# default_planner.py

# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import List, Tuple, Dict, Any, Union, cast

import torch

from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._shard.sharded_tensor import ShardedTensor

from torch.distributed.checkpoint.planner import (
    SavePlanner,
    LoadPlanner,
    SavePlan,
    LoadPlan,
    ReadItem,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    TensorStorageMetadata,
    MetadataIndex,
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
)
from torch.distributed.checkpoint.planner_helpers import (
    _create_read_items,
    _create_write_items,
    _create_default_metadata_only_plan,
)
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import (
    _flatten_sharded_tensors,
)
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.checkpoint._traverse import set_element

logger: logging.Logger = logging.getLogger(__file__)

__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: Update docstrings for default_planner.py
class DefaultSavePlanner(SavePlanner):
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: bool = True,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.dedup_replicated_tensors = dedup_replicated_tensors
        self.mappings = {}

    def set_up_planner(
        self, state_dict: STATE_DICT_TYPE, is_coordinator: bool
    ) -> None:
        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)
        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        plan = create_default_local_save_plan(
            self.state_dict, self.is_coordinator
        )
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        self.plan = plan
        return self.plan

    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        if self.dedup_replicated_tensors:
            all_plans = dedup_tensors(all_plans)

        global_plan, metadata = create_default_global_save_plan(all_plans)

        if self.flatten_state_dict:
            # | does not work for Python 3.8 or older versions.
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
            planner_data_dict = [p.planner_data for p in global_plan]
            merged_mappings = dict(ChainMap(*planner_data_dict))
            metadata = dataclasses.replace(
                metadata, planner_data=merged_mappings
            )

        if not _validate_global_plan(global_plan, metadata):
            raise ValueError("Failed to validate global plan")

        self.global_plan = global_plan
        self.metadata = metadata
        return self.global_plan, self.metadata

    def finish_plan(self, new_plan: SavePlan) -> SavePlan:
        self.plan = new_plan
        return new_plan

    def resolve_data(
        self, write_item: WriteItem
    ) -> Union[torch.Tensor, io.BytesIO]:
        object = self.lookup_object(write_item.index)
        return self.transform_object(write_item, object)

    def lookup_object(self, index: MetadataIndex) -> Any:
        """
        This is an extension from the planner interface to make it easy to extend the default planner
        """
        return find_state_dict_object(self.state_dict, index)

    def transform_object(self, write_item: WriteItem, object: Any):
        """
        This is an extension from the planner interface to make it easy to extend the default planner
        """
        if write_item.type == WriteItemType.BYTE_IO:
            bytes = io.BytesIO()
            torch.save(object, bytes)
            object = bytes
        return object
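

# Illustrative usage sketch, not part of the upstream module: DefaultSavePlanner is
# normally passed to ``torch.distributed.checkpoint.save_state_dict``. The checkpoint
# directory, the example state_dict, and the ``no_dist`` single-process shortcut used
# here are assumptions for demonstration only.
def _example_save_with_default_planner(checkpoint_dir: str = "/tmp/checkpoint") -> None:
    import torch.distributed.checkpoint as dist_cp

    # Nested dicts are fine: flatten_state_dict=True flattens them before planning.
    state_dict = {"model": {"weight": torch.rand(4, 4)}, "step": 10}
    dist_cp.save_state_dict(
        state_dict=state_dict,
        storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
        planner=DefaultSavePlanner(),
        no_dist=True,  # single-process save; drop this under an initialized process group
    )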


class DefaultLoadPlanner(LoadPlanner):
    """
    DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

    In particular it adds the following:

    flatten_state_dict: Handle state_dict with nested dicts
    flatten_sharded_tensors: For FSDP in 2D parallel mode
    """

    original_state_dict: STATE_DICT_TYPE
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.original_state_dict = {}
        self.mappings = {}

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        self.original_state_dict = state_dict

        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)

        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)

        self.state_dict = state_dict
        self.metadata = metadata
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> LoadPlan:
        return create_default_local_load_plan(self.state_dict, self.metadata)

    def create_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        return create_default_global_load_plan(global_plan)

    def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
        return new_plan

    def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
        if self.flatten_state_dict:
            set_element(
                self.original_state_dict,
                self.mappings[read_item.dest_index.fqn],
                torch.load(value),
            )
        else:
            self.state_dict[read_item.dest_index.fqn] = torch.load(value)

    def resolve_tensor(self, read_item: ReadItem):
        tensor = self.lookup_tensor(read_item.dest_index)
        return self.transform_tensor(read_item, tensor)

    def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
        pass

    def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
        """
        This is an extension from the planner interface to make it easy to extend the default planner
        """
        return find_state_dict_object(self.state_dict, index)

    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
        """
        This is an extension from the planner interface to make it easy to extend the default planner
        """
        return narrow_tensor_by_index(
            tensor, read_item.dest_offsets, read_item.lengths
        )
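

# Illustrative usage sketch, not part of the upstream module: DefaultLoadPlanner loads
# into a pre-allocated state_dict whose keys and shapes match the checkpoint. The
# checkpoint directory and state_dict contents are assumptions for demonstration only.
def _example_load_with_default_planner(checkpoint_dir: str = "/tmp/checkpoint") -> None:
    import torch.distributed.checkpoint as dist_cp

    # Tensors are loaded in place, so they must already have the right shape.
    state_dict = {"model": {"weight": torch.empty(4, 4)}, "step": 0}
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
        planner=DefaultLoadPlanner(),
        no_dist=True,  # single-process load; drop this under an initialized process group
    )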


def create_default_local_load_plan(
    state_dict: Dict[str, Any],
    metadata: Metadata,
) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict`` using the metadata in ``metadata``.
    The default behavior is to match keys exactly between state_dict and metadata.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.
    """
    requests = []
    for fqn, obj in state_dict.items():
        md = metadata.state_dict_metadata[fqn]
        requests += _create_read_items(fqn, md, obj)

    return LoadPlan(requests)


def create_default_global_load_plan(
    all_plans: List[LoadPlan],
) -> List[LoadPlan]:
    """
    Create the global load plan used by DefaultLoadPlanner.

    The default load behavior involves no global coordination and this function
    currently doesn't change the local plans.
    """
    return all_plans


def create_default_local_save_plan(
    state_dict: Dict[str, Any], is_coordinator: bool
) -> SavePlan:
    """
    Create the ``SavePlan`` used by DefaultSavePlanner.

    On non-coordinator ranks, this function ignores tensors and non-tensor objects,
    only producing writes for ShardedTensor objects.

    On the coordinator rank, it produces writes for all values.
    """
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, ShardedTensor) or is_coordinator:
            requests += _create_write_items(fqn, obj)

    return SavePlan(requests)
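

# Illustrative sketch, not part of the upstream module: on the coordinator rank a plain
# tensor or object yields exactly one ``WriteItem``, while non-coordinator ranks only
# produce writes for ShardedTensor values. The state_dict below is an assumption for
# demonstration only.
def _example_local_save_plan() -> None:
    state_dict = {"weight": torch.rand(2, 2), "step": 10}
    coord_plan = create_default_local_save_plan(state_dict, is_coordinator=True)
    other_plan = create_default_local_save_plan(state_dict, is_coordinator=False)
    assert len(coord_plan.items) == 2  # one write item per value
    assert len(other_plan.items) == 0  # no ShardedTensors, so nothing to write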


def create_default_global_save_plan(
    all_plans: List[SavePlan],
) -> Tuple[List[SavePlan], Metadata]:
    """
    Create the global plan and metadata used by DefaultSavePlanner.

    Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.

    The only global planning change is to update index hints in all ``MetadataIndex`` objects.
    """
    md: Dict[str, STORAGE_TYPES] = {}
    new_plans = []
    for plan in all_plans:
        new_items = []
        for item in plan.items:
            if not item.type == WriteItemType.SHARD:
                assert item.index.fqn not in md

            if item.type == WriteItemType.BYTE_IO:
                md[item.index.fqn] = BytesStorageMetadata()
                new_items.append(item)
            else:
                assert item.tensor_data is not None
                tensor_md = cast(
                    TensorStorageMetadata,
                    md.setdefault(
                        item.index.fqn,
                        TensorStorageMetadata(
                            properties=item.tensor_data.properties,
                            size=item.tensor_data.size,
                            chunks=[],
                        ),
                    ),
                )
                new_index = dataclasses.replace(
                    item.index, index=len(tensor_md.chunks)
                )
                new_item = dataclasses.replace(item, index=new_index)
                new_items.append(new_item)

                assert (
                    item.tensor_data.chunk is not None
                ), f"""
                    Cannot create MD for tensor without bounds.
                    FQN: {item.index.fqn}
                """
                tensor_md.chunks.append(item.tensor_data.chunk)
        new_plans.append(dataclasses.replace(plan, items=new_items))

    return (new_plans, Metadata(md))


def _create_default_local_metadata(state_dict: STATE_DICT_TYPE) -> Metadata:
    """
    Return the ``Metadata`` that would be produced if DefaultSavePlanner was used to
    checkpoint ``state_dict``.
    """
    plan = _create_default_metadata_only_plan(state_dict)
    _, md = create_default_global_save_plan([plan])
    return md
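

# Illustrative sketch, not part of the upstream module: ``_create_default_local_metadata``
# reconstructs the ``Metadata`` a DefaultSavePlanner checkpoint of ``state_dict`` would
# carry, which can then drive ``create_default_local_load_plan`` without touching storage.
# The example state_dict is an assumption for demonstration only.
def _example_metadata_and_load_plan() -> LoadPlan:
    state_dict = {"weight": torch.rand(4, 4), "step": 10}
    md = _create_default_local_metadata(state_dict)
    # One read item per value; keys are matched exactly between state_dict and metadata.
    return create_default_local_load_plan(state_dict, md)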


def _check_box_overlap(
    box0: ChunkStorageMetadata, box1: ChunkStorageMetadata
) -> bool:
    """
    Check if two boxes overlap. Each box is described by its offsets and sizes.
    """
    # For each dim of each shard, check if one shard resides on the other
    # end of second shard with respect to that dim. As an example for a 2D
    # shard, we would check if one shard is above or on the left of the
    # other shard.
    ndims = len(box0.offsets)
    for i in range(ndims):
        if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
            return False
        if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
            return False

    return True


def _check_box_bounds(
    outer_box_size: torch.Size, inner_box: ChunkStorageMetadata
) -> bool:
    for i in range(len(outer_box_size)):
        if inner_box.offsets[i] < 0:
            return False
        if inner_box.sizes[i] < 0:
            return False
        if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
            return False

    return True


def _validate_global_plan(
    global_plan: List[SavePlan], metadata: Metadata
) -> bool:
    all_good = True
    for key, value in metadata.state_dict_metadata.items():
        if isinstance(value, BytesStorageMetadata):
            continue
        if len(value.size) == 0:
            continue
        chunks_volume = 0
        for chunk_idx, chunk0 in enumerate(value.chunks):
            if not _check_box_bounds(value.size, chunk0):
                logger.warning(
                    f"""
                        key:{key} has out of bounds chunk:
                        tensor-size:{value.size} chunk: {chunk0}
                    """
                )
                all_good = False
            chunks_volume += reduce(operator.mul, chunk0.sizes, 1)

            for chunk1 in value.chunks[chunk_idx + 1 :]:
                if _check_box_overlap(chunk0, chunk1):
                    logger.warning(
                        f"key:{key} has overlapping chunks: {chunk0} {chunk1}"
                    )
                    all_good = False

        tensor_volume = reduce(operator.mul, value.size, 1)
        if chunks_volume != tensor_volume:
            logger.warning(
                f"""
                    key:{key} invalid fill tensor-volume:
                    {tensor_volume} chunks-volume: {chunks_volume}
                """
            )
            all_good = False

    return all_good
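

# Illustrative sketch, not part of the upstream module: the chunk-geometry helpers above
# operate on ``ChunkStorageMetadata`` (per-chunk offsets and sizes). Two chunks overlap
# only if their extents intersect along every dimension. The chunks below are assumptions
# for demonstration only.
def _example_chunk_checks() -> None:
    top = ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([2, 4]))
    bottom = ChunkStorageMetadata(offsets=torch.Size([2, 0]), sizes=torch.Size([2, 4]))
    assert not _check_box_overlap(top, bottom)  # adjacent rows touch but do not overlap
    assert _check_box_bounds(torch.Size([4, 4]), top)  # chunk fits inside a 4x4 tensor
    assert _check_box_bounds(torch.Size([4, 4]), bottom)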