api.py

from abc import ABC, abstractmethod
from dataclasses import dataclass
import functools
from typing import Callable, Dict, List, TYPE_CHECKING

import torch
from ._internals import (
    check_tensor,
    get_chunked_dim_size,
    get_split_size,
    validate_non_overlapping_shards_metadata
)
from torch.distributed._shard.metadata import ShardMetadata
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
from torch.distributed._shard.op_registry_utils import _decorator_func

if TYPE_CHECKING:
    # Only include ShardedTensor when doing type checking, exclude it
    # from run-time to resolve circular dependency.
    from torch.distributed._shard.sharded_tensor import ShardedTensor


class PlacementSpec(ABC):
    """
    Base class representing the placement of an entity. Subclasses of this
    class can be used to specify customized placements which might not be
    covered by existing APIs.
    """
    pass


@dataclass
class DevicePlacementSpec(PlacementSpec):
    """
    Associates placement of an entity with a single device.

    Args:
        device(:class:`torch.distributed._remote_device`): The device to place the entity on.
    """

    device: torch.distributed._remote_device

    def __post_init__(self):
        if not isinstance(self.device, torch.distributed._remote_device):
            self.device = torch.distributed._remote_device(self.device)
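
# Illustrative usage (not part of the original module; a minimal sketch assuming the
# "rank:<N>/<device>" placement string format accepted by torch.distributed._remote_device):
#
#     spec = DevicePlacementSpec("rank:0/cuda:0")
#     # __post_init__ wraps the plain string into a _remote_device, so either a string
#     # or an existing _remote_device instance can be passed.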


class ShardingSpec(ABC):
    """
    Base class representing sharding specifications.
    """
    @abstractmethod
    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
        """
        Given a global tensor size, define how to shard a tensor of this shape
        across ranks and return a ShardedTensorMetadata.

        Args:
            tensor_sizes (:class:`torch.Size`):
                The tensor shape to shard on, a `torch.Size` object that represents the
                tensor shape to be sharded according to the ShardingSpec.
            tensor_properties (:class:`torch.distributed._shard.sharded_tensor.TensorProperties`):
                Tensor properties used to create a ShardedTensor.

        Returns:
            A :class:`ShardedTensorMetadata` object that encodes the information about
            the layout of the ShardedTensor and its properties.
        """

    @abstractmethod
    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        """
        Given a global tensor on src_rank, shard this tensor
        across ranks within the process group and return a ShardedTensor.

        Args:
            tensor (:class:`torch.Tensor`): The tensor that needs to be sharded.

        Keyword args:
            src_rank (int, optional): The source rank which is used as the ground truth of
                the data for the parameter that would be sharded and scattered
                across the rest of the ranks.
                Default: 0.
            process_group (ProcessGroup, optional): The process group to work on. If None,
                the default process group will be used.

        Returns:
            A :class:`ShardedTensor` sharded from the given tensor.
        """


# Ops customized for a particular ShardingSpec.
_CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}


def _has_custom_op(sharding_spec, op):
    """
    Returns whether or not the ShardingSpec has a custom op implementation.
    """
    class_name = type(sharding_spec).__qualname__
    return class_name in _CUSTOM_SHARDING_SPEC_OPS and op in _CUSTOM_SHARDING_SPEC_OPS[class_name]


def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs, process_group):
    """
    Calls the custom op for this ShardingSpec if it exists.
    """
    class_name = type(sharding_spec).__qualname__
    if not _has_custom_op(sharding_spec, op):
        raise RuntimeError(f'Custom op: {op} not registered for {class_name}')
    func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op]
    return func(types, args, kwargs, process_group)


def custom_sharding_spec_op(sharding_spec_class, func):
    """
    Decorator to allow custom registration of ops.

    Args:
        sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
        func(Callable): The op to override (ex: torch.bmm)
    """
    class_name = sharding_spec_class.__qualname__
    if class_name not in _CUSTOM_SHARDING_SPEC_OPS:
        _CUSTOM_SHARDING_SPEC_OPS[class_name] = {}
    return functools.partial(
        _decorator_func,
        op=func,
        op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name]
    )
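
# Illustrative registration (not part of the original module): using the decorator above
# to route torch.nn.functional.linear to a custom handler whenever the input is sharded
# with a ChunkShardingSpec. The handler name is hypothetical; its
# (types, args, kwargs, process_group) signature mirrors _dispatch_custom_op above.
#
#     @custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
#     def my_sharded_linear(types, args, kwargs, process_group):
#         # args/kwargs are the original call arguments; _dispatch_custom_op looks this
#         # handler up in _CUSTOM_SHARDING_SPEC_OPS when the op is invoked.
#         ...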


@dataclass
class EnumerableShardingSpec(ShardingSpec):
    """
    This is a type of ShardingSpec that allows users to specify a generic
    sharding scheme by enumerating exactly how each shard is laid out.

    Args:
        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
            each shard. Note that none of the shards should overlap.
    """

    shards: List[ShardMetadata]

    def __post_init__(self):
        if len(self.shards) == 0:
            raise ValueError(f'Empty shard list provided: {self.shards}')

        # Validate that each shard has the same rank.
        rank = -1
        for shard in self.shards:
            if rank != -1 and rank != len(shard.shard_offsets):
                raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}')
            rank = len(shard.shard_offsets)

        validate_non_overlapping_shards_metadata(self.shards)

    def build_metadata(self,
                       tensor_sizes: torch.Size,
                       tensor_properties: sharded_tensor_meta.TensorProperties,
                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
        # check if shards form a valid tensor
        check_tensor(self.shards, tensor_sizes)
        return sharded_tensor_meta.ShardedTensorMetadata(
            self.shards,
            tensor_sizes,
            tensor_properties
        )

    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec
        raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!")
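
# Illustrative construction (not part of the original module): enumerating a 10x10 tensor
# split row-wise into two non-overlapping shards placed on two ranks. The placement
# strings assume the "rank:<N>/<device>" format.
#
#     spec = EnumerableShardingSpec([
#         ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 10], placement="rank:0/cuda:0"),
#         ShardMetadata(shard_offsets=[5, 0], shard_sizes=[5, 10], placement="rank:1/cuda:1"),
#     ])
#     # __post_init__ validates that both shards are 2-D and do not overlap.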


def _infer_sharding_spec_from_shards_metadata(shards_metadata):
    """
    Infer the sharding spec from the metadata of each shard of a ShardedTensor.
    If the tensor is sharded only on one dimension, we can then verify whether it's
    a ChunkShardingSpec or not. The way to verify it is to first get the total length
    and perform a chunk sharding with the given placements to see if we can get the
    same chunk sizes as the given shards_metadata. If not, we assume it's enum sharded.

    Args:
        shards_metadata (List[ShardMetadata]): List of metadata of local shards.

    Returns:
        A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object of sharding
        spec for one sharded tensor.
    """
    placements = []
    chunk_sharding_dim = None
    chunk_offset_list = []
    shard_size_list = []
    # collect local shard metadata from the global sharded_tensor_metadata
    for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
        placements.append(shard_metadata.placement)
        local_offsets = shard_metadata.shard_offsets
        chunk_offset_list.append(sum(local_offsets))
        shard_size_list.append(shard_metadata.shard_sizes)
        shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0]
        # If the offset is [0, 0, ..., 0] (all zeros),
        # we cannot decide how the tensor is sharded.
        if len(shard_dims) == 0:
            continue
        # If the offset is [0, N, ..., 0, M, 0, ..., 0],
        # we are sure it's sharded by more than one dimension.
        if len(shard_dims) != 1:
            chunk_sharding_dim = None
            break
        # If the offset is [0, 0, ..., 0, M, 0, ..., 0], i.e. it's sharded by just
        # one dimension, we need to make sure all ranks share the same dimension.
        if chunk_sharding_dim is None:
            chunk_sharding_dim = shard_dims[0]
        elif chunk_sharding_dim != shard_dims[0]:
            chunk_sharding_dim = None
            break

    if chunk_sharding_dim is not None:
        # Ensure we infer the correct placement order from offsets
        placements = [
            x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0])
        ]

        from .chunk_sharding_spec import ChunkShardingSpec
        chunk_spec = ChunkShardingSpec(
            dim=chunk_sharding_dim,
            placements=placements,
        )

        shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list])
        shard_total_length = sum(shard_sizes)
        chunks = len(placements)
        split_size = get_split_size(shard_total_length, chunks)
        chunk_shard_sizes = sorted(
            [
                get_chunked_dim_size(shard_total_length, split_size, idx)
                for idx in range(len(placements))
            ]
        )
        if shard_sizes == chunk_shard_sizes:
            return chunk_spec
    return EnumerableShardingSpec(shards_metadata)
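
# Worked example of the inference above (illustrative, not part of the original module,
# assuming the usual ceil-division chunking done by get_split_size / get_chunked_dim_size):
# two shards of a 10x10 tensor with offsets [0, 0] and [5, 0] are both offset only on
# dim 0, so chunk_sharding_dim = 0 and the total length on that dim is 5 + 5 = 10.
#
#     split_size = get_split_size(10, 2)         # 2 placements -> split_size 5
#     get_chunked_dim_size(10, split_size, 0)    # -> 5
#     get_chunked_dim_size(10, split_size, 1)    # -> 5
#
# The recomputed chunk sizes [5, 5] match the observed shard sizes, so a
# ChunkShardingSpec(dim=0, ...) is returned; uneven sizes such as [3, 7] would not
# match and the function would fall back to EnumerableShardingSpec(shards_metadata).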