import contextlib
import warnings
from typing import cast, Generator, List

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp._common_utils import (
    _FSDPState,
    _has_fsdp_params,
    _module_handles,
    HandleTrainingState,
    TrainingState,
)
from torch.distributed.fsdp._runtime_utils import (
    _clear_grads_if_needed,
    _get_fsdp_root_states_with_modules,
    _lazy_init,
    _reshard,
    _reshard_grads,
    _unshard,
    _unshard_grads,
)

from ._utils import p_assert
from .flat_param import FlatParamHandle

FLAT_PARAM = "_flat_param"
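# NOTE: ``FLAT_PARAM`` is the attribute name under which each wrapped module's
# ``FlatParameter`` lives in ``_parameters``; the (de)registration helpers
# below toggle whether it is visible to ``nn.Module`` methods.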


@torch.no_grad()
def _writeback_to_local_shard(
    handles: List[FlatParamHandle],
    writeback_grad: bool,
):
    """
    For each handle, writes back this rank's shard of the unsharded
    flattened parameter to the sharded flattened parameter. If
    ``writeback_grad=True``, then writes back to the sharded gradient as
    well.

    Precondition: Each handle's ``FlatParameter`` 's data points to the
    padded unsharded flattened parameter.
    """
    for handle in handles:

        def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor:
            if handle.uses_sharded_strategy:
                # For sharded strategies, get the *unpadded* shard instead of
                # the *padded* shard to persist user changes to the padding
                # (though FSDP does not explicitly support this)
                shard, _ = FlatParamHandle._get_unpadded_shard(
                    flat_param_or_grad,
                    handle.rank,
                    handle.world_size,
                )
                return shard
            # For `NO_SHARD`, the `flat_param` or its gradient may be
            # modified, so we write it back directly
            return flat_param_or_grad

        param_shard = _get_shard(handle.flat_param)
        handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard)  # type: ignore[attr-defined]
        if writeback_grad:
            existing_grad = handle.sharded_grad
            if existing_grad is not None:
                assert handle.flat_param.grad is not None
                grad_shard = _get_shard(handle.flat_param.grad)
                existing_grad[: grad_shard.numel()].copy_(grad_shard)
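

# NOTE: ``_writeback_to_local_shard`` is invoked by
# ``_unshard_fsdp_state_params`` below when exiting the unshard context with
# ``writeback=True``, so user mutations of the unsharded parameters (and
# optionally gradients) persist into each rank's local shard.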


def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    De-registers the flattened parameter from the wrapped module, hiding it
    from ``nn.Module`` methods.

    We do not use ``del`` because we want ``FLAT_PARAM`` to always be an
    attribute but dynamically change whether it is visible to ``nn.Module``
    methods.
    """
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None)


def _register_flat_param(state: _FSDPState, module: nn.Module) -> None:
    """
    Registers the flattened parameter to the wrapped module, making it
    visible to ``nn.Module`` methods.

    We do not use :meth:`nn.Module.register_parameter` because we want
    ``FLAT_PARAM`` to always be an attribute but dynamically change whether
    it is visible to ``nn.Module`` methods.
    """
    handles = _module_handles(state, module)
    if _has_fsdp_params(state, module):
        # TODO: figure out the case for the composable APIs.
        cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handles[0].flat_param


@contextlib.contextmanager
def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator:
    """
    Assumes that the flattened parameter is unsharded. When in the context,
    de-registers the flattened parameter and unflattens the original
    parameters as ``nn.Parameter`` views into the flattened parameter. After
    the context, re-registers the flattened parameter and restores the
    original parameters as ``Tensor`` views into the flattened parameter.
    """
    handles = _module_handles(state, module)
    if not handles:
        yield
    else:
        _deregister_flat_param(state, module)
        try:
            with handles[0].unflatten_as_params():
                yield
        finally:
            if not handles[0]._use_orig_params:
                _register_flat_param(state, module)
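

# NOTE: ``_unflatten_as_params`` is what lets user code see the original
# (unflattened) ``nn.Parameter`` s inside the unshard context when
# ``use_orig_params=False``; with ``use_orig_params=True``, the handle already
# maintains the original parameters as views, so re-registration is skipped.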


def _validate_unshard_params_args(
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
) -> None:
    if with_grads and (offload_to_cpu or not state._use_orig_params):
        raise NotImplementedError(
            f"with_grads={with_grads}, "
            f"use_orig_params={state._use_orig_params}, "
            f"offload_to_cpu={offload_to_cpu} "
            f"is not supported yet"
        )
    if offload_to_cpu and any(
        not handle.uses_sharded_strategy for handle in state._handles
    ):
        raise NotImplementedError(
            "offload_to_cpu=True and NO_SHARD is not supported yet"
        )
    if writeback and rank0_only:
        # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
        # persist the changes.
        raise NotImplementedError(
            "writeback=True and rank0_only=True is not supported yet"
        )
    if offload_to_cpu and not rank0_only:
        warnings.warn(
            "offload_to_cpu=True and rank0_only=False may result in the "
            "unsharded parameters being redundantly copied to CPU memory for "
            "GPUs sharing the same CPU memory, which risks CPU OOM. We "
            "recommend using offload_to_cpu=True with rank0_only=True."
        )
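

# For example, requesting writeback together with rank0-only unsharding is
# rejected by the validation above. (Illustrative caller sketch only;
# ``fsdp_model`` is a hypothetical FSDP-wrapped module.)
#
#   with FSDP.summon_full_params(fsdp_model, writeback=True, rank0_only=True):
#       ...  # raises NotImplementedError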


@contextlib.contextmanager
def _unshard_fsdp_state_params(
    module: nn.Module,
    state: _FSDPState,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    Unshards the parameters for a single FSDP state ``state`` that
    corresponds to ``module``.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    torch.cuda.synchronize()
    # If a handle is shared by other module(s), it may already be unsharded.
    handles = [
        handle
        for handle in _module_handles(state, module)
        if handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS
    ]
    if not handles:
        yield
        return

    for handle in handles:
        assert (
            handle._training_state == HandleTrainingState.IDLE
        ), f"Expects the handle training state to be IDLE but got {handle._training_state}"

    for handle in handles:
        handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS

    _clear_grads_if_needed(handles)
    free_unsharded_flat_params = [handle.needs_unshard() for handle in handles]
    # No need to call `wait_stream()` since we unshard in the computation
    # stream directly
    computation_stream = torch.cuda.current_stream()
    _unshard(state, handles, computation_stream, computation_stream)
    if with_grads:
        _unshard_grads(handles)

    if rank0_only and state.rank != 0:
        # Free the unsharded flattened parameters early
        _reshard(state, handles, free_unsharded_flat_params)
        if with_grads:
            _reshard_grads(handles)
        try:
            yield
        finally:
            for handle in handles:
                handle._training_state = HandleTrainingState.IDLE
    else:
        # Unflatten the unsharded flattened parameters
        with contextlib.ExitStack() as stack:
            # Invariant: rank == 0 or !rank0_only
            for handle in handles:
                if offload_to_cpu and handle.uses_sharded_strategy:
                    stack.enter_context(handle.to_cpu())
                    # NOTE: Since PyTorch enforces that a parameter and its
                    # gradient must match in metadata (e.g. device), we must
                    # move gradients to CPU *after* we move parameters.
            # NOTE: This assumes 1 `FlatParameter`.
            if not state._use_orig_params:
                stack.enter_context(_unflatten_as_params(state, module))
            try:
                yield
            finally:
                stack.close()
                if writeback:
                    _writeback_to_local_shard(handles, with_grads)
                _reshard(state, handles, free_unsharded_flat_params)
                if with_grads:
                    _reshard_grads(handles)
                for handle in handles:
                    handle._training_state = HandleTrainingState.IDLE
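

# NOTE: ``_unshard_fsdp_state_params`` passes the current (computation) stream
# as both unshard streams so the all-gather runs inline in computation order,
# which is why no extra ``wait_stream()`` synchronization is needed above.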


@contextlib.contextmanager
def _unshard_params_recurse(
    module: nn.Module,
    state: _FSDPState,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This is a helper for :func:`_unshard_params` that recursively calls
    :func:`_unshard_fsdp_state_params` on FSDP states if ``recurse=True``.

    NOTE: This runs lazy initialization.
    """
    _validate_unshard_params_args(
        state, writeback, rank0_only, offload_to_cpu, with_grads
    )
    if recurse:
        with contextlib.ExitStack() as stack:
            # TODO (awgu): The traversal function does not traverse through
            # incompatible composable APIs. Verify if this is the desired
            # behavior for this function.
            for state, fsdp_module in zip(
                *traversal_utils._get_fsdp_states_with_modules(module)
            ):
                stack.enter_context(
                    _unshard_params_recurse(
                        module=fsdp_module,
                        state=state,
                        recurse=False,
                        writeback=writeback,
                        rank0_only=rank0_only,
                        offload_to_cpu=offload_to_cpu,
                        with_grads=with_grads,
                    )
                )
            yield
        return
    _lazy_init(state, module)
    if state.training_state == TrainingState.FORWARD_BACKWARD:
        raise AssertionError(
            "Cannot manually unshard parameters during forward/backward"
        )
    elif state.training_state == TrainingState.SUMMON_FULL_PARAMS:
        raise AssertionError(
            "Cannot manually unshard parameters when already unsharding parameters"
        )
    with _unshard_fsdp_state_params(
        module=module,
        state=state,
        writeback=writeback,
        rank0_only=rank0_only,
        offload_to_cpu=offload_to_cpu,
        with_grads=with_grads,
    ):
        try:
            state.training_state = TrainingState.SUMMON_FULL_PARAMS
            yield
        finally:
            state.training_state = TrainingState.IDLE
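

# NOTE: With ``recurse=True``, the helper above flattens the FSDP subtree into
# one non-recursive ``_unshard_params_recurse`` context per FSDP state, all
# held open by a single ``ExitStack``.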


@contextlib.contextmanager
def _unshard_params(
    module: nn.Module,
    recurse: bool,
    writeback: bool,
    rank0_only: bool,
    offload_to_cpu: bool,
    with_grads: bool,
):
    """
    This unshards FSDP-managed parameters for all modules with FSDP applied in
    the module tree rooted at ``module``.
    """
    root_fsdp_states, root_fsdp_modules = _get_fsdp_root_states_with_modules(module)
    with contextlib.ExitStack() as stack:
        for root_fsdp_state, root_fsdp_module in zip(
            root_fsdp_states, root_fsdp_modules
        ):
            stack.enter_context(
                _unshard_params_recurse(
                    module=root_fsdp_module,
                    state=root_fsdp_state,
                    recurse=recurse,
                    writeback=writeback,
                    rank0_only=rank0_only,
                    offload_to_cpu=offload_to_cpu,
                    with_grads=with_grads,
                )
            )
        yield
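

# Usage sketch (hypothetical; this module is internal, and user code would
# normally reach ``_unshard_params`` through the public
# ``FullyShardedDataParallel.summon_full_params`` API instead):
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   model = FSDP(MyModel().cuda())  # `MyModel` is a placeholder module
#   with FSDP.summon_full_params(model, writeback=False, rank0_only=True):
#       if torch.distributed.get_rank() == 0:
#           # Full (unsharded) parameters are visible on rank 0 here
#           print(sum(p.numel() for p in model.parameters()))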


def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the original parameters; registers the ``FlatParameter``.
    """
    handles = _module_handles(state, module)
    p_assert(
        len(handles) <= 1,
        "Expects <=1 handle per FSDP instance; needs to be refactored "
        "for >1 handle (e.g. non-recursive wrapping)",
    )
    if not handles:
        return
    handle = handles[0]
    p_assert(
        handle._use_orig_params,
        f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} "
        f"handle: {handle._use_orig_params}",
    )
    handle._deregister_orig_params()
    _register_flat_param(state, module)


def _register_orig_params(state: _FSDPState, module: nn.Module) -> None:
    """
    Deregisters the ``FlatParameter``; registers the original parameters.
    """
    handles = _module_handles(state, module)
    if not handles:
        return
    handle = handles[0]
    _deregister_flat_param(state, module)
    if handle.is_sharded(handle.flat_param):
        handle._use_sharded_views()
        handle._use_sharded_grad_views()
    else:
        handle._use_unsharded_views(as_params=True)
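

# NOTE: ``_deregister_orig_params`` / ``_register_orig_params`` are the
# ``use_orig_params=True`` counterparts of the flat-parameter toggles above:
# they swap which view of the data (the single ``FlatParameter`` vs. the
# original parameters as views into it) is visible to ``nn.Module`` methods.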
|