contract.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import uuid
  2. from collections import OrderedDict
  3. from functools import wraps
  4. from typing import Callable, Dict, List, Optional, Type
  5. import torch.nn as nn
  6. from torch.distributed._composable_state import _State
# Use state_slot as the key for module.__dict__ to avoid colliding with other
# properties.
# TODO: all composable distributed features can share the same slot.
  10. class _StateKey(str):
  11. # Make _StateKey as str to satify the assumption that object.__dict__.keys()
  12. # are strings.
  13. def __new__(cls, string="__composable_api_state_key"):
  14. return super().__new__(cls, f"{string}_{str(uuid.uuid4())}")
  15. STATE_KEY = _StateKey()
  16. REGISTRY_KEY = _StateKey()
  17. # TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
  18. # we can add args and kwargs here, and then we can detect whether fully_shard
  19. # is combined with reentrant activation checkpointing and error out with a clear
  20. # message.
  21. class RegistryItem:
  22. pass
  23. def contract(state_cls: Type[_State] = _State):
  24. r"""
  25. Decorate a function as a composable distributed API, where the first
  26. argument of the function must be an :class:`nn.Module` instance. The
  27. decorator verifies that the wrapped function does not modify parameter,
  28. buffer or sub-module fully-qualified names (FQN).
  29. When a function ``func`` is decorated by ``@contract()``, a
  30. ``.state(module: nn.Module)`` method will be installed to the decorated
  31. function. Then you can retrieve and modify the state on a module by calling
  32. ``func.state(module)``.
  33. Example::
  34. >>> # xdoctest: +SKIP
  35. >>> import torch.nn as nn
  36. >>>
  37. >>> class MyModel(nn.Module):
  38. >>> def __init__(self):
  39. >>> super().__init__()
  40. >>> self.l1 = nn.Linear(10, 10)
  41. >>> self.l2 = nn.Linear(10, 10)
  42. >>>
  43. >>> def forward(self, x):
  44. >>> return self.l2(self.l1(x))
  45. >>>
  46. >>> @contract()
  47. >>> def my_feature(module: nn.Module) -> nn.Module:
  48. >>> my_feature.state(module).some_state = "any value"
  49. >>> return module
  50. >>>
  51. >>> model = MyModel()
  52. >>> my_feature(model.l1)
  53. >>> assert my_feature.state(model.l1).some_state == "any value"
  54. >>> my_feature(model.l2)
  55. >>> model(torch.randn(2, 10)).sum().backward()
  56. """
  57. # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
  58. @wraps(state_cls)
  59. def inner(func):
  60. @wraps(func)
  61. def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
  62. # get existing global states
  63. default_all_state: Dict[Callable, _State] = OrderedDict()
  64. all_state: Dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload]
  65. STATE_KEY, default_all_state
  66. )
  67. assert isinstance(
  68. all_state, dict
  69. ), "Distributed composable API states corrupted"
  70. # get global registry
  71. default_registry: Dict[str, RegistryItem] = OrderedDict()
  72. registry: Dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload]
  73. REGISTRY_KEY, default_registry
  74. )
  75. assert isinstance(
  76. registry, dict
  77. ), "Distributed composable API registry corrupted"
  78. # make sure the API func has not been applied to the input module yet.
  79. assert func not in all_state and func.__name__ not in registry, (
  80. "Each distinct composable distributed API can only be applied to a "
  81. f"module once. {func.__name__} has already been applied to the "
  82. f"following module.\n{module}"
  83. )
  84. # install states specific to the wrapped ``func``
  85. all_state.setdefault(func, state_cls())
  86. # register ``func`` in the global registry by name
  87. registry.setdefault(func.__name__, RegistryItem())
  88. orig_named_params = OrderedDict(module.named_parameters())
  89. orig_named_buffers = OrderedDict(
  90. module.named_buffers(remove_duplicate=False)
  91. )
  92. orig_named_modules = OrderedDict(
  93. module.named_modules(remove_duplicate=False)
  94. )
  95. updated = func(module, *args, **kwargs)
  96. if updated is None:
  97. updated = module
  98. new_named_params = OrderedDict(updated.named_parameters())
  99. new_named_buffers = OrderedDict(
  100. updated.named_buffers(remove_duplicate=False)
  101. )
  102. new_named_modules = OrderedDict(
  103. updated.named_modules(remove_duplicate=False)
  104. )
  105. assert isinstance(updated, nn.Module), (
  106. "Output of composable distributed APIs must be either None or "
  107. f"nn.Module, but got {type(updated)}"
  108. )
  109. def check_fqn(orig_fqns: List[str], new_fqns: List[str]):
  110. if orig_fqns == new_fqns:
  111. return
  112. orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
  113. orig_only = orig_fqn_set - new_fqn_set
  114. new_only = new_fqn_set - orig_fqn_set
  115. if len(orig_only) or len(new_only):
  116. raise RuntimeError(
  117. "Composable distributed API implementations cannot modify "
  118. "FQNs.\n"
  119. f"Only in original FQNs: {orig_only},\n"
  120. f"Only in new FQNs: {new_only}"
  121. )
  122. else:
  123. raise RuntimeError(
  124. "Composable distributed API implementations cannot modify "
  125. "the order of FQNs.\n"
  126. f"Original FQNs: {orig_only}\n"
  127. f"New FQNs: {new_only}"
  128. )
  129. check_fqn(list(orig_named_params.keys()), list(new_named_params.keys()))
  130. check_fqn(list(orig_named_buffers.keys()), list(new_named_buffers.keys()))
  131. check_fqn(list(orig_named_modules.keys()), list(new_named_modules.keys()))
  132. # TODO: a stricter verification should also reject changing module
  133. # types and monkey-patching forward() method implementations.
  134. # TODO: verify that installed distributed paradigms are compatible with
  135. # each other.
  136. return updated
  137. def get_state(module: nn.Module) -> Optional[_State]:
  138. return module.__dict__.setdefault( # type: ignore[call-overload]
  139. STATE_KEY,
  140. {}, # TODO(@yhcharles): this is a temporary fix, need a better way
  141. ).get(
  142. func
  143. ) # type: ignore[call-overload]
  144. wrapper.state = get_state # type: ignore[attr-defined]
  145. return wrapper
  146. return inner
  147. def _get_registry(module: nn.Module) -> Dict[str, RegistryItem]:
  148. r"""
  149. Get an ``OrderedDict`` of composable APIs that have been applied to the
  150. ``module``, indexed by the API name.
  151. """
  152. default_registry: Dict[str, RegistryItem] = OrderedDict()
  153. return module.__dict__.setdefault(REGISTRY_KEY, default_registry) # type: ignore[call-overload]