123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- import contextlib
- import functools
- from abc import ABC, abstractmethod
- from typing import Any, Callable, cast, Dict, Generator, Optional, Set, Tuple, Type
- import torch.nn as nn
- from torch.nn.modules.batchnorm import _BatchNorm
# Public API of this module.
__all__ = [
    # Function-style auto wrap policies.
    "always_wrap_policy",
    "lambda_auto_wrap_policy",
    "transformer_auto_wrap_policy",
    "size_based_auto_wrap_policy",
    # Manual wrapping helpers.
    "enable_wrap",
    "wrap",
    # Class-style policy.
    "ModuleWrapPolicy",
]
def always_wrap_policy(*args, **kwargs) -> bool:
    """
    Recursive wrap policy that unconditionally answers ``True``.

    Any positional or keyword arguments are accepted and ignored, so every
    submodule visited by :func:`_recursive_wrap` ends up wrapped by the
    wrapper class.
    """
    del args, kwargs  # The answer does not depend on the inputs.
    return True
- class _FSDPPolicy(ABC):
- """
- This defines an abstract base class that represents an FSDP policy for
- constructing ``FlatParameter`` s.
- """
-
-
- def __init__(self):
- ...
- @property
- @abstractmethod
- def policy(self) -> Callable:
- ...
- def _module_wrap_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- module_classes: Set[Type[nn.Module]],
- ) -> bool:
- """
- This auto wrap policy wraps every module that is an instance of any type in
- ``module_classes`` as its own FSDP instance. The root module given by
- ``module`` is always wrapped as an FSDP instance regardless. Since the
- wrapping proceeds bottom up, each FSDP instance manages the parameters in
- its subtree excluding any already managed by a child FSDP instance.
- Args:
- module (nn.Module): Current module being considered.
- recurse (bool): If ``False``, then this function must decide whether
- ``module`` should be wrapped as an FSDP instance or not. If
- ``True``, then the function is still recursing down the module
- tree as a part of the DFS.
- nonwrapped_numel (int): Parameter numel not yet wrapped.
- module_classes (Set[Type[nn.Module]]): Set of module classes that are
- wrapped as FSDP instances.
- Returns:
- ``True`` if ``recurse=True``, and whether ``module`` should be wrapped
- if ``recurse=False``.
- """
- if recurse:
- return True
- return isinstance(module, tuple(module_classes))
class ModuleWrapPolicy(_FSDPPolicy):
    """
    Policy object that wraps every module whose type is in ``module_classes``.

    This is a thin wrapper around :func:`_module_wrap_policy` with
    ``module_classes`` pre-bound.
    """

    def __init__(self, module_classes: Set[Type[nn.Module]]):
        super().__init__()
        # Readable form of the classes, captured once for __repr__.
        self._module_classes_str = str(module_classes)
        self._policy: Callable = functools.partial(
            _module_wrap_policy, module_classes=module_classes
        )

    @property
    def policy(self):
        return self._policy

    def __repr__(self) -> str:
        base = super().__repr__()
        return f"{base}({self._module_classes_str})"
def lambda_auto_wrap_policy(
    module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable
) -> bool:
    """
    A convenient auto wrap policy to wrap submodules based on an arbitrary user
    function. If ``lambda_fn(submodule)`` returns ``True``, the submodule will
    be wrapped as a ``wrapper_cls`` unit.

    The first three parameters are the ones :func:`_recursive_wrap` passes to
    a policy, so ``lambda_fn`` must be bound beforehand (e.g. via
    :func:`functools.partial`).

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped (not used by
            this policy).
        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
            this module will be wrapped.

    Returns:
        ``True`` if ``recurse`` is truthy; otherwise the value returned by
        ``lambda_fn(module)``, i.e. whether ``module`` should be wrapped.
    """
    if recurse:
        # Always keep descending; the wrap decision is made when recurse=False.
        return True
    return lambda_fn(module)
def transformer_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    transformer_layer_cls: Set[Type[nn.Module]],
) -> bool:
    """
    Wrap every module whose type is in ``transformer_layer_cls``.

    This is :func:`_module_wrap_policy` with ``module_classes`` set to
    ``transformer_layer_cls``. Shared parameters must live in the same FSDP
    instance, so this policy can help wrap shared embeddings of transformer
    models into the same FSDP instance.
    """
    return _module_wrap_policy(
        module=module,
        recurse=recurse,
        nonwrapped_numel=nonwrapped_numel,
        module_classes=transformer_layer_cls,
    )
