# resharding.py
  1. from typing import List, Tuple
  2. from torch.distributed._shard.sharding_spec import (
  3. ShardMetadata,
  4. )
  5. __all__: List[str] = []
  6. def _shards_get_overlap_region_wrt_saved_tensor(
  7. saved_shard: ShardMetadata, current_shard: ShardMetadata
  8. ) -> List[Tuple[int, int, int, int]]:
  9. """
  10. Return the overlapping region between saved_shard and current_shard.
  11. There returned list has the same number of elements as the tensor's dimension.
  12. For each element, we produce a tuple with the following contents:
  13. (dimension, `saved_shard` offset, `current_shard` offset, length)
  14. Offsets are relative to each shard.
  15. """
  16. narrows = []
  17. for dim, (
  18. saved_shard_offset,
  19. current_shard_offset,
  20. saved_shard_size,
  21. current_shard_size,
  22. ) in enumerate(
  23. zip(
  24. saved_shard.shard_offsets,
  25. current_shard.shard_offsets,
  26. saved_shard.shard_sizes,
  27. current_shard.shard_sizes,
  28. )
  29. ):
  30. min_range_end = min(
  31. saved_shard_offset + saved_shard_size,
  32. current_shard_offset + current_shard_size,
  33. )
  34. length = min_range_end - max(current_shard_offset, saved_shard_offset)
  35. if saved_shard_offset > current_shard_offset:
  36. offset_for_saved_tensor = 0
  37. offset_for_current_tensor = (
  38. saved_shard_offset - current_shard_offset
  39. )
  40. else:
  41. offset_for_saved_tensor = current_shard_offset - saved_shard_offset
  42. offset_for_current_tensor = 0
  43. narrows.append(
  44. (dim, offset_for_saved_tensor, offset_for_current_tensor, length)
  45. )
  46. return narrows