# _runtime_utils.py

import functools
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
)

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributed.algorithms._comm_hooks import default_hooks, LOW_PRECISION_HOOKS
from torch.distributed.fsdp._common_utils import (
    _assert_in_training_states,
    _FSDPState,
    _get_module_fsdp_state,
    _get_sharding_strategy,
    _is_composable,
    TrainingState,
)
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
from torch.distributed.fsdp._utils import (
    _apply_to_tensors,
    _no_dispatch_record_stream,
    p_assert,
)
from torch.distributed.fsdp.api import BackwardPrefetch
from torch.distributed.fsdp.flat_param import (
    _HandlesKey,
    FlatParameter,
    FlatParamHandle,
    HandleShardingStrategy,
    HandleTrainingState,
)
from torch.distributed.utils import _to_kwargs

RESHARD_AFTER_FORWARD_STRATEGIES = {
    HandleShardingStrategy.FULL_SHARD,
    HandleShardingStrategy.HYBRID_SHARD,
}

# Do not include "process_group" to enable hybrid shard and MoE cases
HOMOGENEOUS_ATTR_NAMES = (
    "_use_orig_params",
    "limit_all_gathers",
)
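
# NOTE: Throughout this file, a "handles key" (type `_HandlesKey`) is a tuple
# of `FlatParamHandle`s involved in the same module forward computation. It is
# used as the key for per-module bookkeeping dicts such as
# `_handles_prefetched`, `_needs_pre_forward_unshard`, and
# `_needs_pre_backward_unshard`.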


def _get_fsdp_root_states_with_modules(
    module: nn.Module,
) -> Tuple[List[_FSDPState], List[nn.Module]]:
    """
    Returns a tuple containing:
    1. A list of the root ``_FSDPState`` instances in the module tree rooted at
    ``module`` without any duplicates and following the ``module.modules()``
    traversal order (which is assumed to be depth-first).
    2. A corresponding list of the root modules owning the states in the first
    list.

    This is similar to :func:`_get_fsdp_states_with_modules` except that we
    must call :func:`_is_fsdp_root` to force a lazy initialization to determine
    the FSDP root in case lazy initialization has not yet happened.
    """
    fsdp_root_states: List[_FSDPState] = []
    fsdp_root_modules: List[nn.Module] = []
    visited_fsdp_states: Set[_FSDPState] = set()
    # NOTE: This function assumes that `module.modules()` proceeds top-down.
    for submodule in module.modules():
        optional_state = _get_module_fsdp_state(submodule)
        if (
            optional_state is not None
            and optional_state not in visited_fsdp_states
            and _is_fsdp_root(optional_state, submodule)
        ):
            visited_fsdp_states.add(optional_state)
            fsdp_root_states.append(optional_state)
            fsdp_root_modules.append(submodule)
    return fsdp_root_states, fsdp_root_modules


def _get_fsdp_root_states(module: nn.Module) -> List[_FSDPState]:
    """See :func:`_get_fsdp_root_states_with_modules`."""
    fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module)
    return fsdp_root_states


def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool:
    """
    Returns if ``state`` corresponds to that of an FSDP root.

    For the wrapper code path, ``state`` and ``module`` should be the same. For
    the non-wrapper code path, ``state`` should be ``module`` 's state.
    """
    # Force a lazy initialization to determine the FSDP root
    _lazy_init(state, module)
    assert state._is_root is not None  # mypy
    return state._is_root


@no_type_check
def _validate_and_get_hybrid_shard_state(
    root_module: nn.Module,
) -> default_hooks.DefaultState:
    """
    Precondition: ``root_module`` is a ``FullyShardedDataParallel`` instance.

    This checks that all instances using a hybrid sharding strategy have the
    same intra- and inter-node process groups.

    Returns:
        DefaultState: One of the instances' inter-node state (does not
        matter which since they will share the same one).
    """
    intra_node_pgs = set()
    inter_node_pgs = set()
    inter_node_states = set()
    for fsdp_module in traversal_utils._get_fsdp_states(root_module):
        # TODO: Change this to handle's sharding strategy if we deprecate
        # `ShardingStrategy` internally.
        # https://github.com/pytorch/pytorch/issues/90857
        if fsdp_module.sharding_strategy in HYBRID_SHARDING_STRATEGIES:
            intra_node_pgs.add(fsdp_module.process_group)
            inter_node_pgs.add(fsdp_module._inter_node_pg)
            inter_node_states.add(fsdp_module._inter_node_state)
    if len(intra_node_pgs) == 0 and len(inter_node_pgs) == 0:
        # No instances use a hybrid sharding strategy
        return None
    error_prefix = "At least one instance uses a hybrid sharding strategy but has no "
    if len(intra_node_pgs) > 0 and len(inter_node_pgs) == 0:
        raise AssertionError(error_prefix + "inter-node process group set")
    if len(intra_node_pgs) == 0 and len(inter_node_pgs) > 0:
        raise AssertionError(error_prefix + "intra-node process group set")
    error_prefix = "Some instances use a hybrid sharding strategy, but "
    if len(intra_node_pgs) != 1:
        raise ValueError(error_prefix + "intra-node process groups do not match")
    if len(inter_node_pgs) != 1:
        raise ValueError(error_prefix + "inter-node process groups do not match")
    return next(iter(inter_node_states))


@no_type_check
def _lazy_init(
    state: _FSDPState,
    root_module: nn.Module,
) -> _FSDPState:
    """
    Performs initialization lazily, typically right before the first forward
    pass. The laziness is needed to ensure that the parameter device/dtype and
    the FSDP hierarchy have finalized. This method's actual logic only runs on
    the root FSDP instance, which performs initialization for all non-root FSDP
    instances to avoid partial initialization.

    For the non-composable code path, ``state`` and ``root_module`` should be
    the same, namely the FSDP instance itself.
    """
    if state._is_root is not None:
        return  # no-op: already lazily initialized
    if not torch.cuda.is_available():
        # Allow the FSDP constructor to run even without CUDA but check this
        # once we start real execution
        raise RuntimeError("FSDP does not support CPU only execution")
    # The following logic is only run on the root FSDP instance since it will
    # set `_is_root=False` for the non-root instances
    state._is_root = True
    _assert_in_training_states(state, [TrainingState.IDLE])
    _check_flat_params_on_expected_device(state, root_module)
    _init_streams(state)
    buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module)
    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device)
    state._exec_order_data.init(state, root_module, state.process_group)
    _share_state_and_init_handle_attrs(state, root_module)
    return state


def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module):
    """
    Checks that all ``FlatParameter``s in ``module`` 's tree managed by
    ``state`` are on the expected device for *lazy initialization*.
    """
    cpu_device = torch.device("cpu")
    for handle in traversal_utils._get_fsdp_handles(module):
        if (
            not handle._offload_params
            and handle.flat_param.device != state.compute_device
        ):
            raise RuntimeError(
                "An FSDP-managed module unexpectedly has parameters on "
                f"{handle.flat_param.device}. Make sure to move the module to "
                f"{state.compute_device} before training."
            )
        elif handle._offload_params and handle.flat_param.device != cpu_device:
            raise RuntimeError(
                "An FSDP-managed module with parameter CPU offloading enabled "
                f"has parameters on {handle.flat_param.device}. Make sure to "
                f"not move the module from CPU when offloading parameters."
            )


@no_type_check
def _share_state_and_init_handle_attrs(
    root_state: _FSDPState,
    root_module: nn.Module,
) -> None:
    """
    Shares data structure state from the ``root_state`` to all FSDP states in
    ``root_module`` 's module tree, and initializes handle attributes. These
    are done together to require a single loop over the states.
    """
    for handle in root_state._handles:
        handle.init_flat_param_attributes()
    inter_node_state = _validate_and_get_hybrid_shard_state(root_module)
    attr_name_to_values: Dict[str, Set[Any]] = {}
    for attr_name in HOMOGENEOUS_ATTR_NAMES:
        attr_name_to_values[attr_name] = set()
    for fsdp_state in traversal_utils._get_fsdp_states(root_module):
        for attr_name in HOMOGENEOUS_ATTR_NAMES:
            p_assert(
                hasattr(fsdp_state, attr_name),
                f"FSDP state missing attribute {attr_name}",
            )
            attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name))
        if fsdp_state is root_state:
            continue
        handle_sharding_strategy = _get_sharding_strategy(fsdp_state._handles)
        if handle_sharding_strategy in (
            HandleShardingStrategy.HYBRID_SHARD,
            HandleShardingStrategy._HYBRID_SHARD_ZERO2,
        ):
            # Share the all-reduce state across FSDP units. This is not
            # strictly necessary as each one already uses the same process
            # group, but it can save a little memory since the other FSDP
            # units' all-reduce states can then be garbage collected.
            assert inter_node_state is not None, (
                "`_validate_and_get_hybrid_shard_state()` should have returned "
                "a valid inter-node state if there exists an FSDP instance "
                "using a hybrid sharding strategy"
            )
            fsdp_state._inter_node_state = inter_node_state
        # Relax the assert for non-root FSDP instances in case the nested
        # initialized module is wrapped again in FSDP later (e.g. after
        # training to run inference)
        p_assert(
            fsdp_state._is_root is None or not fsdp_state._is_root,
            "Non-root FSDP instance's `_is_root` should not have been "
            "set yet or should have been set to `False`",
        )
        fsdp_state._is_root = False
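        # Share the root's streams and bookkeeping structures so that every
        # FSDP instance in the tree coordinates on the same CUDA streams,
        # execution-order data, and prefetch state.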
        fsdp_state._streams = root_state._streams
        fsdp_state._stream_to_name = root_state._stream_to_name
        fsdp_state._exec_order_data = root_state._exec_order_data
        fsdp_state._free_event_queue = root_state._free_event_queue
        fsdp_state._handles_prefetched = root_state._handles_prefetched
        fsdp_state._needs_pre_backward_unshard = root_state._needs_pre_backward_unshard
        for handle in fsdp_state._handles:
            handle.init_flat_param_attributes()
    for attr_name, attr_values in attr_name_to_values.items():
        if len(attr_values) != 1:
            raise ValueError(
                f"Expects one homogeneous value for {attr_name} but got {attr_values}"
            )


@no_type_check
def _init_streams(
    state: _FSDPState,
) -> _FSDPState:
    """
    Initializes CUDA streams for overlapping communication, computation, and
    data transfers. The streams should be shared across FSDP instances.
    """
    assert state._is_root
    assert torch.cuda.is_available()
    # Stream for unshard logic, including allocating the all-gather destination
    # tensors and the all-gathers themselves.
    state._streams["unshard"] = torch.cuda.Stream()
    # Stream for overlapping gradient reduction with the backward pass gradient
    # computation.
    state._streams["post_backward"] = torch.cuda.Stream()
    # Stream for pre-unshard logic, namely allocations and writes for CPU
    # offloading (H2D copy) and mixed precision (low precision cast).
    state._streams["pre_unshard"] = torch.cuda.Stream()
    # Default stream for computation
    state._streams["default"] = torch.cuda.current_stream()
    state._stream_to_name = {
        torch.cuda.current_stream(): "default",
        state._streams["unshard"]: "unshard",
        state._streams["pre_unshard"]: "pre_unshard",
        state._streams["post_backward"]: "post_backward",
    }
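
# NOTE: When `limit_all_gathers=True`, `_unshard()` and `_reshard()` cooperate
# through `state._free_event_queue` to rate-limit all-gathers: `_reshard()`
# records a CUDA event each time it frees an unsharded flat parameter, and
# `_unshard()` blocks the CPU thread on the oldest such event so that only a
# bounded number of unsharded flat parameters are live at once.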


@no_type_check
def _unshard(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    unshard_stream: torch.cuda.Stream,
    pre_unshard_stream: torch.cuda.Stream,
) -> None:
    """
    Unshards the handles in ``handles``. If the handles are in
    :meth:`summon_full_params` and are using mixed precision, then they are
    forced to full precision.

    Postcondition: Each handle's ``FlatParameter`` 's data is the padded
    unsharded flattened parameter on the compute device.
    """
    if not handles:
        return
    any_ran_pre_unshard = False
    with torch.cuda.stream(pre_unshard_stream):
        for handle in handles:
            ran_pre_unshard = handle.pre_unshard()
            any_ran_pre_unshard = any_ran_pre_unshard or ran_pre_unshard
    if any_ran_pre_unshard:
        unshard_stream.wait_stream(pre_unshard_stream)
    if state.limit_all_gathers:
        event = state._free_event_queue.dequeue_if_needed()
        if event:
            event.synchronize()
    with torch.cuda.stream(unshard_stream):
        for handle in handles:
            handle.unshard()
            handle.post_unshard()


@no_type_check
def _reshard(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    free_unsharded_flat_params: List[bool],
):
    """
    Reshards the handles in ``handles``. ``free_unsharded_flat_params`` should
    have the same length as ``handles``, and each element should give whether
    the corresponding handle should free its padded unsharded flattened
    parameter.
    """
    if not handles:
        return
    p_assert(
        len(handles) == len(free_unsharded_flat_params),
        "Expects both lists to have equal length but got "
        f"{len(handles)} and {len(free_unsharded_flat_params)}",
    )
    for handle, free_unsharded_flat_param in zip(
        handles,
        free_unsharded_flat_params,
    ):
        handle.reshard(free_unsharded_flat_param)
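        # Record a CUDA event marking when this unsharded flat parameter's
        # memory was freed; `_unshard()` may later block on it to rate-limit
        # all-gathers.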
        if state.limit_all_gathers and free_unsharded_flat_param:
            free_event = torch.cuda.Event()
            free_event.record()
            state._free_event_queue.enqueue(free_event)
        handle.post_reshard()
    # Since we prefetch entire handles keys at a time, conservatively mark
    # the entire key as no longer prefetched once we free at least one
    handles_key = tuple(handles)
    if any(free_unsharded_flat_params):
        state._handles_prefetched.pop(handles_key, None)


def _unshard_grads(
    handles: List[FlatParamHandle],
) -> None:
    for handle in handles:
        handle.unshard_grad()


def _reshard_grads(
    handles: List[FlatParamHandle],
) -> None:
    for handle in handles:
        handle.reshard_grad()


@no_type_check
def _pre_forward(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    unshard_fn: Callable,
    module: nn.Module,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Runs the pre-forward logic. This includes an opportunity to unshard
    currently sharded parameters such as those for the current forward and
    registering post-backward hooks for these current parameters. This function
    also converts forward ``args`` and ``kwargs`` to the given precision.

    Args:
        handles (List[FlatParamHandle]): Handles giving the parameters used in
            the current forward.
        unshard_fn (Optional[Callable]): A callable to unshard any currently
            sharded parameters or ``None`` to not do any unsharding.
        module (nn.Module): Module whose forward this method runs right before;
            expected by the hook signature.
        args (Tuple[Any, ...]): Module forward ``args``.
        kwargs (Dict[str, Any]): Module forward ``kwargs``.
    """
    state.training_state = TrainingState.FORWARD_BACKWARD
    state._exec_order_data.record_pre_forward(handles, module.training)
    for handle in handles:
        handle._training_state = HandleTrainingState.FORWARD
    if unshard_fn is not None:
        unshard_fn()
    # Register post-backward hooks to reshard the parameters and reduce-scatter
    # their gradients. They must be re-registered every forward pass in case
    # the `grad_fn` is mutated.
    _register_post_backward_hooks(state, handles)
    # Recursively convert args and kwargs to specified precision.
    input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
    if state.mixed_precision.cast_forward_inputs:
        args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
    return args, kwargs


@no_type_check
def _pre_forward_unshard(
    state: _FSDPState,
    handles: List[FlatParamHandle],
) -> None:
    """Unshards parameters in the pre-forward."""
    if not handles:
        return
    _unshard(state, handles, state._streams["unshard"], state._streams["pre_unshard"])
    handles_key = tuple(handles)
    state._needs_pre_forward_unshard[handles_key] = False
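    # Block computation on the current (default) stream until the all-gather
    # issued in the unshard stream completes so that the module's forward sees
    # the unsharded parameters.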
    torch.cuda.current_stream().wait_stream(state._streams["unshard"])
    _prefetch_handles(state, handles_key)


@no_type_check
def _post_forward(
    state: _FSDPState,
    handles: List[FlatParamHandle],
    reshard_fn: Callable,
    module: nn.Module,
    input: Any,
    output: Any,
) -> Any:
    """
    Runs the post-forward logic. This includes an opportunity to reshard
    currently unsharded parameters such as those used in the current forward
    and registering pre-backward hooks on the forward outputs.

    Args:
        handles (List[FlatParamHandle]): Handles giving the parameters used in
            the current forward.
        reshard_fn (Optional[Callable]): A callable to reshard any currently
            unsharded parameters (e.g. from the current forward) or ``None`` to
            not do any resharding.
        module (nn.Module): Module whose forward just ran, which should be a
            fully sharded module (see [Note: Fully Sharded Module]); expected
            by the hook signature.
        input (Any): Unused; expected by the hook signature.
        output (Any): Forward pass output; pre-backward hooks are registered on
            the tensors that require gradients in this output.

    Postcondition: Each ``FlatParameter`` 's data points to the sharded
    flattened parameter.
    """
    state._exec_order_data.record_post_forward(handles)
    if reshard_fn is not None:
        reshard_fn()
    # Register pre-backward hooks to unshard the flattened parameters
    # for the gradient computation (if needed)
    output = _register_pre_backward_hooks(state, module, output, handles)
    state.training_state = TrainingState.IDLE
    for handle in handles:
        handle._training_state = HandleTrainingState.IDLE
    return output


@no_type_check
def _post_forward_reshard(
    state: _FSDPState,
    handles: List[FlatParamHandle],
) -> None:
    """Reshards parameters in the post-forward."""
    if not handles:
        return
    # Do not free the root's parameters in the post-forward for `FULL_SHARD`
    # with the intention that they are immediately used for backward
    # computation (though this may not be true)
    free_unsharded_flat_params = [
        not state._is_root
        and handle._sharding_strategy in RESHARD_AFTER_FORWARD_STRATEGIES
        for handle in handles
    ]
    _reshard(state, handles, free_unsharded_flat_params)


@no_type_check
def _root_pre_forward(
    state: _FSDPState,
    module: nn.Module,
    args,
    kwargs,
) -> None:
    """
    Runs pre-forward logic specific to the root FSDP instance, which should run
    before any individual module's pre-forward. This starts with an attempt at
    lazy initialization (which only runs non-vacuously once). Otherwise, if
    this is called on a non-root FSDP instance, then it returns directly.

    Args:
        module (nn.Module): Module for which this logic tries to run. It may or
            may not be the root. If not, then this method does not do anything.
    """
    _lazy_init(state, module)
    p_assert(state._is_root is not None, "Expects a root FSDP to have been set")
    if not state._is_root:
        return args, kwargs
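    # If forward prefetching is enabled, mark every FSDP instance's handles key
    # as needing a pre-forward unshard so that `_prefetch_handles()` can target
    # it during this iteration's forward passes.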
    if state.forward_prefetch:
        handles_keys = []
        if _is_composable(state):
            # TODO: This assumes singleton handles keys.
            handles_keys = [(handle,) for handle in state._handles]
        else:
            for fsdp_module in traversal_utils._get_fsdp_states(state):
                handles_key = tuple(fsdp_module._handles)
                handles_keys.append(handles_key)
        for handles_key in handles_keys:
            state._needs_pre_forward_unshard[handles_key] = True
    _wait_for_computation_stream(
        torch.cuda.current_stream(),
        state._streams["unshard"],
        state._streams["pre_unshard"],
    )
    _clear_grads_if_needed(traversal_utils._get_fsdp_handles(module))
    # Prepares the forward inputs by moving them to ``compute_device``
    # TODO: Do not use the side stream for tensor copies for now; investigate
    # the perf with/without it.
    args_tuple, kwargs_tuple = _to_kwargs(
        args, kwargs, state.compute_device.index, False
    )
    args = args_tuple[0]
    kwargs = kwargs_tuple[0]
    input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
    if state.mixed_precision.cast_root_forward_inputs:
        args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
    return args, kwargs


def _cast_forward_inputs(
    input_dtype: Optional[torch.dtype],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Prepares the forward inputs by casting them to ``input_dtype`` if it is not ``None``.
    """
    # TODO: For mixed precision, cast to reduced-precision in a single `to()` call.
    if input_dtype is not None:
        args, kwargs = _cast_fp_inputs_to_dtype(input_dtype, *args, **kwargs)
    return args, kwargs


def _cast_fp_inputs_to_dtype(
    dtype: torch.dtype,
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Casts floating point tensors in ``args`` and ``kwargs`` to ``dtype``. This
    respects the existing ``requires_grad`` on the tensors.
    """

    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))


@no_type_check
def _pre_backward_hook(
    state: _FSDPState,
    module: nn.Module,
    _handles: List[FlatParamHandle],
    *unused: Any,
) -> Any:
    """
    Prepares ``_handles`` 's ``FlatParameter`` s for gradient computation.

    Args:
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).
    """
    _handles_key = tuple(_handles)  # avoid shadowing `handles_key`
    # Only run the pre-backward hook once per group of handles involved in the
    # same module forward computation
    if _handles_key and state._ran_pre_backward_hook.get(_handles_key, False):
        return
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel._pre_backward_hook"
    ):
        # Queue the post-backward callback once for the root FSDP instance to
        # attach it to the outermost backward graph task so that it is called
        # after all backward calls complete
        if state._is_root and not state._post_backward_callback_queued:
            _register_post_backward_final_callback(state, module)
            _clear_grads_if_needed(traversal_utils._get_fsdp_handles(module))
        elif _handles_key:
            allowed_states = [TrainingState.IDLE]
            if _is_composable(state):
                allowed_states.append(TrainingState.FORWARD_BACKWARD)
            _assert_in_training_states(state, allowed_states)
        state.training_state = TrainingState.FORWARD_BACKWARD
        # Queueing the post-backward callback is the only logic that is not
        # per-handle in the pre-backward hook, so we can return early here if
        # there are no handles.
        if not _handles_key:
            return
        for handle in _handles:
            handle._training_state = HandleTrainingState.BACKWARD_PRE
        # If the handles have been prefetched, this `_unshard()` simply
        # switches to using the unsharded parameter
        _unshard(
            state, _handles, state._streams["unshard"], state._streams["pre_unshard"]
        )
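        # Block computation on the current stream until the unshard
        # (all-gather) completes before gradient computation uses the
        # unsharded parameters.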
        torch.cuda.current_stream().wait_stream(state._streams["unshard"])
        # Set this to `False` to ensure that a mistargeted prefetch does not
        # actually unshard these handles
        state._needs_pre_backward_unshard[_handles_key] = False
        _prefetch_handles(state, _handles_key)
        for handle in _handles:
            handle.prepare_gradient_for_backward()
        state._ran_pre_backward_hook[_handles_key] = True


@no_type_check
@torch.no_grad()
def _post_backward_hook(
    state: _FSDPState,
    handle: FlatParamHandle,
    *unused: Any,
):
    """
    Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``.

    Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the
    unsharded gradient for the local batch.

    Postcondition:
    - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced
    unsharded gradient.
    - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded
    gradient (accumulating with any existing gradient).
    """
    flat_param = handle.flat_param
    flat_param._post_backward_called = True
    with torch.autograd.profiler.record_function(
        "FullyShardedDataParallel._post_backward_hook"
    ):
        _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
        # For multiple applications of reentrant AC across submodules sharing
        # the same `FlatParameter`, the post-backward hook may run multiple
        # times in one backward, in which case we permit the state to already
        # be in `BACKWARD_POST`.
        p_assert(
            handle._training_state
            in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST),
            f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}",
        )
        handle._training_state = HandleTrainingState.BACKWARD_POST
        if flat_param.grad is None:
            return
        if flat_param.grad.requires_grad:
            raise RuntimeError("FSDP does not support gradients of gradients")
        free_unsharded_flat_param = _should_free_in_backward(state, handle)
        _reshard(state, [handle], [free_unsharded_flat_param])
        # TODO: Post-backward prefetching does not support the multiple handles
        # per module case since the post-backward hook runs per handle, not per
        # group of handles.
        handles_key = (handle,)
        _prefetch_handles(state, handles_key)
        if not state._sync_gradients:
            if handle._use_orig_params:
                handle._use_unsharded_grad_views()
            return
        # Wait for all ops in the current stream (e.g. gradient
        # computation) to finish before reduce-scattering the gradient
        state._streams["post_backward"].wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(state._streams["post_backward"]):
            autograd_computed_grad = flat_param.grad.data
            if state._exec_order_data.is_first_iter:  # only check once
                _check_comm_hook(
                    state._communication_hook, state._communication_hook_state
                )
            if (
                not _low_precision_hook_enabled(state)
                and flat_param.grad.dtype != handle._reduce_dtype
            ):
                flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype)
            if handle.uses_sharded_strategy:
                # We clear `.grad` to permit multiple backwards. This avoids a
                # race where the second backward pass computation runs ahead
                # of the first backward pass reduction, which is possible
                # since the reduction is issued in a separate stream and is
                # async and would result in reducing the wrong gradient.
                unsharded_grad = flat_param.grad.data
                flat_param.grad = None
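                # Pad the unsharded gradient on the right so that its numel
                # divides evenly by the world size, then hand it to the
                # communication hook (reduce-scatter by default), which writes
                # the reduced shard into a newly allocated tensor.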
                chunks = list(unsharded_grad.chunk(state.world_size))
                numel_to_pad = (
                    state.world_size * chunks[0].numel() - unsharded_grad.numel()
                )
                padded_unsharded_grad = (
                    F.pad(unsharded_grad, [0, numel_to_pad])
                    if numel_to_pad > 0
                    else unsharded_grad
                )
                new_sharded_grad = torch.empty_like(chunks[0])  # padded
                state._communication_hook(
                    state._communication_hook_state,
                    padded_unsharded_grad,
                    new_sharded_grad,
                )
                if handle._sharding_strategy in (
                    HandleShardingStrategy.HYBRID_SHARD,
                    HandleShardingStrategy._HYBRID_SHARD_ZERO2,
                ):
                    default_hooks.allreduce_hook(
                        state=state._inter_node_state,
                        grad=new_sharded_grad,
                    )
                _cast_grad_to_param_dtype(state, new_sharded_grad, flat_param)
                # Save the sharded gradient in `_saved_grad_shard` to support
                # gradient accumulation -- for multiple backwards, the gradient
                # reductions may happen in arbitrary order
                accumulate_grad = hasattr(flat_param, "_saved_grad_shard")
                if accumulate_grad:
                    _check_grad_to_accumulate(
                        new_sharded_grad, flat_param._saved_grad_shard
                    )
                    flat_param._saved_grad_shard += new_sharded_grad
                else:
                    flat_param._saved_grad_shard = new_sharded_grad
                grad_to_offload = flat_param._saved_grad_shard
            else:
                state._communication_hook(
                    state._communication_hook_state, flat_param.grad
                )
                # For `NO_SHARD`, we can keep the low precision gradients by
                # simply omitting the cast altogether
                if not handle._keep_low_precision_grads:
                    _cast_grad_to_param_dtype(state, flat_param.grad, flat_param)
                grad_to_offload = flat_param.grad.data
            if handle._offload_params:
                # Offload the gradient to CPU to ensure parameters and
                # gradients are on the same device as required by the optimizer
                # TODO: Investigate why `NO_SHARD` breaks correctness when
                # using `non_blocking=True` here.
                non_blocking = handle.uses_sharded_strategy
                flat_param._cpu_grad.copy_(  # type: ignore[attr-defined]
                    grad_to_offload.detach(), non_blocking=non_blocking
                )  # synchronized in the post-backward callback
                # Since the gradient being offloaded may have been produced in
                # the computation stream and is being consumed here in the
                # post-backward stream, inform the caching allocator
                _no_dispatch_record_stream(
                    grad_to_offload.data,
                    state._streams["post_backward"],
                )
            # Since the unsharded gradient is produced in the computation
            # stream and consumed in the post-backward stream, inform the
            # caching allocator (before it goes out of scope)
            _no_dispatch_record_stream(
                autograd_computed_grad, state._streams["post_backward"]
            )
            if handle._use_orig_params:
                # Since the handle's `FlatParameter` completed its gradient
                # computation, we should reset the gradient noneness mask
                handle._reset_is_grad_none()
                # Delay using sharded gradient views until after the
                # reduce-scatter instead of immediately after resharding
                handle._use_sharded_grad_views()


@no_type_check
def _should_free_in_backward(
    state: _FSDPState,
    handle: FlatParamHandle,
) -> bool:
    """
    Returns whether FSDP should free the unsharded flattened parameter in the
    post-backward or not.
    """
    # We always free if we are syncing gradients (i.e. not in `no_sync()`) and
    # the parameters are sharded.
    free_unsharded = state._sync_gradients and handle.uses_sharded_strategy
    # For `NO_SHARD`, we do not need to free the full parameters; for the
    # ZeRO-2 strategies, we skip freeing in the backward.
    return free_unsharded or (
        handle._sharding_strategy in RESHARD_AFTER_FORWARD_STRATEGIES
    )


@no_type_check
def _cast_grad_to_param_dtype(
    state: _FSDPState,
    sharded_grad: torch.Tensor,
    param: FlatParameter,
):
    """
    Casts ``sharded_grad`` back to the full parameter dtype so that the
    optimizer step runs with that dtype. This performs an actual cast if
    1. parameters were in reduced precision during the forward since then
    gradients would be in that reduced precision, or
    2. parameters were not in reduced precision but gradients were in
    reduced precision for communication.
    However, if a low precision communication hook is registered, then this
    dtype cast happens in the hook instead.
    """
    _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD])
    if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype:
        low_prec_grad_data = sharded_grad.data
        sharded_grad.data = sharded_grad.data.to(dtype=param.dtype)
        # Since for `NO_SHARD`, the gradient is produced in the computation
        # stream and consumed here in the post-backward stream, inform the
        # caching allocator; for the sharded strategies, the gradient is
        # produced in the post-backward stream, so this `record_stream()`
        # should be a no-op
        _no_dispatch_record_stream(low_prec_grad_data, torch.cuda.current_stream())


def _check_comm_hook(
    comm_hook: Any,
    comm_hook_state: Any,
) -> None:
    p_assert(comm_hook is not None, "Communication hook should not be `None`")
    p_assert(
        comm_hook_state is not None, "Communication hook state should not be `None`"
    )


def _check_grad_to_accumulate(
    new_sharded_grad: torch.Tensor,
    accumulated_grad: torch.Tensor,
) -> None:
    p_assert(
        accumulated_grad.shape == new_sharded_grad.shape,
        "Shape mismatch when accumulating gradients: "
        f"existing gradient shape={accumulated_grad.shape} "
        f"new gradient shape={new_sharded_grad.shape}",
    )
    p_assert(
        accumulated_grad.device == new_sharded_grad.device,
        "Device mismatch when accumulating gradients: "
        f"existing gradient device={accumulated_grad.device} "
        f"new gradient device={new_sharded_grad.device}",
    )


@no_type_check
def _low_precision_hook_enabled(state: _FSDPState) -> bool:
    return state._communication_hook in LOW_PRECISION_HOOKS


@no_type_check
@torch.no_grad()
def _post_backward_final_callback(
    state: _FSDPState,
    module: nn.Module,
):
    """
    This waits for the post-backward to finish and performs some final cleanup.
    This runs at the end of the entire backward pass and should only be called
    on the root FSDP instance.
    """
    p_assert(
        state._is_root,
        "The post-backward callback should only be called on the root FSDP instance",
    )
    root_state = state
    if root_state._sync_gradients:
        torch.cuda.current_stream().wait_stream(root_state._streams["post_backward"])
        if root_state.cpu_offload.offload_params:
            # Wait for non-blocking GPU -> CPU sharded gradient copies from the
            # post-backward hooks to finish explicitly since CPU gradients do
            # not automatically synchronize with the GPU
            torch.cuda.current_stream().synchronize()
    root_state._exec_order_data.next_iter()
    for fsdp_state in traversal_utils._get_fsdp_states(module):
        _catch_all_reshard(fsdp_state)
        _finalize_params(fsdp_state)
        fsdp_state._ran_pre_backward_hook.clear()
        fsdp_state.training_state = TrainingState.IDLE
        for handle in fsdp_state._handles:
            handle._training_state = HandleTrainingState.IDLE
        fsdp_state._handles_prefetched.clear()
    # Reset for cases like one forward and multiple backwards
    root_state._post_backward_callback_queued = False


@no_type_check
def _catch_all_reshard(
    state: _FSDPState,
) -> None:
    """
    Reshards the parameters that may not have been resharded in the
    post-backward hook. This can happen when a module's output is used in the
    forward pass, meaning that its pre-backward hook runs (unsharding the
    parameter), but the post-backward hook does not run because the output was
    not used in the loss computation corresponding to this backward pass.
    """
    # Wrap with a try-except to provide a more informative traceback if an
    # error is raised
    try:
        free_unsharded_flat_params: List[bool] = []
        handles_to_reshard: List[FlatParamHandle] = []
        for handle in state._handles:
            # TODO: This already-resharded check is brittle:
            # https://github.com/pytorch/pytorch/issues/83956
            already_resharded = (
                handle.flat_param.data_ptr()
                == handle.flat_param._local_shard.data_ptr()
            )
            if already_resharded:
                continue
            free_unsharded_flat_params.append(_should_free_in_backward(state, handle))
            handles_to_reshard.append(handle)
        if handles_to_reshard:
            _reshard(state, handles_to_reshard, free_unsharded_flat_params)
    except Exception as e:
        p_assert(
            False,
            f"Got exception in the catch-all reshard for {state}: {str(e)}",
            raise_assertion_error=False,
        )
        raise e


@no_type_check
def _finalize_params(
    state: _FSDPState,
) -> None:
    """Finalizes the parameters before the next iteration."""
    for handle in state._handles:
        flat_param = handle.flat_param
        if flat_param.requires_grad:
            if hasattr(flat_param, "_post_backward_hook_state"):
                p_assert(
                    len(flat_param._post_backward_hook_state) == 2,
                    f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}",
                )
                flat_param._post_backward_hook_state[1].remove()
                delattr(flat_param, "_post_backward_hook_state")
            if not state._sync_gradients:
                # Preserve the gradient accumulation state if not synchronizing
                # gradients: `.grad` remains the unsharded gradient from prior
                # `no_sync()` iterations, and `_saved_grad_shard` remains the
                # sharded gradient from the last synchronized iteration
                continue
            handle.prepare_gradient_for_optim()
            p_assert(
                hasattr(flat_param, "_post_backward_called"),
                "Expects `_post_backward_called` to be set on the `FlatParameter`",
            )
            flat_param._post_backward_called = False


@no_type_check
def _prefetch_handles(
    state: _FSDPState,
    current_handles_key: _HandlesKey,
) -> None:
    """
    Prefetches the next handles if needed (without synchronization). An empty
    handles key cannot prefetch.
    """
    if not current_handles_key:
        return
    handles_to_prefetch = _get_handles_to_prefetch(state, current_handles_key)
    for handles_key in handles_to_prefetch:
        # Prefetch the next set of handles without synchronizing to allow
        # the sync to happen as late as possible to maximize overlap
        _unshard(
            state, handles_key, state._streams["unshard"], state._streams["pre_unshard"]
        )
        state._handles_prefetched[handles_key] = True


@no_type_check
def _get_handles_to_prefetch(
    state: _FSDPState,
    current_handles_key: _HandlesKey,
) -> List[_HandlesKey]:
    """
    Returns a :class:`list` of the handles keys to prefetch for the next
    module(s), where ``current_handles_key`` represents the current module.

    "Prefetching" refers to running the unshard logic early (without
    synchronization), and the "next" modules depend on the recorded execution
    order and the current training state.
    """
    training_state = _get_training_state(current_handles_key)
    valid_training_states = (
        HandleTrainingState.BACKWARD_PRE,
        HandleTrainingState.BACKWARD_POST,
        HandleTrainingState.FORWARD,
    )
    p_assert(
        training_state in valid_training_states,
        f"Prefetching is only supported in {valid_training_states} but "
        f"currently in {training_state}",
    )
    eod = state._exec_order_data
    target_handles_keys: List[_HandlesKey] = []
    if (
        training_state == HandleTrainingState.BACKWARD_PRE
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE
    ) or (
        training_state == HandleTrainingState.BACKWARD_POST
        and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST
    ):
        target_handles_keys = [
            target_handles_key
            for target_handles_key in eod.get_handles_to_backward_prefetch(
                current_handles_key
            )
            if state._needs_pre_backward_unshard.get(target_handles_key, False)
            and not state._handles_prefetched.get(target_handles_key, False)
        ]
    elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch:
        target_handles_keys = [
            target_handles_key
            for target_handles_key in eod.get_handles_to_forward_prefetch(
                current_handles_key
            )
            if state._needs_pre_forward_unshard.get(target_handles_key, False)
            and not state._handles_prefetched.get(target_handles_key, False)
        ]
    return target_handles_keys


def _get_training_state(
    handles_key: _HandlesKey,
) -> HandleTrainingState:
    """Returns the training state of the handles in ``handles_key``."""
    p_assert(len(handles_key) > 0, "Expects a non-empty handles key")
    training_states = {handle._training_state for handle in handles_key}
    p_assert(
        len(training_states) == 1,
        f"Expects uniform training state but got {training_states}",
    )
    return next(iter(training_states))


@no_type_check
def _register_pre_forward_hooks(
    state: _FSDPState,
    modules: Iterable[nn.Module],
) -> None:
    """
    Registers pre-forward hooks on all modules in ``modules``. The pre-forward
    hooks are partially applied based on the current ``FlatParamHandle``
    construction, meaning that they must be re-registered if the construction
    changes.
    """
    for forward_handle in state._pre_forward_handles:
        forward_handle.remove()
    state._pre_forward_handles.clear()
    for module in modules:
        module_param_handles = state._fully_sharded_module_to_handles.get(module, [])
        if module_param_handles:
            unshard_fn = functools.partial(
                _pre_forward_unshard,
                state,
                module_param_handles,
            )
            hook = functools.partial(
                _pre_forward, state, module_param_handles, unshard_fn
            )
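            # `prepend=True` runs this hook before any other registered
            # forward pre-hooks, and `with_kwargs=True` passes the forward
            # kwargs so that they can be cast as well.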
            state._pre_forward_handles.append(
                module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
            )


@no_type_check
def _register_post_forward_hooks(
    state: _FSDPState,
    modules: Iterable[nn.Module],
) -> None:
    """
    Registers post-forward hooks on all modules in ``modules``. The
    post-forward hooks are partially applied based on the current
    ``FlatParamHandle`` construction, meaning that they must be re-registered
    if the construction changes.
    """
    for forward_handle in state._post_forward_handles:
        forward_handle.remove()
    state._post_forward_handles.clear()
    for module in modules:
        module_param_handles = state._fully_sharded_module_to_handles.get(module, [])
        if module_param_handles:
            reshard_fn = functools.partial(
                _post_forward_reshard,
                state,
                module_param_handles,
            )
            hook = functools.partial(
                _post_forward,
                state,
                module_param_handles,
                reshard_fn,
            )
            state._post_forward_handles.append(module.register_forward_hook(hook))


@no_type_check
def _register_root_pre_forward_hook(
    state: _FSDPState,
    module: nn.Module,
):
    """
    Registers root pre-forward hook on ``module``, which should be the local
    FSDP root.

    NOTE: For the current composable FSDP design, each application of
    ``fully_shard()`` to a module indicates that that module is the local
    FSDP root. We may remove this assumption in the future, in which case we
    will need to register this root pre-forward hook on any candidate module
    that may be the local FSDP root.
    """
    for forward_handle in state._root_pre_forward_handles:
        forward_handle.remove()
    state._root_pre_forward_handles.clear()
    hook = functools.partial(_root_pre_forward, state)
    state._root_pre_forward_handles.append(
        module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True)
    )


@no_type_check
def _register_pre_backward_hooks(
    state: _FSDPState,
    module: nn.Module,
    outputs: Any,
    handles: List[FlatParamHandle],
) -> None:
    """
    Registers pre-backward hooks on the tensors that require gradients in the
    forward pass outputs ``outputs``, which were computed using the
    ``FlatParameter`` s of ``handles``.

    Args:
        module (nn.Module): Fully sharded module (see [Note: Fully Sharded
            Module]).

    Returns:
        Forward pass outputs with pre-backward hooks registered to tensors that
        require gradients.
    """
    # If there is no gradient computation, then there is no need for
    # pre-backward logic
    if not torch.is_grad_enabled():
        return outputs
    if state._is_root:
        state._post_backward_callback_queued = False  # only defined on the root
    handles_key = tuple(handles)
    if handles_key:
        # Since these handles' `FlatParameter`s participated in a forward, we
        # conservatively assume that they will be used in the backward
        state._needs_pre_backward_unshard[handles_key] = False
        state._ran_pre_backward_hook[handles_key] = False

    def _register_hook(t: torch.Tensor) -> torch.Tensor:
        if t.requires_grad:
            t.register_hook(
                functools.partial(_pre_backward_hook, state, module, handles)
            )
            state._needs_pre_backward_unshard[handles_key] = True
        return t

    return _apply_to_tensors(_register_hook, outputs)


def _register_post_backward_hooks(
    state: _FSDPState,
    handles: List[FlatParamHandle],
) -> None:
    """
    Registers post-backward hooks on the ``FlatParameter`` s'
    ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients.

    The ``AccumulateGrad`` object represents the last function that finalizes
    the ``FlatParameter`` 's gradient, so it only runs after its entire
    gradient computation has finished.

    We register the post-backward hook only once in the *first* forward that a
    ``FlatParameter`` participates in. This relies on the ``AccumulateGrad``
    object being preserved through multiple forwards.
    """
    # If there is no gradient computation, then there is no need for
    # post-backward logic
    if not torch.is_grad_enabled():
        return
    for handle in handles:
        flat_param = handle.flat_param
        already_registered = hasattr(flat_param, "_post_backward_hook_state")
        if already_registered or not flat_param.requires_grad:
            continue
        # Get the `AccumulateGrad` object
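        # `expand_as` creates a non-leaf view whose `grad_fn` (an expand
        # backward node) holds the leaf `FlatParameter`'s `AccumulateGrad`
        # object in `next_functions`, which is where we attach the hook.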
        temp_flat_param = flat_param.expand_as(flat_param)
        p_assert(
            temp_flat_param.grad_fn is not None,
            "The `grad_fn` is needed to access the `AccumulateGrad` and "
            "register the post-backward hook",
        )
        acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
        assert acc_grad is not None
        hook_handle = acc_grad.register_hook(
            functools.partial(_post_backward_hook, state, handle)
        )
        flat_param._post_backward_hook_state = (acc_grad, hook_handle)  # type: ignore[attr-defined]


@no_type_check
def _register_post_backward_final_callback(
    state: _FSDPState, module: nn.Module
) -> None:
    """
    Registers the post-backward final callback that runs at the end of the
    backward pass. This should be called from the root FSDP instance at the
    beginning of the pre-backward.
    """
    p_assert(
        state._is_root,
        "Only the root FSDP instance should register the post-backward callback",
    )
    if state._post_backward_callback_queued:
        return
    _assert_in_training_states(state, [TrainingState.IDLE])
    state._post_backward_callback_queued = True
    Variable._execution_engine.queue_callback(
        functools.partial(_post_backward_final_callback, state, module)
    )


def _wait_for_computation_stream(
    computation_stream: torch.cuda.Stream,
    unshard_stream: torch.cuda.Stream,
    pre_unshard_stream: torch.cuda.Stream,
):
    """
    Has the unshard and pre-unshard streams wait for the computation stream.
    For example, this should be called in the FSDP root's pre-forward to
    respect optimizer step computation.
    """
    unshard_stream.wait_stream(computation_stream)
    # Having the pre-all-gather stream wait for the current stream even if we
    # do not leverage the pre-all-gather stream is tolerable since this only
    # runs once per iteration
    pre_unshard_stream.wait_stream(computation_stream)


def _clear_grads_if_needed(
    handles: List[FlatParamHandle],
):
    """
    Clears the original parameters' gradients if needed. This method's CPU
    overhead is minimal, so we may call it throughout FSDP methods, which serve
    as callsites to free the gradient memory earlier.
    """
    for handle in handles:
        if handle._use_orig_params:
            handle._clear_grads_if_needed()


@no_type_check
def _get_buffers_and_dtypes_for_computation(
    state: _FSDPState,
    root_module: nn.Module,
) -> Tuple[List[torch.Tensor], List[Optional[torch.dtype]]]:
    """
    Returns all buffers in the module tree rooted at ``root_module`` and a
    corresponding list of the buffer dtypes for computation. Each buffer dtype
    is either ``None`` if buffer mixed precision is not enabled or the buffer
    low precision dtype otherwise.
    """
    p_assert(state._is_root, "Expects the root to cast buffers")
    buffers: List[torch.Tensor] = []
    buffer_dtypes: List[Optional[torch.dtype]] = []
    if _is_composable(state):
        buffers = [
            buffer for module in root_module.modules() for buffer in module.buffers()
        ]
        buffer_dtypes = [
            state.mixed_precision.buffer_dtype for _ in range(len(buffers))
        ]
    else:
        visited_buffers = set()
        # Traverse the FSDP instances bottom-up so that we prefer the owning
        # FSDP instance's mixed precision setting for each buffer
        for fsdp_module in reversed(traversal_utils._get_fsdp_states(root_module)):
            for buffer in fsdp_module.buffers():
                if buffer in visited_buffers:
                    continue
                visited_buffers.add(buffer)
                buffers.append(buffer)
                buffer_dtypes.append(fsdp_module.mixed_precision.buffer_dtype)
    assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}"
    return buffers, buffer_dtypes


@no_type_check
def _get_buffer_dtypes(
    state: _FSDPState,
    buffer_names: List[str],
) -> List[torch.dtype]:
    """
    Returns the original buffer types of the given buffer names.
    """
    buffer_dtypes: List[torch.dtype] = []
    for buffer_name in buffer_names:
        p_assert(
            buffer_name in state._buffer_name_to_orig_dtype,
            f"{buffer_name} is missing from pre-computed dict on rank "
            f"{state.rank}, which only has keys "
            f"{state._buffer_name_to_orig_dtype.keys()}",
        )
        buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name])
    return buffer_dtypes


def _cast_buffers_to_dtype_and_device(
    buffers: List[torch.Tensor],
    buffer_dtypes: List[Optional[torch.dtype]],
    device: torch.device,
) -> None:
    """
    Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them
    to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the
    corresponding buffer is only moved to ``device``.
    """
    p_assert(
        buffer_dtypes is None or len(buffers) == len(buffer_dtypes),
        f"Expects `buffers` and `buffer_dtypes` to have the same length if "
        f"`buffer_dtypes` is specified but got {len(buffers)} and "
        f"{len(buffer_dtypes)}",
    )
    for buffer, buffer_dtype in zip(buffers, buffer_dtypes):
        if not torch.is_floating_point(buffer) or buffer_dtype is None:
            buffer.data = buffer.to(device=device)
        else:
            buffer.data = buffer.to(device=device, dtype=buffer_dtype)