- def _wrap_batchnorm_individually(
- module: nn.Module,
- recurse: bool,
- *args,
- **kwargs,
- ) -> bool:
- """
- A policy that wraps ``BatchNorm`` instances in their own FSDP instance.
- """
- if recurse:
-
- return True
- else:
-
-
- return isinstance(module, _BatchNorm)
- def _or_policy(
- module: nn.Module,
- recurse: bool,
- nonwrapped_numel: int,
- policies,
- ) -> bool:
- """
- A policy that wraps ``module`` if any policy in the passed in iterable of
- ``policies`` returns ``True``.
- """
- return any(policy(module, recurse, nonwrapped_numel) for policy in policies)
def size_based_auto_wrap_policy(
    module: nn.Module,
    recurse: bool,
    nonwrapped_numel: int,
    # Knobs specific to this policy; bind them (e.g. with functools.partial).
    min_num_params: int = int(1e8),
    force_leaf_modules: Optional[Set[Type[nn.Module]]] = None,
    exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None,
) -> bool:
    """
    Auto wrap policy driven by how much parameter numel is still unwrapped.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If truthy, decide whether to keep descending into
            ``module``; otherwise decide whether ``module`` itself should be
            wrapped as an FSDP instance.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        min_num_params (int): Threshold, in numel, that ``nonwrapped_numel``
            must reach (``>=``) for either answer to be ``True``.
        force_leaf_modules (Set[Type[nn.Module]]): Module types that are never
            descended into, so their children are never wrapped. ``None``
            selects ``size_based_auto_wrap_policy.FORCE_LEAF_MODULES``.
        exclude_wrap_modules (Set[Type[nn.Module]]): Module types that are
            never wrapped themselves. ``None`` selects
            ``size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES``.

    Returns:
        Whether to descend into (``recurse`` truthy) or wrap (``recurse``
        falsy) ``module``.
    """
    if force_leaf_modules is None:
        force_leaf_modules = size_based_auto_wrap_policy.FORCE_LEAF_MODULES
    if exclude_wrap_modules is None:
        exclude_wrap_modules = size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES

    # While descending, leaf-forced types block recursion; when deciding to
    # wrap, excluded types block wrapping. Both require enough unwrapped numel.
    blocked_types = force_leaf_modules if recurse else exclude_wrap_modules
    big_enough = nonwrapped_numel >= min_num_params
    return big_enough and not isinstance(module, tuple(blocked_types))


# Default type sets consulted when the corresponding argument is ``None``.
size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict}
size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention}
@contextlib.contextmanager
def enable_wrap(
    *, wrapper_cls: Any, **wrapper_kwargs: Any
) -> Generator[None, None, None]:
    """
    Context manager to wrap modules using a wrapper.

    Useful for when you'd like to apply the same configuration arguments to all
    child modules that you wrap. A particularly important use case is wrapping
    large layers so that they get sharded (in-place) during initialization, to
    avoid running out of system memory. Large layers can indicate that they
    should be sharded via the ``wrap`` annotation and this context manager can
    provide the exact configuration for these nested instances.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **params):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        wrapper_cls:
            Class that ``wrap`` annotation will wrap modules with, such as
            ``FullyShardedDataParallel``. Must be passed by keyword.
        **wrapper_kwargs:
            Configuration settings that will be passed to all ``wrap``
            instances inside the context.

    Raises:
        NotImplementedError: If another ``enable_wrap`` context is already
            active (nesting is not supported).
    """
    # ``wrapper_cls`` is keyword-only, so it can never also appear in
    # ``wrapper_kwargs``; merging both gives the full context configuration.
    kwargs = {"wrapper_cls": wrapper_cls, **wrapper_kwargs}
    with _ConfigAutoWrap(**kwargs):
        yield
def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module:
    """
    Annotate that a module should be wrapped.

    Outside of an :func:`enable_wrap` context this is a no-op that returns
    ``module`` unchanged, which lets the same model code run with or without
    a wrapper. Inside the context, ``module`` is wrapped with the
    ``wrapper_cls`` given to ``enable_wrap``.

    Both ``enable_wrap`` and ``wrap`` accept kwargs describing how to construct
    the ``wrapper_cls`` instance; when a key appears in both, the value passed
    to ``wrap`` wins.

    Usage::

        with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
            # Wraps layer in FSDP by default if within context
            self.l1 = wrap(torch.nn.Linear(5, 5))

    Args:
        module (nn.Module): module to wrap (if in :func:`enable_wrap` context)
        **wrap_overrides: configuration overrides that take priority over the
            values provided by the :func:`enable_wrap` context
    """
    if not _ConfigAutoWrap.in_autowrap_context:
        return module
    assert _ConfigAutoWrap.wrapper_cls is not None
    # Context defaults first, call-site overrides last so they take priority.
    merged_kwargs = {**_ConfigAutoWrap.kwargs, **wrap_overrides}
    return _wrap(module, _ConfigAutoWrap.wrapper_cls, **merged_kwargs)
