123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- from typing import List, Any
- import torch
- from torch.distributed._shard.metadata import ShardMetadata
- from torch.distributed._shard.sharded_tensor import ShardedTensor
- from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
- from torch.distributed._shard.sharded_tensor.shard import Shard
- from torch.distributed._shard.sharding_spec._internals import (
- _check_shard_metadata_pair_overlap,
- )
- from .planner import (
- LoadItemType,
- SavePlan,
- ReadItem,
- WriteItem,
- WriteItemType,
- TensorWriteData,
- )
- from .metadata import (
- BytesStorageMetadata,
- ChunkStorageMetadata,
- TensorStorageMetadata,
- MetadataIndex,
- STATE_DICT_TYPE,
- STORAGE_TYPES,
- )
- from .resharding import _shards_get_overlap_region_wrt_saved_tensor
- __all__: List[str] = []
- def _create_shard_metadata(size: torch.Size) -> ShardMetadata:
- return ShardMetadata(
- shard_offsets=[0] * len(size),
- shard_sizes=list(size),
- )
def _create_shard_from_tensor(tensor: torch.Tensor) -> Shard:
    """Wrap a plain tensor in a ``Shard`` whose metadata covers the whole tensor.

    Args:
        tensor: The tensor to wrap.

    Returns:
        A ``Shard`` holding ``tensor`` together with metadata built from
        ``tensor.size()`` by ``_create_shard_metadata``.
    """
    whole_tensor_md = _create_shard_metadata(tensor.size())
    return Shard(tensor=tensor, metadata=whole_tensor_md)
def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    """Convert shard metadata into the equivalent storage chunk description.

    Args:
        shard_md: Metadata carrying the shard's offsets and sizes.

    Returns:
        A ``ChunkStorageMetadata`` whose ``offsets`` and ``sizes`` are the
        shard's offsets and sizes converted to ``torch.Size``.
    """
    chunk_offsets = torch.Size(shard_md.shard_offsets)
    chunk_sizes = torch.Size(shard_md.shard_sizes)
    return ChunkStorageMetadata(offsets=chunk_offsets, sizes=chunk_sizes)
def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    """Describe one shard of a sharded tensor as ``TensorWriteData``.

    Args:
        sharded_tensor: The sharded tensor the shard belongs to.
        shard_md: Metadata of the shard being described.

    Returns:
        A ``TensorWriteData`` whose chunk is derived from ``shard_md`` and
        whose ``properties`` and ``size`` are read from
        ``sharded_tensor.metadata()``.
    """
    chunk = _chunk_for_shard(shard_md)
    properties = sharded_tensor.metadata().tensor_properties
    global_size = sharded_tensor.metadata().size
    return TensorWriteData(chunk=chunk, properties=properties, size=global_size)
def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    """Create the ``SHARD`` write item for one shard of a sharded tensor.

    Args:
        fqn: Name under which the tensor is stored.
        sharded_tensor: The sharded tensor the shard belongs to.
        shard_md: Metadata of the shard to write.

    Returns:
        A ``WriteItem`` of type ``WriteItemType.SHARD`` indexed by ``fqn``
        and the shard's offsets, with tensor data built by
        ``_sharded_tensor_metadata``.
    """
    shard_index = MetadataIndex(fqn, torch.Size(shard_md.shard_offsets))
    shard_data = _sharded_tensor_metadata(sharded_tensor, shard_md)
    return WriteItem(
        index=shard_index,
        type=WriteItemType.SHARD,
        tensor_data=shard_data,
    )
