# _unshard_param_utils.py

import contextlib
import warnings
from typing import cast, Generator, List

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _has_fsdp_params,
    _module_handles,
    HandleTrainingState,
    TrainingState,
)
from torch.distributed.fsdp._runtime_utils import (
    _clear_grads_if_needed,
    _get_fsdp_root_states_with_modules,
    _lazy_init,
    _reshard,
    _reshard_grads,
    _unshard,
    _unshard_grads,
)

from ._utils import p_assert
from .flat_param import FlatParamHandle

FLAT_PARAM = "_flat_param"

@torch.no_grad()
def _writeback_to_local_shard(
    handles: List[FlatParamHandle],
    writeback_grad: bool,
):
    """
    For each handle, writes back this rank's shard of the unsharded
    flattened parameter to the sharded flattened parameter. If
    ``writeback_grad=True``, then writes back to the sharded gradient as
    well.

    Precondition: Each handle's ``FlatParameter``'s data points to the
    padded unsharded flattened parameter.
    """
    for handle in handles:

        def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor:
            if handle.uses_sharded_strategy:
                # For sharded strategies, get the *unpadded* shard instead of
                # the *padded* shard to persist user changes to the padding
                # (though FSDP does not explicitly support this)
                shard, _ = FlatParamHandle._get_unpadded_shard(
                    flat_param_or_grad,
                    handle.rank,
                    handle.world_size,
                )
                return shard
            # For `NO_SHARD`, the `flat_param` or its gradient may be modified,
            # so we write it back directly
            return flat_param_or_grad

        param_shard = _get_shard(handle.flat_param)
        handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)  # type: ignore[attr-defined]
        if writeback_grad:
            existing_grad = handle.sharded_grad
            if existing_grad is not None:
                assert handle.flat_param.grad is not None
                grad_shard = _get_shard(handle.flat_param.grad)
                existing_grad[: grad_shard.numel()].copy_(grad_shard)

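# Illustrative sketch (not part of this module): the writeback above is a
# slice-and-copy of this rank's portion of the unsharded flat tensor into its
# local shard. The helper below is hypothetical and assumes the padded
# unsharded tensor was evenly chunked across ranks.
def _example_writeback_one_rank(
    padded_unsharded: torch.Tensor,
    local_shard: torch.Tensor,
    rank: int,
    world_size: int,
    unpadded_numel: int,
) -> None:
    # Each rank owns one equal-sized chunk of the padded unsharded tensor
    chunk = padded_unsharded.chunk(world_size)[rank]
    # Drop any trailing padding that extends past the true (unpadded) size
    start = rank * chunk.numel()
    valid_numel = max(0, min(chunk.numel(), unpadded_numel - start))
    shard = chunk[:valid_numel]
    local_shard[: shard.numel()].copy_(shard)
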
def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    De-registers the flattened parameter from the wrapped module, hiding it
    from ``nn.Module`` methods.

    We do not use ``del`` because we want ``FLAT_PARAM`` to always be an
    attribute but dynamically change whether it is visible to ``nn.Module``
    methods.
    """
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None)

def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    Registers the flattened parameter to the wrapped module, making it
    visible to ``nn.Module`` methods.

    We do not use :meth:`nn.Module.register_parameter` because we want
    ``FLAT_PARAM`` to always be an attribute but dynamically change whether
    it is visible to ``nn.Module`` methods.
    """
    handles = _module_handles(state, module)
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handles[0].flat_param

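# NOTE: The two helpers above only toggle visibility of the ``FlatParameter``
# to ``nn.Module`` machinery (e.g. ``named_parameters()`` and ``state_dict()``)
# by editing the wrapped module's ``_parameters`` dict; the underlying storage
# is untouched. Illustrative only, assuming ``state``/``module`` come from an
# existing FSDP instance:
#
#   _deregister_flat_param(state, module)  # "_flat_param" hidden from named_parameters()
#   _register_flat_param(state, module)    # "_flat_param" visible again
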
@contextlib.contextmanager
def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
    """
    Assumes that the flattened parameter is unsharded. When in the context,
    de-registers the flattened parameter and unflattens the original
    parameters as ``nn.Parameter`` views into the flattened parameter.
    After the context, re-registers the flattened parameter and restores
    the original parameters as ``Tensor`` views into the flattened
    parameter.
    """
    handles = _module_handles(state, module)
    if not handles:
        yield
    else:
        _deregister_flat_param(state, module)
        try:
            with handles[0].unflatten_as_params():
                yield
        finally:
            if not handles[0]._use_orig_params:
                _register_flat_param(state, module)

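# NOTE: ``_unflatten_as_params`` should only be entered while the flattened
# parameter is unsharded (see the docstring's precondition). The original
# parameters it exposes are views into the flattened parameter, so mutating
# them inside the context mutates the flattened parameter directly.
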
def _validate_unshard_params_args(
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
) -> None:
    if with_grads and (offload_to_cpu or not state._use_orig_params):
        raise NotImplementedError(
            f"with_grads={with_grads}, "
            f"use_orig_params={state._use_orig_params}, "
            f"offload_to_cpu={offload_to_cpu} "
            f"is not supported yet"
        )
    if offload_to_cpu and any(
        not handle.uses_sharded_strategy for handle in state._handles
    ):
        raise NotImplementedError(
            "offload_to_cpu=True and NO_SHARD is not supported yet"
        )
    if writeback and rank0_only:
        # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
        # persist the changes.
        raise NotImplementedError(
            "writeback=True and rank0_only=True is not supported yet"
        )
    if offload_to_cpu and not rank0_only:
        warnings.warn(
            "offload_to_cpu=True and rank0_only=False may result in the "
            "unsharded parameters being redundantly copied to CPU memory for "
            "GPUs sharing the same CPU memory, which risks CPU OOM. We "
            "recommend using offload_to_cpu=True with rank0_only=True."
        )

@contextlib.contextmanager
def _unshard_fsdp_state_params(
    module: nn.Module,
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards the parameters for a single FSDP state ``state`` that
    corresponds to ``module``.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    torch.cuda.synchronize()
    # If handles are shared by other module(s), they may already be unsharded.
    handles = [
        handle
        for handle in _module_handles(state, module)
        if handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS
    ]
    if not handles:
        yield
        return

    for handle in handles:
        assert (
            handle._training_state == HandleTrainingState.IDLE
        ), f"Expects the handle training state to be IDLE but got {handle._training_state}"

    for handle in handles:
        handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

    _clear_grads_if_needed(handles)
    free_unsharded_flat_params = [handle.needs_unshard() for handle in handles]
    # No need to call `wait_stream()` since we unshard in the computation
    # stream directly
    computation_stream = torch.cuda.current_stream()
    _unshard(state, handles, computation_stream, computation_stream)
    if with_grads:
        _unshard_grads(handles)

    if rank0_only and state.rank != 0:
        # Free the unsharded flattened parameter early
        _reshard(state, handles, free_unsharded_flat_params)
        if with_grads:
            _reshard_grads(handles)
        try:
            yield
        finally:
            for handle in handles:
                handle._training_state = HandleTrainingState.IDLE
    else:
        # Unflatten the unsharded flattened parameters
        with contextlib.ExitStack() as stack:
            # Invariant: rank == 0 or !rank0_only
            for handle in handles:
                if offload_to_cpu and handle.uses_sharded_strategy:
                    stack.enter_context(handle.to_cpu())
                    # NOTE: Since PyTorch enforces that a parameter and its
                    # gradients need to match metadata (e.g. device), we must
                    # move gradients to CPU *after* we move parameters.
            # NOTE: This assumes 1 `FlatParameter`
            if not state._use_orig_params:
                stack.enter_context(_unflatten_as_params(state, module))
            try:
                yield
            finally:
                stack.close()
                if writeback:
                    _writeback_to_local_shard(handles, with_grads)
                _reshard(state, handles, free_unsharded_flat_params)
                if with_grads:
                    _reshard_grads(handles)
                for handle in handles:
                    handle._training_state = HandleTrainingState.IDLE

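# NOTE: The lifecycle of ``_unshard_fsdp_state_params`` is:
#   1. validate the argument combination and synchronize with CUDA
#   2. unshard (all-gather) the flat parameters in the computation stream
#   3. optionally offload to CPU and/or unflatten into the original parameters
#   4. yield to the caller
#   5. on exit, write back this rank's shard if ``writeback=True``, reshard
#      (and reshard gradients if ``with_grads=True``), and reset each handle's
#      training state to IDLE
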
@contextlib.contextmanager
def _unshard_params_recurse(
    module: nn.Module,
    state: _FSDPState,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This is a helper for :func:`_unshard_params` that recursively calls
    :func:`_unshard_fsdp_state_params` on FSDP states if ``recurse=True``.

    NOTE: This runs lazy initialization.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    if recurse:
        with contextlib.ExitStack() as stack:
            # TODO (awgu): The traversal function does not traverse through
            # incompatible composable APIs. Verify if this is the desired
            # behavior for this function.
            for state, fsdp_module in zip(
                *traversal_utils._get_fsdp_states_with_modules(module)
            ):
                stack.enter_context(
                    _unshard_params_recurse(
                        module=fsdp_module,
                        state=state,
                        recurse=False,
                        writeback=writeback,
                        rank0_only=rank0_only,
                        offload_to_cpu=offload_to_cpu,
                        with_grads=with_grads,
                    )
                )
            yield
        return
    _lazy_init(state, module)
    if state.training_state == TrainingState.FORWARD_BACKWARD:
        raise AssertionError(
            "Cannot manually unshard parameters during forward/backward"
        )
    elif state.training_state == TrainingState.SUMMON_FULL_PARAMS:
        raise AssertionError(
            "Cannot manually unshard parameters when already unsharding parameters"
        )
    with _unshard_fsdp_state_params(
        module=module,
        state=state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    ):
        try:
            state.training_state = TrainingState.SUMMON_FULL_PARAMS
            yield
        finally:
            state.training_state = TrainingState.IDLE

@contextlib.contextmanager
def _unshard_params(
    module: nn.Module,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards FSDP-managed parameters for all modules with FSDP applied in
    the module tree rooted at ``module``.
    """
    root_fsdp_states, root_fsdp_modules = _get_fsdp_root_states_with_modules(module)
    with contextlib.ExitStack() as stack:
        for root_fsdp_state, root_fsdp_module in zip(
            root_fsdp_states, root_fsdp_modules
        ):
            stack.enter_context(
                _unshard_params_recurse(
                    module=root_fsdp_module,
                    state=root_fsdp_state,
                    recurse=recurse,
                    writeback=writeback,
                    rank0_only=rank0_only,
                    offload_to_cpu=offload_to_cpu,
                    with_grads=with_grads,
                )
            )
        yield
    return

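# Illustrative sketch (not part of this module): ``_unshard_params`` backs the
# public ``FullyShardedDataParallel.summon_full_params`` context manager. The
# hypothetical helper below shows a typical use, assuming a distributed
# process group has already been initialized:
def _example_summon_full_params(model: nn.Module) -> List[str]:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    fsdp_model = FSDP(model)
    with FSDP.summon_full_params(fsdp_model, writeback=False, rank0_only=False):
        # Inside the context, the original (unsharded) parameters are
        # registered and visible to `nn.Module` methods
        return [name for name, _ in fsdp_model.named_parameters()]
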
def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the original parameters; registers the ``FlatParameter``.
    """
    handles = _module_handles(state, module)
    p_assert(
        len(handles) <= 1,
        "Expects <=1 handle per FSDP instance; needs to be refactored "
        "for >1 handle (e.g. non-recursive wrapping)",
    )
    if not handles:
        return
    handle = handles[0]
    p_assert(
        handle._use_orig_params,
        f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
        f"handle: {handle._use_orig_params}",
    )
    handle._deregister_orig_params()
    _register_flat_param(state, module)

def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the ``FlatParameter``; registers the original parameters.
    """
    handles = _module_handles(state, module)
    if not handles:
        return
    handle = handles[0]
    _deregister_flat_param(state, module)
    if handle.is_sharded(handle.flat_param):
        handle._use_sharded_views()
        handle._use_sharded_grad_views()
    else:
        handle._use_unsharded_views(as_params=True)
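

# NOTE: ``_deregister_orig_params`` / ``_register_orig_params`` are the
# ``use_orig_params=True`` counterparts of the flat-parameter registration
# helpers above: they swap which representation (the single ``FlatParameter``
# or the original per-tensor parameters) is registered on the wrapped module,
# choosing sharded or unsharded views based on whether the flat parameter is
# currently sharded.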