_state_dict_utils.py

import functools
import math
import warnings
from typing import Any, Callable, cast, Dict, Iterator, no_type_check, Tuple

import torch
import torch.distributed as dist
import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardedTensor,
)
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _has_fsdp_params,
    _is_composable,
    _module_handles,
    clean_tensor_name,
    FSDP_PREFIX,
    FSDP_WRAPPED_MODULE,
)
from torch.distributed.fsdp._runtime_utils import (
    _cast_buffers_to_dtype_and_device,
    _clear_grads_if_needed,
    _get_buffer_dtypes,
    _lazy_init,
)
from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType
from torch.distributed.utils import _replace_by_prefix

from ._fsdp_extensions import (
    _ext_chunk_tensor,
    _ext_pre_load_state_dict_transform,
    _extensions as _user_extensions,
)
from ._unshard_param_utils import (
    _deregister_orig_params,
    _register_orig_params,
    _unshard_fsdp_state_params,
    FLAT_PARAM,
)
from .flat_param import FlatParamHandle


def _convert_to_wrapped_module_name(module_name: str) -> str:
    module_name = module_name.replace(f"{FSDP_PREFIX}", "")
    module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "")
    if module_name:
        module_name = f"{module_name}."
    # `CheckpointWrapper` adds a prefix that has to be removed as well.
    module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "")
    return module_name
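

# Illustrative examples (assuming FSDP_WRAPPED_MODULE == "_fsdp_wrapped_module"
# and FSDP_PREFIX == FSDP_WRAPPED_MODULE + "."; both come from _common_utils):
#   _convert_to_wrapped_module_name("_fsdp_wrapped_module.layer1") -> "layer1."
#   _convert_to_wrapped_module_name("_fsdp_wrapped_module") -> ""
# The trailing "." lets callers build FQNs as f"{module_name}{param_name}".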


def _param_fqns(
    module: nn.Module, fsdp_state: _FSDPState
) -> Iterator[Tuple[str, str, str]]:
    if not _has_fsdp_params(fsdp_state, module):
        return
    for param_name, module_name in _module_handles(fsdp_state, module)[
        0
    ].parameter_module_names():
        module_name = _convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


def _shared_param_fqns(module: nn.Module, fsdp_state) -> Iterator[Tuple[str, str, str]]:
    for param_name, module_name in _module_handles(fsdp_state, module)[
        0
    ].shared_parameter_module_names():
        module_name = _convert_to_wrapped_module_name(module_name)
        fqn = f"{module_name}{param_name}"
        yield fqn, param_name, module_name


@no_type_check
def _enter_unshard_params_ctx(
    module: nn.Module,
    fsdp_state: _FSDPState,
    writeback: bool = False,
    rank0_only: bool = False,
    offload_to_cpu: bool = False,
    with_grads: bool = False,
) -> None:
    """
    state_dict hooks cannot use the pure context call as the checkpoint flow
    requires entering the context in the pre-hook and leaving it in the
    post-hook. This API enters the context of ``_unshard_fsdp_state_params``.
    """
    assert module not in fsdp_state._unshard_params_ctx, (
        "Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] "
        "is not None."
    )
    fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
        module,
        fsdp_state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    )
    fsdp_state._unshard_params_ctx[module].__enter__()


@no_type_check
def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
    """A helper function to exit the ``_unshard_fsdp_state_params`` context."""
    fsdp_state._unshard_params_ctx[module].__exit__(None, None, None)
    fsdp_state._unshard_params_ctx.pop(module)
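

# Minimal sketch of the pairing (illustrative, not an exact call sequence):
#   pre-hook : _enter_unshard_params_ctx(module, fsdp_state, writeback=False)
#   ...        nn.Module.state_dict() runs and sees the unsharded parameters
#   post-hook: _exit_unshard_params_ctx(module, fsdp_state)
# FULL_STATE_DICT and SHARDED_STATE_DICT both use this enter/exit split via
# the _common_unshard_* hooks below.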


def _common_pre_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
) -> None:
    """Performs the pre-state_dict tasks shared by all state_dict types."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # TODO: need to check if this is always correct for composable FSDP.
    _lazy_init(fsdp_state, module)
    # TODO: change to this call after pre_state_dict_hook is in `nn.Module`.
    if fsdp_state._is_root:
        _clear_grads_if_needed(traversal_utils._get_fsdp_handles(module))


def _common_unshard_pre_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    offload_to_cpu: bool,
    rank0_only: bool,
) -> None:
    """
    Performs the pre-state_dict tasks shared by all state_dict types that require
    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use
    this hook.
    """
    _enter_unshard_params_ctx(
        module,
        fsdp_state,
        writeback=False,
        offload_to_cpu=offload_to_cpu,
        rank0_only=rank0_only,
    )


# TODO: change to the decorator style. See ``_full_pre_state_dict_hook``.
@no_type_check
def _common_unshard_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
    param_hook: Callable,
) -> Dict[str, Any]:
    """
    The post-state_dict flow that is shared by all state_dict types that require
    ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use
    this hook.
    """
    _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
    # Return early for trivial cases
    if not state_dict or not _has_fsdp_params(fsdp_state, module):
        _exit_unshard_params_ctx(module, fsdp_state)
        return state_dict

    # If a rank does not have unsharded parameters (when `rank0_only=True`
    # and `rank != 0`), then the rank only needs to participate in the
    # all-gather and does not need to save the state dict. We simply check
    # `rank0_only` to handle this case.
    rank0_only = (
        fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT
        and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only
    )
    # no_fsdp_return means the state_dict returned by this rank should contain
    # only non-FSDP controlled parameters and buffers.
    no_fsdp_return = rank0_only and fsdp_state.rank != 0
    if no_fsdp_return and not fsdp_state._use_orig_params:
        for clean_key in fsdp_state._buffer_names:
            # This is a hack to support activation checkpoint.
            clean_key = clean_key.replace(
                f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
            )
            state_dict.pop(f"{prefix}{clean_key}", None)
        # Non-zero ranks have the flat_param key when rank0_only=True because
        # rank0_only=True is passed to the unshard context, but non-zero ranks
        # reshard early, causing this flat_param to appear in the state_dict.
        state_dict.pop(f"{prefix}{FLAT_PARAM}")
        _exit_unshard_params_ctx(module, fsdp_state)
        return state_dict

    # Loop over only the parameters saved in this instance's wrapped module to
    # avoid processing buffers.
    for fqn, param_name, module_name in _param_fqns(module, fsdp_state):
        fqn = f"{prefix}{fqn}"
        if no_fsdp_return:
            state_dict.pop(fqn)
            continue
        assert fqn in state_dict, (
            f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
            f"has {state_dict.keys()}. "
            f"prefix={prefix}, module_name={module_name}, "
            f"param_name={param_name} rank={fsdp_state.rank}."
        )
        param_hook(state_dict, prefix, fqn)
    _exit_unshard_params_ctx(module, fsdp_state)

    cpu_device = torch.device("cpu")
    buffer_clean_fqns = []
    buffers = []
    for clean_key in fsdp_state._buffer_names:
        # This is a hack to support activation checkpoint.
        clean_key = clean_tensor_name(clean_key)
        fqn = f"{prefix}{clean_key}"
        if fqn not in state_dict:
            # A buffer can be registered as non-persistent.
            continue
        if no_fsdp_return:
            state_dict.pop(fqn)
        else:
            buffer = state_dict[fqn]
            if (
                fsdp_state._state_dict_config.offload_to_cpu
                and buffer.device != cpu_device
            ):
                state_dict[fqn] = buffer.to(cpu_device)
            # TODO: for composable FSDP, this should be clean_tensor_name(clean_key).
            buffer_clean_fqns.append(clean_key)
            buffers.append(state_dict[fqn])
    if buffers:
        mixed_precision_enabled_for_buffers = (
            fsdp_state._mixed_precision_enabled_for_buffers()
            if not _is_composable(fsdp_state)
            else (fsdp_state.mixed_precision.buffer_dtype is not None)
        )
        if mixed_precision_enabled_for_buffers:
            buffer_dtypes = _get_buffer_dtypes(fsdp_state, buffer_clean_fqns)
            _cast_buffers_to_dtype_and_device(
                buffers, buffer_dtypes, fsdp_state.compute_device
            )
            for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
                fqn = f"{prefix}{clean_fqn}"
                state_dict[fqn] = buffer.clone()
    return state_dict


@no_type_check
def _full_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. A pre-state_dict hook is
    not actually supported by ``nn.Module``. As a result, this API is called
    from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict
    is supported in ``nn.Module``, this hook will be registered as a hook in
    ``nn.Module``.

    TODO: clean the callsites and hacks after ``pre_state_dict_hook`` is supported
    in ``nn.Module``.
    """
    _common_pre_state_dict_hook(module, fsdp_state)
    _common_unshard_pre_state_dict_hook(
        module,
        fsdp_state,
        offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,
        rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only,
    )


@no_type_check
def _full_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    Hook that runs after model.state_dict() is called and before the result is
    returned to the user. For FSDP, we may have to clone the tensors in state_dict
    as the parameters go back to their sharded version after
    _unshard_fsdp_state_params ends, and we also remove the
    ``FSDP_WRAPPED_MODULE`` prefix.
    """

    def param_hook(
        state_dict: Dict[str, Any],
        prefix: str,
        fqn: str,
    ) -> None:
        clean_key = fqn
        clean_prefix = clean_tensor_name(prefix)
        # Strip the prefix out of the key if needed, as buffer names and param
        # names do not have the prefix applied since they are not computed in
        # the `state_dict` call.
        if clean_key.startswith(clean_prefix):
            clean_key = clean_key[len(clean_prefix) :]

        # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
        if not getattr(state_dict[fqn], "_has_been_cloned", False):
            try:
                state_dict[fqn] = state_dict[fqn].clone().detach()
                state_dict[fqn]._has_been_cloned = True  # type: ignore[attr-defined]
            except BaseException as e:
                warnings.warn(
                    f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
                    "This may mean that this state_dict entry could point to invalid "
                    "memory regions after returning from state_dict() call if this "
                    "parameter is managed by FSDP. Please check clone "
                    f"implementation of {fqn}. Error: {str(e)}"
                )

    return _common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


def _full_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    _lazy_init(fsdp_state, module)
    _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
    # Add FSDP_PREFIX only for wrapper-based FSDP.
    if not _is_composable(fsdp_state):
        _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")


def _full_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    _exit_unshard_params_ctx(module, fsdp_state)


def _local_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. Right now, a pre-state_dict
    hook is not supported by PyTorch core, so this API is called from
    ``_local_post_state_dict_hook()`` to simulate the case.
    """
    if (
        _has_fsdp_params(fsdp_state, module)
        and not _module_handles(fsdp_state, module)[0].uses_sharded_strategy
    ):
        raise RuntimeError(
            "``local_state_dict`` can only be used when parameters are flattened "
            "and sharded."
        )
    _common_pre_state_dict_hook(module, fsdp_state)


@no_type_check
def _local_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    This hook creates a ShardedTensor from the local flat_param and replaces
    state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
    happens; the underlying storage is the same.
    """
    _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix)
    if not _has_fsdp_params(fsdp_state, module):
        return state_dict

    # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
    # value as the flat_param, but it is a plain Tensor because
    # nn.Module.state_dict() detaches the parameter. Therefore, we need
    # the flat_param itself to get the metadata.
    assert _module_handles(fsdp_state, module), "Should have returned early"
    flat_param = _module_handles(fsdp_state, module)[0].flat_param
    # Constructs a ShardedTensor from the flat_param "without" padding.
    # Removing the padding allows users to change the number of ranks
    # when loading the local_state_dict.
    full_numel = flat_param._unpadded_unsharded_size.numel()  # type: ignore[attr-defined]
    shard_offset = flat_param.numel() * fsdp_state.rank
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
    if valid_data_size > 0:
        # If a FlatParameter is returned, FlatParameter._local_shard causes a
        # pickling issue (it can be torch.save'd but not torch.load'ed). Since
        # there is no benefit for state_dict to return the actual FlatParameter
        # class, a view (which is a plain tensor) of the FlatParameter is
        # returned instead.
        flat_param = flat_param[:valid_data_size].view(valid_data_size)
        local_shards = [
            Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank)
        ]
    else:
        local_shards = []
    sharded_tensor = init_from_local_shards(
        local_shards, full_numel, process_group=fsdp_state.process_group
    )  # type: ignore[assignment]
    if fsdp_state._state_dict_config.offload_to_cpu:
        sharded_tensor = sharded_tensor.cpu()
    state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
    return state_dict
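

# Worked example (hypothetical numbers): with world_size=2 and a local
# flat_param of numel()=6 where _shard_numel_padded=2 on rank 1, rank 1
# contributes a 4-element view at offset 6 * 1 = 6; if rank 0 has no padding,
# the ShardedTensor spans full_numel = 6 + 4 = 10 elements, i.e. the unpadded,
# unsharded size, so the padding never appears in the checkpoint.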


def _local_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    pass


def _local_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    This hook finds the local flat_param for this FSDP module from the
    state_dict. The flat_param should be a ShardedTensor. This hook converts
    the ShardedTensor to a tensor. No copy happens unless padding is required.
    """
    _lazy_init(fsdp_state, module)
    _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}")
    fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"
    if fqn not in state_dict:
        assert not _has_fsdp_params(fsdp_state, module), (
            "No `FlatParameter` in `state_dict` for this FSDP instance "
            "but it has parameters"
        )
        return
    load_tensor = state_dict[fqn]
    assert isinstance(
        load_tensor, ShardedTensor
    ), "Tensors in local_state_dict should be ShardedTensor."

    # Convert the ShardedTensor to a Tensor.
    flat_param = _module_handles(fsdp_state, module)[0].flat_param
    assert flat_param is not None
    valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
    shards = load_tensor.local_shards()
    if valid_data_size > 0:
        assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
        load_tensor = shards[0].tensor

        # Get the metadata of the flat_param to decide whether to pad the loaded
        # tensor.
        if flat_param._shard_numel_padded > 0:
            assert load_tensor.numel() < flat_param.numel(), (
                f"Local shard size = {flat_param.numel()} and the tensor in "
                f"the state_dict is {load_tensor.numel()}."
            )
            load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
    else:
        load_tensor = flat_param
    state_dict[fqn] = load_tensor
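

# Worked example (hypothetical numbers): if the runtime flat_param has
# numel()=6 with _shard_numel_padded=2, the saved shard holds only the 4
# valid elements, so the hook restores the runtime shard size with
# F.pad(load_tensor, [0, 2]) before placing the tensor back into state_dict.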


def _sharded_pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    Hook that runs before model.state_dict() is called. Check
    ``_full_pre_state_dict_hook`` for the details.
    """
    if (
        _has_fsdp_params(fsdp_state, module)
        and not _module_handles(fsdp_state, module)[0].uses_sharded_strategy
    ):
        raise RuntimeError(
            "``sharded_state_dict`` can only be used when parameters are flattened "
            "and sharded."
        )
    _common_pre_state_dict_hook(module, fsdp_state)
    # Setting offload_to_cpu here does not work even if offload_to_cpu is True.
    # We have to create the ShardedTensor first and then move it to CPU.
    _common_unshard_pre_state_dict_hook(
        module,
        fsdp_state,
        offload_to_cpu=False,
        rank0_only=False,
    )


@no_type_check
def _sharded_post_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """
    The hook replaces the unflattened, unsharded parameter in the state_dict
    with an unflattened, sharded parameter (a ShardedTensor).
    """

    def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
        param = state_dict[fqn]
        sharded_tensor = _ext_chunk_tensor(
            tensor=param,
            rank=fsdp_state.rank,
            world_size=fsdp_state.world_size,
            num_devices_per_node=torch.cuda.device_count(),
            pg=fsdp_state.process_group,
        )
        if fsdp_state._state_dict_config.offload_to_cpu:
            sharded_tensor = sharded_tensor.cpu()
        state_dict[fqn] = sharded_tensor

    return _common_unshard_post_state_dict_hook(
        module, fsdp_state, state_dict, prefix, param_hook
    )


@no_type_check
def _sharded_post_load_state_dict_hook(
    module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
) -> None:
    if fsdp_state._use_orig_params:
        _register_orig_params(module, fsdp_state)


@no_type_check
def _sharded_pre_load_state_dict_hook(
    module: nn.Module,
    fsdp_state: _FSDPState,
    state_dict: Dict[str, Any],
    prefix: str,
) -> None:
    """
    The hook combines the unflattened, sharded parameters (ShardedTensors) into
    a new FlatParameter and shards the new FlatParameter to the local chunk.
    """
    _lazy_init(fsdp_state, module)
    _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
    if not _has_fsdp_params(fsdp_state, module):
        return

    if not _module_handles(fsdp_state, module)[0].uses_sharded_strategy:
        raise RuntimeError(
            "load_sharded_state_dict can only be called when parameters "
            "are flattened and sharded."
        )

    nonsharded_tensors = []
    shared_fqns = [fqn for fqn, _, _ in _shared_param_fqns(module, fsdp_state)]
    loaded_shapes = []
    for fqn, _, _ in _param_fqns(module, fsdp_state):
        full_fqn = f"{prefix}{FSDP_PREFIX}{fqn}"
        param = state_dict.pop(full_fqn)
        if fqn in shared_fqns:
            continue
        # All-gather the param (ShardedTensor)
        param, shards = _ext_pre_load_state_dict_transform(param)
        loaded_shapes.append(param.size())
        assert len(shards) < 2, (
            "Expects 0 or 1 shard per rank "
            f"but got {len(shards)} shards on rank {fsdp_state.rank}."
        )
        param_numel = param.size().numel()
        dim_0_size = param.size()[0]
        chunk_size = (
            math.ceil(dim_0_size / fsdp_state.world_size) * param_numel // dim_0_size
        )
        if len(shards) == 1:
            local_tensor = shards[0].tensor.flatten()
            if not local_tensor.is_cuda:
                local_tensor = local_tensor.cuda()
            num_padding = chunk_size - local_tensor.numel()
            if num_padding > 0:
                local_tensor = F.pad(local_tensor, [0, num_padding])
        else:
            local_tensor = torch.zeros(chunk_size, dtype=param.dtype).cuda()
        tensor = torch.empty(
            chunk_size * fsdp_state.world_size, dtype=local_tensor.dtype
        ).cuda()
        dist.all_gather_into_tensor(
            tensor, local_tensor, group=fsdp_state.process_group
        )
        tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
        nonsharded_tensors.append(tensor)

    # Create a new flat_param from the loaded, non-sharded tensors.
    flat_param = _module_handles(fsdp_state, module)[0].flat_param
    loaded_flat_param = FlatParamHandle.flatten_params(
        nonsharded_tensors, requires_grad=False
    )

    # Get the chunk from the loaded flat_param for the local rank.
    loaded_flat_tensor, num_to_pad = FlatParamHandle._get_shard(
        loaded_flat_param,
        fsdp_state.rank,
        fsdp_state.world_size,
    )
    loaded_flat_tensor = loaded_flat_tensor.to(flat_param.device)
    assert all(s1 == s2 for s1, s2 in zip(loaded_shapes, flat_param._shapes)), (
        f"The original shapes in FSDP are {flat_param._shapes}. "
        f"The loaded shapes are {loaded_shapes}. "
        f"FSDP extension is {'NOT' if _user_extensions is not None else ''} None."
    )
    assert flat_param.numel() == loaded_flat_tensor.numel(), (
        f"The loaded local chunk has different numel({loaded_flat_tensor.numel()}) "
        f"from the local chunk {flat_param.numel()}."
    )
    assert flat_param._shard_numel_padded == num_to_pad, (
        f"The loaded local chunk has different padding({num_to_pad}) "
        f"from the local chunk {flat_param._shard_numel_padded}."
    )
    state_dict[f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"] = loaded_flat_tensor
    if fsdp_state._use_orig_params:
        _deregister_orig_params(module, fsdp_state)
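

# Worked example for the chunk size computed above (hypothetical numbers):
# a parameter of shape (10, 8) with world_size=4 gives param_numel=80 and
# dim_0_size=10, so chunk_size = ceil(10 / 4) * 80 // 10 = 3 * 8 = 24; the
# all-gathered buffer then holds 24 * 4 = 96 elements and is narrowed back
# to the original 80 before being reshaped to (10, 8).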


@no_type_check
@torch.no_grad()
def _post_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> Dict[str, Any]:
    """
    _post_state_dict_hook() is called after the state_dict() of this
    FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide
    what postprocessing will be done.
    """
    _post_state_dict_hook_fn = {
        StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
        StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
        StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
    }
    processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
        module, fsdp_state, state_dict, prefix
    )
    return processed_state_dict


@no_type_check
@torch.no_grad()
def _pre_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args,
    **kwargs,
) -> None:
    """
    This is called before the core state dict saving logic of ``module``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
    """
    _pre_state_dict_hook_fn = {
        StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
        StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
        StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
    }
    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
        fsdp_state,
        module,
        *args,
        **kwargs,
    )


@no_type_check
@torch.no_grad()
def _pre_load_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    state_dict: Dict[str, Any],
    prefix: str,
    *args: Any,
) -> None:
    """
    This is called before ``module._load_from_state_dict()``.
    ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
    be done.
    """
    _pre_load_state_dict_hook_fn = {
        StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
        StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
        StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
    }
    # Code that is common for all state_dict impls
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Dispatch into state_dict specific implementation of pre-hook.
    _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
        module, fsdp_state, state_dict, prefix
    )


@no_type_check
@torch.no_grad()
def _post_load_state_dict_hook(
    fsdp_state: _FSDPState,
    module: nn.Module,
    *args: Any,
) -> None:
    _post_load_state_dict_hook_fn = {
        StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
        StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
        StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
    }
    # Code that is common for all state_dict impls
    # Dispatch into state_dict type specific implementation of post-hook for
    # loading state_dict.
    _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)


def _register_all_state_dict_hooks(state: _FSDPState):
    """
    Registers pre-save, post-save, pre-load, and post-load state dict hooks.
    """
    for hook_registration_fn_str, hook, hook_registration_fn_kwargs in (
        ("register_state_dict_pre_hook", _pre_state_dict_hook, {}),
        ("_register_state_dict_hook", _post_state_dict_hook, {}),
        (
            "_register_load_state_dict_pre_hook",
            _pre_load_state_dict_hook,
            {"with_module": True},
        ),
        ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}),
    ):
        _register_state_dict_hooks_base(
            state, hook_registration_fn_str, hook, hook_registration_fn_kwargs
        )


@no_type_check
def _register_state_dict_hooks_base(
    state: _FSDPState,
    hook_registration_fn_name: str,
    hook: Callable,
    hook_registration_fn_kwargs: Dict[str, Any],
) -> None:
    """Registers ``hook`` using ``hook_registration_fn``."""
    # TODO: Use `_get_submodule_state(module)` in each hook instead of
    # `partial`: https://github.com/pytorch/pytorch/issues/90788
    hook_with_state = functools.partial(hook, state)
    if not _is_composable(state):
        getattr(state, hook_registration_fn_name)(
            hook_with_state, **hook_registration_fn_kwargs
        )
    else:
        for handle in state._handles:
            getattr(handle._fully_sharded_module, hook_registration_fn_name)(
                hook_with_state, **hook_registration_fn_kwargs
            )
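

# Usage sketch (illustrative only; not part of this module): the hooks above
# are selected through the public ``FSDP.state_dict_type`` context manager,
# roughly as follows.
#
#   from torch.distributed.fsdp import (
#       FullStateDictConfig,
#       FullyShardedDataParallel as FSDP,
#       StateDictType,
#   )
#
#   cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
#   with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
#       state_dict = model.state_dict()  # dispatches into _pre/_post_state_dict_hook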