def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    """Create the ``TENSOR`` write item for a regular (non-sharded) tensor.

    The tensor is described as a single chunk that starts at offset zero in
    every dimension and has the tensor's full size.

    Args:
        fqn: Name under which the tensor is stored.
        tensor: The tensor to write.

    Returns:
        A ``WriteItem`` of type ``WriteItemType.TENSOR`` indexed by ``fqn``
        and the all-zero offsets.
    """
    full_size = tensor.size()
    origin = torch.Size([0] * len(full_size))
    item_index = MetadataIndex(fqn, origin)
    whole_chunk = ChunkStorageMetadata(offsets=origin, sizes=full_size)
    write_data = TensorWriteData(
        chunk=whole_chunk,
        properties=TensorProperties.create_from_tensor(tensor),
        size=full_size,
    )
    return WriteItem(
        index=item_index,
        type=WriteItemType.TENSOR,
        tensor_data=write_data,
    )
def _create_write_item_for_bytesio(fqn: str, bytes: Any):
    """Create the ``BYTE_IO`` write item for a non-tensor object.

    Args:
        fqn: Name under which the object is stored.
        bytes: The object to be written. It is not inspected here; only
            ``fqn`` ends up in the returned item.

    Returns:
        A ``WriteItem`` of type ``WriteItemType.BYTE_IO`` indexed by ``fqn``.
    """
    item_index = MetadataIndex(fqn)
    return WriteItem(index=item_index, type=WriteItemType.BYTE_IO)
def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
):
    """Create a ``BYTE_IO`` read item from scalar offsets and a scalar length.

    Args:
        dest_index: Index identifying the destination object.
        dest_offset: Offset into the destination.
        storage_index: Index identifying the stored object.
        storage_offset: Offset into the stored data.
        length: Length of the region to read.

    Returns:
        A ``ReadItem`` of type ``LoadItemType.BYTE_IO`` in which each scalar
        offset/length is wrapped in a one-element ``torch.Size``.
    """
    item_type = LoadItemType.BYTE_IO
    # ReadItem stores offsets and lengths as torch.Size, so each scalar is
    # wrapped in a one-element size.
    dest_offsets = torch.Size((dest_offset,))
    storage_offsets = torch.Size((storage_offset,))
    lengths = torch.Size((length,))
    return ReadItem(
        type=item_type,
        dest_index=dest_index,
        dest_offsets=dest_offsets,
        storage_index=storage_index,
        storage_offsets=storage_offsets,
        lengths=lengths,
    )
def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
):
    """Create a ``TENSOR`` read item from per-dimension offsets and lengths.

    Args:
        dest_index: Index identifying the destination tensor or shard.
        dest_offsets: Offsets into the destination, one per dimension.
        storage_index: Index identifying the stored chunk.
        storage_offsets: Offsets into the stored chunk, one per dimension.
        lengths: Extent of the region to read, one per dimension.

    Returns:
        A ``ReadItem`` of type ``LoadItemType.TENSOR`` with the offsets and
        lengths converted to ``torch.Size``.
    """
    item_type = LoadItemType.TENSOR
    dest_size = torch.Size(dest_offsets)
    storage_size = torch.Size(storage_offsets)
    region_size = torch.Size(lengths)
    return ReadItem(
        type=item_type,
        dest_index=dest_index,
        dest_offsets=dest_size,
        storage_index=storage_index,
        storage_offsets=storage_size,
        lengths=region_size,
    )
