from typing import Any, List

import torch

from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._shard.sharding_spec._internals import (
    _check_shard_metadata_pair_overlap,
)

from .planner import (
    LoadItemType,
    SavePlan,
    ReadItem,
    WriteItem,
    WriteItemType,
    TensorWriteData,
)

from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    TensorStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
)

from .resharding import _shards_get_overlap_region_wrt_saved_tensor

__all__: List[str] = []


def _create_shard_metadata(size: torch.Size) -> ShardMetadata:
    # A single shard covering the whole tensor, starting at offset 0 in every dim.
    return ShardMetadata(
        shard_offsets=[0] * len(size),
        shard_sizes=list(size),
    )


def _create_shard_from_tensor(tensor: torch.Tensor) -> Shard:
    return Shard(tensor=tensor, metadata=_create_shard_metadata(tensor.size()))


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=sharded_tensor.metadata().tensor_properties,
        size=sharded_tensor.metadata().size,
    )


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_bytesio(fqn: str, bytes: Any) -> WriteItem:
    # Byte IO objects carry no tensor metadata; only the index is recorded.
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )


def _create_read_item_for_byteio(
    dest_index, dest_offset, storage_index, storage_offset, length
) -> ReadItem:
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index, dest_offsets, storage_index, storage_offsets, lengths
) -> ReadItem:
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def _create_sharded_read_items(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_shards: List[Shard],
) -> List[ReadItem]:
    """
    Create one ReadItem per overlapping (local shard, checkpoint chunk) pair.
    """
    read_items = []
    # this is a naive quadratic algo that can be optimized later
    for idx, shard in enumerate(local_shards):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            shard_md_from_storage = ShardMetadata(
                shard_sizes=list(storage_md.sizes),
                shard_offsets=list(storage_md.offsets),
            )

            if not _check_shard_metadata_pair_overlap(
                shard.metadata, shard_md_from_storage
            ):
                continue

            # Compute, per dimension, where the overlap region sits in the
            # saved chunk (storage) and in the local shard (destination).
            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=shard_md_from_storage, current_shard=shard.metadata
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(
                        fqn, shard.metadata.shard_offsets, idx
                    ),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(
                        fqn, storage_md.offsets, storage_idx
                    ),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items


def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
    """
    Create a SavePlan covering every entry in ``state_dict``, including all
    shards of each ShardedTensor, suitable for producing checkpoint metadata.
    """
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, ShardedTensor):
            for shard_md in obj.metadata().shards_metadata:
                requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
    """
    Create the WriteItems needed to save ``object`` under ``fqn``,
    one per local shard for a ShardedTensor.
    """
    if isinstance(object, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, object, shard.metadata)
            for shard in object.local_shards()
        ]
    elif isinstance(object, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, object)]
    else:
        return [_create_write_item_for_bytesio(fqn, object)]


def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
    """
    Create the ReadItems needed to load ``obj`` under ``fqn`` given the
    checkpoint metadata ``md``.
    """
    if isinstance(md, BytesStorageMetadata):
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]
    elif isinstance(obj, ShardedTensor):
        local_shards = obj.local_shards()
    elif isinstance(obj, torch.Tensor):
        local_shards = [_create_shard_from_tensor(obj)]
    else:
        raise ValueError(
            f"Invalid checkpoint metadata for {fqn}, "
            + f"expected BytesStorageMetadata but found {type(md)}"
        )

    return _create_sharded_read_items(fqn, md, local_shards)
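

# Usage sketch (illustrative only; ``model`` and the surrounding setup are
# assumptions, not part of this module): the per-object helpers above can be
# composed into a naive local SavePlan for an arbitrary state_dict, much like
# ``_create_default_metadata_only_plan`` does for the global metadata case.
#
#     state_dict = model.state_dict()
#     write_items = [
#         item
#         for fqn, obj in state_dict.items()
#         for item in _create_write_items(fqn, obj)
#     ]
#     plan = SavePlan(write_items)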