# fully_sharded_data_parallel.py

import contextlib
import copy
import functools
import math
import traceback
import warnings
from contextlib import contextmanager
from enum import auto, Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_WRAPPED_MODULE,
    ActivationWrapper,
)
from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _get_param_to_fqns,
    FSDP_PREFIX,
    FSDP_WRAPPED_MODULE,
    TrainingState,
)
from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo
from torch.distributed.fsdp._init_utils import (
    _check_orig_params_flattened,
    _get_default_comm_hook,
    _init_buffer_state,
    _init_core_state,
    _init_ignored_module_states,
    _init_param_handle_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
    HYBRID_SHARDING_STRATEGIES,
    ProcessGroupType,
)
from torch.distributed.fsdp._runtime_utils import (
    _get_fsdp_root_states,
    _is_fsdp_root,
    _lazy_init,
    _post_forward,
    _post_forward_reshard,
    _pre_forward,
    _pre_forward_unshard,
    _root_pre_forward,
)
from torch.distributed.fsdp._wrap_utils import _auto_wrap
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    LocalOptimStateDictConfig,
    LocalStateDictConfig,
    MixedPrecision,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictSettings,
    StateDictType,
)

from ._optim_utils import (
    _broadcast_pos_dim_tensor_states,
    _broadcast_processed_optim_state_dict,
    _flatten_optim_state_dict,
    _get_param_id_to_param_from_optim_input,
    _get_param_key_to_param,
    _get_param_to_param_id_from_optim_input,
    _get_param_to_param_key,
    _optim_state_dict,
    _process_pos_dim_tensor_state,
    _rekey_sharded_optim_state_dict,
)
from ._state_dict_utils import _register_all_state_dict_hooks
from ._unshard_param_utils import (
    _deregister_orig_params,
    _register_flat_param,
    _register_orig_params,
    _unshard_params,
    _unshard_params_recurse,
)
from ._utils import p_assert
from .flat_param import FlatParameter
from .wrap import _FSDPPolicy


__all__ = [
    "FullyShardedDataParallel",
    "OptimStateKeyType",
]

FLAT_PARAM = "_flat_param"


class OptimStateKeyType(Enum):
    PARAM_NAME = auto()
    PARAM_ID = auto()


class FullyShardedDataParallel(nn.Module, _FSDPState):
  111. """
  112. A wrapper for sharding Module parameters across data parallel workers. This
  113. is inspired by `Xu et al.`_ as well as the ZeRO Stage 3 from DeepSpeed_.
  114. FullyShardedDataParallel is commonly shortened to FSDP.
  115. .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
  116. .. _DeepSpeed: https://www.deepspeed.ai/
  117. Example::
  118. >>> # xdoctest: +SKIP("undefined variables")
  119. >>> import torch
  120. >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  121. >>> torch.cuda.set_device(device_id)
  122. >>> sharded_module = FSDP(my_module)
  123. >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
  124. >>> x = sharded_module(x, y=3, z=torch.Tensor([1]))
  125. >>> loss = x.sum()
  126. >>> loss.backward()
  127. >>> optim.step()
  128. .. warning::
  129. The optimizer must be initialized *after* the module has been wrapped,
  130. since FSDP will shard parameters in-place and this will break any
  131. previously initialized optimizers.
  132. .. warning::
  133. If the destination CUDA device has ID ``dev_id``, either (1)
  134. ``module`` should already be placed on that device, (2) the device
  135. should be set using ``torch.cuda.set_device(dev_id)``, or (3)
  136. ``dev_id`` should be passed into the ``device_id`` constructor
  137. argument. This FSDP instance's compute device will be that destination
  138. device. For (1) and (3), the FSDP initialization always occurs on GPU.
  139. For (2), the FSDP initialization happens on ``module`` 's current
  140. device, which may be CPU.
  141. .. warning::
  142. FSDP currently does not support gradient accumulation outside
  143. ``no_sync()`` when using CPU offloading. Trying to do so yields
  144. incorrect results since FSDP will use the newly-reduced gradient
  145. instead of accumulating with any existing gradient.
  146. .. warning::
  147. Changing the original parameter variable names after construction will
  148. lead to undefined behavior.
  149. .. warning::
  150. Passing in `sync_module_states=True` flag requires module to be put
  151. on GPU, or to use ``device_id`` argument to specify a CUDA device that
  152. FSDP will move module to. This is because ``sync_module_states=True``
  153. requires GPU communication.
  154. .. warning::
  155. As of PyTorch 1.12, FSDP only offers limited support for shared parameters
  156. (for example, setting one ``Linear`` layer's weight to another's). In
  157. particular, modules that share parameters must be wrapped as part of the
  158. same FSDP unit. If enhanced shared parameter support is needed for your
  159. use case, please ping https://github.com/pytorch/pytorch/issues/77724
  160. .. note:
  161. Attempting to run the forward pass of a submodule that is contained in an
  162. FSDP instance is not supported and will result in errors. This is because the
  163. submodule's parameters will be sharded, but it itself is not an FSDP instance,
  164. so its forward pass will not all-gather the full parameters appropriately.
  165. This could potentially happen when attempting to run only the encoder of a
  166. encoder-decoder model, and the encoder is not wrapped in its own FSDP instance. To
  167. resolve this, please wrap the submodule in its own FSDP unit.
  168. .. note::
  169. Inputs into FSDP ``forward`` function will be moved to compute device
  170. (same device FSDP module is on) before running ``forward``, so user does
  171. not have to manually move inputs from CPU -> GPU.
  172. Args:
  173. module (nn.Module):
  174. This is the module to be wrapped with FSDP.
  175. process_group: Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]
  176. This is the process group used for collective communications and
  177. the one over which the model is sharded. For hybrid sharding strategies such as
  178. ``ShardingStrategy.HYBRID_SHARD`` users can
  179. pass in a tuple of process groups representing the groups to shard and replicate across,
  180. respectively.
  181. sharding_strategy (Optional[ShardingStrategy]):
  182. This configures the sharding strategy used by FSDP, which may trade
  183. off memory saving and communication overhead. See
  184. :class:`ShardingStrategy` for details. (Default: ``FULL_SHARD``)
  185. cpu_offload (Optional[CPUOffload]):
  186. This configures CPU offloading. If this is set to ``None``, then
  187. no CPU offloading happens. See :class:`CPUOffload` for details.
  188. (Default: ``None``)
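
            Example::

                >>> # xdoctest: +SKIP("undefined variables")
                >>> # Illustrative sketch (assumes an existing ``my_module``):
                >>> # keep the sharded parameters on CPU between uses
                >>> fsdp_model = FSDP(my_module, cpu_offload=CPUOffload(offload_params=True))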
        auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], _FSDPPolicy]]):
            This is either ``None``, an ``_FSDPPolicy``, or a callable of
            a fixed signature. If it is ``None``, then ``module`` is wrapped
            with only a top-level FSDP instance without any nested wrapping. If
            it is an ``_FSDPPolicy``, then the wrapping follows the given
            policy. ``ModuleWrapPolicy`` in ``torch.distributed.fsdp.wrap.py``
            is an example. If it is a callable, then it should take in three
            arguments ``module: nn.Module``, ``recurse: bool``, and
            ``nonwrapped_numel: int`` and should return a ``bool`` specifying
            whether the passed-in ``module`` should be wrapped if
            ``recurse=False`` or if the traversal should continue down the
            subtree if ``recurse=True``. Additional custom arguments may be
            added to the callable. The ``size_based_auto_wrap_policy`` in
            ``torch.distributed.fsdp.wrap.py`` gives an example callable that
            wraps a module if the parameters in its subtree exceed 100M numel.
            A good practice is to print the model after wrapping and adjust as
            needed.

            Example::

                >>> def custom_auto_wrap_policy(
                >>>     module: nn.Module,
                >>>     recurse: bool,
                >>>     nonwrapped_numel: int,
                >>>     # Additional custom arguments
                >>>     min_num_params: int = int(1e8),
                >>> ) -> bool:
                >>>     return nonwrapped_numel >= min_num_params
                >>> # Configure a custom `min_num_params`
                >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
        backward_prefetch (Optional[BackwardPrefetch]):
            This configures explicit backward prefetching of all-gathers. See
            :class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``)
        mixed_precision (Optional[MixedPrecision]):
            This configures native mixed precision for FSDP. If this is set to
            ``None``, then no mixed precision is used. Otherwise, parameter,
            buffer, and gradient reduction dtypes can be set. See
            :class:`MixedPrecision` for details. (Default: ``None``)
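
            Example::

                >>> # xdoctest: +SKIP("undefined variables")
                >>> # Illustrative sketch (assumes an existing ``my_module``):
                >>> # run parameters, gradient reduction, and buffers in FP16
                >>> mp = MixedPrecision(
                >>>     param_dtype=torch.float16,
                >>>     reduce_dtype=torch.float16,
                >>>     buffer_dtype=torch.float16,
                >>> )
                >>> fsdp_model = FSDP(my_module, mixed_precision=mp)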
        ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose
            own parameters and child modules' parameters and buffers are
            ignored by this instance. None of the modules directly in
            ``ignored_modules`` should be :class:`FullyShardedDataParallel`
            instances, and any child modules that are already-constructed
            :class:`FullyShardedDataParallel` instances will not be ignored if
            they are nested under this instance. This argument may be used to
            avoid sharding specific parameters at module granularity when using an
            ``auto_wrap_policy`` or if parameters' sharding is not managed by
            FSDP. (Default: ``None``)
        param_init_fn (Optional[Callable[[nn.Module], None]]):
            A ``Callable[torch.nn.Module] -> None`` that
            specifies how modules that are currently on the meta device should be initialized
            onto an actual device. Note that as of v1.12, we detect modules on the meta
            device via an ``is_meta`` check and apply a default initialization that calls
            the ``reset_parameters`` method on the passed-in ``nn.Module`` if ``param_init_fn``
            is not specified; otherwise we run ``param_init_fn`` to initialize the passed-in
            ``nn.Module``. In particular, this means that if ``is_meta=True`` for any
            module parameters for modules that will be wrapped with FSDP and ``param_init_fn``
            is not specified, we assume your module properly implements ``reset_parameters()``
            and will throw errors if not. Note that additionally, we offer support for modules
            initialized with torchdistX's (https://github.com/pytorch/torchdistX)
            ``deferred_init`` API. In this case, deferred modules would be initialized
            by a default initialization function that calls torchdistX's
            ``materialize_module``, or the passed-in ``param_init_fn``, if it is not
            ``None``. The same ``Callable`` is applied to initialize all meta modules.
            Note that this initialization function is applied before doing any FSDP sharding
            logic.

            Example::

                >>> # xdoctest: +SKIP("undefined variables")
                >>> module = MyModule(device="meta")
                >>> def my_init_fn(module):
                >>>     # responsible for initializing a module, such as with reset_parameters
                >>>     ...
                >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy)
                >>> print(next(fsdp_model.parameters()).device) # current CUDA device
                >>> # With torchdistX
                >>> module = deferred_init.deferred_init(MyModule, device="cuda")
                >>> # Will initialize via deferred_init.materialize_module().
                >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
        device_id (Optional[Union[int, torch.device]]): An ``int`` or ``torch.device``
            describing the CUDA device the FSDP module should be moved to, determining where
            initialization such as sharding takes place. If this argument is not specified
            and ``module`` is on CPU, we issue a warning mentioning that this argument can
            be specified for faster initialization. If specified, resulting FSDP instances
            will reside on this device, including moving ignored modules' parameters if
            needed. Note that if ``device_id`` is specified but ``module`` is already on a
            different CUDA device, an error will be thrown. (Default: ``None``)
        sync_module_states (bool): If ``True``, each individually wrapped FSDP unit will broadcast
            module parameters from rank 0 to ensure they are the same across all ranks after
            initialization. This helps ensure model parameters are the same across ranks
            before starting training, but adds communication overhead to ``__init__``, as at least
            one broadcast is triggered per individually wrapped FSDP unit.
            This can also help load checkpoints taken by ``state_dict`` and to be loaded by
            ``load_state_dict`` in a memory-efficient way. See documentation for
            :class:`FullStateDictConfig` for an example of this. (Default: ``False``)
        forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches
            the next upcoming all-gather while executing in the forward pass.
            This may improve communication and computation overlap for CPU-bound
            workloads. This should only be used for static-graph models
            since the forward order is fixed based on the first iteration's
            execution. (Default: ``False``)
        limit_all_gathers (bool): If ``False``, then FSDP allows the CPU
            thread to schedule all-gathers without any extra synchronization.
            If ``True``, then FSDP explicitly synchronizes the CPU thread to
            prevent too many in-flight all-gathers. This ``bool`` only affects
            the sharded strategies that schedule all-gathers. Enabling this can
            help lower the number of CUDA malloc retries. (Default: ``False``)
        ignored_parameters (Optional[Iterable[torch.nn.Parameter]]): Ignored
            parameters will not be managed by this FSDP instance, which means
            that these parameters will not be flattened and sharded by FSDP and
            their gradients will not be synchronized either. With this newly added
            argument, ``ignored_modules`` could be deprecated soon. For backward compatibility,
            both ``ignored_parameters`` and ``ignored_modules`` are kept for now,
            but FSDP only allows one of them to be specified as not ``None``.
    """

    def __init__(
        self,
        module: nn.Module,
        process_group: ProcessGroupType = None,
        sharding_strategy: Optional[ShardingStrategy] = None,
        cpu_offload: Optional[CPUOffload] = None,
        auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
        backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
        mixed_precision: Optional[MixedPrecision] = None,
        ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
        param_init_fn: Optional[Callable[[nn.Module], None]] = None,
        device_id: Optional[Union[int, torch.device]] = None,
        sync_module_states: bool = False,
        forward_prefetch: bool = False,
        limit_all_gathers: bool = False,
        use_orig_params: bool = False,
        ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
    ):
        torch._C._log_api_usage_once("torch.distributed.fsdp")
        super().__init__()
        _init_ignored_module_states(self, module, ignored_modules, ignored_parameters)

        # Add module annotations for Dynamo support (see function for details)
        _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params)

        # Initializes self.process_group, along with rank and world size. This will
        # also set another attribute, _inter_node_pg, to control the process group
        # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}.
        # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up
        # the same process group state as the root FSDP module.
        _init_process_group_state(
            self, process_group, sharding_strategy, auto_wrap_policy
        )
        if auto_wrap_policy is not None:
            auto_wrap_kwargs = {
                "module": module,
                "auto_wrap_policy": auto_wrap_policy,
                "wrapper_cls": FullyShardedDataParallel,
                "ignored_modules": self._ignored_modules,
                "ignored_params": self._ignored_params,
                "only_wrap_children": True,  # avoid double wrapping the root
            }
            fsdp_kwargs = {
                "process_group": process_group,
                "sharding_strategy": sharding_strategy,
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
                "mixed_precision": mixed_precision,
                "param_init_fn": param_init_fn,
                "device_id": device_id,
                "sync_module_states": sync_module_states,
                "forward_prefetch": forward_prefetch,
                "limit_all_gathers": limit_all_gathers,
                "use_orig_params": use_orig_params,
            }
            if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
                # Share root process groups with children to maintain
                # the invariant that all FSDP modules will have the same
                # process groups.
                fsdp_kwargs["process_group"] = (self.process_group, self._inter_node_pg)
            _auto_wrap(auto_wrap_kwargs, fsdp_kwargs, FullyShardedDataParallel)

        backward_prefetch_limit = 1
        forward_prefetch_limit = 1
        _init_core_state(
            self,
            sharding_strategy,
            mixed_precision,
            cpu_offload,
            limit_all_gathers,
            use_orig_params,
            backward_prefetch_limit,
            forward_prefetch_limit,
        )
        _init_runtime_state(self)
        _init_prefetching_state(self, backward_prefetch, forward_prefetch)
        _init_buffer_state(self, module)
        _init_param_handle_from_module(
            self,
            module,
            device_id,
            param_init_fn,
            sync_module_states,
            FullyShardedDataParallel,
        )
        self._fsdp_wrapped_module = module
        if not use_orig_params:
            _check_orig_params_flattened(self, self._ignored_params)
            _register_flat_param(self, self)

        # `_state_dict_type` controls the `state_dict()` behavior, which is
        # implemented using post-save and pre-load hooks
        _init_state_dict_state(self)
        _register_all_state_dict_hooks(self)

    @property
    def module(self) -> nn.Module:
        """
        Returns the wrapped module (like :class:`DistributedDataParallel`).
        """
        # FSDP's `.module` must refer to the innermost wrapped module when
        # composing with other module wrappers in order for state dict to work
        if isinstance(self._fsdp_wrapped_module, ActivationWrapper):
            return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE)
        return self._fsdp_wrapped_module

    @property
    def _has_params(self) -> bool:
        """Returns whether this FSDP instance manages any parameters."""
        return hasattr(self, "_handles") and len(self._handles) > 0

    @property
    def _flat_param(self) -> Optional[FlatParameter]:
        return self._handles[0].flat_param if self._handles else None

    def __getattr__(self, name: str) -> Any:
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self._fsdp_wrapped_module, name)

    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is an ``nn.Sequential``."""
        if hasattr(self, FSDP_WRAPPED_MODULE):
            return self._fsdp_wrapped_module.__getitem__(key)  # type: ignore[operator]
        return super().__getitem__(key)

    def check_is_root(self) -> bool:
        return _is_fsdp_root(self, self)

    @staticmethod
    def fsdp_modules(
        module: nn.Module,
        root_only: bool = False,
    ) -> List["FullyShardedDataParallel"]:
        """
        Returns all nested FSDP instances, possibly including ``module`` itself
        and only including FSDP root modules if ``root_only=True``.

        Args:
            module (torch.nn.Module): Root module, which may or may not be an
                ``FSDP`` module.
            root_only (bool): Whether to return only FSDP root modules.
                (Default: ``False``)

        Returns:
            List[FullyShardedDataParallel]: FSDP modules that are nested in
            the input ``module``.
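
        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> # Illustrative sketch (assumes an existing wrapped ``model``):
            >>> # count all nested FSDP units and collect only the root ones
            >>> num_fsdp_units = len(FSDP.fsdp_modules(model))
            >>> root_fsdp_units = FSDP.fsdp_modules(model, root_only=True)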
  437. """
  438. if root_only:
  439. return _get_fsdp_root_states(module)
  440. return traversal_utils._get_fsdp_states(module)
  441. def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel":
  442. r"""Applies ``fn`` recursively to every submodule (as returned by ``.children()``)
  443. as well as self. Typical use includes initializing the parameters of a model
  444. (see also :ref:`nn-init-doc`).
  445. Compared to ``torch.nn.Module.apply``, this version additionally gathers
  446. the full parameters before applying ``fn``. It should not be called from
  447. within another ``summon_full_params`` context.
  448. Args:
  449. fn (:class:`Module` -> None): function to be applied to each submodule
  450. Returns:
  451. Module: self
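
        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> # Illustrative sketch (assumes an existing ``fsdp_model`` and an
            >>> # ``init_weights(module)`` function): re-initialize weights while
            >>> # the full parameters are gathered
            >>> fsdp_model.apply(init_weights)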
  452. """
  453. uninitialized = self._is_root is None
  454. self._assert_state(TrainingState.IDLE)
  455. # Use `_unshard_params_recurse()` with `recurse=False` instead of
  456. # `_unshard_fsdp_state_params()` directly to perform lazy
  457. # initialization, which is needed to initialize `FlatParameter`
  458. # parameter attributes as required by the unshard logic
  459. with _unshard_params_recurse(
  460. self,
  461. self,
  462. recurse=False,
  463. writeback=True,
  464. rank0_only=False,
  465. offload_to_cpu=False,
  466. with_grads=False,
  467. ):
  468. ret = super().apply(fn)
  469. # Reset lazy init called in `_unshard_params_recurse()` since `apply()`
  470. # may have been called on FSDP instance that is not truly a root, in
  471. # which case it will be incorrectly marked as one.
  472. if uninitialized and self._is_root:
  473. for module in traversal_utils._get_fsdp_states(self):
  474. module._reset_lazy_init()
  475. return ret
  476. def _mixed_precision_enabled_for_buffers(self) -> bool:
  477. """
  478. Returns if the user explicitly enabled buffer mixed precision.
  479. NOTE: Unlike parameters and gradient reduction, buffer mixed precision
  480. is applied at the FSDP instance level, not the ``FlatParameter`` level,
  481. which may be different for the composable code path.
  482. """
  483. return self.mixed_precision.buffer_dtype is not None

    def _low_precision_hook_enabled(self) -> bool:
        """
        Whether a low precision hook is registered or not.
        """
        return (
            self._communication_hook is not None
            and self._communication_hook in LOW_PRECISION_HOOKS
        )

    def _reset_lazy_init(self) -> None:
        """
        Reset instance so :func:`_lazy_init` will run on the next forward.
        """
        self._is_root: Optional[bool] = None

    @staticmethod
    def set_state_dict_type(
        module: nn.Module,
        state_dict_type: StateDictType,
        state_dict_config: Optional[StateDictConfig] = None,
        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
    ) -> StateDictSettings:
        """
        Set the ``state_dict_type`` and the corresponding (optional)
        configurations of all the descendant FSDP modules of the target module.
        The target module does not have to be an FSDP module. If the target
        module is an FSDP module, its ``state_dict_type`` will also be changed.

        .. note:: This API should be called for only the top-level (root)
            module.

        .. note:: This API enables users to transparently use the conventional
            ``state_dict`` API to take model checkpoints in cases where the
            root FSDP module is wrapped by another ``nn.Module``. For example,
            the following will ensure ``state_dict`` is called on all non-FSDP
            instances, while dispatching into the ``sharded_state_dict``
            implementation for FSDP:

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = DDP(FSDP(...))
            >>> FSDP.set_state_dict_type(
            >>>     model,
            >>>     StateDictType.SHARDED_STATE_DICT,
            >>>     state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
            >>>     optim_state_dict_config=OptimStateDictConfig(offload_to_cpu=True),
            >>> )
            >>> param_state_dict = model.state_dict()
            >>> optim_state_dict = FSDP.optim_state_dict(model, optim)

        Args:
            module (torch.nn.Module): Root module.
            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
            state_dict_config (Optional[StateDictConfig]): the configuration for the
                target ``state_dict_type``.
            optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration
                for the optimizer state dict of the target ``state_dict_type``.

        Returns:
            A StateDictSettings that includes the previous state_dict type and
            configuration for the module.
        """
        _state_dict_type_to_config = {
            StateDictType.FULL_STATE_DICT: FullStateDictConfig,
            StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
            StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
        }
        _optim_state_dict_type_to_config = {
            StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig,
            StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig,
            StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig,
        }

        # Use the default config if a state_dict config is not set.
        state_dict_config_type = _state_dict_type_to_config[state_dict_type]
        optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type]
        if state_dict_config is None:
            state_dict_config = state_dict_config_type()
        if optim_state_dict_config is None:
            optim_state_dict_config = optim_state_dict_config_type()
        if state_dict_config_type != type(state_dict_config):
            raise RuntimeError(
                f"Expected state_dict_config of type {state_dict_config_type} "
                f"but got {type(state_dict_config)}"
            )
        if optim_state_dict_config_type != type(optim_state_dict_config):
            raise RuntimeError(
                f"Expected optim_state_dict_config of type {optim_state_dict_config_type} "
                f"but got {type(optim_state_dict_config)}"
            )

        # Set the state_dict type and configurations.
        prev_state_dict_type = None
        prev_state_dict_config = None
        prev_optim_state_dict_config = None
        for submodule in traversal_utils._get_fsdp_states(module):
            if prev_state_dict_type is None:
                prev_state_dict_type = submodule._state_dict_type
            else:
                assert (
                    prev_state_dict_type == submodule._state_dict_type
                ), "All FSDP modules should have the same state_dict_type."
            if prev_state_dict_config is None:
                prev_state_dict_config = submodule._state_dict_config
            else:
                assert isinstance(
                    submodule._state_dict_config, type(prev_state_dict_config)
                ), "All FSDP modules must have the same type of state_dict_config."
            if prev_optim_state_dict_config is None:
                prev_optim_state_dict_config = submodule._optim_state_dict_config
            else:
                assert isinstance(
                    submodule._optim_state_dict_config,
                    type(prev_optim_state_dict_config),
                ), "All FSDP modules must have the same type of optim_state_dict_config."
            submodule._state_dict_type = state_dict_type
            submodule._state_dict_config = state_dict_config
            submodule._optim_state_dict_config = optim_state_dict_config

        return StateDictSettings(
            prev_state_dict_type, prev_state_dict_config, prev_optim_state_dict_config
        )

    @staticmethod
    def get_state_dict_type(module: nn.Module) -> StateDictSettings:
        state_dict_settings: Optional[StateDictSettings] = None
        for submodule in FullyShardedDataParallel.fsdp_modules(module):
            if state_dict_settings is None:
                state_dict_settings = StateDictSettings(
                    state_dict_type=submodule._state_dict_type,
                    state_dict_config=submodule._state_dict_config,
                    optim_state_dict_config=submodule._optim_state_dict_config,
                )
            else:
                submodule_settings = StateDictSettings(
                    submodule._state_dict_type,
                    submodule._state_dict_config,
                    submodule._optim_state_dict_config,
                )
                assert state_dict_settings == submodule_settings, (
                    "All FSDP modules must have the same state dict settings. "
                    f"Got {submodule_settings} and {state_dict_settings}."
                )
        return state_dict_settings

    @staticmethod
    @contextlib.contextmanager
    def state_dict_type(
        module: nn.Module,
        state_dict_type: StateDictType,
        state_dict_config: Optional[StateDictConfig] = None,
        optim_state_dict_config: Optional[OptimStateDictConfig] = None,
    ) -> Generator:
        """
        A context manager to set the ``state_dict_type`` of all the descendant
        FSDP modules of the target module. This context manager has the same
        functionality as :meth:`set_state_dict_type`. Read the documentation of
        :meth:`set_state_dict_type` for details.

        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = DDP(FSDP(...))
            >>> with FSDP.state_dict_type(
            >>>     model,
            >>>     StateDictType.SHARDED_STATE_DICT,
            >>> ):
            >>>     checkpoint = model.state_dict()

        Args:
            module (torch.nn.Module): Root module.
            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
            state_dict_config (Optional[StateDictConfig]): the configuration for the
                target ``state_dict_type``.
        """
        try:
            prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
                module,
                state_dict_type,
                state_dict_config,
                optim_state_dict_config,
            )
            yield
        except Exception as e:
            raise e
        FullyShardedDataParallel.set_state_dict_type(
            module,
            prev_state_dict_settings.state_dict_type,
            prev_state_dict_settings.state_dict_config,
            prev_state_dict_settings.optim_state_dict_config,
        )

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """
        Runs the forward pass for the wrapped module, inserting FSDP-specific
        pre- and post-forward sharding logic.
        """
        with torch.autograd.profiler.record_function(
            "FullyShardedDataParallel.forward"
        ):
            args, kwargs = _root_pre_forward(self, self, args, kwargs)
            unused = None
            unshard_fn = functools.partial(_pre_forward_unshard, self, self._handles)
            reshard_fn = functools.partial(_post_forward_reshard, self, self._handles)
            args, kwargs = _pre_forward(
                self, self._handles, unshard_fn, self._fsdp_wrapped_module, args, kwargs
            )
            for handle in self._handles:
                p_assert(
                    handle.flat_param.device == self.compute_device,
                    "Expected `FlatParameter` to be on the compute device "
                    f"{self.compute_device} but got {handle.flat_param.device}",
                )
            output = self._fsdp_wrapped_module(*args, **kwargs)
            return _post_forward(self, self._handles, reshard_fn, self, unused, output)

    @staticmethod
    @contextlib.contextmanager
    def summon_full_params(
        module: nn.Module,
        recurse: bool = True,
        writeback: bool = True,
        rank0_only: bool = False,
        offload_to_cpu: bool = False,
        with_grads: bool = False,
    ) -> Generator:
  691. r"""A context manager to expose full params for FSDP instances.
  692. Can be useful *after* forward/backward for a model to get
  693. the params for additional processing or checking. It can take a non-FSDP
  694. module and will summon full params for all contained FSDP modules as
  695. well as their children, depending on the ``recurse`` argument.
  696. .. note:: This can be used on inner FSDPs.
  697. .. note:: This can *not* be used within a forward or backward pass. Nor
  698. can forward and backward be started from within this context.
  699. .. note:: Parameters will revert to their local shards after the context
  700. manager exits, storage behavior is the same as forward.
  701. .. note:: The full parameters can be modified, but only the portion
  702. corresponding to the local param shard will persist after the
  703. context manager exits (unless ``writeback=False``, in which case
  704. changes will be discarded). In the case where FSDP does not shard
  705. the parameters, currently only when ``world_size == 1``, or ``NO_SHARD``
  706. config, the modification is persisted regardless of ``writeback``.
  707. .. note:: This method works on modules which are not FSDP themselves but
  708. may contain multiple independent FSDP units. In that case, the given
  709. arguments will apply to all contained FSDP units.
  710. .. warning:: Note that ``rank0_only=True`` in conjunction with
  711. ``writeback=True`` is not currently supported and will raise an
  712. error. This is because model parameter shapes would be different
  713. across ranks within the context, and writing to them can lead to
  714. inconsistency across ranks when the context is exited.
  715. .. warning:: Note that ``offload_to_cpu`` and ``rank0_only=False`` will
  716. result in full parameters being redundantly copied to CPU memory for
  717. GPUs that reside on the same machine, which may incur the risk of
  718. CPU OOM. It is recommended to use ``offload_to_cpu`` with
  719. ``rank0_only=True``.
  720. Args:
  721. recurse (bool, Optional): recursively summon all params for nested
  722. FSDP instances (default: True).
  723. writeback (bool, Optional): if ``False``, modifications to params are
  724. discarded after the context manager exits;
  725. disabling this can be slightly more efficient (default: True)
  726. rank0_only (bool, Optional): if ``True``, full parameters are
  727. materialized on only global rank 0. This means that within the
  728. context, only rank 0 will have full parameters and the other
  729. ranks will have sharded parameters. Note that setting
  730. ``rank0_only=True`` with ``writeback=True`` is not supported,
  731. as model parameter shapes will be different across ranks
  732. within the context, and writing to them can lead to
  733. inconsistency across ranks when the context is exited.
  734. offload_to_cpu (bool, Optional): If ``True``, full parameters are
  735. offloaded to CPU. Note that this offloading currently only
  736. occurs if the parameter is sharded (which is only not the case
  737. for world_size = 1 or ``NO_SHARD`` config). It is recommended
  738. to use ``offload_to_cpu`` with ``rank0_only=True`` to avoid
  739. redundant copies of model parameters being offloaded to the same CPU memory.
  740. with_grads (bool, Optional): If ``True``, gradients are also
  741. unsharded with the parameters. Currently, this is only
  742. supported when passing ``use_orig_params=True`` to the FSDP
  743. constructor and ``offload_to_cpu=False`` to this method.
  744. (Default: ``False``)
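
        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> # Illustrative sketch (assumes an existing ``fsdp_model``):
            >>> # inspect the full, unsharded parameters on rank 0 only
            >>> with FSDP.summon_full_params(
            >>>     fsdp_model, rank0_only=True, writeback=False, offload_to_cpu=True
            >>> ):
            >>>     if dist.get_rank() == 0:
            >>>         print(next(fsdp_model.parameters()).shape)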
  745. """
  746. with _unshard_params(
  747. module, recurse, writeback, rank0_only, offload_to_cpu, with_grads
  748. ):
  749. yield
  750. @contextlib.contextmanager
  751. def _deregister_orig_params_ctx(self):
  752. """
  753. This deregisters the original parameters and exposes the
  754. :class:`FlatParameter` s. If a :class:`FlatParameter` is sharded, then
  755. this refreshes the sharded views before exiting. This method shouuld
  756. only be called when using the original parameters.
  757. """
  758. p_assert(
  759. self._use_orig_params,
  760. "`_deregister_orig_params_ctx()` should only be called when "
  761. "`_use_orig_params=True`",
  762. )
  763. for fsdp_module in traversal_utils._get_fsdp_states(self):
  764. _deregister_orig_params(fsdp_module, fsdp_module)
  765. try:
  766. yield
  767. finally:
  768. for fsdp_module in traversal_utils._get_fsdp_states(self):
  769. _register_orig_params(fsdp_module, fsdp_module)

    def _apply(self, *args, **kwargs):
        """
        When using the original parameters, this deregisters the original
        parameters and exposes the :class:`FlatParameter` s before calling
        ``_apply()``.
        """
        # When using the original parameters: Since (1) the `FlatParameter`s
        # own the storage and (2) `_apply()` is the subroutine underlying the
        # most common storage-changing ops like `to()` and `cuda()`, we
        # override `_apply()` to have the storage change directly performed on
        # the `FlatParameter`s instead of applying to the original parameters
        # and then writing back to the `FlatParameter`s.
        context = (
            self._deregister_orig_params_ctx()
            if self._use_orig_params
            else contextlib.suppress()
        )
        with context:
            return super()._apply(*args, **kwargs)

    def named_buffers(
        self,
        *args,
        **kwargs,
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        """
        Overrides :meth:`named_buffers()` to intercept buffer names and
        remove all occurrences of the FSDP-specific flattened buffer prefix
        when inside the :meth:`summon_full_params` context manager.
        """
        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
        for buffer_name, buffer in super().named_buffers(*args, **kwargs):
            if should_clean_name:
                # Remove any instances of the FSDP-specific prefix; there can
                # be multiple in the case of nested FSDP modules
                buffer_name = buffer_name.replace(FSDP_PREFIX, "")
            yield (buffer_name, buffer)

    def named_parameters(
        self,
        *args,
        **kwargs,
    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
        """
        Overrides :meth:`named_parameters()` to intercept parameter names and
        remove all occurrences of the FSDP-specific flattened parameter prefix
        when inside the :meth:`summon_full_params` context manager.
        """
        should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
        for param_name, param in super().named_parameters(*args, **kwargs):
            if should_clean_name:
                # Remove any instances of the FSDP-specific prefix; there can
                # be multiple in the case of nested FSDP modules
                param_name = param_name.replace(FSDP_PREFIX, "")
            yield (param_name, param)

    def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
        """Assert we are in the given state."""
        # Since assert can be turned off and this error checking
        # is really important, we use explicit error checking
        # and raise a ValueError if needed.
        if isinstance(state, TrainingState):
            state = [state]
        if self.training_state not in state:
            msg = (
                f"expected to be in states {state} but current state "
                f"is {self.training_state}"
            )
            # In case we are failing in the context of autograd hook, asserting
            # may not generate useful msg. So, let's print it to be sure.
            if self.rank == 0:
                print(f"Asserting FSDP instance is: {self}")
                print(f"ERROR: {msg}")
                traceback.print_stack()
            raise ValueError(msg)

    @contextmanager
    def no_sync(self) -> Generator:
        """
        A context manager to disable gradient synchronizations across FSDP
        instances. Within this context, gradients will be accumulated in module
        variables, which will later be synchronized in the first
        forward-backward pass after exiting the context. This should only be
        used on the root FSDP instance and will recursively apply to all
        child FSDP instances.

        .. note:: This likely results in higher memory usage because FSDP will
            accumulate the full model gradients (instead of gradient shards)
            until the eventual sync.

        .. note:: When used with CPU offloading, the gradients will not be
            offloaded to CPU when inside the context manager. Instead, they
            will only be offloaded right after the eventual sync.
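
        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> # Illustrative sketch (assumes ``fsdp_model``, ``optim``, and a
            >>> # list of ``batches``): accumulate gradients locally over all
            >>> # but the last batch, then sync on the final backward
            >>> with fsdp_model.no_sync():
            >>>     for batch in batches[:-1]:
            >>>         fsdp_model(batch).sum().backward()
            >>> fsdp_model(batches[-1]).sum().backward()
            >>> optim.step()
            >>> optim.zero_grad()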
  857. """
  858. _lazy_init(self, self)
  859. if not self._is_root:
  860. raise RuntimeError(
  861. "`no_sync()` on inner FSDP instances is not supported. Please call `no_sync()` on root FSDP module."
  862. )
  863. self._assert_state(TrainingState.IDLE)
  864. old_flags = []
  865. for m in self.modules():
  866. if isinstance(m, FullyShardedDataParallel):
  867. old_flags.append((m, m._sync_gradients))
  868. m._sync_gradients = False
  869. try:
  870. yield
  871. finally:
  872. for m, old_flag in old_flags:
  873. assert not m._sync_gradients, (
  874. "`_sync_gradients` was incorrectly set to "
  875. "`True` while in the `no_sync()` context manager"
  876. )
  877. m._sync_gradients = old_flag
  878. @torch.no_grad()
  879. def clip_grad_norm_(
  880. self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0
  881. ) -> torch.Tensor:
  882. """
  883. Clips the gradient norm of all parameters. The norm is computed over
  884. all parameters' gradients as viewed as a single vector, and the
  885. gradients are modified in-place.
  886. Args:
  887. max_norm (float or int): max norm of the gradients
  888. norm_type (float or int): type of the used p-norm. Can be ``'inf'``
  889. for infinity norm.
  890. Returns:
  891. Total norm of the parameters (viewed as a single vector).
  892. .. note:: If every FSDP instance uses ``NO_SHARD``, meaning that no
  893. gradients are sharded across ranks, then you may directly use
  894. :func:`torch.nn.utils.clip_grad_norm_`.
  895. .. note:: If at least some FSDP instance uses a sharded strategy (i.e.
  896. one other than ``NO_SHARD``), then you should use this method
  897. instead of :func:`torch.nn.utils.clip_grad_norm_` since this method
  898. handles the fact that gradients are sharded across ranks.
  899. .. note:: The total norm returned will have the "largest" dtype across
  900. all parameters/gradients as defined by PyTorch's type promotion
  901. semantics. For example, if *all* parameters/gradients use a low
  902. precision dtype, then the returned norm's dtype will be that low
  903. precision dtype, but if there exists at least one parameter/
  904. gradient using FP32, then the returned norm's dtype will be FP32.
  905. .. warning:: This needs to be called on all ranks since it uses
  906. collective communications.
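
        Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> # Illustrative sketch (assumes ``fsdp_model``, ``optim``, and
            >>> # an ``inp`` batch): clip gradients before the optimizer step
            >>> loss = fsdp_model(inp).sum()
            >>> loss.backward()
            >>> fsdp_model.clip_grad_norm_(max_norm=1.0)
            >>> optim.step()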
  907. """
  908. _lazy_init(self, self)
  909. if not self._is_root:
  910. raise RuntimeError(
  911. "`clip_grad_norm_()` should only be called on the root FSDP instance"
  912. )
  913. self._assert_state(TrainingState.IDLE)
  914. # If every FSDP instance uses `NO_SHARD`, then we can directly use
  915. # the normal `nn.utils` one targeting local gradients
  916. all_no_shard = all(
  917. not handle.uses_sharded_strategy
  918. for handle in traversal_utils._get_fsdp_handles(self)
  919. )
  920. if all_no_shard:
  921. return torch.nn.utils.clip_grad_norm_(
  922. self.parameters(), max_norm, norm_type
  923. )
  924. # Otherwise, there exists some FSDP instance using a sharded strategy,
  925. # where sharded and non-sharded parameters must be handled separately
  926. max_norm = float(max_norm)
  927. norm_type = float(norm_type)
  928. sharded_params = set()
  929. nonsharded_params = set() # `NO_SHARD` or not FSDP-managed
  930. grads: List[torch.Tensor] = []
  931. for handle in traversal_utils._get_fsdp_handles(self):
  932. target_set = (
  933. sharded_params if handle.uses_sharded_strategy else nonsharded_params
  934. )
  935. if handle._use_orig_params:
  936. for param in handle.flat_param._params:
  937. target_set.add(param)
  938. if param.grad is not None:
  939. grads.append(param.grad)
  940. else:
  941. target_set.add(handle.flat_param)
  942. if handle.flat_param.grad is not None:
  943. grads.append(handle.flat_param.grad)
  944. for param in self.parameters():
  945. not_fsdp_managed = (
  946. param not in sharded_params and param not in nonsharded_params
  947. )
  948. if not_fsdp_managed:
  949. nonsharded_params.add(param)
  950. if param.grad is not None:
  951. grads.append(param.grad)
  952. # Compute local norms (forced to be in FP32)
  953. local_sharded_norm = _get_grad_norm(sharded_params, norm_type).to(
  954. self.compute_device
  955. )
  956. local_nonsharded_norm = _get_grad_norm(nonsharded_params, norm_type).to(
  957. self.compute_device
  958. )
  959. # Reconstruct the total gradient norm depending on the norm type
  960. if norm_type == math.inf:
  961. total_norm = torch.maximum(local_sharded_norm, local_nonsharded_norm)
  962. dist.all_reduce(
  963. total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group
  964. )
  965. else:
  966. total_norm = local_sharded_norm**norm_type
  967. dist.all_reduce(total_norm, group=self.process_group)
  968. # All-reducing the local non-sharded norm would count it an extra
  969. # world-size-many times
  970. total_norm += local_nonsharded_norm**norm_type
  971. total_norm = total_norm ** (1.0 / norm_type)
  972. if self.cpu_offload.offload_params:
  973. total_norm = total_norm.cpu()
  974. clip_coef = max_norm / (total_norm + 1e-6)
  975. # Multiplying by the clamped coefficient is meaningless when it is
  976. # equal to 1, but it avoids the host-device sync that would result from
  977. # `if clip_coef < 1`
  978. clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
  979. for grad in grads:
  980. grad.detach().mul_(clip_coef_clamped.to(grad.device, grad.dtype))
  981. # Use the "largest" dtype by type promotion semantics to use the same
  982. # dtype as if we did not force local norm computation to be in FP32
  983. if len(grads) == 0:
  984. # If this rank has no gradients, then we must default to FP32
  985. # unless we use additional communication, which we prefer to avoid
  986. # since `clip_grad_norm_()` is called in the training loop
  987. warnings.warn(
  988. f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no "
  989. "gradients -- returning the total norm in the default dtype "
  990. f"{total_norm.dtype}"
  991. ) # warn since this is generally unexpected
  992. return total_norm
  993. total_norm_dtype = functools.reduce(
  994. lambda dtype1, dtype2: torch.promote_types(dtype1, dtype2),
  995. [grad.dtype for grad in grads],
  996. )
  997. return total_norm.to(total_norm_dtype)

    @staticmethod
    def _warn_optim_input(optim_input):
        if optim_input is not None:
            warnings.warn(
                "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. You may remove it "
                "from your code without changing its functionality."
            )

    @staticmethod
    def _is_using_optim_input(optim_input, optim) -> bool:
        if optim_input is None and optim is None:
            # Use the default behavior of `optim_input`
            return True
        if optim_input is not None:
            # Use the `optim_input` code path
            return True
        # Use the `optim` code path
        return False

    @staticmethod
    def _warn_legacy_optim_state_dict(curr: str, new: str):
        warnings.warn(
            f"``FullyShardedDataParallel.{curr}`` is being deprecated and is "
            f"replaced by ``FullyShardedDataParallel.{new}``. "
            f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2."
        )
  1022. @staticmethod
  1023. def _optim_state_dict_impl(
  1024. model: torch.nn.Module,
  1025. optim: torch.optim.Optimizer,
  1026. optim_state_dict: Dict[str, Any],
  1027. optim_input: Optional[
  1028. Union[
  1029. List[Dict[str, Any]],
  1030. Iterable[torch.nn.Parameter],
  1031. ]
  1032. ] = None,
  1033. rank0_only: bool = True,
  1034. full_state_dict: bool = True,
  1035. group: Optional[dist.ProcessGroup] = None,
  1036. ) -> Dict[str, Any]:
  1037. """
  1038. The internal API that is used by all the optim_state_dict implementations.
  1039. Given model, optim, the original optim_state_dict, this API removes the
  1040. FSDP internal information and internal sharding from the optim_state_dict.
  1041. """
  1042. if full_state_dict:
  1043. FullyShardedDataParallel._warn_optim_input(optim_input)
  1044. using_optim_input = FullyShardedDataParallel._is_using_optim_input(
  1045. optim_input,
  1046. optim,
  1047. )
  1048. else:
  1049. using_optim_input = False
  1050. assert optim_input is None and not rank0_only
  1051. use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
  1052. 0
  1053. ]._use_orig_params
  1054. assert all(
  1055. use_orig_params == m._use_orig_params
  1056. for m in FullyShardedDataParallel.fsdp_modules(model)
  1057. ), "Not all FSDP modules have the same _use_orig_params value"
  1058. return _optim_state_dict(
  1059. model=model,
  1060. optim=optim,
  1061. optim_state_dict=optim_state_dict,
  1062. optim_input=optim_input,
  1063. rank0_only=rank0_only,
  1064. shard_state=not full_state_dict,
  1065. group=group,
  1066. using_optim_input=using_optim_input,
  1067. use_orig_params=use_orig_params,
  1068. )
  1069. @staticmethod
  1070. def _optim_state_dict_to_load_impl(
  1071. optim_state_dict: Dict[str, Any],
  1072. model: torch.nn.Module,
  1073. optim_input: Optional[
  1074. Union[
  1075. List[Dict[str, Any]],
  1076. Iterable[torch.nn.Parameter],
  1077. ]
  1078. ] = None,
  1079. optim: Optional[torch.optim.Optimizer] = None,
  1080. full_state_dict: bool = True,
  1081. rank0_only: bool = False,
  1082. is_named_optimizer: bool = False,
  1083. group: Optional[dist.ProcessGroup] = None,
  1084. ) -> Dict[str, Any]:
  1085. """
1086. The internal API used by all the load optim_state_dict implementations
1087. except for loading an optim_state_dict with the ``rank0_only=True``
1088. option.
1089. Given the model, optimizer, and saved optim_state_dict, this API adds the
1090. FSDP-internal information and internal sharding to the optim_state_dict.
  1091. """
  1092. FullyShardedDataParallel._warn_optim_input(optim_input)
  1093. using_optim_input = FullyShardedDataParallel._is_using_optim_input(
  1094. optim_input,
  1095. optim,
  1096. )
  1097. use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[
  1098. 0
  1099. ]._use_orig_params
  1100. assert all(
  1101. use_orig_params == m._use_orig_params
  1102. for m in FullyShardedDataParallel.fsdp_modules(model)
  1103. ), "Not all FSDP modules have the same _use_orig_params value"
  1104. if rank0_only:
  1105. rank = dist.get_rank(group)
  1106. world_size = dist.get_world_size(group)
  1107. # Flatten the optimizer state dict and construct a copy with the
  1108. # positive-dimension tensors' shapes in place of the tensors themselves
  1109. # since those tensors will be broadcast separately to avoid copying
  1110. if rank == 0:
  1111. flat_osd = _flatten_optim_state_dict(
  1112. optim_state_dict,
  1113. model=model,
  1114. shard_state=False,
  1115. use_orig_params=use_orig_params,
  1116. optim=(optim if is_named_optimizer else None),
  1117. )
  1118. processed_osd = _process_pos_dim_tensor_state(flat_osd, world_size)
  1119. # Broadcast the optim state dict without positive-dimension tensor
  1120. # state and the FSDP parameter IDs from rank 0 to all ranks
  1121. processed_osd = _broadcast_processed_optim_state_dict(
  1122. processed_osd if rank == 0 else None,
  1123. rank,
  1124. group,
  1125. )
  1126. # Broadcast positive-dimension tensor state (both sharded tensors for
  1127. # FSDP parameters and unsharded tensors for non-FSDP parameters)
  1128. broadcast_device = (
  1129. torch.device("cuda")
  1130. if torch.cuda.is_available()
  1131. else torch.device("cpu")
  1132. )
  1133. sharded_osd = _broadcast_pos_dim_tensor_states(
  1134. processed_osd,
  1135. flat_osd if rank == 0 else None,
  1136. rank,
  1137. world_size,
  1138. group,
  1139. broadcast_device,
  1140. )
  1141. # Rekey the optimizer state dict to use parameter IDs according to this
  1142. # rank's `optim`
  1143. ret_state_dict = _rekey_sharded_optim_state_dict(
  1144. sharded_osd,
  1145. model=model,
  1146. optim=optim,
  1147. optim_input=optim_input,
  1148. using_optim_input=using_optim_input,
  1149. is_named_optimizer=is_named_optimizer,
  1150. )
  1151. else:
  1152. sharded_osd = _flatten_optim_state_dict(
  1153. optim_state_dict,
  1154. model=model,
  1155. shard_state=True,
  1156. use_orig_params=use_orig_params,
  1157. optim=(optim if is_named_optimizer else None),
  1158. )
  1159. ret_state_dict = _rekey_sharded_optim_state_dict(
  1160. sharded_osd,
  1161. model=model,
  1162. optim=optim,
  1163. optim_input=optim_input,
  1164. using_optim_input=using_optim_input,
  1165. is_named_optimizer=is_named_optimizer,
  1166. )
  1167. return ret_state_dict
  1168. @staticmethod
  1169. def full_optim_state_dict(
  1170. model: torch.nn.Module,
  1171. optim: torch.optim.Optimizer,
  1172. optim_input: Optional[
  1173. Union[
  1174. List[Dict[str, Any]],
  1175. Iterable[torch.nn.Parameter],
  1176. ]
  1177. ] = None,
  1178. rank0_only: bool = True,
  1179. group: Optional[dist.ProcessGroup] = None,
  1180. ) -> Dict[str, Any]:
  1181. """
  1182. Consolidates the full optimizer state on rank 0 and returns it
  1183. as a :class:`dict` following the convention of
  1184. :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"``
  1185. and ``"param_groups"``. The flattened parameters in ``FSDP`` modules
  1186. contained in ``model`` are mapped back to their unflattened parameters.
  1187. .. warning:: This needs to be called on all ranks since it uses
  1188. collective communications. However, if ``rank0_only=True``, then
  1189. the state dict is only populated on rank 0, and all other ranks
  1190. return an empty :class:`dict`.
  1191. .. warning:: Unlike ``torch.optim.Optimizer.state_dict()``, this method
  1192. uses full parameter names as keys instead of parameter IDs.
  1193. .. note:: Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors
  1194. contained in the optimizer state dict are not cloned, so there may
  1195. be aliasing surprises. For best practices, consider saving the
  1196. returned optimizer state dict immediately, e.g. using
  1197. ``torch.save()``.
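Example (a minimal sketch; with the default ``rank0_only=True``, the
populated dict exists only on rank 0, so only rank 0 saves)::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model, optim = ...
>>> full_osd = FSDP.full_optim_state_dict(model, optim)  # called on all ranks
>>> if dist.get_rank() == 0:
>>>     torch.save(full_osd, PATH)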
  1198. Args:
  1199. model (torch.nn.Module): Root module (which may or may not be a
  1200. :class:`FullyShardedDataParallel` instance) whose parameters
  1201. were passed into the optimizer ``optim``.
  1202. optim (torch.optim.Optimizer): Optimizer for ``model`` 's
  1203. parameters.
  1204. optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
  1205. Input passed into the optimizer ``optim`` representing either a
  1206. :class:`list` of parameter groups or an iterable of parameters;
  1207. if ``None``, then this method assumes the input was
  1208. ``model.parameters()``. This argument is deprecated, and there
  1209. is no need to pass it in anymore. (Default: ``None``)
  1210. rank0_only (bool): If ``True``, saves the populated :class:`dict`
  1211. only on rank 0; if ``False``, saves it on all ranks. (Default:
  1212. ``True``)
  1213. group (dist.ProcessGroup): Model's process group or ``None`` if using
  1214. the default process group. (Default: ``None``)
  1215. Returns:
  1216. Dict[str, Any]: A :class:`dict` containing the optimizer state for
  1217. ``model`` 's original unflattened parameters and including keys
  1218. "state" and "param_groups" following the convention of
  1219. :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``,
  1220. then nonzero ranks return an empty :class:`dict`.
  1221. """
  1222. FullyShardedDataParallel._warn_legacy_optim_state_dict(
  1223. "full_optim_state_dict", "optim_state_dict"
  1224. )
  1225. return FullyShardedDataParallel._optim_state_dict_impl(
  1226. model=model,
  1227. optim=optim,
  1228. optim_state_dict=optim.state_dict(),
  1229. optim_input=optim_input,
  1230. rank0_only=rank0_only,
  1231. group=group,
  1232. full_state_dict=True,
  1233. )
  1234. @staticmethod
  1235. def sharded_optim_state_dict(
  1236. model: torch.nn.Module,
  1237. optim: torch.optim.Optimizer,
  1238. group: Optional[dist.ProcessGroup] = None,
  1239. ) -> Dict[str, Any]:
  1240. """
1241. This API is similar to :meth:`full_optim_state_dict` but chunks all
1242. non-zero-dimension state into :class:`ShardedTensor` s to save memory.
1243. This API should only be used when the model ``state_dict`` is derived
1244. within the context manager ``with state_dict_type(SHARDED_STATE_DICT):``.
  1245. For the detailed usage, refer to :meth:`full_optim_state_dict`.
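Example (a minimal sketch; ``model`` and ``optim`` are assumed to be an
FSDP-wrapped model and its optimizer)::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.distributed.fsdp import StateDictType
>>> model, optim = ...
>>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
>>>     state_dict = model.state_dict()
>>>     sharded_osd = FSDP.sharded_optim_state_dict(model, optim)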
  1246. .. warning:: The returned state dict contains ``ShardedTensor`` and
  1247. cannot be directly used by the regular ``optim.load_state_dict``.
  1248. """
  1249. FullyShardedDataParallel._warn_legacy_optim_state_dict(
  1250. "sharded_optim_state_dict", "optim_state_dict"
  1251. )
  1252. return FullyShardedDataParallel._optim_state_dict_impl(
  1253. model=model,
  1254. optim=optim,
  1255. optim_state_dict=optim.state_dict(),
  1256. optim_input=None,
  1257. rank0_only=False,
  1258. full_state_dict=False,
  1259. group=group,
  1260. )
  1261. @staticmethod
  1262. def shard_full_optim_state_dict(
  1263. full_optim_state_dict: Dict[str, Any],
  1264. model: torch.nn.Module,
  1265. optim_input: Optional[
  1266. Union[
  1267. List[Dict[str, Any]],
  1268. Iterable[torch.nn.Parameter],
  1269. ]
  1270. ] = None,
  1271. optim: Optional[torch.optim.Optimizer] = None,
  1272. ) -> Dict[str, Any]:
  1273. """
  1274. Shards the full optimizer state dict ``full_optim_state_dict`` by
  1275. remapping the state to flattened parameters instead of unflattened
  1276. parameters and restricting to only this rank's part of the optimizer
  1277. state. The first argument should be the return value of
  1278. :meth:`full_optim_state_dict`.
  1279. Example::
  1280. >>> # xdoctest: +SKIP("undefined variables")
  1281. >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  1282. >>> model, optim = ...
  1283. >>> full_osd = FSDP.full_optim_state_dict(model, optim)
  1284. >>> torch.save(full_osd, PATH)
  1285. >>> # Define new model with possibly different world size
  1286. >>> new_model, new_optim = ...
  1287. >>> full_osd = torch.load(PATH)
  1288. >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model)
  1289. >>> new_optim.load_state_dict(sharded_osd)
  1290. .. note:: Both :meth:`shard_full_optim_state_dict` and
  1291. :meth:`scatter_full_optim_state_dict` may be used to get the
  1292. sharded optimizer state dict to load. Assuming that the full
  1293. optimizer state dict resides in CPU memory, the former requires
  1294. each rank to have the full dict in CPU memory, where each rank
  1295. individually shards the dict without any communication, while the
  1296. latter requires only rank 0 to have the full dict in CPU memory,
  1297. where rank 0 moves each shard to GPU memory (for NCCL) and
  1298. communicates it to ranks appropriately. Hence, the former has
  1299. higher aggregate CPU memory cost, while the latter has higher
  1300. communication cost.
  1301. Args:
  1302. full_optim_state_dict (Dict[str, Any]): Optimizer state dict
  1303. corresponding to the unflattened parameters and holding the
  1304. full non-sharded optimizer state.
  1305. model (torch.nn.Module): Root module (which may or may not be a
  1306. :class:`FullyShardedDataParallel` instance) whose parameters
  1307. correspond to the optimizer state in ``full_optim_state_dict``.
  1308. optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
  1309. Input passed into the optimizer representing either a
  1310. :class:`list` of parameter groups or an iterable of parameters;
  1311. if ``None``, then this method assumes the input was
  1312. ``model.parameters()``. This argument is deprecated, and there
  1313. is no need to pass it in anymore. (Default: ``None``)
  1314. optim (Optional[torch.optim.Optimizer]): Optimizer that will load
  1315. the state dict returned by this method. This is the preferred
  1316. argument to use over ``optim_input``. (Default: ``None``)
  1317. Returns:
  1318. Dict[str, Any]: The full optimizer state dict now remapped to
  1319. flattened parameters instead of unflattened parameters and
  1320. restricted to only include this rank's part of the optimizer state.
  1321. """
  1322. FullyShardedDataParallel._warn_legacy_optim_state_dict(
  1323. "shard_full_optim_state_dict", "optim_state_dict_to_load"
  1324. )
  1325. return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  1326. optim_state_dict=full_optim_state_dict,
  1327. model=model,
  1328. optim_input=optim_input,
  1329. optim=optim,
  1330. full_state_dict=True,
  1331. is_named_optimizer=False,
  1332. )
  1333. @staticmethod
  1334. def flatten_sharded_optim_state_dict(
  1335. sharded_optim_state_dict: Dict[str, Any],
  1336. model: torch.nn.Module,
  1337. optim: torch.optim.Optimizer,
  1338. ) -> Dict[str, Any]:
  1339. """
  1340. The API is similar to :meth:`shard_full_optim_state_dict`. The only
  1341. difference is that the input ``sharded_optim_state_dict`` should be
  1342. returned from :meth:`sharded_optim_state_dict`. Therefore, there will
  1343. be all-gather calls on each rank to gather ``ShardedTensor`` s.
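Example (a minimal sketch; ``load_a_checkpoint`` is a hypothetical loader
for a checkpoint saved from :meth:`sharded_optim_state_dict`)::
>>> # xdoctest: +SKIP("undefined variables")
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> new_model, new_optim = ...
>>> sharded_osd = load_a_checkpoint()
>>> flattened_osd = FSDP.flatten_sharded_optim_state_dict(
>>>     sharded_osd, new_model, new_optim
>>> )
>>> new_optim.load_state_dict(flattened_osd)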
  1344. Args:
  1345. sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict
  1346. corresponding to the unflattened parameters and holding the
  1347. sharded optimizer state.
  1348. model (torch.nn.Module):
1349. Refer to :meth:`shard_full_optim_state_dict`.
  1350. optim (torch.optim.Optimizer): Optimizer for ``model`` 's
  1351. parameters.
  1352. Returns:
  1353. Refer to :meth:`shard_full_optim_state_dict`.
  1354. """
  1355. FullyShardedDataParallel._warn_legacy_optim_state_dict(
  1356. "flatten_sharded_optim_state_dict", "optim_state_dict_to_load"
  1357. )
  1358. return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  1359. optim_state_dict=sharded_optim_state_dict,
  1360. model=model,
  1361. optim_input=None,
  1362. optim=optim,
  1363. full_state_dict=False,
  1364. is_named_optimizer=False,
  1365. )
  1366. @staticmethod
  1367. def scatter_full_optim_state_dict(
  1368. full_optim_state_dict: Optional[Dict[str, Any]],
  1369. model: torch.nn.Module,
  1370. optim_input: Optional[
  1371. Union[
  1372. List[Dict[str, Any]],
  1373. Iterable[torch.nn.Parameter],
  1374. ]
  1375. ] = None,
  1376. optim: Optional[torch.optim.Optimizer] = None,
  1377. group: Optional[Any] = None,
  1378. ) -> Dict[str, Any]:
  1379. """
  1380. Scatters the full optimizer state dict from rank 0 to all other ranks,
  1381. returning the sharded optimizer state dict on each rank. The return
  1382. value is the same as :meth:`shard_full_optim_state_dict`, and on rank
  1383. 0, the first argument should be the return value of
  1384. :meth:`full_optim_state_dict`.
  1385. Example::
  1386. >>> # xdoctest: +SKIP("undefined variables")
  1387. >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  1388. >>> model, optim = ...
  1389. >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0
  1390. >>> # Define new model with possibly different world size
  1391. >>> new_model, new_optim, new_group = ...
  1392. >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group)
  1393. >>> new_optim.load_state_dict(sharded_osd)
  1394. .. note:: Both :meth:`shard_full_optim_state_dict` and
  1395. :meth:`scatter_full_optim_state_dict` may be used to get the
  1396. sharded optimizer state dict to load. Assuming that the full
  1397. optimizer state dict resides in CPU memory, the former requires
  1398. each rank to have the full dict in CPU memory, where each rank
  1399. individually shards the dict without any communication, while the
  1400. latter requires only rank 0 to have the full dict in CPU memory,
  1401. where rank 0 moves each shard to GPU memory (for NCCL) and
  1402. communicates it to ranks appropriately. Hence, the former has
  1403. higher aggregate CPU memory cost, while the latter has higher
  1404. communication cost.
  1405. Args:
  1406. full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state
  1407. dict corresponding to the unflattened parameters and holding
  1408. the full non-sharded optimizer state if on rank 0; the argument
  1409. is ignored on nonzero ranks.
  1410. model (torch.nn.Module): Root module (which may or may not be a
  1411. :class:`FullyShardedDataParallel` instance) whose parameters
  1412. correspond to the optimizer state in ``full_optim_state_dict``.
  1413. optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]):
  1414. Input passed into the optimizer representing either a
  1415. :class:`list` of parameter groups or an iterable of parameters;
  1416. if ``None``, then this method assumes the input was
  1417. ``model.parameters()``. This argument is deprecated, and there
  1418. is no need to pass it in anymore. (Default: ``None``)
  1419. optim (Optional[torch.optim.Optimizer]): Optimizer that will load
  1420. the state dict returned by this method. This is the preferred
  1421. argument to use over ``optim_input``. (Default: ``None``)
  1422. group (dist.ProcessGroup): Model's process group or ``None`` if
  1423. using the default process group. (Default: ``None``)
  1424. Returns:
  1425. Dict[str, Any]: The full optimizer state dict now remapped to
  1426. flattened parameters instead of unflattened parameters and
  1427. restricted to only include this rank's part of the optimizer state.
  1428. """
  1429. FullyShardedDataParallel._warn_legacy_optim_state_dict(
  1430. "scatter_full_optim_state_dict", "optim_state_dict_to_load"
  1431. )
  1432. return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  1433. optim_state_dict=full_optim_state_dict,
  1434. model=model,
  1435. optim_input=optim_input,
  1436. optim=optim,
  1437. full_state_dict=True,
  1438. rank0_only=True,
  1439. is_named_optimizer=False,
  1440. group=group,
  1441. )
  1442. @staticmethod
  1443. def rekey_optim_state_dict(
  1444. optim_state_dict: Dict[str, Any],
  1445. optim_state_key_type: OptimStateKeyType,
  1446. model: torch.nn.Module,
  1447. optim_input: Optional[
  1448. Union[
  1449. List[Dict[str, Any]],
  1450. Iterable[torch.nn.Parameter],
  1451. ]
  1452. ] = None,
  1453. optim: Optional[torch.optim.Optimizer] = None,
  1454. ) -> Dict[str, Any]:
  1455. """
  1456. Re-keys the optimizer state dict ``optim_state_dict`` to use the key
  1457. type ``optim_state_key_type``. This can be used to achieve
  1458. compatibility between optimizer state dicts from models with FSDP
  1459. instances and ones without.
  1460. To re-key an FSDP full optimizer state dict (i.e. from
  1461. :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to
  1462. a non-wrapped model::
  1463. >>> # xdoctest: +SKIP("undefined variables")
  1464. >>> wrapped_model, wrapped_optim = ...
  1465. >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
  1466. >>> nonwrapped_model, nonwrapped_optim = ...
  1467. >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
  1468. >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
  1469. To re-key a normal optimizer state dict from a non-wrapped model to be
  1470. loadable to a wrapped model::
  1471. >>> # xdoctest: +SKIP("undefined variables")
  1472. >>> nonwrapped_model, nonwrapped_optim = ...
  1473. >>> osd = nonwrapped_optim.state_dict()
  1474. >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
  1475. >>> wrapped_model, wrapped_optim = ...
  1476. >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
  1477. >>> wrapped_optim.load_state_dict(sharded_osd)
  1478. Returns:
  1479. Dict[str, Any]: The optimizer state dict re-keyed using the
  1480. parameter keys specified by ``optim_state_key_type``.
  1481. """
  1482. FullyShardedDataParallel._warn_optim_input(optim_input)
  1483. using_optim_input = FullyShardedDataParallel._is_using_optim_input(
  1484. optim_input,
  1485. optim,
  1486. )
  1487. assert optim_state_key_type in (
  1488. OptimStateKeyType.PARAM_NAME,
  1489. OptimStateKeyType.PARAM_ID,
  1490. )
  1491. osd = optim_state_dict # alias
  1492. # Validate that the existing parameter keys are uniformly typed
  1493. uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]]
  1494. uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]]
  1495. if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or (
  1496. any(uses_param_id_mask) and not all(uses_param_id_mask)
  1497. ):
  1498. error_msg = f"Invalid parameter keys: {osd['state'].keys()}"
  1499. raise ValueError(error_msg)
  1500. # Return directly if the existing key type matches the target key type
  1501. if (
  1502. optim_state_key_type == OptimStateKeyType.PARAM_NAME
  1503. and all(uses_param_name_mask)
  1504. ) or (
  1505. optim_state_key_type == OptimStateKeyType.PARAM_ID
  1506. and all(uses_param_id_mask)
  1507. ):
  1508. return osd
  1509. # Otherwise, actually perform the re-keying
  1510. new_osd = {}
  1511. if optim_state_key_type == OptimStateKeyType.PARAM_NAME: # ID -> name
  1512. param_id_to_param = (
  1513. _get_param_id_to_param_from_optim_input(model, optim_input)
  1514. if using_optim_input
  1515. else _get_param_key_to_param(optim)
  1516. )
  1517. param_to_param_name = _get_param_to_fqn(model)
  1518. param_id_to_param_name: List[str] = [
  1519. param_to_param_name[param] for param in param_id_to_param.values()
  1520. ]
  1521. new_osd["state"] = {
  1522. param_id_to_param_name[param_id]: param_state
  1523. for param_id, param_state in osd["state"].items()
  1524. }
  1525. new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
  1526. for param_group in new_osd["param_groups"]:
  1527. param_group["params"] = sorted(
  1528. [
  1529. param_id_to_param_name[param_id]
  1530. for param_id in param_group["params"]
  1531. ]
  1532. )
  1533. return new_osd
  1534. elif optim_state_key_type == OptimStateKeyType.PARAM_ID: # name -> ID
  1535. param_name_to_param = _get_fqn_to_param(model)
  1536. param_to_param_id = (
  1537. _get_param_to_param_id_from_optim_input(model, optim_input)
  1538. if using_optim_input
  1539. else _get_param_to_param_key(optim)
  1540. )
  1541. # Because not all model parameters may be passed as the optimizer
  1542. # input, we may need to drop some parameters from this mapping
  1543. param_name_to_param_id = {
  1544. param_name: param_to_param_id[param]
  1545. for param_name, param in param_name_to_param.items()
  1546. if param in param_to_param_id
  1547. }
  1548. new_osd["state"] = {
  1549. param_name_to_param_id[param_name]: param_state
  1550. for param_name, param_state in osd["state"].items()
  1551. }
  1552. new_osd["param_groups"] = copy.deepcopy(osd["param_groups"])
  1553. for param_group in new_osd["param_groups"]:
  1554. param_group["params"] = sorted(
  1555. [
  1556. param_name_to_param_id[param_name]
  1557. for param_name in param_group["params"]
  1558. ]
  1559. )
  1560. return new_osd
  1561. return new_osd # should never reach here
  1562. @staticmethod
  1563. def optim_state_dict(
  1564. model: torch.nn.Module,
  1565. optim: torch.optim.Optimizer,
  1566. group: Optional[dist.ProcessGroup] = None,
  1567. ) -> Dict[str, Any]:
  1568. """
1569. Returns the state dict of ``optim`` for the ``model`` that is (partially)
1570. sharded by FSDP. The state may be sharded, fully consolidated, or
1571. consolidated only on rank 0, depending on the ``state_dict_type`` set by
1572. :meth:`set_state_dict_type` or :meth:`state_dict_type`.
  1573. Example::
  1574. >>> # xdoctest: +SKIP("undefined variables")
  1575. >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  1576. >>> from torch.distributed.fsdp import StateDictType
  1577. >>> from torch.distributed.fsdp import FullStateDictConfig
  1578. >>> from torch.distributed.fsdp import FullOptimStateDictConfig
  1579. >>> # Save a checkpoint
  1580. >>> model, optim = ...
  1581. >>> FSDP.set_state_dict_type(
  1582. >>> model,
  1583. >>> StateDictType.FULL_STATE_DICT,
  1584. >>> FullStateDictConfig(rank0_only=False),
  1585. >>> FullOptimStateDictConfig(rank0_only=False),
  1586. >>> )
  1587. >>> state_dict = model.state_dict()
  1588. >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
  1589. >>> save_a_checkpoint(state_dict, optim_state_dict)
  1590. >>> # Load a checkpoint
  1591. >>> model, optim = ...
1592. >>> state_dict, optim_state_dict = load_a_checkpoint()
  1593. >>> FSDP.set_state_dict_type(
  1594. >>> model,
  1595. >>> StateDictType.FULL_STATE_DICT,
  1596. >>> FullStateDictConfig(rank0_only=False),
  1597. >>> FullOptimStateDictConfig(rank0_only=False),
  1598. >>> )
  1599. >>> model.load_state_dict(state_dict)
  1600. >>> optim_state_dict = FSDP.optim_state_dict_to_load(
  1601. >>> optim_state_dict, model, optim
  1602. >>> )
  1603. >>> optim.load_state_dict(optim_state_dict)
  1604. Args:
  1605. model (torch.nn.Module): Root module (which may or may not be a
  1606. :class:`FullyShardedDataParallel` instance) whose parameters
  1607. were passed into the optimizer ``optim``.
  1608. optim (torch.optim.Optimizer): Optimizer for ``model`` 's
  1609. parameters.
  1610. group (dist.ProcessGroup): Model's process group across which parameters
  1611. are sharded or ``None`` if using the default process group. (
  1612. Default: ``None``)
  1613. Returns:
  1614. Dict[str, Any]: A :class:`dict` containing the optimizer state for
  1615. ``model``. The sharding of the optimizer state is based on
  1616. ``state_dict_type``.
  1617. """
  1618. state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
  1619. return FullyShardedDataParallel._optim_state_dict_impl(
  1620. model=model,
  1621. optim=optim,
  1622. optim_state_dict=optim.state_dict(),
  1623. optim_input=None,
  1624. rank0_only=getattr(state_dict_settings, "rank0_only", False),
  1625. full_state_dict=state_dict_settings.state_dict_type
  1626. == StateDictType.FULL_STATE_DICT,
  1627. group=group,
  1628. )
  1629. @staticmethod
  1630. def optim_state_dict_post_hook(
  1631. model: torch.nn.Module,
  1632. optim: torch.optim.Optimizer,
  1633. optim_state_dict: Dict[str, Any],
  1634. group: Optional[dist.ProcessGroup] = None,
  1635. ) -> Dict[str, Any]:
  1636. """
1637. This hook is intended to be used by ``torch.distributed.NamedOptimizer``.
1638. The functionality is identical to :meth:`optim_state_dict` except
1639. for the different arguments.
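Example (a minimal sketch that calls the hook directly rather than through
``NamedOptimizer``; ``named_optim`` is assumed to be a NamedOptimizer over
``model`` 's parameters)::
>>> # xdoctest: +SKIP("undefined variables")
>>> named_osd = named_optim.state_dict()
>>> osd = FSDP.optim_state_dict_post_hook(model, named_optim, named_osd)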
  1640. Args:
  1641. model (torch.nn.Module): Root module (which may or may not be a
  1642. :class:`FullyShardedDataParallel` instance) whose parameters
  1643. were passed into the optimizer ``optim``.
  1644. optim (torch.optim.Optimizer): Optimizer for ``model`` 's
  1645. parameters.
1646. optim_state_dict (Dict[str, Any]): The optim_state_dict to be converted.
1647. The value is typically returned by ``NamedOptimizer.state_dict()``.
  1648. group (dist.ProcessGroup): Model's process group across which parameters
  1649. are sharded or ``None`` if using the default process group. (
  1650. Default: ``None``)
  1651. Returns:
  1652. Dict[str, Any]: A :class:`dict` containing the optimizer state for
  1653. ``model``. The sharding of the optimizer state is based on
  1654. ``state_dict_type``.
  1655. """
  1656. state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
  1657. return FullyShardedDataParallel._optim_state_dict_impl(
  1658. model=model,
  1659. optim=optim,
  1660. optim_state_dict=optim_state_dict,
  1661. optim_input=None,
  1662. rank0_only=getattr(state_dict_settings, "rank0_only", False),
  1663. full_state_dict=state_dict_settings.state_dict_type
  1664. == StateDictType.FULL_STATE_DICT,
1665. group=group,
  1666. )
  1667. @staticmethod
  1668. def optim_state_dict_to_load(
  1669. optim_state_dict: Dict[str, Any],
  1670. model: torch.nn.Module,
  1671. optim: torch.optim.Optimizer,
  1672. is_named_optimizer: bool = False,
  1673. group: Optional[dist.ProcessGroup] = None,
  1674. ) -> Dict[str, Any]:
  1675. """
1676. Given a saved ``optim_state_dict``, converts it to the optimizer state dict
1677. that can be loaded into ``optim``, which is the optimizer for ``model``,
1678. where ``model`` is (partially) sharded by FullyShardedDataParallel.
  1679. >>> # xdoctest: +SKIP("undefined variables")
  1680. >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  1681. >>> from torch.distributed.fsdp import StateDictType
  1682. >>> from torch.distributed.fsdp import FullStateDictConfig
  1683. >>> from torch.distributed.fsdp import FullOptimStateDictConfig
  1684. >>> # Save a checkpoint
  1685. >>> model, optim = ...
  1686. >>> FSDP.set_state_dict_type(
  1687. >>> model,
  1688. >>> StateDictType.FULL_STATE_DICT,
  1689. >>> FullStateDictConfig(rank0_only=False),
  1690. >>> FullOptimStateDictConfig(rank0_only=False),
  1691. >>> )
  1692. >>> state_dict = model.state_dict()
  1693. >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
  1694. >>> save_a_checkpoint(state_dict, optim_state_dict)
  1695. >>> # Load a checkpoint
  1696. >>> model, optim = ...
1697. >>> state_dict, optim_state_dict = load_a_checkpoint()
  1698. >>> FSDP.set_state_dict_type(
  1699. >>> model,
  1700. >>> StateDictType.FULL_STATE_DICT,
  1701. >>> FullStateDictConfig(rank0_only=False),
  1702. >>> FullOptimStateDictConfig(rank0_only=False),
  1703. >>> )
  1704. >>> model.load_state_dict(state_dict)
  1705. >>> optim_state_dict = FSDP.optim_state_dict_to_load(
  1706. >>> optim_state_dict, model, optim
  1707. >>> )
  1708. >>> optim.load_state_dict(optim_state_dict)
  1709. Args:
  1710. optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
  1711. model (torch.nn.Module): Root module (which may or may not be a
  1712. :class:`FullyShardedDataParallel` instance) whose parameters
  1713. were passed into the optimizer ``optim``.
  1714. optim (torch.optim.Optimizer): Optimizer for ``model`` 's
  1715. parameters.
1716. is_named_optimizer (bool): Whether this optimizer is a NamedOptimizer or
1717. KeyedOptimizer. Set this to ``True`` only if ``optim`` is TorchRec's
1718. KeyedOptimizer or torch.distributed's NamedOptimizer.
  1719. group (dist.ProcessGroup): Model's process group across which parameters
  1720. are sharded or ``None`` if using the default process group. (
  1721. Default: ``None``)
  1722. """
  1723. state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
  1724. return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  1725. optim_state_dict=optim_state_dict,
  1726. model=model,
  1727. optim_input=None,
  1728. optim=optim,
  1729. full_state_dict=(
  1730. state_dict_settings.state_dict_type == StateDictType.FULL_STATE_DICT
  1731. ),
  1732. rank0_only=getattr(state_dict_settings, "rank0_only", False),
  1733. is_named_optimizer=is_named_optimizer,
  1734. group=group,
  1735. )
  1736. @staticmethod
  1737. def load_optim_state_dict_pre_hook(
  1738. model: torch.nn.Module,
  1739. optim: torch.optim.Optimizer,
  1740. optim_state_dict: Dict[str, Any],
  1741. group: Optional[dist.ProcessGroup] = None,
  1742. ) -> Dict[str, Any]:
  1743. """
1744. This hook is intended to be used by ``torch.distributed.NamedOptimizer``.
1745. The functionality is identical to :meth:`optim_state_dict_to_load`
1746. except for the different arguments.
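Example (a minimal sketch that calls the hook directly rather than through
``NamedOptimizer``; ``saved_osd`` is assumed to come from a checkpoint)::
>>> # xdoctest: +SKIP("undefined variables")
>>> osd = FSDP.load_optim_state_dict_pre_hook(model, named_optim, saved_osd)
>>> named_optim.load_state_dict(osd)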
  1747. Args:
  1748. model (torch.nn.Module): Root module (which may or may not be a
  1749. :class:`FullyShardedDataParallel` instance) whose parameters
  1750. were passed into the optimizer ``optim``.
  1751. optim (torch.optim.Optimizer): Optimizer for ``model`` 's
  1752. parameters.
  1753. optim_state_dict (Dict[str, Any]): The optimizer states to be loaded.
  1754. group (dist.ProcessGroup): Model's process group across which parameters
  1755. are sharded or ``None`` if using the default process group. (
  1756. Default: ``None``)
  1757. """
  1758. state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model)
  1759. return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  1760. optim_state_dict=optim_state_dict,
  1761. model=model,
  1762. optim_input=None,
  1763. optim=optim,
  1764. full_state_dict=state_dict_settings.state_dict_type
  1765. == StateDictType.FULL_STATE_DICT,
  1766. is_named_optimizer=True,
  1767. group=group,
  1768. )
  1769. def register_comm_hook(self, state: object, hook: callable):
  1770. """
1771. Registers a communication hook, which is a flexible extension point
1772. that lets users specify how FSDP aggregates gradients
1773. across multiple workers.
  1774. This hook can be used to implement several algorithms like
  1775. `GossipGrad <https://arxiv.org/abs/1803.05880>`_ and gradient compression
  1776. which involve different communication strategies for
  1777. parameter syncs while training with :class:`FullyShardedDataParallel`.
  1778. .. warning ::
  1779. FSDP communication hook should be registered before running an initial forward pass
  1780. and only once.
  1781. Args:
  1782. state (object): Passed to the hook to maintain any state information during the training process.
  1783. Examples include error feedback in gradient compression,
  1784. peers to communicate with next in `GossipGrad <https://arxiv.org/abs/1803.05880>`_, etc.
  1785. It is locally stored by each worker
  1786. and shared by all the gradient tensors on the worker.
  1787. hook (Callable): Callable, which has one of the following signatures:
  1788. 1) ``hook: Callable[torch.Tensor] -> None``:
  1789. This function takes in a Python tensor, which represents
  1790. the full, flattened, unsharded gradient with respect to all variables
  1791. corresponding to the model this FSDP unit is wrapping
  1792. (that are not wrapped by other FSDP sub-units).
  1793. It then performs all necessary processing and returns ``None``;
  1794. 2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``:
  1795. This function takes in two Python tensors, the first one represents
  1796. the full, flattened, unsharded gradient with respect to all variables
  1797. corresponding to the model this FSDP unit is wrapping
  1798. (that are not wrapped by other FSDP sub-units). The latter
  1799. represents a pre-sized tensor to store a chunk of a sharded gradient after
  1800. reduction.
1801. In both cases, the callable performs all necessary processing and returns ``None``.
1802. Callables with signature 1 are expected to handle gradient communication for the ``NO_SHARD`` case.
1803. Callables with signature 2 are expected to handle gradient communication for sharded cases.
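Example (a minimal sketch for the ``NO_SHARD`` case; it assumes, mirroring
the built-in default hooks, that the hook also receives the registered
``state`` as its first argument)::
>>> # xdoctest: +SKIP("undefined variables")
>>> def allreduce_mean_hook(state, grad):
>>>     dist.all_reduce(grad)
>>>     grad.div_(dist.get_world_size())
>>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.NO_SHARD)
>>> fsdp_model.register_comm_hook(state=None, hook=allreduce_mean_hook)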
  1804. """
  1805. if not self.check_is_root():
  1806. raise AssertionError(
  1807. "register_comm_hook can only be called on a root instance."
  1808. )
  1809. for submodule in traversal_utils._get_fsdp_states(self):
  1810. assert (
  1811. not submodule._hook_registered
  1812. ), "communication hook can be only registered once"
  1813. submodule._hook_registered = True
  1814. assert submodule._communication_hook == _get_default_comm_hook(
  1815. self.sharding_strategy
  1816. ), f"communication hook should be default, but it is {submodule._communication_hook.__name__} instead"
  1817. submodule._communication_hook_state = state
  1818. submodule._communication_hook = hook
  1819. def _get_grad_norm(
  1820. params: Iterable[nn.Parameter],
  1821. norm_type: float,
  1822. ) -> torch.Tensor:
  1823. """
1824. Returns the gradient norm of the parameters in ``params``, where the gradients
  1825. are viewed as a single vector. The returned norm is in FP32 even if
  1826. parameters/gradients are in a low precision. This is because the downstream
  1827. use of this return value is a reduction across ranks.
  1828. """
  1829. params_with_grad = [param for param in params if param.grad is not None]
  1830. if len(params_with_grad) == 0:
  1831. return torch.tensor(0.0)
  1832. grads = [param.grad for param in params_with_grad]
  1833. grad_dtypes = {grad.dtype for grad in grads}
  1834. if len(grad_dtypes) != 1:
  1835. raise ValueError(
  1836. f"Requires uniform dtype across all gradients but got {grad_dtypes}"
  1837. )
  1838. # Compute the gradient norm in FP32, where we treat the gradients as a
  1839. # single vector
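# Note: for a p-norm, the norm of the per-gradient norms equals the norm of
# the concatenated gradient vector, so stacking per-gradient norms below avoids
# materializing that concatenation.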
  1840. grad_norm = torch.linalg.vector_norm(
  1841. torch.stack(
  1842. [
  1843. torch.linalg.vector_norm(grad.detach(), norm_type, dtype=torch.float32)
  1844. for grad in grads
  1845. ],
  1846. ),
  1847. norm_type,
  1848. dtype=torch.float32,
  1849. )
  1850. return grad_norm
  1851. def _get_param_to_fqn(
  1852. model: torch.nn.Module,
  1853. ) -> Dict[torch.nn.Parameter, str]:
  1854. """
  1855. Constructs a mapping from parameters to their parameter names. ``model``
  1856. should not contain any :class:`FullyShardedDataParallel` instances, which
  1857. means that none of the parameters should be ``FlatParameter`` s. As a
  1858. result, compared to :meth:`_get_param_to_fqns`, the mapped
  1859. values may be flattened from singleton :class:`list` s to the contained
  1860. names themselves.
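For example, a parameter registered as ``model.linear.weight`` maps to the
name ``"linear.weight"``.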
  1861. Args:
  1862. model (torch.nn.Module): Root module, which should not contain any
  1863. :class:`FullyShardedDataParallel` instances.
  1864. """
  1865. param_to_param_names = _get_param_to_fqns(model)
  1866. for param_names in param_to_param_names.values():
  1867. assert len(param_names) > 0, (
  1868. "`_get_param_to_fqns()` " "should not construct empty lists"
  1869. )
  1870. if len(param_names) > 1:
  1871. raise RuntimeError(
  1872. "Each parameter should only map to one parameter name but got "
  1873. f"{len(param_names)}: {param_names}"
  1874. )
  1875. param_to_param_name = {
  1876. param: param_names[0] for param, param_names in param_to_param_names.items()
  1877. }
  1878. return param_to_param_name
  1879. def _get_fqn_to_param(
  1880. model: torch.nn.Module,
  1881. ) -> Dict[str, torch.nn.Parameter]:
  1882. """Constructs the inverse mapping of :meth:`_get_param_to_fqn`."""
  1883. param_to_param_name = _get_param_to_fqn(model)
  1884. return dict(zip(param_to_param_name.values(), param_to_param_name.keys()))