123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the BSD license found in the
- # LICENSE file in the root directory of this source tree.
- import contextlib
- import functools
- from abc import ABC, abstractmethod
- from typing import Any, Callable, cast, Dict, Generator, Optional, Set, Tuple, Type
- import torch.nn as nn
- from torch.nn.modules.batchnorm import _BatchNorm
- __all__ = [
- "always_wrap_policy",
- "lambda_auto_wrap_policy",
- "transformer_auto_wrap_policy",
- "size_based_auto_wrap_policy",
- "enable_wrap",
- "wrap",
- "ModuleWrapPolicy",
- ]
- def always_wrap_policy(*args, **kwargs) -> bool:
- """
- A simple recursive wrap policy that always returns ``True``. This means
- that every submodule is wrapped by the wrapper class in
- :func:`_recursive_wrap`.
- """
- return True
- class _FSDPPolicy(ABC):
- """
- This defines an abstract base class that represents an FSDP policy for
- constructing ``FlatParameter`` s.
- """
- # The motivation for this abstract base class is to hide the interface
- # expected by `_recursive_wrap()` from users (i.e. the `recurse` argument).
- def __init__(self):
- ...
- @property
- @abstractmethod
- def policy(self) -> Callable:
- ...
- def _module_wrap_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- module_classes: Set[Type[nn.Module]],
- ) -> bool:
- """
- This auto wrap policy wraps every module that is an instance of any type in
- ``module_classes`` as its own FSDP instance. The root module given by
- ``module`` is always wrapped as an FSDP instance regardless. Since the
- wrapping proceeds bottom up, each FSDP instance manages the parameters in
- its subtree excluding any already managed by a child FSDP instance.
- Args:
- module (nn.Module): Current module being considered.
- recurse (bool): If ``False``, then this function must decide whether
- ``module`` should be wrapped as an FSDP instance or not. If
- ``True``, then the function is still recursing down the module
- tree as a part of the DFS.
- nonwrapped_numel (int): Parameter numel not yet wrapped.
- module_classes (Set[Type[nn.Module]]): Set of module classes that are
- wrapped as FSDP instances.
- Returns:
- ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
- if ``recurse=False``.
- """
- if recurse:
- return True # always recurse
- return isinstance(module, tuple(module_classes))
- class ModuleWrapPolicy(_FSDPPolicy):
- """This is a wrapper around :func:`_module_wrap_policy`."""
- def __init__(self, module_classes: Set[Type[nn.Module]]):
- self._policy: Callable = functools.partial(
- _module_wrap_policy,
- module_classes=module_classes,
- )
- self._module_classes_str = str(module_classes)
- @property
- def policy(self):
- return self._policy
- def __repr__(self) -> str:
- return super().__repr__() + f"({self._module_classes_str})"
- def lambda_auto_wrap_policy(
- module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
- ) -> bool:
- """
- A convenient auto wrap policy to wrap submodules based on an arbitrary user
- function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as
- a `wrapper_cls` unit.
- Return if a module should be wrapped during auto wrapping.
- The first three parameters are required by :func:`_recursive_wrap`.
- Args:
- module (nn.Module): Current module being considered.
- recurse (bool): If ``False``, then this function must decide whether
- ``module`` should be wrapped as an FSDP instance or not. If
- ``True``, then the function is still recursing down the module
- tree as a part of the DFS.
- nonwrapped_numel (int): Parameter numel not yet wrapped.
- lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
- this module will be wrapped.
- """
- if recurse:
- return True # always recurse
- return lambda_fn(module)
- def transformer_auto_wrap_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- transformer_layer_cls: Set[Type[nn.Module]],
- ) -> bool:
- """
- See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the
- same as ``module_classes``. Note that shared parameters must be wrapped in
- the same FSDP instance, so this auto wrap policy can help wrap shared
- embeddings into the same FSDP instance for transformer models.
- """
- return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls)
- def _wrap_batchnorm_individually(
- module: nn.Module,
- recurse: bool,
- *args,
- **kwargs,
- ) -> bool:
- """
- A policy that wraps ``BatchNorm`` instances in their own FSDP instance.
- """
- if recurse:
- # always recurse
- return True
- else:
- # if not recursing, decide whether we should wrap based on whether it is a
- # BN layer or not.
- return isinstance(module, _BatchNorm)
- def _or_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- policies,
- ) -> bool:
- """
- A policy that wraps ``module`` if any policy in the passed in iterable of
- ``policies`` returns ``True``.
- """
- return any(policy(module, recurse, nonwrapped_numel) for policy in policies)
- def size_based_auto_wrap_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- # Additional custom arguments
- min_num_params: int = int(1e8),
- force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
- exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
- ) -> bool:
- """
- A size-based auto wrap policy.
- Args:
- module (nn.Module): Current module being considered.
- recurse (bool): If ``False``, then this function must decide whether
- ``module`` should be wrapped as an FSDP instance or not. If
- ``True``, then the function is still recursing down the module
- tree as a part of the DFS.
- nonwrapped_numel (int): Parameter numel not yet wrapped.
- min_num_params (int): Customizable policy input that controls the size
- threshold over which a module is ready to be wrapped. This is in
- units of numel.
- force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep
- as leaves, i.e. their children will never be wrapped.
- exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be
- excluded in wrapping.
- Returns:
- Whether ``module`` should be wrapped.
- """
- force_leaf_modules = (
- size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined]
- if force_leaf_modules is None
- else force_leaf_modules
- )
- exclude_wrap_modules = (
- size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined]
- if exclude_wrap_modules is None
- else exclude_wrap_modules
- )
- # Keep the argument `min_num_params` for BC for now, but it represents the
- # minimum non-wrapped *numel* before triggering a wrapping
- min_nonwrapped_numel = min_num_params
- is_large = nonwrapped_numel >= min_nonwrapped_numel
- if recurse:
- # We should recurse if the module is big enough but not in force_leaf_modules list.
- return is_large and not isinstance(module, tuple(force_leaf_modules))
- else:
- # If we are not recursing, determine if we should wrap.
- return is_large and not isinstance(module, tuple(exclude_wrap_modules))
- # Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported.
- size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined]
- size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined]
- @contextlib.contextmanager
- def enable_wrap(
- *, wrapper_cls: Any, **wrapper_kwargs: Any
- ) -> Generator[None, None, None]:
- """
- Context manager to wrap modules using a wrapper.
- Useful for when you'd like to apply the same configuration arguments to all
- child modules that you wrap. A particularly important use case is wrapping
- large layers so that they get sharded (in-place) during initialization, to
- avoid running out of system memory. Large layers can indicate that they
- should be sharded via the ``wrap`` annotation and this context manager can
- provide the exact configuration for these nested instances.
- Usage::
- with enable_wrap(wrapper_cls, **params):
- # Wraps layer in FSDP by default if within context
- self.l1 = wrap(torch.nn.Linear(5, 5))
- Args:
- wrapper_cls:
- Class that `wrap` annotation will `wrap` modules with, such as
- `FullyShardedDataParallel`.
- **wrapper_kwargs:
- Configuration settings that will be passed to all ``wrap``
- instances inside the context
- """
- kwargs = {
- **{"wrapper_cls": wrapper_cls},
- **wrapper_kwargs,
- }
- with _ConfigAutoWrap(**kwargs):
- yield
- def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
- """
- Annotate that a module should be wrapped. Annotated modules will only be
- wrapped if inside of an :func:`enable_wrap` context manager. This allows
- a module to be initialized both with and without a wrapper without code
- change.
- The class that this function wraps the passed in ``nn.Module`` with is the
- passed in ``wrapper_cls`` argument into ``enable_wrap``. Both
- ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct
- the ``wrapper_cls`` instance. In the case of duplicate kwargs in
- ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be
- respected.
- Usage::
- with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
- # Wraps layer in FSDP by default if within context
- self.l1 = wrap(torch.nn.Linear(5, 5))
- Args:
- module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
- **wrap_overrides: configuration overrides that will take priority over
- the values provided by the :func:`enable_wrap` context
- """
- if _ConfigAutoWrap.in_autowrap_context:
- assert _ConfigAutoWrap.wrapper_cls is not None
- wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
- return _wrap(
- module,
- _ConfigAutoWrap.wrapper_cls,
- **wrap_overrides,
- )
- return module
- def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
- assert wrapper_cls is not None
- if hasattr(module, "_wrap_overrides"):
- # If module has a _wrap_overrides attribute, we force overriding the
- # FSDP config with these attributes for this module. Currently this
- # is only used to disable mixed precision for BatchNorm when
- # auto_wrapping.
- overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type]
- return wrapper_cls(module, **overrides)
- return wrapper_cls(module, **kwargs)
- def _recursive_wrap(
- module: nn.Module,
- auto_wrap_policy: Callable,
- wrapper_cls: Callable,
- ignored_modules: Set[nn.Module],
- ignored_params: Set[nn.Parameter],
- only_wrap_children: bool = False,
- **kwargs: Any,
- ) -> Tuple[nn.Module, int]:
- """
- Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
- ``True`` with ``wrapper_cls``.
- Args:
- module (nn.Module): Module to recursively wrap.
- auto_wrap_policy (Callable): A callable representing a policy that
- determines which modules to recursively wrap with ``wrapper_cls``.
- ignored_modules (Set[torch.nn.Module]): Modules to ignore when
- wrapping.
- ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
- wrapping; these should be the parameters contained in the modules
- in ``ignored_modules``.
- Returns:
- (nn.Module, int):
- ``module`` after wrapping and the numel recursively wrapped.
- """
- assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
- assert wrapper_cls is not None, "Must specify wrapper_cls"
- # Make sure no child is already wrapped.
- for _, child in module.named_modules():
- if child in ignored_modules:
- continue
- try:
- assert not isinstance(child, cast(type, wrapper_cls))
- except TypeError:
- # wrapper_cls is a function as opposed to a class type, just bypass above check.
- pass
- # We count all params, assuming none of them are already wrapped.
- nonwrapped_numel = sum(
- p.numel() for p in module.parameters() if p not in ignored_params
- )
- assert auto_wrap_policy is not None
- if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
- total_wrapped_numel = 0
- # Iterate through the children, recursively wrap if necessary
- for name, child in module.named_children():
- if child in ignored_modules:
- continue
- wrapped_child, num_wrapped_params = _recursive_wrap(
- module=child,
- auto_wrap_policy=auto_wrap_policy,
- wrapper_cls=wrapper_cls,
- ignored_modules=ignored_modules,
- ignored_params=ignored_params,
- **kwargs,
- )
- setattr(module, name, wrapped_child)
- # Keep track of how many parameters have been wrapped
- total_wrapped_numel += num_wrapped_params
- # decide if we need to wrap the current module,
- # since the left over parameters exceed the number of params to wrap
- remainder = nonwrapped_numel - total_wrapped_numel
- if not only_wrap_children and auto_wrap_policy(
- module=module, recurse=False, nonwrapped_numel=remainder
- ):
- # Leaf node or final wrapping of the remainder both happen here.
- return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
- else:
- return module, total_wrapped_numel
- return module, 0
- class _ConfigAutoWrap:
- """
- Helper class to wrap modules based on default config args via a context manager.
- See :func:`enable_wrap` for more information.
- """
- in_autowrap_context: bool = False # Context flag
- wrapper_cls: Optional[Callable] = None # The wrapper class
- kwargs: Dict[str, Any] = {} # Wrapper's args
- def __init__(self, **kwargs: Dict[str, Any]):
- self.kwargs = kwargs
- @staticmethod
- def enable_autowrap_context(kwargs: Any) -> None:
- if _ConfigAutoWrap.in_autowrap_context:
- raise NotImplementedError(
- "You are already within an autowrap context and we currently do not supported nested autowrap."
- )
- _ConfigAutoWrap.in_autowrap_context = True
- # Get and save the wrapper cls for the context.
- assert (
- "wrapper_cls" in kwargs.keys()
- ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
- _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
- del kwargs["wrapper_cls"]
- # Save the rest.
- _ConfigAutoWrap.kwargs = kwargs
- @staticmethod
- def disable_autowrap_context() -> None:
- _ConfigAutoWrap.in_autowrap_context = False
- _ConfigAutoWrap.wrapper_cls = None
- _ConfigAutoWrap.kwargs = {}
- def __enter__(self) -> None:
- self.enable_autowrap_context(self.kwargs)
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
- self.disable_autowrap_context()
|