_shard_utils.py

import bisect
import itertools
import math
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    TensorProperties,
)
from torch.distributed._shard.sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardingSpec,
    ShardMetadata,
)


def _sharding_spec_to_offsets(
    sharding_spec: ShardingSpec, tensor_numel: int, world_size: int
) -> List[int]:
    r"""
    Translates the sharding spec to a list of offsets along dim 0. If the
    sharding spec is ChunkShardingSpec, only the ``dim`` is used and the
    placement is not used.
    """
    offsets: List[int] = []
    if isinstance(sharding_spec, EnumerableShardingSpec):
        for shard in sharding_spec.shards:
            offsets.append(shard.shard_offsets[0])
    elif isinstance(sharding_spec, ChunkShardingSpec):
        assert sharding_spec.dim == 0
        chunk_size = math.ceil(tensor_numel / world_size)
        if chunk_size == 1:
            offsets = [
                rank if rank < tensor_numel else tensor_numel
                for rank in range(world_size)
            ]
        else:
            offsets = [chunk_size if rank > 0 else 0 for rank in range(world_size)]
            offsets = list(itertools.accumulate(offsets))
    else:
        raise ValueError(f"Un-recognized sharding spec type {type(sharding_spec)}.")

    return offsets
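

# Illustrative example for _sharding_spec_to_offsets() (not exercised by this
# module): for a 1-D tensor of 10 elements chunk-sharded across 4 ranks, the chunk
# size is ceil(10 / 4) = 3, so the dim-0 offsets are [0, 3, 6, 9]:
#
#     >>> spec = ChunkShardingSpec(
#     ...     dim=0, placements=[f"rank:{r}/cpu" for r in range(4)]
#     ... )
#     >>> _sharding_spec_to_offsets(spec, tensor_numel=10, world_size=4)
#     [0, 3, 6, 9]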


def _offsets_to_split_sizes(
    input_offsets: List[int],
    output_offsets: List[int],
    tensor_numel: int,
    world_size: int,
    my_rank: int,
) -> Tuple[List[int], List[int]]:
    r"""
    Given the shard offsets for each rank of the input tensor and output tensor,
    this API returns the corresponding split sizes that can be passed to
    all_to_all_single().
    """

    def _get_interval(offsets):
        if my_rank != world_size - 1:
            return offsets[my_rank], offsets[my_rank + 1] - 1
        else:
            return offsets[my_rank], tensor_numel - 1

    def _offsets_to_sizes(offsets, begin, end):
        sizes = []
        for i, offset in enumerate(offsets):
            next_offset = offsets[i + 1] if i < len(offsets) - 1 else end + 1
            sizes.append(
                (next_offset - offset)
                - max(begin - offset, 0)
                - max(next_offset - end - 1, 0)
            )
        return sizes

    def _convert(from_offsets, to_offsets, split_sizes):
        begin, end = _get_interval(from_offsets)
        to_begin_rank = bisect.bisect(to_offsets, begin) - 1
        to_end_rank = bisect.bisect(to_offsets, end) - 1
        _split_sizes = _offsets_to_sizes(
            to_offsets[to_begin_rank : to_end_rank + 1], begin, end
        )
        split_sizes[to_begin_rank : to_end_rank + 1] = _split_sizes

    input_split_sizes = [0 for _ in range(world_size)]
    output_split_sizes = [0 for _ in range(world_size)]
    _convert(input_offsets, output_offsets, input_split_sizes)
    _convert(output_offsets, input_offsets, output_split_sizes)

    return input_split_sizes, output_split_sizes
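

# Illustrative example for _offsets_to_split_sizes(): with 10 elements on 4 ranks,
# input shards starting at offsets [0, 3, 6, 9] and output shards starting at
# offsets [0, 2, 5, 7], rank 1 owns input elements 3..5. It sends 2 elements to
# rank 1 and 1 element to rank 2, and receives 1 element from rank 0 and 2
# elements from rank 1:
#
#     >>> _offsets_to_split_sizes([0, 3, 6, 9], [0, 2, 5, 7], 10, 4, 1)
#     ([0, 2, 1, 0], [1, 2, 0, 0])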


def _reshard_flatten_tensor(
    input_tensor: ShardedTensor,
    output_spec: ShardingSpec,
    world_size: int,
    my_rank: int,
    device: torch.device,
    process_group: Optional[dist.ProcessGroup],
) -> torch.Tensor:
    """
    Reshard a sharded flatten tensor. FSDP uses this to produce the sharded
    state_dict, because the functionality is not supported by ShardedTensor
    itself. This API is designed to be used for FSDP; therefore this API
    supports only 1-D ShardedTensor (hence the name, reshard_flatten_tensor).

    This API uses the ChunkShardingSpec and EnumerableShardingSpec from
    torch.distributed.sharding_spec but ignores the placement field in
    ChunkShardingSpec, as the placement requires the callee to understand the
    number of GPUs per node. The API simply uses the semantics of the sharding
    specs.

    Args:
        input_tensor (ShardedTensor): the original ShardedTensor. Must be 1D.
        output_spec (ShardingSpec): the sharding spec for the output tensor.
        world_size (int): total trainer count.
        my_rank (int): the rank for this trainer.
        device (torch.device): the device on which the new local shard is created.
        process_group (Optional[dist.ProcessGroup]): the process group used for
            the all_to_all_single() communication.

    Returns:
        The local shard for the new ShardedTensor.
    """
    input_spec = input_tensor.sharding_spec()
    size = input_tensor.size()
    if isinstance(size, int):
        raise ValueError("The input tensor has no dimensions.")
    tensor_numel = size.numel()
    input_offsets = _sharding_spec_to_offsets(input_spec, tensor_numel, world_size)
    output_offsets = _sharding_spec_to_offsets(output_spec, tensor_numel, world_size)
    input_split_sizes, output_split_sizes = _offsets_to_split_sizes(
        input_offsets, output_offsets, tensor_numel, world_size, my_rank
    )
    output_size = sum(output_split_sizes)
    local_shard = torch.empty(output_size, dtype=input_tensor.dtype, device=device)
    dist.all_to_all_single(
        local_shard,
        input_tensor.local_shards()[0].tensor,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=process_group,
    )
    return local_shard
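

# Usage sketch for _reshard_flatten_tensor() (hypothetical; assumes an initialized
# default process group and that `flat_param` is a 1-D ShardedTensor such as one
# found in an FSDP sharded state_dict; the placements below are illustrative and
# are ignored by the offset computation anyway):
#
#     new_spec = ChunkShardingSpec(
#         dim=0,
#         placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
#     )
#     local_shard = _reshard_flatten_tensor(
#         flat_param,
#         new_spec,
#         dist.get_world_size(),
#         dist.get_rank(),
#         torch.device("cuda", torch.cuda.current_device()),
#         process_group=None,
#     )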


def _all_gather_sharded_tensor(
    sharded_tensor: ShardedTensor, pg: Optional[dist.ProcessGroup] = None
) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    # Every rank must contribute the same number of elements to the all-gather,
    # so derive the per-rank chunk size from the dim-0 chunking of the tensor.
    chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
    cuda_device = torch.device("cuda", torch.cuda.current_device())
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if not local_tensor.is_cuda:
            move_to_cpu = torch.ones(1, device=cuda_device)
            local_tensor = local_tensor.cuda()
        else:
            move_to_cpu = torch.zeros(1, device=cuda_device)
        # Pad the local shard with zeros up to chunk_size so that all ranks pass
        # equally sized inputs to _all_gather_base().
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        # This rank holds no shard; contribute an all-zero chunk.
        local_tensor = torch.zeros(
            chunk_size, dtype=sharded_tensor.dtype, device=cuda_device
        )
        move_to_cpu = torch.zeros(1, device=cuda_device)

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=cuda_device,
    )
    dist._all_gather_base(tensor, local_tensor, group=pg)

    # Drop the trailing padding and restore the original global shape.
    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor
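

# Worked arithmetic for _all_gather_sharded_tensor() (illustrative): for a
# ShardedTensor of global shape (10, 4) gathered across 3 ranks, dim_0_size = 10,
# tensor_numel = 40, and chunk_size = ceil(10 / 3) * 40 // 10 = 16 elements per
# rank. The rank holding the 2-row remainder pads its 8 elements up to 16, the
# gathered buffer holds 48 elements, and narrow(0, 0, 40) drops the padding
# before reshaping back to (10, 4).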


def _gather_state_dict(
    state_dict: Dict[str, Any],
    pg: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
    """
    Given a state_dict, this API gathers all the ShardedTensors in the state_dict.
    """
    new_state_dict = {}
    for key, tensor in state_dict.items():
        if isinstance(tensor, ShardedTensor):
            output_tensor = _all_gather_sharded_tensor(tensor, pg)
            if tensor.local_shards() and tensor.local_shards()[0].tensor.is_cuda:
                tensor = output_tensor
            else:
                tensor = output_tensor.cpu()
        new_state_dict[key] = tensor
    return new_state_dict
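

# Usage sketch for _gather_state_dict() (hypothetical; assumes the default process
# group is initialized and `sharded_sd` is a state_dict whose values may include
# ShardedTensors, e.g. an FSDP sharded state_dict):
#
#     full_sd = _gather_state_dict(sharded_sd)
#     # Every ShardedTensor value is replaced by the fully gathered dense tensor;
#     # non-ShardedTensor values are passed through unchanged.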


def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    num_devices_per_node: int,
    pg: dist.ProcessGroup,
) -> ShardedTensor:
    """
    Shard a tensor to chunks along the first dimension. The local rank will get
    its corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes])
    )[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    placements = [
        f"rank:{r}/cuda:{r % num_devices_per_node}" for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size, placement)
        for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ),
    )
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg
    )
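

# Usage sketch for _create_chunk_sharded_tensor() (hypothetical; assumes an
# initialized process group of 4 single-GPU trainers and that `full_tensor` is
# the same (10, 4) tensor replicated on every rank):
#
#     st = _create_chunk_sharded_tensor(
#         full_tensor,
#         rank=dist.get_rank(),
#         world_size=4,
#         num_devices_per_node=1,
#         pg=distributed_c10d._get_default_group(),
#     )
#     # Rank r keeps rows [3 * r, 3 * r + 3) as its local shard (ceil(10 / 4) = 3),
#     # with the last rank holding only the single remaining row.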