import abc
from dataclasses import dataclass
from typing import List, Any

from torch.futures import Future

from .metadata import (
    Metadata,
    MetadataIndex,
)
from .planner import (
    LoadPlan,
    SavePlan,
    SavePlanner,
    LoadPlanner,
)

__all__ = ["WriteResult", "StorageWriter", "StorageReader"]


@dataclass(frozen=True)
class WriteResult:
    index: MetadataIndex
    size_in_bytes: int
    storage_data: Any
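

# For illustration only: a WriteResult produced for one saved tensor shard
# might look like the following. The fqn, offset, and file name here are
# hypothetical, and the contents of ``storage_data`` are entirely
# implementation-defined (the writer round-trips it back to itself).
#
#     WriteResult(
#         index=MetadataIndex(fqn="model.weight", offset=torch.Size([0, 0])),
#         size_in_bytes=4096,
#         storage_data="rank0_item0.pt",
#     )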


class StorageWriter(abc.ABC):
    """
    Interface used by ``save_state_dict`` to write to storage.

    One StorageWriter instance acts as both the coordinator and the follower
    in a distributed checkpoint. As part of initialization, each instance
    is told its role.

    A subclass should expect the following sequence of calls:

    1) (all ranks) set_up_storage_writer()
    2) (all ranks) prepare_local_plan()
    3) (coordinator) prepare_global_plan()
    4) (all ranks) write_data()
    5) (coordinator) finish()

    An illustrative sketch of a minimal implementation follows this class.
    """

    @abc.abstractmethod
    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            is_coordinator (bool): Whether this instance is responsible for
                coordinating the checkpoint.
        """
        pass

    @abc.abstractmethod
    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        """
        Perform storage-specific local planning.

        While this method can produce a completely different plan, the
        recommended way is to store storage-specific data in
        ``SavePlan::storage_data``.

        Args:
            plan (SavePlan): The local plan from the ``SavePlanner`` in use.

        Returns:
            A transformed ``SavePlan`` after storage local planning.
        """
        pass

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        """
        Perform centralized planning of storage.

        This method is only called on the coordinator instance.

        While this method can produce a completely different plan, the
        preferred way is to store storage-specific data in
        ``SavePlan::storage_data``.

        Args:
            plans: A list of ``SavePlan`` instances, one for each rank.

        Returns:
            A list of transformed ``SavePlan`` after storage global planning.
        """
        pass

    @abc.abstractmethod
    def write_data(
        self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        """
        Write all items from ``plan`` using ``planner`` to resolve the data.

        A subclass should call ``SavePlanner::resolve_data`` on each item
        from the plan to get access to the underlying object to write.

        Subclasses should lazily call ``resolve_data`` as it can allocate memory.

        In case of tensors, make the following assumptions:

        - They might be on any device, including one that does not match the
          device described by ``WriteItem::tensor_data``.
        - They might be views or not contiguous. Only the projection needs
          to be saved.

        Args:
            plan (SavePlan): The save plan to execute.
            planner (SavePlanner): Planner object to be used to resolve items to data.

        Returns:
            A future that completes to a list of ``WriteResult``.
        """
        pass

    @abc.abstractmethod
    def finish(
        self, metadata: Metadata, results: List[List[WriteResult]]
    ) -> None:
        """
        Write the metadata and mark the current checkpoint as successful.

        The actual format/schema used for serializing ``metadata`` is an
        implementation detail. The only requirement is that it is recoverable
        into the same object graph.

        Args:
            metadata (Metadata): metadata for the new checkpoint
            results: A list of ``WriteResult`` lists, one from each rank.

        Returns:
            None
        """
        pass
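

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this module's API: a minimal StorageWriter
# that saves each item of the local plan to its own file in a shared
# directory. The file layout, the use of ``torch.save``, and the name
# ``_FileWriterSketch`` are assumptions made for this example; a production
# implementation would also need error handling, deduplication across ranks,
# and asynchronous I/O.

import os

import torch


class _FileWriterSketch(StorageWriter):
    def __init__(self, path: str, rank: int) -> None:
        self.path = path
        self.rank = rank
        self.is_coordinator = False

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self.is_coordinator = is_coordinator
        os.makedirs(self.path, exist_ok=True)

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        # Nothing storage-specific to record for this sketch.
        return plan

    def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
        return plans

    def write_data(
        self, plan: SavePlan, planner: SavePlanner
    ) -> Future[List[WriteResult]]:
        results = []
        for i, item in enumerate(plan.items):
            # Resolve lazily, one item at a time, since resolution may
            # allocate memory (e.g. materializing a copy of a tensor).
            data = planner.resolve_data(item)
            if torch.is_tensor(data):
                # Tensors may be non-contiguous views on any device; save
                # only the projection, as a contiguous CPU copy.
                data = data.detach().cpu().contiguous()
            else:
                # Non-tensor items resolve to an io.BytesIO payload; store
                # the raw bytes so they survive a torch.save round trip.
                data = data.getvalue()
            file_name = f"rank{self.rank}_item{i}.pt"
            file_path = os.path.join(self.path, file_name)
            torch.save(data, file_path)
            results.append(
                WriteResult(
                    index=item.index,
                    size_in_bytes=os.path.getsize(file_path),
                    storage_data=file_name,
                )
            )
        # This sketch writes synchronously, so the future is already done.
        fut: Future[List[WriteResult]] = Future()
        fut.set_result(results)
        return fut

    def finish(
        self, metadata: Metadata, results: List[List[WriteResult]]
    ) -> None:
        # Record where every item landed so a reader can locate it later,
        # then persist the metadata, which marks the checkpoint as complete.
        metadata.storage_data = {
            wr.index: wr.storage_data for rank_results in results for wr in rank_results
        }
        torch.save(metadata, os.path.join(self.path, ".metadata"))
# ---------------------------------------------------------------------------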


class StorageReader(abc.ABC):
    """
    Interface used by ``load_state_dict`` to read from storage.

    One StorageReader instance acts as both the coordinator and the follower
    in a distributed checkpoint. As part of initialization, each instance
    is told its role.

    A subclass should expect the following sequence of calls by ``load_state_dict``:

    1) (all ranks) read_metadata()
    2) (all ranks) set_up_storage_reader()
    3) (all ranks) prepare_local_plan()
    4) (coordinator) prepare_global_plan()
    5) (all ranks) read_data()

    An illustrative sketch of a minimal implementation follows this class.
    """

    @abc.abstractmethod
    def read_metadata(self) -> Metadata:
        """
        Read the checkpoint metadata.

        Returns:
            The metadata object associated with the checkpoint being loaded.
        """
        pass

    @abc.abstractmethod
    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """
        Initialize this instance.

        Args:
            metadata (Metadata): The metadata schema to use.
            is_coordinator (bool): Whether this instance is responsible for
                coordinating the checkpoint.
        """
        pass

    @abc.abstractmethod
    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """
        Perform storage-specific local planning.

        While this method can produce a completely different plan, the
        recommended way is to store storage-specific data in
        ``LoadPlan::storage_data``.

        Args:
            plan (LoadPlan): The local plan from the ``LoadPlanner`` in use.

        Returns:
            A transformed ``LoadPlan`` after storage local planning.
        """
        pass

    @abc.abstractmethod
    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        """
        Perform centralized planning of storage loading.

        This method is only called on the coordinator instance.

        While this method can produce a completely different plan, the
        preferred way is to store storage-specific data in
        ``LoadPlan::storage_data``.

        Args:
            plans: A list of ``LoadPlan`` instances, one for each rank.

        Returns:
            A list of transformed ``LoadPlan`` after storage global planning.
        """
        pass

    @abc.abstractmethod
    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Read all items from ``plan`` using ``planner`` to resolve the data.

        A subclass should call ``LoadPlanner::load_bytes`` to deserialize a
        BytesIO object into the right place.

        A subclass should call ``LoadPlanner::resolve_tensor`` to get access
        to the tensors that it should load data into.

        It's the StorageReader's responsibility to properly schedule any
        cross-device copies required.

        Args:
            plan (LoadPlan): The local plan to execute on.
            planner (LoadPlanner): The planner object to use to resolve items.

        Returns:
            A future that completes once all reads are finished.
        """
        pass
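

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this module's API: a minimal StorageReader
# matching the ``_FileWriterSketch`` above. It assumes every item was saved
# whole (so the offsets/lengths on read items can be ignored) and that
# ``Metadata.storage_data`` maps each item's index to its file name, as done
# by the writer sketch's ``finish``. It reuses the ``os`` and ``torch``
# imports introduced with the writer sketch.

import io


class _FileReaderSketch(StorageReader):
    def __init__(self, path: str) -> None:
        self.path = path
        self.storage_data = {}

    def read_metadata(self) -> Metadata:
        return torch.load(os.path.join(self.path, ".metadata"))

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        return plan

    def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
        return plans

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        for item in plan.items:
            file_name = self.storage_data[item.storage_index]
            loaded = torch.load(os.path.join(self.path, file_name))
            if torch.is_tensor(loaded):
                # Let the planner hand us the destination tensor; copy_
                # performs the cross-device copy if the devices differ.
                target = planner.resolve_tensor(item)
                target.copy_(loaded)
                planner.commit_tensor(item, target)
            else:
                # Byte items were stored as raw bytes by the writer sketch.
                planner.load_bytes(item, io.BytesIO(loaded))
        # Reads are synchronous in this sketch, so return a completed future.
        fut: Future[None] = Future()
        fut.set_result(None)
        return fut
# ---------------------------------------------------------------------------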