"""
This file includes private common utilities for FSDP.
"""
import traceback
import warnings
from enum import auto, Enum
from typing import (
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp.flat_param as flat_param_file
import torch.nn as nn
from torch.distributed._composable_state import _get_module_state, _State
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)

from .api import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    OptimStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictType,
)

FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
FSDP_FLATTENED = "_fsdp_flattened"


class _FSDPState(_State):
    def __init__(self) -> None:
        # TODO: Move all the attributes to this class to enable typing for
        # FSDP/fully_shard.
        self._ignored_modules: Set[nn.Module] = set()
        self._ignored_params: Set[nn.Parameter] = set()
        self.process_group: Optional[dist.ProcessGroup] = None
        self.rank: int = -1
        self.world_size: int = -1
        self.sharding_strategy = ShardingStrategy.FULL_SHARD
        self._use_orig_params: bool = False
        self.training_state = TrainingState.IDLE
        self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
        self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
        self._state_dict_config: StateDictConfig = FullStateDictConfig()
        self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
        self._is_root: Optional[bool] = None
        self._handles: List[flat_param_file.FlatParamHandle] = []
        # Maps each fully sharded module to the list of handles that manage
        # its parameters (the values are lists, as consumed by
        # `_module_handles` below).
        self._fully_sharded_module_to_handles: Dict[
            nn.Module, List[flat_param_file.FlatParamHandle]
        ] = {}
        self.compute_device = torch.device("cuda", torch.cuda.current_device())


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    state = _get_module_state(module)
    if state is None or not isinstance(state, _FSDPState):
        return None
    return state


def _get_module_fsdp_state_if_fully_sharded_module(
    module: nn.Module,
) -> Optional[_FSDPState]:
    state = _get_module_fsdp_state(module)
    if state is None:
        return None
    if state == module:  # FullyShardedDataParallel module case.
        return state
    if module in state._fully_sharded_module_to_handles:  # fully_shard case.
        return state
    return None
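
# Illustrative usage of the two getters above (a hedged sketch; ``MyModel`` is
# hypothetical and not part of this file):
#
#   fsdp_model = FullyShardedDataParallel(MyModel())
#   # The wrapper is its own state (the ``state == module`` case above):
#   assert _get_module_fsdp_state_if_fully_sharded_module(fsdp_model) is not None
#   # A plain submodule that is neither a wrapper instance nor a
#   # ``fully_shard``-ed module is expected to return ``None``.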


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


def _is_composable(state: _FSDPState):
    # TODO: This is a temporary hack to differentiate between code paths: the
    # wrapper `FullyShardedDataParallel` is an `nn.Module`, while the
    # composable `fully_shard` state is not.
    return not isinstance(state, nn.Module)


@no_type_check
def _module_handles(state: _FSDPState, module: nn.Module) -> List:
    """
    Returns the ``FlatParamHandle`` s corresponding to ``module``. These are
    the handles that contain some parameter in ``module``.
    """
    if _is_composable(state):
        assert (
            module in state._fully_sharded_module_to_handles
        ), f"Expects a `comm_module` but got {module} on rank {state.rank}"
        return state._fully_sharded_module_to_handles[module][:]
    else:
        # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
        return module._handles[:]


@no_type_check
def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
    """Returns if ``module`` has parameters managed by FSDP."""
    return len(_module_handles(state, module)) > 0
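
# Illustrative usage (a hedged sketch, combining the helpers above):
#
#   state = _get_module_fsdp_state(module)
#   if state is not None and _has_fsdp_params(state, module):
#       for handle in _module_handles(state, module):
#           ...  # e.g. inspect handle.flat_param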


def _get_sharding_strategy(handles: Iterable):
    """
    Returns the sharding strategy of the group of handles given by ``handles``
    or ``None`` if ``handles`` is empty. The input should be the handles
    corresponding to one module, so we enforce that they all share the same
    sharding strategy.
    """
    sharding_strategy = None
    for handle in handles:
        if sharding_strategy is None:
            sharding_strategy = handle._sharding_strategy
        elif (
            sharding_strategy is not None
            and sharding_strategy != handle._sharding_strategy
        ):
            raise AssertionError(
                "Expects each group of handles to have the same sharding "
                f"strategy but got {sharding_strategy} and {handle._sharding_strategy}"
            )
    return sharding_strategy
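
# Illustrative usage (a hedged sketch; ``state`` and ``module`` come from the
# helpers above):
#
#   handles = _module_handles(state, module)
#   strategy = _get_sharding_strategy(handles)  # None if ``handles`` is empty
#   if strategy == ShardingStrategy.FULL_SHARD:
#       ...  # e.g. reshard parameters after forward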


def clean_tensor_name(tensor_name: str) -> str:
    """
    Cleans the parameter or buffer name by removing any module wrapper
    prefixes.
    """
    tensor_name = tensor_name.replace(FSDP_PREFIX, "")
    # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
    # it couples `CheckpointWrapper` and FSDP and also does not scale for more
    # module wrappers.
    tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
    return tensor_name
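
# Illustrative behavior (a hedged sketch; the FQN below is hypothetical):
#
#   clean_tensor_name("layer1._fsdp_wrapped_module.linear.weight")
#   # -> "layer1.linear.weight"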


def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
    """
    Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to
    avoid re-flattening it during nested construction.
    """
    setattr(tensor, FSDP_FLATTENED, True)


def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
    """Returns if ``tensor`` has been marked as flattened by FSDP."""
    return getattr(tensor, FSDP_FLATTENED, False)
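
# Illustrative round trip (a hedged sketch):
#
#   p = nn.Parameter(torch.zeros(4))
#   assert not _is_fsdp_flattened(p)
#   _set_fsdp_flattened(p)
#   assert _is_fsdp_flattened(p)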


def _get_param_to_fqns(
    model: torch.nn.Module,
    dedup_shared_params: bool = True,
) -> Dict[nn.Parameter, List[str]]:
    """
    Constructs a mapping from parameter to a list of its FQNs. Each normal
    parameter maps to a singleton list containing its FQN, while each
    ``FlatParameter`` maps to a list of its original parameter FQNs, which may
    have length greater than one. All FQNs are prefixed starting from
    ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance).
        dedup_shared_params (bool): For shared parameters, if ``True``, only
            includes the FQNs corresponding to the first encounter of the
            shared parameter in the module traversal; if ``False``, then
            includes the FQNs across all encounters. (Default: ``True``)
    """

    def module_fn(module, prefix, param_to_fqns):
        for param_name, param in module.named_parameters(recurse=False):
            local_fqns = (
                param._fqns
                if type(param) is flat_param_file.FlatParameter
                else [param_name]
            )  # prefixed from `module`
            global_fqns = [
                clean_tensor_name(prefix + name) for name in local_fqns
            ]  # prefixed from the top level `model` (i.e. including `prefix`)
            is_shared_param = param in param_to_fqns
            if not is_shared_param:
                param_to_fqns[param] = global_fqns
            else:
                if type(param) is flat_param_file.FlatParameter:
                    # DMP overrides `named_parameters` and skips the wrapped
                    # module (e.g. `_dmp_wrapped_module` and
                    # `_fsdp_wrapped_module`), advancing to the next child
                    # module. When a user calls `named_children` to traverse
                    # the module recursively and calls `named_parameters` with
                    # `recurse=False`, parameters will be traversed more than
                    # once.
                    # This hack is specifically designed for DMP + FSDP. We
                    # overwrite the flat parameter's traversal result to keep
                    # only the last one, which happens to be the correct one.
                    #
                    # TODO: Remove this hack once DMP + FSDP is no longer
                    # supported.
                    warnings.warn(
                        "FlatParameter is being traversed more than once. "
                        "This case should only happen when using "
                        "DistributedModelParallel with FullyShardedDataParallel."
                    )
                    param_to_fqns[param] = global_fqns
                elif not dedup_shared_params:
                    param_to_fqns[param].extend(global_fqns)

    def return_fn(param_to_fqns):
        return param_to_fqns

    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        param_to_unflat_param_names,
    )
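
# Illustrative result (a hedged sketch; the two-layer model is hypothetical):
#
#   model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
#   _get_param_to_fqns(model)
#   # -> {<0.weight>: ["0.weight"], <0.bias>: ["0.bias"], ...}
#   # After FSDP flattening, a single ``FlatParameter`` would instead map to
#   # the list of its original FQNs, e.g. ["0.weight", "0.bias", "1.weight",
#   # "1.bias"].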


def _apply_to_modules(
    root_module: torch.nn.Module,
    module_fn: Callable,
    return_fn: Callable,
    filter_fqns: Optional[List[str]] = None,
    *args,
    **kwargs,
):
    """
    Performs a pre-order traversal of the modules in the hierarchy rooted at
    ``root_module``, applying ``module_fn`` at each module and finally
    returning a value using ``return_fn``. The traversal constructs the full
    module prefix name (e.g. "module.submodule." just like in the model state
    dict) and makes that available to ``module_fn``.

    ``filter_fqns`` is needed because some modules add their own wrapper
    prefix (similar to ``FullyShardedDataParallel``) but override
    ``named_parameters()`` to strip that prefix; the given FQNs let the
    traversal detect and skip such prefixes.
    """

    def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
        # Call the module function before recursing over children (pre-order)
        module_fn(module, prefix, *args, **kwargs)
        for submodule_name, submodule in module.named_children():
            if submodule is None:
                continue
            new_prefix = prefix + submodule_name + "."
            if filter_fqns is not None:
                for fqn in filter_fqns:
                    if fqn.startswith(new_prefix):
                        break
                else:
                    # DMP's `named_parameters()` will mess up the traversal
                    # with `named_children()` + `named_parameters(recurse=False)`.
                    # This hack is needed to make the traversal work.
                    # TODO: Remove this hack once DMP + FSDP is no longer
                    # supported.
                    if (
                        submodule_name == "_fsdp_wrapped_module"
                        or submodule_name == "_dmp_wrapped_module"
                    ):
                        warnings.warn(
                            "An unexpected prefix is detected. This case "
                            "should only happen when using DMP with FSDP. "
                            f"prefix = {prefix}, "
                            f"submodule_name = {submodule_name}"
                        )
                        new_prefix = prefix
            f(submodule, new_prefix, *args, **kwargs)

    f(root_module, "", *args, **kwargs)
    return return_fn(*args, **kwargs)
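
# Illustrative usage (a hedged sketch): collect every module's prefixed name
# in pre-order.
#
#   def module_fn(module, prefix, names):
#       names.append(prefix)
#
#   def return_fn(names):
#       return names
#
#   _apply_to_modules(model, module_fn, return_fn, None, [])
#   # -> ["", "0.", "1.", ...] for an ``nn.Sequential``, for example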


@no_type_check
def _assert_in_training_states(
    state: _FSDPState,
    training_states: List[TrainingState],
) -> None:
    """Asserts that FSDP is in one of the states given by ``training_states``."""
    # Raise a `ValueError` instead of using `assert` to ensure that these
    # logical assertions run even if `assert`s are disabled
    if state.training_state not in training_states:
        msg = (
            f"expected to be in states {training_states} but current state is "
            f"{state.training_state}"
        )
        # Print the error on rank 0 in case this is called in the backward pass
        if state.rank == 0:
            if isinstance(state, nn.Module):
                print(f"Asserting FSDP instance is: {state}")
            print(f"ERROR: {msg}")
            traceback.print_stack()
        raise ValueError(msg)


def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
    """
    Returns:
        Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
        parent-less) with respect to the modules in the set itself. In other
        words, these are the modules in ``modules`` that are not the child of
        any other module in ``modules``.
    """
    root_modules: Set[nn.Module] = set()
    module_to_submodules = {module: set(module.modules()) for module in modules}
    for candidate_module in modules:
        is_root_module = True
        for module, submodules in module_to_submodules.items():
            is_child_module = (
                candidate_module is not module and candidate_module in submodules
            )
            if is_child_module:
                is_root_module = False
                break
        if is_root_module:
            root_modules.add(candidate_module)
    return root_modules
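
# Illustrative behavior (a hedged sketch; the modules are hypothetical):
#
#   outer = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
#   inner = outer[0]
#   other = nn.Linear(2, 2)
#   assert _get_root_modules({outer, inner, other}) == {outer, other}
#   # ``inner`` is dropped because it is a submodule of ``outer``.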