_wrap_utils.py

import collections
import functools
import warnings
from typing import Any, Deque, Dict, List, NamedTuple, Set, Tuple

import torch
import torch.nn as nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
from torch.distributed.fsdp._utils import (
    _contains_batchnorm,
    _override_batchnorm_mixed_precision,
)
from torch.distributed.fsdp.wrap import (
    _FSDPPolicy,
    _or_policy,
    _recursive_wrap,
    _wrap_batchnorm_individually,
)


class FullyShardedModuleState(NamedTuple):
    """
    Module state for ``_get_fully_sharded_module_to_states()``, representing
    a logical grouping (e.g. parameters to be flattened together).
    """

    params: List[nn.Parameter]
    buffers: List[torch.Tensor]
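# --- Illustrative note (not part of the original module) -------------------
# A minimal sketch of what one such grouping holds; `nn.Linear` is used here
# purely as an example module.
#
#   linear = nn.Linear(2, 2)
#   state = FullyShardedModuleState(
#       params=list(linear.parameters()), buffers=list(linear.buffers())
#   )
#   # state.params holds the weight and bias; state.buffers is empty.

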
def _auto_wrap(
    auto_wrap_kwargs: Dict[str, Any],
    fsdp_kwargs: Dict[str, Any],
    module_wrapper_cls: Any,  # e.g. `FullyShardedDataParallel`
) -> None:
    """
    Recursively auto wraps the root module given by the key "module" in
    ``auto_wrap_kwargs`` with the arguments in ``auto_wrap_kwargs`` and
    ``fsdp_kwargs``.

    Precondition: ``auto_wrap_kwargs`` contains the arguments expected by
    ``_recursive_wrap()``, where the ``auto_wrap_policy`` key is not ``None``,
    and ``fsdp_kwargs`` contains all FSDP arguments except ``module``.
    """
    auto_wrap_policy = auto_wrap_kwargs["auto_wrap_policy"]
    # Support new way to pass an auto wrap policy
    if isinstance(auto_wrap_policy, _FSDPPolicy):
        auto_wrap_policy = auto_wrap_policy.policy
    root_module = auto_wrap_kwargs["module"]
    assert auto_wrap_policy is not None
    # For auto wrapping, submodules should not already be wrapped with FSDP
    # since double wrapping is not supported
    for module_name, module in root_module.named_modules():
        if isinstance(module, module_wrapper_cls):
            raise ValueError(
                f"Expected {module_name} to NOT be FullyShardedDataParallel "
                "if using an `auto_wrap_policy`"
            )
    mixed_precision = fsdp_kwargs["mixed_precision"]
    if mixed_precision is not None and _contains_batchnorm(root_module):
        _override_batchnorm_mixed_precision(root_module)
        auto_wrap_policy = functools.partial(
            _or_policy, policies=[_wrap_batchnorm_individually, auto_wrap_policy]
        )
        warnings.warn(
            "Both mixed precision and an `auto_wrap_policy` were specified "
            "for FSDP, where the wrapped module has batch norm submodules. "
            "The batch norm submodules will be wrapped as separate FSDP "
            "instances with mixed precision disabled since some batch norm "
            "kernels do not support low precision."
        )
    auto_wrap_kwargs["auto_wrap_policy"] = auto_wrap_policy
    _recursive_wrap(**auto_wrap_kwargs, **fsdp_kwargs)
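
# --- Illustrative usage sketch (not part of the original module) -----------
# A hedged example of how `_auto_wrap` might be driven. The dict keys below
# cover only what `_auto_wrap` and `_recursive_wrap` read directly; the real
# call site (FSDP's constructor) forwards many more FSDP arguments through
# ``fsdp_kwargs``. `size_based_auto_wrap_policy` is the public policy from
# `torch.distributed.fsdp.wrap`. Running this requires an initialized default
# process group, since each selected submodule becomes a real FSDP instance.
def _demo_auto_wrap(model: nn.Module) -> None:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    auto_wrap_kwargs = {
        "module": model,
        "auto_wrap_policy": functools.partial(
            size_based_auto_wrap_policy, min_num_params=1
        ),
        "wrapper_cls": FSDP,
        "ignored_modules": set(),
        "ignored_params": set(),
        "only_wrap_children": True,
    }
    fsdp_kwargs = {"mixed_precision": None}  # plus any other FSDP ctor kwargs
    _auto_wrap(auto_wrap_kwargs, fsdp_kwargs, FSDP)
    # After the call, eligible submodules of `model` have been replaced in
    # place by FSDP instances.

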
def _get_fully_sharded_module_to_states(
    root_module: nn.Module,
    auto_wrap_policy: _FSDPPolicy,
    ignored_modules: Set[nn.Module],
    ignored_params: Set[nn.Parameter],
) -> Dict[nn.Module, FullyShardedModuleState]:
    """
    Returns a mapping from fully sharded module to its parameters and buffers,
    where each entry logically represents a grouping according to the given
    auto wrap policy and ignored modules/parameters. However, this method does
    not actually perform any module wrapping.

    The mapped-to values are the states from the subtree rooted at the
    corresponding submodule key, excluding child submodules in the mapping and
    ignored state. Sibling submodules cannot be grouped together.

    Each non-ignored parameter and buffer appears exactly once in the returned
    ``dict``, and the ``dict`` is ordered by increasing tree depth. A mapped-to
    parameter list may be empty if the fully sharded module has no parameters
    or if its parameters were assigned to a parent fully sharded module
    instead.
    """
    # Record the modules to wrap without actually wrapping
    wrapped_modules_set: Set[nn.Module] = set()  # these are only logically wrapped
    wrapper_cls = functools.partial(_record_module_wrapper_cls, wrapped_modules_set)
    if auto_wrap_policy is not None:
        _recursive_wrap(
            root_module,
            auto_wrap_policy=auto_wrap_policy.policy,
            wrapper_cls=wrapper_cls,
            ignored_modules=ignored_modules,
            ignored_params=ignored_params,
            only_wrap_children=False,
        )
    # Always include the root module even if not wrapped by the given policy
    wrapped_modules_set.add(root_module)

    fully_sharded_module_to_states = collections.OrderedDict()
    visited_params = set()
    for ignored_param in ignored_params:
        visited_params.add(ignored_param)
    visited_buffers = set()
    # Construct `wrapped_modules` to follow `.modules()` order to ensure that
    # downstream data structures (`._handles`) match those of the wrapper path.
    # NOTE: Since `.modules()` follows a depth-first order, which is a
    # topological sort, and we iterate over `wrapped_modules` following that
    # order, parent-child shared parameters are assigned to the parent module.
    wrapped_modules: List[nn.Module] = []
    for module in root_module.modules():
        if module in wrapped_modules_set:
            wrapped_modules.append(module)
    for submodule in wrapped_modules:
        # Perform a DFS from `submodule` and record all unvisited state that is
        # not already associated with another module in `wrapped_modules`. We
        # use DFS to follow the `.modules()` order.
        deque: Deque[Tuple[nn.Module, str]] = collections.deque()
        deque.append((submodule, ""))
        params: List[nn.Parameter] = []
        buffers: List[torch.Tensor] = []
        while len(deque) > 0:
            module, prefix = deque.popleft()
            # Reverse `named_children()`, use `appendleft()`, and add to the
            # deque before processing to perform non-recursive DFS
            for child_module_name, child_module in reversed(
                list(module.named_children())
            ):
                if child_module not in wrapped_modules_set:
                    deque.appendleft((child_module, prefix + child_module_name + "."))
            for param in module.parameters(recurse=False):
                if param not in visited_params and not _is_fsdp_flattened(param):
                    params.append(param)
                    visited_params.add(param)
            for buffer in module.buffers(recurse=False):
                if buffer not in visited_buffers:
                    buffers.append(buffer)
                    visited_buffers.add(buffer)
        fully_sharded_module_to_states[submodule] = FullyShardedModuleState(
            params, buffers
        )
    return fully_sharded_module_to_states
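
# --- Illustrative sketch (not part of the original module) -----------------
# A minimal, hedged example of the grouping computed above, assuming the
# public `ModuleWrapPolicy` from `torch.distributed.fsdp.wrap` (an
# `_FSDPPolicy` implementation) is available in this version. No wrapping or
# distributed initialization is needed, since the traversal is purely logical.
def _demo_module_to_states() -> None:
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy  # assumed available

    model = nn.Sequential(
        nn.Linear(4, 4),
        nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4)),
    )
    policy = ModuleWrapPolicy({nn.Sequential})
    mapping = _get_fully_sharded_module_to_states(model, policy, set(), set())
    # Two groups, ordered root-first: the root owns `model[0]`'s weight and
    # bias; the inner `nn.Sequential` owns its linear and batch norm state
    # (including the batch norm running stats as buffers).
    for module, state in mapping.items():
        print(type(module).__name__, len(state.params), len(state.buffers))

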
def _record_module_wrapper_cls(
    wrapped_modules_set: Set[nn.Module],
    module: nn.Module,
    **kwargs,
) -> nn.Module:
    """
    This function is passed (via ``functools.partial``) as a pseudo wrapper
    class to ``_recursive_wrap()``: it records the would-be wrapped module in
    the given ``wrapped_modules_set`` and returns the module unchanged instead
    of actually wrapping it.
    """
    wrapped_modules_set.add(module)
    return module
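
# --- Illustrative note (not part of the original module) -------------------
# After binding the set with ``functools.partial`` (as done in
# ``_get_fully_sharded_module_to_states``), the resulting callable has the
# shape `_recursive_wrap` expects of a wrapper class: it is called with the
# module (plus any extra kwargs) and must return a module. Here it simply
# records the module and hands it back, e.g.:
#
#   recorded: Set[nn.Module] = set()
#   pseudo_wrapper = functools.partial(_record_module_wrapper_cls, recorded)
#   layer = nn.Linear(2, 2)
#   assert pseudo_wrapper(layer) is layer and layer in recorded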