import collections
import warnings

import torch.cuda
from typing import Optional, Sequence, Union

__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']

SUM = 0  # ncclRedOp_t


def is_available(tensors):
    if not hasattr(torch._C, '_nccl_all_reduce'):
        warnings.warn('PyTorch is not compiled with NCCL support')
        return False

    devices = set()
    for tensor in tensors:
        # NCCL collectives require dense, contiguous CUDA tensors, with at
        # most one tensor per device.
        if tensor.is_sparse:
            return False
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            return False
        devices.add(device)

    return True
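
# Usage sketch (illustrative, not part of the module; assumes a NCCL build and
# two CUDA devices):
#
#     import torch
#     import torch.cuda.nccl as nccl
#     xs = [torch.ones(4, device=f'cuda:{i}') for i in range(2)]
#     nccl.is_available(xs)  # True only if every check above passes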


def version():
    # The C extension packs the NCCL version into one integer: the major
    # version above bit 32, then 16-bit minor and patch fields.
    ver = torch._C._nccl_version()
    major = ver >> 32
    minor = (ver >> 16) & 65535
    patch = ver & 65535
    return (major, minor, patch)
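
# Example (illustrative): a result of (2, 10, 3) means the build links against
# NCCL 2.10.3.
#
#     major, minor, patch = torch.cuda.nccl.version()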


def unique_id():
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    return torch._C._nccl_init_rank(num_ranks, uid, rank)
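
# Sketch of the low-level flow these bindings support (illustrative; the uid
# created on rank 0 must reach every other process through some side channel,
# which torch.distributed normally handles for you):
#
#     uid = unique_id()                        # created once, on rank 0
#     comm = init_rank(num_ranks, uid, rank)   # called by every rank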


def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    if not isinstance(inputs, collections.abc.Container) or isinstance(inputs, torch.Tensor):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    _check_sequence_type(inputs)
    if outputs is None:
        outputs = inputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
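
# Usage sketch (illustrative; assumes two CUDA devices): with outputs omitted,
# the reduction happens in place and every tensor ends up with the sum.
#
#     xs = [torch.full((4,), i + 1.0, device=f'cuda:{i}') for i in range(2)]
#     all_reduce(xs)  # each xs[i] now holds [3., 3., 3., 3.]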


# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(inputs: Sequence[torch.Tensor],
           output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
           root: int = 0,
           op: int = SUM,
           streams: Optional[Sequence[torch.cuda.Stream]] = None,
           comms=None, *,
           outputs: Optional[Sequence[torch.Tensor]] = None) -> None:
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' cannot both be specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None).")
        else:
            warnings.warn(
                "nccl.reduce with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead.")
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(output, collections.abc.Sequence):
        # The caller used the old API with a positional list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.")
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
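
# Usage sketch (illustrative; assumes two CUDA devices): reducing into a single
# output tensor on the root device, per the current signature.
#
#     xs = [torch.ones(4, device=f'cuda:{i}') for i in range(2)]
#     out = torch.empty(4, device='cuda:0')
#     reduce(xs, output=out, root=0)  # out now holds [2., 2., 2., 2.]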


def broadcast(inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)
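
# Usage sketch (illustrative; assumes two CUDA devices): the root tensor's
# contents are copied in place into every other tensor in the list.
#
#     xs = [torch.zeros(4, device=f'cuda:{i}') for i in range(2)]
#     xs[0].fill_(7)
#     broadcast(xs, root=0)  # xs[1] now also holds [7., 7., 7., 7.]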


def all_gather(inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
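
# Usage sketch (illustrative; assumes two CUDA devices): each output must be
# large enough to hold the concatenation of all the inputs.
#
#     xs = [torch.full((2,), float(i), device=f'cuda:{i}') for i in range(2)]
#     outs = [torch.empty(4, device=f'cuda:{i}') for i in range(2)]
#     all_gather(xs, outs)  # every outs[i] holds [0., 0., 1., 1.]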


def reduce_scatter(inputs: Sequence[torch.Tensor],
                   outputs: Sequence[torch.Tensor],
                   op: int = SUM,
                   streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
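
# Usage sketch (illustrative; assumes two CUDA devices): the inputs are reduced
# elementwise and the result is split, so each output gets one equal slice.
#
#     xs = [torch.ones(4, device=f'cuda:{i}') for i in range(2)]
#     outs = [torch.empty(2, device=f'cuda:{i}') for i in range(2)]
#     reduce_scatter(xs, outs)  # each outs[i] holds [2., 2.]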