_utils.py

import torch
from torch.distributed._shard.metadata import ShardMetadata
from typing import Sequence


def narrow_tensor_by_index(
    tensor: torch.Tensor, offsets: Sequence[int], sizes: Sequence[int]
) -> torch.Tensor:
    """
    Narrow the tensor according to ``offsets`` and ``sizes``.
    """
    narrowed_tensor = tensor
    for idx, (offset, size) in enumerate(zip(offsets, sizes)):
        if size < tensor.size(idx):
            # Narrow along this dimension to get the shard for this rank.
            # We don't want autograd to record the narrow op here, since the
            # resulting 'local_shard' should be a leaf variable in the
            # autograd graph.
            narrowed_tensor = narrowed_tensor.narrow(idx, offset, size)
    return narrowed_tensor


def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor:
    """
    Narrow the tensor according to the metadata.
    """
    return narrow_tensor_by_index(tensor, metadata.shard_offsets, metadata.shard_sizes)
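
A minimal usage sketch (not part of _utils.py): it assumes the two helpers above are importable, and the 4x4 tensor, the offsets/sizes, and the "rank:0/cpu" placement string are illustrative values chosen here, not taken from the file.

import torch
from torch.distributed._shard.metadata import ShardMetadata

tensor = torch.arange(16).reshape(4, 4)

# Narrow directly by offsets/sizes: rows 0-1, columns 2-3.
shard = narrow_tensor_by_index(tensor, [0, 2], [2, 2])
assert shard.shape == (2, 2)

# Describe the same shard via ShardMetadata and let narrow_tensor resolve it.
meta = ShardMetadata(shard_offsets=[0, 2], shard_sizes=[2, 2], placement="rank:0/cpu")
assert torch.equal(narrow_tensor(tensor, meta), shard)

Note that narrow returns a view of the input tensor's storage, so the shard shares memory with the original tensor rather than copying it.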