- def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module:
- assert wrapper_cls is not None
- if hasattr(module, "_wrap_overrides"):
-
-
-
-
- overrides = {**kwargs, **module._wrap_overrides}
- return wrapper_cls(module, **overrides)
- return wrapper_cls(module, **kwargs)
def _recursive_wrap(
    module: nn.Module,
    auto_wrap_policy: Callable,
    wrapper_cls: Callable,
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
    only_wrap_children: bool = False,
    **kwargs: Any,
) -> Tuple[nn.Module, int]:
    """
    Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns
    ``True`` with ``wrapper_cls``.

    Children are processed before ``module`` itself. ``module`` is then
    considered using only the numel that its children did not already wrap.

    Args:
        module (nn.Module): Module to recursively wrap.
        auto_wrap_policy (Callable): A callable representing a policy that
            determines which modules to recursively wrap with ``wrapper_cls``.
            It is called with the keyword arguments ``module``, ``recurse``
            and ``nonwrapped_numel``. It is called with ``recurse=True`` to
            decide whether to descend into ``module``, and with
            ``recurse=False`` to decide whether to wrap ``module`` itself.
        wrapper_cls (Callable): Used via :func:`_wrap` to wrap each selected
            module. If it is a class, no non-ignored module yielded by
            ``module.named_modules()`` may already be an instance of it.
        ignored_modules (Set[torch.nn.Module]): Modules to ignore when
            wrapping. Children in this set are neither recursed into nor
            wrapped.
        ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when
            wrapping; these should be the parameters contained in the modules
            in ``ignored_modules``. They are excluded from numel counts.
        only_wrap_children (bool): If ``True``, ``module`` itself is never
            wrapped; only its descendants may be. This flag applies to this
            call only and is not forwarded to the recursive calls.
        **kwargs: Forwarded to :func:`_wrap` for every module that is wrapped.

    Returns:
        (nn.Module, int):
            ``module`` after wrapping and the numel recursively wrapped. The
            numel is ``0`` when the policy declines to recurse into ``module``.

    Raises:
        AssertionError: If ``auto_wrap_policy`` or ``wrapper_cls`` is
            ``None``, or if a non-ignored module from
            ``module.named_modules()`` is already an instance of
            ``wrapper_cls``.
    """
    assert auto_wrap_policy is not None, "Must specify auto_wrap_policy."
    assert wrapper_cls is not None, "Must specify wrapper_cls"
    # Make sure no non-ignored module from ``named_modules()`` is already
    # wrapped.
    # NOTE(review): this scan runs again in every recursive call, so deep
    # trees revisit the same descendants repeatedly.
    for _, child in module.named_modules():
        if child in ignored_modules:
            continue
        try:
            assert not isinstance(child, cast(type, wrapper_cls))
        except TypeError:
            # ``wrapper_cls`` is not a class (e.g. a plain function), so the
            # isinstance check is impossible; skip it.
            pass
    # Count all non-ignored parameters, assuming none of them are wrapped yet.
    nonwrapped_numel = sum(
        p.numel() for p in module.parameters() if p not in ignored_params
    )
    # NOTE(review): redundant with the assertion at the top of the function.
    assert auto_wrap_policy is not None
    if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel):
        total_wrapped_numel = 0
        # Recursively handle the children first.
        for name, child in module.named_children():
            if child in ignored_modules:
                continue
            wrapped_child, num_wrapped_params = _recursive_wrap(
                module=child,
                auto_wrap_policy=auto_wrap_policy,
                wrapper_cls=wrapper_cls,
                ignored_modules=ignored_modules,
                ignored_params=ignored_params,
                **kwargs,
            )
            # Put the (possibly wrapped) child back under the same name.
            setattr(module, name, wrapped_child)
            # Keep track of how much numel the children have wrapped.
            total_wrapped_numel += num_wrapped_params
        # Decide whether to wrap ``module`` itself, based on the numel its
        # children left unwrapped.
        remainder = nonwrapped_numel - total_wrapped_numel
        if not only_wrap_children and auto_wrap_policy(
            module=module, recurse=False, nonwrapped_numel=remainder
        ):
            # Once ``module`` is wrapped, all of its non-ignored numel counts
            # as wrapped.
            return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel
        else:
            return module, total_wrapped_numel
    # The policy declined to recurse, so ``module`` is returned untouched.
    return module, 0
- class _ConfigAutoWrap:
- """
- Helper class to wrap modules based on default config args via a context manager.
- See :func:`enable_wrap` for more information.
- """
- in_autowrap_context: bool = False
- wrapper_cls: Optional[Callable] = None
- kwargs: Dict[str, Any] = {}
- def __init__(self, **kwargs: Dict[str, Any]):
- self.kwargs = kwargs
- @staticmethod
- def enable_autowrap_context(kwargs: Any) -> None:
- if _ConfigAutoWrap.in_autowrap_context:
- raise NotImplementedError(
- "You are already within an autowrap context and we currently do not supported nested autowrap."
- )
- _ConfigAutoWrap.in_autowrap_context = True
-
- assert (
- "wrapper_cls" in kwargs.keys()
- ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap."
- _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"])
- del kwargs["wrapper_cls"]
-
- _ConfigAutoWrap.kwargs = kwargs
- @staticmethod
- def disable_autowrap_context() -> None:
- _ConfigAutoWrap.in_autowrap_context = False
- _ConfigAutoWrap.wrapper_cls = None
- _ConfigAutoWrap.kwargs = {}
- def __enter__(self) -> None:
- self.enable_autowrap_context(self.kwargs)
- def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
- self.disable_autowrap_context()
|