planner.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. import abc
  2. from dataclasses import dataclass
  3. import io
  4. from typing import List, Tuple, Any, Union, Optional
  5. from enum import Enum, auto
  6. import torch
  7. from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
  8. from .metadata import (
  9. ChunkStorageMetadata,
  10. MetadataIndex,
  11. Metadata,
  12. STATE_DICT_TYPE,
  13. )
  14. __all__ = [
  15. "WriteItemType",
  16. "LoadItemType",
  17. "TensorWriteData",
  18. "WriteItem",
  19. "ReadItem",
  20. "SavePlan",
  21. "LoadPlan",
  22. "SavePlanner",
  23. "LoadPlanner",
  24. ]
  25. class WriteItemType(Enum):
  26. TENSOR = auto()
  27. SHARD = auto()
  28. BYTE_IO = auto()
  29. class LoadItemType(Enum):
  30. TENSOR = auto()
  31. BYTE_IO = auto()
  32. @dataclass(frozen=True)
  33. class TensorWriteData:
  34. chunk: ChunkStorageMetadata
  35. properties: TensorProperties
  36. size: torch.Size
  37. @dataclass(frozen=True)
  38. class WriteItem:
  39. index: MetadataIndex
  40. type: WriteItemType
  41. # Value present if it's a tensor write
  42. tensor_data: Optional[TensorWriteData] = None
  43. @dataclass(frozen=True)
  44. class ReadItem:
  45. # Read Item
  46. type: LoadItemType
  47. # Index into the state_dict
  48. dest_index: MetadataIndex
  49. # Offsets into destination tensor
  50. dest_offsets: torch.Size
  51. # Index into the checkpoint
  52. storage_index: MetadataIndex
  53. # Offset into the checkpoint data
  54. storage_offsets: torch.Size
  55. # Size of the hypercube to copy
  56. lengths: torch.Size
  57. @dataclass(frozen=True)
  58. class SavePlan:
  59. items: List[WriteItem]
  60. storage_data: Any = None
  61. planner_data: Any = None
  62. @dataclass
  63. class LoadPlan:
  64. items: List[ReadItem]
  65. storage_data: Any = None
  66. planner_data: Any = None
  67. class SavePlanner(abc.ABC):
  68. """
  69. Abstract class defining the protocol used by save_state_dict to plan the save process.
  70. SavePlanners are stateful objects that can be used to customize the whole save process.
  71. SavePlanner acts as an access proxy to the state_dict, so any transfomation done to it
  72. will be visible to the whole process.
  73. A planner subclass can expect the following sequence of calls during save_state_dict:
  74. 1) set_up_planner - called on all ranks.
  75. Signals the start of a checkpoint save.
  76. 2) create_local_plan - called on all ranks.
  77. Process the state_dict and produces a `SavePlan` that will be sent for global planning.
  78. 3) create_global_plan - called on the coordinator rank only.
  79. Takes the SavePlan from all ranks and make any global decision.
  80. 4) finish_plan - called on all ranks.
  81. This gives each rank a chance to adjust to global planning decisions.
  82. 5) resolve_data - called multiple times on each rank
  83. Lookups a value on the `state_dict` for the storage layer to write.
  84. Users are recomended to extend DefaultSavePlanner instead of this interface directly as
  85. most changes can be expressed by changes in a single method.
  86. There are 3 usual patterns of extension:
  87. Rewriting state_dict. This is the simplest way to extend the save process as it
  88. doesn't requite understanding the intrincacies of how SavePlan works:
  89. >>> # xdoctest: +SKIP("undefined vars")
  90. >>> class RenamePlanner(DefaultSavePlanner):
  91. >>> def set_up_planner(self, state_dict, is_coordinator):
  92. >>> # prefix all keys with `foo_``
  93. >>> super().set_up_planner(self, {"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
  94. Modifying local plan and lookup in tandem. This is useful when fine control of how data is persisted
  95. >>> # xdoctest: +SKIP("undefined vars")
  96. >>> class FP16Planner(DefaultSavePlanner):
  97. >>> def create_local_plan(self):
  98. >>> plan = super().create_local_plan()
  99. >>> for p in plan:
  100. >>> if p.tensor_data is not None:
  101. >>> p.tensor_data.properties.dtype = torch.float16
  102. >>>
  103. >>> def resolve_data(self, write_item):
  104. >>> item = super().resolve_data(write_item)
  105. >>> return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)
  106. Using the global planning step to make central decisions that can't be made individually by each rank
  107. >>> # xdoctest: +SKIP("undefined vars")
  108. >>> from itertools import islice
  109. >>> from dataclasses import replace
  110. >>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
  111. >>> # This uses the default local plan behavior of having all non-sharded writes in rank 0
  112. >>> # This sample doesn't handle ShardedTensors
  113. >>> def create_global_plan(self, all_plans):
  114. >>> def chunk(it, size):
  115. >>> it = iter(it)
  116. >>> return list(iter(lambda: tuple(islice(it, size)), ()))
  117. >>> all_plans = [
  118. >>> replace(plan, items=items) for plan, items in
  119. >>> zip(all_plans, chunk(all_plans[0].items, len(all_plans)))
  120. >>> ]
  121. >>> return super().create_global_plan(all_plans)
  122. Finally, some planners need to save additional metadata in the checkpoint, this is
  123. accomplished by having each rank contribute their data items in the local plan and
  124. the global planner aggregate them:
  125. >>> # xdoctest: +SKIP("undefined vars")
  126. >>> class SaveExtraDataPlanner(DefaultSavePlanner):
  127. >>> def create_local_plan(self) -> SavePlan:
  128. >>> plan = super().create_local_plan()
  129. >>> return replace(plan, planner_data="per-rank-data")
  130. >>>
  131. >>> def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
  132. >>> global_plan, metadata = super().create_global_plan(all_plans)
  133. >>> merged_data = [p.planner_data for p in global_plan]
  134. >>> metadata = replace(metadata, planner_data=merged_data)
  135. >>> return global_plan, metadata
  136. """
  137. @abc.abstractmethod
  138. def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
  139. """
  140. Intialize this planner to save ``state_dict``.
  141. Implementations should save those values as they won't be provided lated in the save process.
  142. This is called on all ranks.
  143. """
  144. pass
  145. @abc.abstractmethod
  146. def create_local_plan(self) -> SavePlan:
  147. """
  148. Compute the save plan for the current rank.
  149. This will be aggregated and passed to create_global_plan.
  150. Planner specific data can be passed through SavePlan::planner_data.
  151. This is called on all ranks.
  152. """
  153. pass
  154. @abc.abstractmethod
  155. def create_global_plan(
  156. self, all_plans: List[SavePlan]
  157. ) -> Tuple[List[SavePlan], Metadata]:
  158. """
  159. Compute the global checkpoint plan and return the local plan of each rank.
  160. This is called on the coordinator rank only.
  161. """
  162. pass
  163. @abc.abstractmethod
  164. def finish_plan(self, new_plan: SavePlan) -> SavePlan:
  165. """
  166. Merge the plan created by `create_local_plan` and the result of `create_global_plan`.
  167. This is called on all ranks.
  168. """
  169. pass
  170. @abc.abstractmethod
  171. def resolve_data(
  172. self, write_item: WriteItem
  173. ) -> Union[torch.Tensor, io.BytesIO]:
  174. """
  175. Lookup the object associated with ``write_item`` in ``state_dict`` and apply any
  176. transformation (such as serialization) prior to the storage layer consuming it.
  177. Called on each rank multiple times, at least once per WriteItem in the final SavePlan.
  178. This method should be idepotent and thread-save. StorageWriter implementations
  179. are free to call it as frequently as they need.
  180. Any transformation that allocates memory should be lazily done when his method
  181. is called in order to reduce peak memory required by checkpointing.
  182. When returning tensors, they can be on any device or format, they can be views too.
  183. It's the storage layer responsibility to figure out how to save them.
  184. """
  185. pass
  186. class LoadPlanner:
  187. """
  188. Abstract class defining the protocol used by load_state_dict to plan the load process.
  189. LoadPlanner are stateful objects that can be used to customize the whole load process.
  190. LoadPlanner acts as an access proxy to the state_dict, so any transfomation done to it
  191. will be visible to the whole process.
  192. A planner subclass can expect the following sequence of calls during load_state_dict:
  193. 1) set_up_planner - called on all ranks.
  194. Signals the start of loading a checkpoint.
  195. 2) create_local_plan - called on all ranks.
  196. Process the state_dict and produces a `LoadPlan` that will be sent for global planning.
  197. 3) create_global_plan - called on the coordinator rank only.
  198. Takes the LoadPlan from all ranks and make any global decision.
  199. 4) load_bytes - called multiple times on each rank
  200. This is called once per non-tensor value in state_dict.
  201. 5) resolve_tensor and commit_tensor - called multiple times on each rank
  202. They are called in pair for each Tensor value in state_dict.
  203. Users are recomended to extend DefaultLoadPlanner instead of this interface directly as
  204. most changes can be expressed by changes in a single method.
  205. There are two usual patterns of extension:
  206. Rewriting state_dict. This is the simplest way to extend the load process as it
  207. doesn't requite understanding the intrincacies of how LoadPlan works. We need
  208. to keep a reference to the original state_dict as load happens in place so
  209. we need to be able to perform it in place
  210. >>> # xdoctest: +SKIP("undefined vars")
  211. >>> class RenamePlanner(DefaultLoadPlanner):
  212. >>> def set_up_planner(self, state_dict, metadata, is_coordinator):
  213. >>> self.original_state_dict = state_dict
  214. >>> super().set_up_planner(self, {"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)
  215. >>>
  216. >>> def load_bytes(self, read_item, value):
  217. >>> # Remove the "foo_" prefix
  218. >>> self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)
  219. Modifying resolve_tensor and commit_tensor to handle load time transformation.
  220. >>> # xdoctest: +SKIP("undefined vars")
  221. >>> class MetaModelMaterialize(DefaultSavePlanner):
  222. >>> def resolve_tensor(self, read_item):
  223. >>> tensor = super().resolve_tensor(read_item)
  224. >>> return torch.empty_like(tensor, device="cpu")
  225. >>>
  226. >>> def commit_tensor(self, read_item, tensor):
  227. >>> self.state_dict[read_item.dest_index.fqn] = tensor
  228. """
  229. @abc.abstractmethod
  230. def set_up_planner(
  231. self,
  232. state_dict: STATE_DICT_TYPE,
  233. metadata: Metadata,
  234. is_coordinator: bool,
  235. ) -> None:
  236. """
  237. Initialize this instance to load data into ``state_dict``
  238. . N.B. This is called on every rank.
  239. """
  240. pass
  241. @abc.abstractmethod
  242. def create_local_plan(self) -> LoadPlan:
  243. """
  244. Create a LoadPlan based on state_dict and metadata provided by set_up_planner.
  245. . N.B. This is called on every rank.
  246. """
  247. pass
  248. @abc.abstractmethod
  249. def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
  250. """
  251. Compute the global load plan and return plans for each rank.
  252. . N.B. This is called on the coordinator rank only
  253. """
  254. pass
  255. @abc.abstractmethod
  256. def finish_plan(self, central_plan: LoadPlan) -> LoadPlan:
  257. """
  258. Accept the plan from coordinator and return final LoadPlan.
  259. """
  260. pass
  261. @abc.abstractmethod
  262. def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
  263. """
  264. Load the item described by ``read_item``and ``value``.
  265. This method is expected to modify in-place the underlying state_dict.
  266. The contents of ``value`` are defined by the SavePlanner used to produce
  267. the checkpoint being loaded.
  268. """
  269. pass
  270. @abc.abstractmethod
  271. def resolve_tensor(self, read_item: ReadItem) -> torch.Tensor:
  272. """
  273. Return the tensor described by ``read_item`` to be used by the StorageReader to load `read_item`.
  274. The tensor should alias with one on the underlying state_dict as StorageReader will replace its contents.
  275. If, for any reason, that's not possible, the planner can use the ``commit_tensor`` method to copy the data
  276. back to the one in state_dict.
  277. """
  278. pass
  279. @abc.abstractmethod
  280. def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
  281. """
  282. This method is called once the StorageReader finished loading data into ``tensor``.
  283. The provided tensor is the same one returned by the call to ``resolve_tensor``.
  284. This method is only needed if this LoadPlanner needs to post process ``tensor`` prior to
  285. copying it back to the one in the state_dict.
  286. The contents of tensor will follow its device synchronization model.
  287. """
  288. pass