123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- import uuid
- from collections import OrderedDict
- from functools import wraps
- from typing import Callable, Dict, List, Optional, Type
- import torch.nn as nn
- from torch.distributed._composable_state import _State
# Use state_slot as the key into module.__dict__ to avoid colliding with other
# properties.
# TODO: revisit this, since all composable distributed features can share the
# same slot.
- class _StateKey(str):
- # Make _StateKey as str to satify the assumption that object.__dict__.keys()
- # are strings.
- def __new__(cls, string="__composable_api_state_key"):
- return super().__new__(cls, f"{string}_{str(uuid.uuid4())}")
# Key under which ``module.__dict__`` stores the per-API state mapping.
STATE_KEY: _StateKey = _StateKey()
# Key under which ``module.__dict__`` stores the registry of applied APIs.
REGISTRY_KEY: _StateKey = _StateKey()
- # TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
- # we can add args and kwargs here, and then we can detect whether fully_shard
- # is combined with reentrant activation checkpointing and error out with a clear
- # message.
class RegistryItem:
    """Placeholder entry stored in a module's registry for each applied API.

    It currently carries no attributes or behavior.
    """
def _check_fqn(orig_fqns: List[str], new_fqns: List[str]) -> None:
    """Raise ``RuntimeError`` unless the two FQN lists are equal, order included."""
    if orig_fqns == new_fqns:
        return

    orig_only = set(orig_fqns) - set(new_fqns)
    new_only = set(new_fqns) - set(orig_fqns)
    if orig_only or new_only:
        raise RuntimeError(
            "Composable distributed API implementations cannot modify "
            "FQNs.\n"
            f"Only in original FQNs: {orig_only},\n"
            f"Only in new FQNs: {new_only}"
        )

    # Both sides hold the same names, so the set differences above are empty
    # here; report the full ordered lists so the reordering is visible.
    raise RuntimeError(
        "Composable distributed API implementations cannot modify "
        "the order of FQNs.\n"
        f"Original FQNs: {orig_fqns}\n"
        f"New FQNs: {new_fqns}"
    )


def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance. The
    decorator verifies that the wrapped function does not modify parameter,
    buffer or sub-module fully-qualified names (FQN).

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: class instantiated (with no arguments) to create the state
            object stored on a module the first time the decorated function
            is applied to it.

    Returns:
        A decorator. The decorated function returns the module produced by
        the wrapped function, or the input module if that function returns
        ``None``.

    Raises:
        AssertionError: when calling the decorated function, if the same API
            has already been applied to the module, if the state or registry
            stored on the module is not a ``dict``, or if the wrapped
            function returns something other than ``None`` or an
            :class:`nn.Module`.
        RuntimeError: when calling the decorated function, if the wrapped
            function changed the names or the order of parameter, buffer or
            sub-module FQNs.

    Example::

        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
    @wraps(state_cls)
    def inner(func):
        @wraps(func)
        def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
            # get existing global states
            default_all_state: Dict[Callable, _State] = OrderedDict()
            all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY, default_all_state
            )
            assert isinstance(
                all_state, dict
            ), "Distributed composable API states corrupted"

            # get global registry
            default_registry: Dict[str, RegistryItem] = OrderedDict()
            registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
                REGISTRY_KEY, default_registry
            )
            assert isinstance(
                registry, dict
            ), "Distributed composable API registry corrupted"

            # make sure the API func has not been applied to the input module yet.
            assert func not in all_state and func.__name__ not in registry, (
                "Each distinct composable distributed API can only be applied to a "
                f"module once. {func.__name__} has already been applied to the "
                f"following module.\n{module}"
            )

            # install states specific to the wrapped ``func``
            all_state.setdefault(func, state_cls())
            # register ``func`` in the global registry by name
            registry.setdefault(func.__name__, RegistryItem())

            # Snapshot the FQNs before ``func`` runs so they can be compared
            # with the FQNs of its output afterwards.
            orig_named_params = OrderedDict(module.named_parameters())
            orig_named_buffers = OrderedDict(
                module.named_buffers(remove_duplicate=False)
            )
            orig_named_modules = OrderedDict(
                module.named_modules(remove_duplicate=False)
            )

            updated = func(module, *args, **kwargs)
            if updated is None:
                updated = module

            # Validate the output type before calling any nn.Module method on
            # it, so a wrong return value is reported with this message rather
            # than an AttributeError from the calls below.
            assert isinstance(updated, nn.Module), (
                "Output of composable distributed APIs must be either None or "
                f"nn.Module, but got {type(updated)}"
            )

            new_named_params = OrderedDict(updated.named_parameters())
            new_named_buffers = OrderedDict(
                updated.named_buffers(remove_duplicate=False)
            )
            new_named_modules = OrderedDict(
                updated.named_modules(remove_duplicate=False)
            )

            _check_fqn(list(orig_named_params), list(new_named_params))
            _check_fqn(list(orig_named_buffers), list(new_named_buffers))
            _check_fqn(list(orig_named_modules), list(new_named_modules))

            # TODO: a stricter verification should also reject changing module
            # types and monkey-patching forward() method implementations.

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            # Returns the state installed by ``func`` on ``module``, or None if
            # ``func`` has not been applied to it.
            return module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY,
                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
            ).get(
                func
            )  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner
def _get_registry(module: nn.Module) -> Dict[str, RegistryItem]:
    r"""
    Return the registry of composable APIs applied to ``module``, keyed by API
    name. An empty ``OrderedDict`` is stored on the module and returned if no
    registry exists yet.
    """
    module_attrs = module.__dict__
    if REGISTRY_KEY not in module_attrs:
        module_attrs[REGISTRY_KEY] = OrderedDict()
    return module_attrs[REGISTRY_KEY]  # type: ignore[index]
|