# nccl.py

import collections
import warnings

import torch.cuda
from typing import Optional, Sequence, Union

__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']

SUM = 0  # ncclRedOp_t


def is_available(tensors):
    # NCCL collectives need a build of PyTorch compiled with NCCL support and
    # dense, contiguous CUDA tensors, with at most one tensor per device.
    if not hasattr(torch._C, '_nccl_all_reduce'):
        warnings.warn('PyTorch is not compiled with NCCL support')
        return False

    devices = set()
    for tensor in tensors:
        if tensor.is_sparse:
            return False
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            return False
        devices.add(device)

    return True


def version():
    ver = torch._C._nccl_version()
    major = ver >> 32
    minor = (ver >> 16) & 65535
    patch = ver & 65535
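    # For example, a packed value of (2 << 32) | (18 << 16) | 3 (NCCL 2.18.3,
    # purely illustrative) decodes to the tuple (2, 18, 3).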
    return (major, minor, patch)


def unique_id():
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    return torch._C._nccl_init_rank(num_ranks, uid, rank)


def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    if not isinstance(inputs, collections.abc.Container) or isinstance(inputs, torch.Tensor):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    _check_sequence_type(inputs)
    if outputs is None:
        outputs = inputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
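

# Usage sketch (illustrative, not part of this module): with the default
# `outputs=None` the reduction happens in place, assuming an NCCL-enabled
# build and one contiguous CUDA tensor per device, e.g.
#
#     import torch
#     from torch.cuda import nccl
#
#     tensors = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
#     if nccl.is_available(tensors):
#         nccl.all_reduce(tensors)  # each tensor now holds the element-wise sum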


# `output` used to be `outputs`, taking in a list of tensors, so we keep both
# arguments for backward-compatibility reasons.
def reduce(inputs: Sequence[torch.Tensor],
           output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
           root: int = 0,
           op: int = SUM,
           streams: Optional[Sequence[torch.cuda.Stream]] = None,
           comms=None, *,
           outputs: Optional[Sequence[torch.Tensor]] = None) -> None:
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' cannot both be specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None).")
        else:
            warnings.warn(
                "nccl.reduce with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead.")
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(output, collections.abc.Sequence):
        # The caller used the old API positionally, passing a list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.")
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)


def broadcast(inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)
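

# Broadcast sketch (illustrative): the tensor on the `root` device is copied,
# in place, into every other tensor in `inputs`, e.g.
#
#     tensors = [torch.full((4,), float(i), device=f"cuda:{i}") for i in range(2)]
#     nccl.broadcast(tensors)  # tensors[1] now matches tensors[0]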


def all_gather(inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)


def reduce_scatter(inputs: Sequence[torch.Tensor],
                   outputs: Sequence[torch.Tensor],
                   op: int = SUM,
                   streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
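

# Shape sketch for the two collectives above (illustrative, assuming two
# devices): `all_gather` concatenates every input (numel k) into every output
# (numel 2 * k); `reduce_scatter` sums the inputs (numel 2 * k) and writes one
# per-device chunk (numel k) to the matching output, e.g.
#
#     ins = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
#     outs = [torch.empty(8, device=f"cuda:{i}") for i in range(2)]
#     nccl.all_gather(ins, outs)      # each output holds both inputs back to back
#
#     ins = [torch.ones(8, device=f"cuda:{i}") for i in range(2)]
#     outs = [torch.empty(4, device=f"cuda:{i}") for i in range(2)]
#     nccl.reduce_scatter(ins, outs)  # each output holds one summed chunk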