# metadata.py
  1. from dataclasses import dataclass
  2. from typing import List, Union, Optional
  3. from functools import reduce
  4. from torch.distributed.remote_device import _remote_device
  5. @dataclass
  6. class ShardMetadata:
  7. """
  8. Represents a shard of the overall Tensor including its
  9. offsets, lengths and device placement.
  10. Args:
  11. shard_offsets(List[int]): Offsets in the original tensor indicating
  12. the start offsets for this shard. Should have the same rank as
  13. the original tensor.
  14. shard_sizes(List[int]): Integers indicating the size of each
  15. dimension for this shard. Should have the same rank as the
  16. original tensor.
  17. placement(:class:`torch.distributed._remote_device`):
  18. Specifies the placement of this shard.
  19. """
  20. __slots__ = ['shard_offsets', 'shard_sizes', 'placement']
  21. shard_offsets: List[int]
  22. shard_sizes: List[int]
  23. placement: Optional[_remote_device]
  24. def __init__(
  25. self,
  26. shard_offsets: List[int],
  27. shard_sizes: List[int],
  28. placement: Optional[Union[str, _remote_device]] = None
  29. ):
  30. self.shard_offsets = shard_offsets
  31. self.shard_sizes = shard_sizes
  32. if isinstance(placement, str):
  33. self.placement = _remote_device(placement)
  34. else:
  35. self.placement = placement
  36. if len(self.shard_offsets) != len(self.shard_sizes):
  37. raise ValueError(
  38. f'shard_offsets and shard_sizes should have '
  39. f'the same number of elements, found {len(self.shard_offsets)} '
  40. f'and {self.shard_sizes} respectively')
  41. for i in range(len(self.shard_offsets)):
  42. if self.shard_offsets[i] < 0:
  43. raise ValueError('shard_offsets should be >=0')
  44. if self.shard_sizes[i] < 0:
  45. raise ValueError('shard_sizes should be >= 0')
  46. def __hash__(self):
  47. def _hash_reduce(a, b):
  48. return (a << 8) + hash(b)
  49. res = reduce(_hash_reduce, self.shard_offsets, 37)
  50. res = reduce(_hash_reduce, self.shard_sizes, res)
  51. res = _hash_reduce(res, self.placement)
  52. return res