redistribute.py

# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, Dict, List, Sequence, Tuple

import torch
import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import (
    _Partial,
    Placement,
    Replicate,
    Shard,
)

_PlacementItem = Tuple[int, Tuple[Placement, Placement]]


def _replicate_then_shard(val: _PlacementItem) -> int:
    """
    Replicate from inner to outer dimension.
    Shard from outer to inner dimension.
    """
    i, (current, target) = val
    if (target.is_replicate() or target.is_partial()) and current.is_shard():
        return -i
    elif (current.is_replicate() or current.is_partial()) and target.is_shard():
        return i
    else:
        return 0
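

# Illustrative sketch (not part of the original module): `_replicate_then_shard` is
# used below as a sort key so that shard -> replicate/partial transitions are handled
# before replicate/partial -> shard ones on higher mesh dimensions. The placements
# here are made up purely for the example.
def _example_sort_order() -> List[_PlacementItem]:
    current = (Replicate(), Shard(0))  # per-mesh-dim current placements
    target = (Shard(0), Replicate())  # per-mesh-dim target placements
    items = list(enumerate(zip(current, target)))
    items.sort(key=_replicate_then_shard)
    # sort keys are [0, -1], so mesh dim 1 (Shard -> Replicate) is processed
    # before mesh dim 0 (Replicate -> Shard)
    return items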


def _decompose_reshard(val: List[_PlacementItem]) -> List[_PlacementItem]:
    """
    Decompose Si -> Sj into Si -> R -> Sj.
    There are two ways shardings can differ within a mesh dimension:
        1) sharding on different tensor dimensions, e.g. Shard(0) -> Shard(1)
        2) different sub-shards of a repeated shard ("mis-aligned sharding")
            (Shard(0), Shard(0)) -> (Replicate(), Shard(0))
            Here the Shard(0) -> Shard(0) on the second mesh dimension is actually
            a reshard: in the current placement it is a sub-shard of a tensor
            dimension 0 that is already sharded on the first mesh dimension, while
            in the target placement it is the first (and only) sharding of tensor
            dimension 0.
    """
    # detect mis-aligned repeated shardings
    from collections import defaultdict

    repeat_dim_current: Dict[int, int] = defaultdict(int)
    repeat_dim_target: Dict[int, int] = defaultdict(int)

    output: List[_PlacementItem] = []

    for i, (current, target) in val:
        # detect mis-aligned sharding
        if current.is_shard():
            repeat_dim_current[cast(Shard, current).dim] += 1
        if target.is_shard():
            repeat_dim_target[cast(Shard, target).dim] += 1
        if (
            isinstance(current, Shard)
            and isinstance(target, Shard)
            and (
                current.dim != target.dim
                or repeat_dim_current[current.dim] != repeat_dim_target[target.dim]
            )
        ):
            # decompose Shard(i) -> Shard(j) into Shard(i) -> Replicate() -> Shard(j)
            output.append((i, (current, Replicate())))
            output.append((i, (Replicate(), target)))
        else:
            output.append((i, (current, target)))

    return output
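

# Illustrative sketch (not part of the original module): the mis-aligned repeated
# sharding example from the docstring above. (Shard(0), Shard(0)) -> (Replicate(), Shard(0))
# makes the second mesh dim look like a no-op, but it is really a reshard, so it gets
# decomposed through Replicate().
def _example_decompose_misaligned() -> List[_PlacementItem]:
    current = (Shard(0), Shard(0))
    target = (Replicate(), Shard(0))
    decomposed = _decompose_reshard(list(enumerate(zip(current, target))))
    # decomposed == [
    #     (0, (Shard(0), Replicate())),
    #     (1, (Shard(0), Replicate())),
    #     (1, (Replicate(), Shard(0))),
    # ]
    return decomposed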


# Intentionally expose this API to trace ops on local tensors
def _redistribute_with_local_tensor(
    local_tensor: torch.Tensor,
    size: torch.Size,
    device_mesh: DeviceMesh,
    current_placements: Sequence[Placement],
    target_placements: Sequence[Placement],
) -> torch.Tensor:
    new_local_tensor = None

    sorted_placements = list(enumerate(zip(current_placements, target_placements)))
    sorted_placements = _decompose_reshard(sorted_placements)
    sorted_placements.sort(key=_replicate_then_shard)

    for i, (current, target) in sorted_placements:
        my_coordinate = device_mesh.get_coordinate_on_dim(i)
        num_chunks = device_mesh.size(dim=i)
        # TODO: what should happen if rank is not in the mesh?
        # see issue https://github.com/pytorch/tau/pull/492
        assert (
            my_coordinate is not None
        ), "Rank is not part of the mesh"  # TODO: figure out behavior here
        if current == target:
            # short cut, just use the original local tensor
            new_local_tensor = local_tensor
            continue

        if target.is_replicate():
            # Case 1: target is Replicate
            if current.is_partial():
                partial_spec = cast(_Partial, current)
                new_local_tensor = partial_spec._to_replicate(
                    local_tensor, device_mesh, i
                )
            elif current.is_shard():
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, size, device_mesh, i
                )
            else:
                raise RuntimeError(
                    f"redistribute from {current_placements} to {target_placements} not supported yet"
                )
        elif target.is_shard():
            # Case 2: target is Shard
            target_placement = cast(Shard, target)
            if current.is_partial():
                partial_spec = cast(_Partial, current)
                new_local_tensor = partial_spec._to_shard(
                    local_tensor, device_mesh, i, target_placement
                )
            elif current.is_replicate():
                # split the tensor and return the corresponding cloned local shard
                shards, _ = target_placement._split_tensor(
                    local_tensor,
                    num_chunks,
                    with_padding=False,
                    contiguous=False,
                )
                new_local_tensor = shards[my_coordinate].clone()
            else:
                # NOTE: this case should not be reachable after _decompose_reshard,
                # which decomposes Shard(0) -> Shard(1) into Shard(0) -> Replicate() -> Shard(1)
                assert (
                    current.is_shard()
                ), f"Current placement should be shard but found {current}"
                shard_spec = cast(Shard, current)
                if shard_spec.dim != target_placement.dim:
                    # TODO: enable this with all_to_all
                    raise NotImplementedError(
                        "Changing sharding dim is not supported yet!"
                    )
        elif target.is_partial():
            if current.is_replicate():
                # For replicate -> partial, we zero out the local tensor on all but one
                # rank of the current mesh dim, so only one rank keeps the data and the
                # reshard requires no communication.
                if my_coordinate is not None and my_coordinate != 0:
                    new_local_tensor = local_tensor.zero_()
                else:
                    new_local_tensor = local_tensor
            else:
                raise RuntimeError(
                    f"redistribute from {current_placements} to {target_placements} not supported yet"
                )

        assert new_local_tensor is not None
        local_tensor = new_local_tensor

    assert new_local_tensor is not None, "redistribute failed!"
    return new_local_tensor
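

# Illustrative sketch (assumes torch.distributed is already initialized with two
# ranks; the mesh and tensor sizes are hypothetical): gather a row-sharded local
# chunk back into the full replicated tensor using the local-tensor API above.
def _example_local_reshard(local_chunk: torch.Tensor) -> torch.Tensor:
    mesh = DeviceMesh("cuda", [0, 1])  # 1-D mesh over ranks 0 and 1
    return _redistribute_with_local_tensor(
        local_chunk,  # this rank's [4, 4] chunk
        torch.Size([8, 4]),  # global (logical) tensor size
        mesh,
        current_placements=[Shard(0)],  # currently sharded on tensor dim 0
        target_placements=[Replicate()],  # want the full tensor on every rank
    )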


def redistribute_dtensor(
    input: "dtensor.DTensor",
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
) -> "dtensor.DTensor":
    if input.device_mesh != device_mesh:
        # TODO: alltoall reshuffling to change device_mesh if they are not the same
        raise NotImplementedError("Cross device mesh comm not supported yet!")

    local_tensor = input._local_tensor
    new_local_tensor = _redistribute_with_local_tensor(
        local_tensor,
        input.size(),
        device_mesh,
        input.placements,
        placements,
    )

    return dtensor.DTensor(
        new_local_tensor,
        device_mesh,
        placements,
        size=input.size(),
        requires_grad=local_tensor.requires_grad,
    )


class Redistribute(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        input: "dtensor.DTensor",
        device_mesh: DeviceMesh,
        placements: List[Placement],
    ):
        ctx.previous_placement = input.placements
        ctx.previous_device_mesh = input.device_mesh
        return redistribute_dtensor(input, device_mesh, placements)

    @staticmethod
    def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
        previous_placement = ctx.previous_placement
        previous_device_mesh = ctx.previous_device_mesh

        # When we run the backward pass of redistribute (i.e. a manual redistribute
        # from user code instead of torch_dispatch), we first scan for one special
        # case where the target placement should change: replicate -> partial.
        # In that case we keep the grad as replicate. Although converting the
        # replicated gradients back to partial would conform to the same logical
        # layout, it is actually useless: a reduction would be required later,
        # which is more expensive than simply keeping the gradient replicated.
        # For this reason, we keep the replicate grad here.
        # TODO: see if this makes sense for all cases.
        target_placements: List[Placement] = []
        for current, target in zip(grad_output.placements, previous_placement):
            if current.is_replicate() and target.is_partial():
                # keep target placement as replicate instead of partial in this case
                target_placements.append(current)
            else:
                target_placements.append(target)

        return (
            redistribute_dtensor(grad_output, previous_device_mesh, target_placements),
            None,
            None,
        )
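

# Illustrative sketch (hypothetical; assumes `sharded` is a DTensor with
# requires_grad=True on an initialized mesh): Redistribute is applied like any other
# autograd.Function, so gradients flow back through the layout change. Note the
# special case documented in backward(): if the forward went replicate -> partial,
# the gradient is kept replicated rather than converted back to partial.
def _example_autograd_redistribute(
    sharded: "dtensor.DTensor", mesh: DeviceMesh
) -> "dtensor.DTensor":
    replicated = Redistribute.apply(sharded, mesh, [Replicate()])
    replicated.sum().backward()  # grads are redistributed toward `sharded`'s placements
    return replicated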