from dataclasses import dataclass
from typing import List, Union, Optional
from functools import reduce

from torch.distributed.remote_device import _remote_device

@dataclass
class ShardMetadata:
    """
    Represents a shard of the overall Tensor including its
    offsets, lengths and device placement.

    Args:
        shard_offsets(List[int]): Offsets in the original tensor indicating
            the start offsets for this shard. Should have the same rank as
            the original tensor.
        shard_sizes(List[int]): Integers indicating the size of each
            dimension for this shard. Should have the same rank as the
            original tensor.
        placement(:class:`torch.distributed._remote_device`):
            Specifies the placement of this shard.
    """

    # __slots__ keeps instances lightweight. @dataclass still generates
    # __eq__ and __repr__ from the annotations below, while the __init__
    # and __hash__ defined explicitly in this class are left untouched by
    # the dataclass machinery.
    __slots__ = ['shard_offsets', 'shard_sizes', 'placement']

    shard_offsets: List[int]
    shard_sizes: List[int]
    placement: Optional[_remote_device]
    def __init__(
        self,
        shard_offsets: List[int],
        shard_sizes: List[int],
        placement: Optional[Union[str, _remote_device]] = None
    ):
        self.shard_offsets = shard_offsets
        self.shard_sizes = shard_sizes
        # Accept either an already-built _remote_device or a string such
        # as "rank:0/cuda:0" or "worker0/cpu" that _remote_device can parse.
        if isinstance(placement, str):
            self.placement = _remote_device(placement)
        else:
            self.placement = placement
        if len(self.shard_offsets) != len(self.shard_sizes):
            raise ValueError(
                f'shard_offsets and shard_sizes should have '
                f'the same number of elements, found {len(self.shard_offsets)} '
                f'and {len(self.shard_sizes)} respectively')
        for i in range(len(self.shard_offsets)):
            if self.shard_offsets[i] < 0:
                raise ValueError('shard_offsets should be >= 0')
            if self.shard_sizes[i] < 0:
                raise ValueError('shard_sizes should be >= 0')
    def __hash__(self):
        # Fold offsets, sizes and placement into a single value with a
        # simple shift-and-add rolling hash; together with the
        # dataclass-generated __eq__, this lets shards be stored in sets
        # and used as dict keys.
        def _hash_reduce(a, b):
            return (a << 8) + hash(b)

        res = reduce(_hash_reduce, self.shard_offsets, 37)
        res = reduce(_hash_reduce, self.shard_sizes, res)
        res = _hash_reduce(res, self.placement)
        return res
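

# A minimal usage sketch (not part of the original module), assuming a
# hypothetical 16 x 8 tensor split row-wise across two ranks. The placement
# strings follow _remote_device's "rank:<rank>/<device>" format; torch must
# be importable, but no process group needs to be initialized.
if __name__ == "__main__":
    shard0 = ShardMetadata(
        shard_offsets=[0, 0],   # starts at row 0, col 0
        shard_sizes=[8, 8],     # covers rows 0-7, all 8 columns
        placement="rank:0/cuda:0",
    )
    shard1 = ShardMetadata(
        shard_offsets=[8, 0],   # starts at row 8, col 0
        shard_sizes=[8, 8],     # covers rows 8-15, all 8 columns
        placement="rank:1/cuda:1",
    )

    # The explicit __hash__ and dataclass-generated __eq__ make shards
    # usable as set members and dict keys.
    assert len({shard0, shard1}) == 2

    # Mismatched ranks between offsets and sizes fail fast at construction.
    try:
        ShardMetadata(shard_offsets=[0], shard_sizes=[8, 8])
    except ValueError as err:
        print(f"rejected as expected: {err}")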