def _create_sharded_read_items(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_shards: List[Shard],
) -> List[ReadItem]:
    """Plan the reads needed to fill local shards from stored chunks.

    Every local shard is compared against every chunk recorded in
    ``checkpoint_md``; each overlapping (shard, chunk) pair yields one read
    item covering their overlap region.

    Args:
        fqn: Name under which the tensor is stored.
        checkpoint_md: Storage metadata listing the chunks that were saved.
        local_shards: Shards that must be populated on this rank.

    Returns:
        One ``ReadItem`` per overlapping pair, ordered by local shard and
        then by chunk position.
    """
    planned: List[ReadItem] = []

    for local_pos, local_shard in enumerate(local_shards):
        local_md = local_shard.metadata
        for chunk_pos, chunk in enumerate(checkpoint_md.chunks):
            # Express the stored chunk as ShardMetadata so the shard overlap
            # helpers can compare it with the local shard.
            saved_md = ShardMetadata(
                shard_sizes=list(chunk.sizes),
                shard_offsets=list(chunk.offsets),
            )
            if _check_shard_metadata_pair_overlap(local_md, saved_md):
                # Each entry is (dim, offset in saved shard,
                # offset in current shard, length).
                overlap = list(
                    _shards_get_overlap_region_wrt_saved_tensor(
                        saved_shard=saved_md, current_shard=local_md
                    )
                )
                saved_offsets = [saved for _, saved, _, _ in overlap]
                local_offsets = [current for _, _, current, _ in overlap]
                extents = [extent for _, _, _, extent in overlap]

                planned.append(
                    _create_read_item_for_tensor(
                        dest_index=MetadataIndex(
                            fqn, local_md.shard_offsets, local_pos
                        ),
                        dest_offsets=local_offsets,
                        storage_index=MetadataIndex(
                            fqn, chunk.offsets, chunk_pos
                        ),
                        storage_offsets=saved_offsets,
                        lengths=extents,
                    )
                )
    return planned
def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    """Build a save plan with write items for every entry of ``state_dict``.

    Sharded tensors contribute one write item per shard listed in their
    metadata (``shards_metadata``), regular tensors contribute a single
    tensor item, and every other object contributes a byte-IO item.

    Args:
        state_dict: Mapping from fully qualified name to the object to save.

    Returns:
        A ``SavePlan`` wrapping the collected write items.
    """
    write_items = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, ShardedTensor):
            write_items.extend(
                _create_write_item_for_shard(fqn, obj, shard_md)
                for shard_md in obj.metadata().shards_metadata
            )
            continue
        if isinstance(obj, torch.Tensor):
            write_items.append(_create_write_item_for_tensor(fqn, obj))
            continue
        write_items.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(write_items)
def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
    """Create the write items for a single state-dict entry.

    Args:
        fqn: Name under which the object is stored.
        object: The value to save.

    Returns:
        For a ``ShardedTensor``, one item per shard returned by
        ``local_shards()``; for a ``torch.Tensor``, a single tensor item;
        for anything else, a single byte-IO item.
    """
    # ShardedTensor is tested first so it always takes the shard path.
    if isinstance(object, ShardedTensor):
        shard_items = []
        for local_shard in object.local_shards():
            shard_items.append(
                _create_write_item_for_shard(fqn, object, local_shard.metadata)
            )
        return shard_items

    if isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]

    return [_create_write_item_for_bytesio(fqn, object)]
def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
    """Create the read items needed to load one state-dict entry.

    Args:
        fqn: Name under which the object is stored.
        md: Storage metadata recorded for ``fqn`` in the checkpoint.
        obj: The destination object in the state dict being loaded.

    Returns:
        A single byte-IO read item (zero offsets, zero length) when ``md``
        is ``BytesStorageMetadata``; otherwise the tensor read items
        computed by ``_create_sharded_read_items`` for the local shards of
        ``obj``.

    Raises:
        ValueError: If ``md`` is not ``BytesStorageMetadata`` and ``obj`` is
            neither a ``ShardedTensor`` nor a ``torch.Tensor``.
    """
    if isinstance(md, BytesStorageMetadata):
        byte_item = _create_read_item_for_byteio(
            dest_index=MetadataIndex(fqn),
            dest_offset=0,
            storage_index=MetadataIndex(fqn),
            storage_offset=0,
            length=0,
        )
        return [byte_item]

    # Tensor-like destinations: collect the shards that have to be filled.
    if isinstance(obj, ShardedTensor):
        local_shards = obj.local_shards()
    elif isinstance(obj, torch.Tensor):
        # A regular tensor is handled as one shard spanning the whole tensor.
        local_shards = [_create_shard_from_tensor(obj)]
    else:
        raise ValueError(
            f"Invalid checkpoint metadata for {fqn}, "
            + f"expected BytesStorageMetadata but found {type(md)}"
        )

    return _create_sharded_read_items(fqn, md, local_shards)
|