metadata.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from dataclasses import dataclass, field
  2. from typing import Dict, List, Union, Optional, Sequence, Any
  3. from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
  4. import torch
  5. from torch.distributed._shard.sharded_tensor import (
  6. ShardedTensor,
  7. )
  8. __all__ = [
  9. "ChunkStorageMetadata",
  10. "TensorStorageMetadata",
  11. "BytesStorageMetadata",
  12. "Metadata",
  13. "MetadataIndex",
  14. ]
  15. @dataclass
  16. class ChunkStorageMetadata:
  17. """
  18. Each chunk is expected to have the same properties of the TensorStorageMetadata that includes it.
  19. """
  20. offsets: torch.Size
  21. sizes: torch.Size
  22. @dataclass
  23. class TensorStorageMetadata:
  24. properties: TensorProperties
  25. size: torch.Size
  26. chunks: List[ChunkStorageMetadata]
  27. @dataclass
  28. class BytesStorageMetadata:
  29. pass
  30. TENSOR_TYPE = Union[torch.Tensor, ShardedTensor]
  31. STORAGE_TYPES = Union[TensorStorageMetadata, BytesStorageMetadata]
  32. STATE_DICT_TYPE = Dict[str, Any]
  33. @dataclass
  34. class Metadata:
  35. # Keys are the same from the `state_dict` used.
  36. state_dict_metadata: Dict[str, STORAGE_TYPES]
  37. planner_data: Any = None
  38. storage_data: Any = None
  39. @dataclass(frozen=True)
  40. class MetadataIndex:
  41. """
  42. This class represents a lookup key for items in a state dict or Metadata.
  43. """
  44. fqn: str
  45. """Fully Qualified Name of the object"""
  46. offset: Optional[torch.Size] = None
  47. """If the object is a tensor, offset into the tensor we're looking for"""
  48. index: Optional[int] = field(hash=False, compare=False, default=None)
  49. """
  50. Index hint when searching for tensor chunk to speedup lookups (optional)
  51. A common representation of a sharded tensor is as a list of chunks so to
  52. find the index in such a list you need to linear search it.
  53. When constructing an instance of MetadataIndex that points to that list,
  54. one can provide the index as a hint and it will be probed first before
  55. the linear search and thus making it significantly faster.
  56. """
  57. def __init__(
  58. self,
  59. fqn: str,
  60. offset: Optional[Sequence[int]] = None,
  61. index: Optional[int] = None,
  62. ):
  63. # We must use object.__setattr__ due to frozen=True
  64. object.__setattr__(self, "fqn", fqn)
  65. object.__setattr__(self, "index", index)
  66. if offset is not None:
  67. object.__setattr__(self, "offset", torch.Size(offset))