replicated_tensor.py

import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed import distributed_c10d
from torch.overrides import get_default_nowrap_functions

_REPLICATED_WITH_NON_TENSOR_ALLOWLIST = [
    # List of ops where if parameters are a combination of ReplicatedTensors
    # and non-tensors, we can still return a ReplicatedTensor as the result.
    torch.unsqueeze,
    torch.Tensor.unsqueeze,
    torch.Tensor.__getitem__,
]


class ReplicatedTensor(torch.Tensor):
    """
    ReplicatedTensor represents a tensor that is replicated across the ``world_size`` and
    has the same value on each rank.

    ReplicatedTensor is a :class:`~torch.Tensor` subclass and can be used together with
    ShardedTensor/Tensor to express different types of computation. The inter-op rules are
    defined as follows (using ``torch.add`` as an example op):
        ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
        ReplicatedTensor + torch.Tensor = torch.Tensor
        ReplicatedTensor + ShardedTensor = ShardedTensor
        ReplicatedTensor + other type (e.g. a scalar) = other type

    NOTE: We do not guarantee equal content of a ReplicatedTensor across nodes after its
    construction. Although we define proper inter-op rules to keep a ReplicatedTensor
    the same, there is no enforcement of it (i.e. if you manually modify content on
    some ranks, the modified value will not automatically get synced to other nodes). If
    you wish to manually validate that tensors are the same across ranks, use ``validate()``.
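
    Example::
        # illustrative sketch; assumes the default process group has already been
        # initialized (e.g. via ``torch.distributed.init_process_group``) so a
        # replicated value can actually exist on every rank
        >>> rt = ReplicatedTensor(torch.ones(2, 2))
        >>> out = rt + rt                    # ReplicatedTensor (all operands replicated)
        >>> mixed = rt + torch.ones(2, 2)    # plain torch.Tensor (mixed operands)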
  29. """
    _process_group: distributed_c10d.ProcessGroup

    __slots__ = ["_process_group"]

    def __new__(cls, data=None, process_group=None):
        if data is None:
            data = torch.empty(0)
        r = torch.Tensor._make_subclass(cls, data, data.requires_grad)  # type: ignore[arg-type]
        r._process_group = (  # type: ignore[attr-defined]
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )
        return r

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.data.clone(memory_format=torch.preserve_format), self._process_group)
            memo[id(self)] = result
            return result

    def __repr__(self):
        return f"ReplicatedTensor({super().__repr__()})"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # We re-dispatch execution to ShardedTensor's __torch_function__ if we find any
        # ShardedTensor operands. We also track whether args/kwargs are all ReplicatedTensor
        # operands, so that we only convert results back to ReplicatedTensor when every
        # operand is replicated.
        all_replicated = True
        replicated_with_non_tensor = True
        replicated_pg = None

        def dispatch_arg(arg):
            # This helper returns a tuple: the first element indicates whether the op was
            # re-dispatched (and therefore already executed), the second is the result of
            # that execution (or None if it was not re-dispatched).
            nonlocal replicated_pg, all_replicated, replicated_with_non_tensor
            if isinstance(arg, ShardedTensor):
                # redispatch to ShardedTensor
                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
                return True, arg.__torch_function__(func, types, args, kwargs)
            if isinstance(arg, ReplicatedTensor):
                if replicated_pg is None:
                    replicated_pg = arg._process_group
                elif replicated_pg != arg._process_group:
                    raise RuntimeError(
                        f"ReplicatedTensor operands must be in the same process group "
                        f"in torch function '{func.__name__}', but found at least two "
                        f"ReplicatedTensor operands in different process groups!")
            elif isinstance(arg, torch.Tensor):
                replicated_with_non_tensor = False
                all_replicated = False
            else:
                all_replicated = False
            return False, None

        for arg in args:
            redispatched, res = dispatch_arg(arg)
            if redispatched:
                return res

        for k, v in kwargs.items():
            redispatched, res = dispatch_arg(v)
            if redispatched:
                return res

        # We can't use super().__torch_function__() because it implicitly converts the
        # result back to the tensor subclass; in our case we need to control the output
        # type based on the inter-op rules we defined above.
        with torch._C.DisableTorchFunctionSubclass():
            rs = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return rs

            result_not_replicated = isinstance(rs, torch.Tensor) and not isinstance(rs, ReplicatedTensor)
            should_convert_to_replicated = all_replicated or (
                replicated_with_non_tensor and func in _REPLICATED_WITH_NON_TENSOR_ALLOWLIST
            )
            if result_not_replicated and should_convert_to_replicated:
                # The operands were all ReplicatedTensors (or ReplicatedTensors plus
                # non-tensors for allow-listed ops) and nothing was re-dispatched to
                # ShardedTensor's __torch_function__, so the result is a plain torch.Tensor;
                # convert it back to a ReplicatedTensor according to our inter-op rules.
                rs = rs.as_subclass(ReplicatedTensor)  # type: ignore[arg-type]
                # propagate the process_group field to the result
                rs._process_group = replicated_pg  # type: ignore[attr-defined]

            return rs

    def validate(self) -> bool:
        """
        Validate that this ReplicatedTensor is legitimate by all-gathering the tensor from
        all ranks and checking that the gathered copies are the same.

        The all-gather is performed on this tensor's own process group
        (``self._process_group``). If any rank holds a different value, a ValueError is
        raised.

        Returns:
            True if validation succeeds.
        """
        world_size = dist.get_world_size(self._process_group)
        current_rank = dist.get_rank(self._process_group)

        tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)]

        dist.all_gather(tensors_on_rank, self, group=self._process_group)
        # check that every gathered tensor matches the local one
        for rank, tensor in enumerate(tensors_on_rank):
            if not torch.allclose(self, tensor):
                raise ValueError(
                    f"ReplicatedTensor has different values on rank {current_rank} and {rank}")
        return True

    def __setstate__(self, state):
        with torch._C.DisableTorchFunctionSubclass():
            self.data = state
            self.requires_grad = state.requires_grad
            from torch.distributed._shard.api import _get_current_process_group
            self._process_group = _get_current_process_group()

    def __getstate__(self):
        return self.data
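

# Minimal usage sketch: this guarded block is illustrative only. It assumes the module is
# launched under a distributed launcher such as ``torchrun`` (so that ``init_process_group``
# can read rank/world-size from the environment) and that the "gloo" backend is available;
# adjust the backend for your setup.
if __name__ == "__main__":
    dist.init_process_group(backend="gloo")

    # Every rank constructs the same value, so the replication invariant holds here.
    replicated = ReplicatedTensor(torch.ones(2, 2))
    replicated.validate()  # all-gathers across ranks and raises ValueError on mismatch

    # Inter-op rules from the class docstring:
    out = replicated + replicated          # ReplicatedTensor + ReplicatedTensor -> ReplicatedTensor
    mixed = replicated + torch.ones(2, 2)  # ReplicatedTensor + torch.Tensor -> torch.Tensor
    assert isinstance(out, ReplicatedTensor)
    assert not isinstance(mixed, ReplicatedTensor)

    dist.destroy_process_group()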