planner_helpers.py
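"""Helpers for building save/load plans for distributed checkpointing.

These functions translate state_dict entries (ShardedTensor, plain tensor,
or arbitrary picklable objects) into the WriteItem and ReadItem requests
consumed by the planner and storage layers.
"""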

from typing import List, Any

import torch
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec._internals import (
    _check_shard_metadata_pair_overlap,
)

from .planner import (
    LoadItemType,
    SavePlan,
    ReadItem,
    WriteItem,
    WriteItemType,
    TensorWriteData,
)
from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    TensorStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
)
from .resharding import _shards_get_overlap_region_wrt_saved_tensor

__all__: List[str] = []


def _create_shard_metadata(size: torch.Size) -> ShardMetadata:
    return ShardMetadata(
        shard_offsets=[0] * len(size),
        shard_sizes=list(size),
    )


def _create_shard_from_tensor(tensor: torch.Tensor) -> Shard:
    return Shard(tensor=tensor, metadata=_create_shard_metadata(tensor.size()))


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=sharded_tensor.metadata().tensor_properties,
        size=sharded_tensor.metadata().size,
    )


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_bytesio(fqn: str, obj: Any) -> WriteItem:
    # The payload itself is not captured here; it is resolved from the
    # state_dict via ``index`` when the plan is executed.
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )


def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
) -> ReadItem:
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
) -> ReadItem:
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def _create_sharded_read_items(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_shards: List[Shard],
) -> List[ReadItem]:
    """Produce one ReadItem per overlapping (local shard, saved chunk) pair."""
    read_items = []
    # Naive quadratic algorithm that can be optimized later: test every
    # local shard against every chunk saved in the checkpoint.
    for idx, shard in enumerate(local_shards):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            shard_md_from_storage = ShardMetadata(
                shard_sizes=list(storage_md.sizes),
                shard_offsets=list(storage_md.offsets),
            )
            if not _check_shard_metadata_pair_overlap(
                shard.metadata, shard_md_from_storage
            ):
                continue

            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=shard_md_from_storage, current_shard=shard.metadata
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(fqn, shard.metadata.shard_offsets, idx),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items


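# Illustration with hypothetical shapes: suppose a (4, 4) tensor was saved as
# two row chunks, chunk 0 at offsets (0, 0) and chunk 1 at offsets (2, 0),
# each sized (2, 4), and is now loaded into one local shard covering the whole
# tensor. _create_sharded_read_items would then yield two ReadItems, with
# offsets expressed relative to each chunk/shard:
#   dest_offsets=(0, 0), storage_offsets=(0, 0), lengths=(2, 4)  # chunk 0
#   dest_offsets=(2, 0), storage_offsets=(0, 0), lengths=(2, 4)  # chunk 1

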
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    """Create a SavePlan with a WriteItem for every entry in ``state_dict``,
    covering all shards of a ShardedTensor rather than only the local ones."""
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, ShardedTensor):
            for shard_md in obj.metadata().shards_metadata:
                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


def _create_write_items(fqn: str, obj: Any) -> List[WriteItem]:
    """Create the WriteItems needed to save ``obj`` under ``fqn``."""
    if isinstance(obj, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, obj, shard.metadata)
            for shard in obj.local_shards()
        ]
    elif isinstance(obj, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, obj)]
    else:
        return [_create_write_item_for_bytesio(fqn, obj)]


def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
    """Create the ReadItems needed to load ``obj`` under ``fqn`` from the
    storage described by ``md``."""
    if isinstance(md, BytesStorageMetadata):
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]
    elif isinstance(obj, ShardedTensor):
        local_shards = obj.local_shards()
    elif isinstance(obj, torch.Tensor):
        local_shards = [_create_shard_from_tensor(obj)]
    else:
        raise ValueError(
            f"Invalid checkpoint metadata for {fqn}, "
            + f"expected BytesStorageMetadata but found {type(md)}"
        )
    return _create_sharded_read_items(fqn, md, local_shards)


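if __name__ == "__main__":
    # Minimal usage sketch (illustration only), assuming the module is run
    # inside its package (e.g. via ``python -m``) so the relative imports
    # resolve, and that SavePlan exposes its write requests as ``.items``.
    demo_state_dict: STATE_DICT_TYPE = {
        "weight": torch.randn(4, 4),  # saved as a single TENSOR write item
        "step": 7,  # non-tensor objects become BYTE_IO write items
    }
    plan = _create_default_metadata_only_plan(demo_state_dict)
    for item in plan.items:
        print(item.index.fqn, item.type)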