from typing import List, Tuple

from torch.distributed._shard.sharding_spec import (
    ShardMetadata,
)

__all__: List[str] = []


def _shards_get_overlap_region_wrt_saved_tensor(
    saved_shard: ShardMetadata, current_shard: ShardMetadata
) -> List[Tuple[int, int, int, int]]:
    """
    Compute, per dimension, where ``saved_shard`` and ``current_shard`` intersect.

    The result holds one tuple per tensor dimension, in dimension order:

        (dimension, offset into `saved_shard`, offset into `current_shard`, length)

    Both offsets are measured from the start of their own shard, not from the
    start of the global tensor.

    No overlap check is performed here: if the two shards are disjoint along
    some dimension, the reported length for that dimension is zero or negative.
    Callers are expected to pass shards that actually overlap.
    """
    per_dim = zip(
        saved_shard.shard_offsets,
        saved_shard.shard_sizes,
        current_shard.shard_offsets,
        current_shard.shard_sizes,
    )

    regions: List[Tuple[int, int, int, int]] = []
    for axis, (s_start, s_len, c_start, c_len) in enumerate(per_dim):
        # The intersection starts at the later of the two shard starts and
        # ends at the earlier of the two shard ends.
        overlap_start = max(s_start, c_start)
        overlap_end = min(s_start + s_len, c_start + c_len)
        # Subtracting each shard's own start makes the offset shard-relative;
        # the shard that starts later always gets offset 0.
        regions.append(
            (
                axis,
                overlap_start - s_start,
                overlap_start - c_start,
                overlap_end - overlap_start,
            )
        )
    return regions