partial_tensor.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. import functools
  2. from typing import Callable, Dict, TYPE_CHECKING
  3. import torch
  4. import torch.distributed as dist
  5. import torch.distributed._shard.sharding_spec as shard_spec
  6. from torch.distributed import distributed_c10d
  7. from torch.distributed.nn.functional import (
  8. reduce_scatter,
  9. )
  10. from torch.distributed._shard.common_op_utils import _register_default_op
  11. from torch.distributed._shard.op_registry_utils import _decorator_func
  12. from torch.utils._pytree import tree_map
  13. if TYPE_CHECKING:
  14. # Only include ShardedTensor when do type checking, exclude it
  15. # from run-time to resolve circular dependency.
  16. from torch.distributed._shard.sharded_tensor import ShardedTensor
  17. # Custom PartialTensor ops
  18. _PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {}
  19. def _custom_partial_tensor_op(func):
  20. """
  21. Decorate for custom partial tensor op
  22. Args:
  23. func(Callable): Torch function for which we want to provide a PartialTensor
  24. implementation (ex: torch.nn.functional.linear)
  25. """
  26. return functools.partial(
  27. _decorator_func,
  28. op=func,
  29. op_table=_PARTIAL_TENSOR_OPS
  30. )
  31. class _PartialTensor(torch.Tensor):
  32. """
  33. PartialTensor is an abstraction to represent Tensors that need
  34. aggregation across multiple devices and multiple processes.
  35. PartialTensor is initialized in an SPMD like fashion where each rank
  36. initializes the PartialTensor. The PartialTensor object on each rank
  37. then only stores the local partial shard, process group and the
  38. aggregation way to get a full tensor.
  39. PartialTensor doesn't provide any Tensor like operations but is a
  40. wrapper providing the Tensor representing the local partial shard.
  41. We assume the size of each local tensor to be exactly the same.
  42. Users can apply custom distributed sharded computations on top of
  43. this primitive.
  44. Args:
  45. local_partial_shard (Tensor): Partial result stored across ranks.
  46. process_group (ProcessGroup): The process group to aggregate on.
  47. reduce_op (distributed_c10d.ReduceOp): Way to aggregate the partial result.
  48. Default: ``distributed_c10d.ReduceOp.SUM``
  49. Examples:
  50. >>> # All tensors below are of torch.int64 type.
  51. >>> # We have 2 process groups, 2 ranks.
  52. >>> # xdoctest: +SKIP
  53. >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
  54. >>> tensor = torch.cat([tensor, tensor + 2])
  55. >>> tensor
  56. tensor([1, 2, 3, 4]) # Rank 0
  57. tensor([3, 4, 5, 6]) # Rank 1
  58. >>> partial_tensor = _PartialTensor(tensor, distributed_c10d.ReduceOp.MAX)
  59. >>> sharding_dim = 0
  60. >>> collect_spec = shard_spec.ChunkShardingSpec(
  61. dim=sharding_dim,
  62. placements=[
  63. "rank:0/cuda:0",
  64. "rank:1/cuda:1",
  65. ],
  66. )
  67. >>> complete_tensor = partial_tensor.reshard(collect_spec)
  68. >>> complete_tensor
  69. ShardedTensor(
  70. ShardedTensorMetadata(
  71. shards_metadata=[
  72. ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0),
  73. ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1)],
  74. size=torch.Size([4])
  75. )
  76. >>> complete_tensor.local_tensor()
  77. tensor([3, 4]) # Rank 0
  78. tensor([5, 6]) # Rank 1
  79. >>> # All tensors below are of torch.cfloat type.
  80. >>> # We have 2 process groups, 2 ranks.
  81. >>> tensor = torch.tensor([1, 2]) + 2 * rank
  82. >>> tensor = torch.cat([tensor, tensor + 2])
  83. >>> tensor
  84. tensor([1, 2, 3, 4]) # Rank 0
  85. tensor([3, 4, 5, 6]) # Rank 1
  86. >>> partial_tensor = _PartialTensor(tensor)
  87. >>> complete_tensor = partial_tensor.reshard(collect_spec)
  88. >>> complete_tensor
  89. ShardedTensor(
  90. ShardedTensorMetadata(
  91. shards_metadata=[
  92. ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0),
  93. ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1)],
  94. size=torch.Size([4])
  95. )
  96. >>> complete_tensor.local_tensor()
  97. tensor([4, 6]) # Rank 0
  98. tensor([8, 10]) # Rank 1
  99. """
  100. _process_group: distributed_c10d.ProcessGroup
  101. _local_shard: torch.Tensor
  102. _reduce_op: distributed_c10d.ReduceOp
  103. __slots__ = ["_process_group", "_local_shard", "_reduce_op"]
  104. def __new__(cls, local_shard, process_group=None, reduce_op=distributed_c10d.ReduceOp.SUM):
  105. r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
  106. cls,
  107. local_shard.size(),
  108. dtype=local_shard.dtype,
  109. layout=local_shard.layout,
  110. pin_memory=local_shard.is_pinned(),
  111. requires_grad=local_shard.requires_grad) # type: ignore[arg-type]
  112. r._process_group = ( # type: ignore[attr-defined]
  113. process_group
  114. if process_group is not None
  115. else distributed_c10d._get_default_group()
  116. )
  117. r._reduce_op = reduce_op
  118. r._local_shard = local_shard
  119. return r
  120. def __post_init__(self):
  121. if not isinstance(self._reduce_op, distributed_c10d.ReduceOp):
  122. raise ValueError(
  123. "reduce_op needs to be a member of distributed_c10d.ReduceOp."
  124. )
  125. def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> "ShardedTensor":
  126. """
  127. The reshard happens in two steps logically:
  128. 1. Aggregate all the shards of the partial tensor.
  129. 2. Shard this tensor according to the provided spec.
  130. In reality, for the sake of performance, we consolidate all partial tensors
  131. across multiple ranks and covert to a sharded tensor in one step.
  132. Args:
  133. resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
  134. The specification describing how we reshard the aggregated local result.
  135. Returns:
  136. A :class:`ShardedTensor` filled with local aggregated result.
  137. """
  138. from torch.distributed._shard.sharded_tensor.api import ShardedTensor
  139. if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
  140. raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
  141. if self._local_shard.is_complex():
  142. raise NotImplementedError("Only real partial tensor supported for reshard.")
  143. sharding_dim = int(resharding_spec.dim) # type: ignore[attr-defined]
  144. chunk_mode_res = self._local_shard.size(sharding_dim) % self._process_group.size()
  145. local_shard = self._local_shard
  146. # Add padding when the size is not divisible by the world size.
  147. if chunk_mode_res != 0:
  148. padding = [0] * (local_shard.dim() * 2)
  149. padding[-1] = self._process_group.size() - chunk_mode_res
  150. local_shard = torch.nn.functional.pad(
  151. local_shard,
  152. tuple(padding),
  153. "constant",
  154. 0,
  155. )
  156. current_rank = dist.get_rank(self._process_group) # type: ignore[attr-defined]
  157. rank_idx = None
  158. rearrange_local_shards = False
  159. indices = [0] * self._process_group.size()
  160. for idx, placement in enumerate(resharding_spec.placements): # type: ignore[attr-defined]
  161. if placement.rank() == current_rank: # type: ignore[index, union-attr]
  162. rank_idx = idx # type: ignore[attr-defined]
  163. if placement.rank() != idx: # type: ignore[index, union-attr]
  164. rearrange_local_shards = True
  165. indices[placement.rank()] = idx # type: ignore[index, union-attr]
  166. local_shards = local_shard.chunk(self._process_group.size(), dim=sharding_dim)
  167. if rearrange_local_shards:
  168. # Need to re-arrange original shard_dim of output_tensor_list.
  169. local_shards = [local_shards[idx] for idx in indices] # type: ignore[call-overload]
  170. local_result = reduce_scatter(
  171. torch.empty_like(local_shards[0]),
  172. list(local_shards),
  173. op=self._reduce_op,
  174. group=self._process_group,
  175. )
  176. sharded_tensor_size = self._local_shard.size()
  177. # Remove padding when the size is not divisible by the world size.
  178. if chunk_mode_res != 0:
  179. uneven_local_shards = self._local_shard.chunk(
  180. self._process_group.size(), dim=sharding_dim
  181. )
  182. expected_size = uneven_local_shards[rank_idx].size() # type: ignore[index]
  183. if local_result.size() != expected_size:
  184. local_result = local_result.narrow(
  185. sharding_dim,
  186. 0,
  187. expected_size[sharding_dim],
  188. )
  189. return ShardedTensor._init_from_local_tensor(
  190. local_result,
  191. resharding_spec,
  192. sharded_tensor_size,
  193. process_group=self._process_group,
  194. )
  195. @classmethod
  196. def __torch_function__(cls, func, types, args=(), kwargs=None):
  197. # Find process_group
  198. process_group = None
  199. def find_process_group(e):
  200. nonlocal process_group
  201. if process_group is None and isinstance(e, _PartialTensor):
  202. process_group = e._process_group
  203. tree_map(find_process_group, args)
  204. tree_map(find_process_group, kwargs)
  205. if func in _PARTIAL_TENSOR_OPS:
  206. return _PARTIAL_TENSOR_OPS[func](types, args, kwargs, process_group)
  207. # Need to disable all dispatch to print args and kwargs appropriately.
  208. guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
  209. try:
  210. with torch._C.DisableTorchFunctionSubclass():
  211. raise RuntimeError(
  212. f"torch function '{func.__name__}', with args: {args} and "
  213. f"kwargs: {kwargs} not supported for PartialTensor!")
  214. finally:
  215. del guard
  216. @classmethod
  217. def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
  218. raise RuntimeError(
  219. f"A {cls.__name__} object is being used from c++ "
  220. f"while calling {func.__module__}.{func.__name__} "
  221. "but the there is no custom __torch_dispatch__ implementation for it."
  222. )
  223. def __repr__(self):
  224. return f"PartialTensor({super().__repr__()})"
  225. def _transpose_impl(types, args=(), kwargs=None, process_group=None):
  226. partial_tensor = args[0]
  227. input = partial_tensor._local_shard
  228. dim0 = args[1]
  229. dim1 = args[2]
  230. return _PartialTensor(
  231. torch.transpose(input, dim0, dim1),
  232. process_group,
  233. partial_tensor._reduce_op
  234. )
  235. @_custom_partial_tensor_op(torch.Tensor.transpose)
  236. def partial_transpose(types, args=(), kwargs=None, process_group=None):
  237. return _transpose_impl(types, args, kwargs, process_group)
  238. @_custom_partial_tensor_op(torch.transpose)
  239. def partial_torch_transpose(types, args=(), kwargs=None, process_group=None):
  240. return _transpose_impl(types, args, kwargs, process_group)
  241. @_custom_partial_tensor_op(torch.cat)
  242. def partial_cat(types, args=(), kwargs=None, process_group=None):
  243. input_list = args[0]
  244. if len(input_list) == 0:
  245. raise RuntimeError('Empty list of tensors to torch.cat!')
  246. local_shards = []
  247. for idx, input in enumerate(input_list):
  248. if not isinstance(input, _PartialTensor):
  249. raise RuntimeError('All inputs need to be an instance of _PartialTensor')
  250. if idx == 0:
  251. reduce_op = input._reduce_op
  252. elif reduce_op != input._reduce_op:
  253. raise RuntimeError(
  254. 'All _PartialTensor reduce_ops need to be the same, found: '
  255. '{reduce_op} and {input._reduce_op}'
  256. )
  257. local_shards.append(input._local_shard)
  258. if kwargs is None:
  259. dim = 0
  260. else:
  261. if 'out' in kwargs:
  262. raise RuntimeError('"out" kwarg is not supported!')
  263. dim = kwargs['dim'] if 'dim' in kwargs else 0
  264. return _PartialTensor(torch.cat(local_shards, dim), process_group, input._reduce_op)
  265. # Tensor properties access
  266. _register_default_op(torch.Tensor.requires_grad.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
  267. _register_default_op(torch.Tensor.shape.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
  268. _register_default_op(torch.Tensor.dtype.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
  269. _register_default_op(torch.Tensor.layout.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
  270. _register_default_op(torch.Tensor.size, _custom_partial_tensor_op)
  271. _register_default_op(torch.Tensor.dim, _custom_partial_tensor_op)
  272. _register_default_op(torch.Tensor.ndim.__get__, _custom_partial_tensor_op) # type: ignore[attr-defined]
  273. _register_default_op(torch.Tensor.is_contiguous, _custom_partial_tensor_op)
  274. _register_default_op(torch.Tensor.contiguous, _custom_partial_tensor_op)