_common_utils.py

"""
This file includes private common utilities for FSDP.
"""
import traceback
import warnings
from enum import auto, Enum
from typing import (
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
)

import torch
import torch.distributed as dist
import torch.distributed.fsdp.flat_param as flat_param_file
import torch.nn as nn
from torch.distributed._composable_state import _get_module_state, _State
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)

from .api import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    OptimStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictType,
)

FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
FSDP_FLATTENED = "_fsdp_flattened"


class _FSDPState(_State):
    def __init__(self) -> None:
        # TODO: Move all the attributes to this class to enable typing for
        # FSDP/fully_shard.
        self._ignored_modules: Set[nn.Module] = set()
        self._ignored_params: Set[nn.Parameter] = set()
        self.process_group: Optional[dist.ProcessGroup] = None
        self.rank: int = -1
        self.world_size: int = -1
        self.sharding_strategy = ShardingStrategy.FULL_SHARD
        self._use_orig_params: bool = False
        self.training_state = TrainingState.IDLE
        self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
        self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
        self._state_dict_config: StateDictConfig = FullStateDictConfig()
        self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
        self._is_root: Optional[bool] = None
        self._handles: List[flat_param_file.FlatParamHandle] = []
        # Maps each fully sharded module to the handles managing its
        # parameters; the value is a list (it is sliced in `_module_handles`).
        self._fully_sharded_module_to_handles: Dict[
            nn.Module, List[flat_param_file.FlatParamHandle]
        ] = {}
        self.compute_device = torch.device("cuda", torch.cuda.current_device())


def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]:
    state = _get_module_state(module)
    if state is None or not isinstance(state, _FSDPState):
        return None
    return state


def _get_module_fsdp_state_if_fully_sharded_module(
    module: nn.Module,
) -> Optional[_FSDPState]:
    state = _get_module_fsdp_state(module)
    if state is None:
        return None
    if state == module:  # `FullyShardedDataParallel` module case
        return state
    if module in state._fully_sharded_module_to_handles:  # `fully_shard` case
        return state
    return None
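
# Example (illustrative only, commented out so nothing runs at import time):
# a plain module has no FSDP state, while a `FullyShardedDataParallel` wrapper
# doubles as its own `_FSDPState`. The names `model`/`fsdp_model` are
# hypothetical, and a default process group is assumed to be initialized.
#
#   >>> model = nn.Linear(4, 4)
#   >>> _get_module_fsdp_state(model) is None
#   True
#   >>> fsdp_model = FullyShardedDataParallel(model)
#   >>> _get_module_fsdp_state(fsdp_model) is not None
#   True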


class TrainingState(Enum):
    """
    An enum that indicates the state of a ``FullyShardedDataParallel`` instance.
    """

    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """
    An enum that indicates the state of a ``FlatParamHandle``.
    """

    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


def _is_composable(state: _FSDPState):
    # TODO: This is a temporary hack to differentiate between code paths.
    return not isinstance(state, nn.Module)


@no_type_check
def _module_handles(state: _FSDPState, module: nn.Module) -> List:
    """
    Returns the ``FlatParamHandle`` s corresponding to ``module``. These are
    the handles that contain some parameter in ``module``.
    """
    if _is_composable(state):
        assert (
            module in state._fully_sharded_module_to_handles
        ), f"Expects a `comm_module` but got {module} on rank {state.rank}"
        return state._fully_sharded_module_to_handles[module][:]
    else:
        # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance.
        return module._handles[:]


@no_type_check
def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool:
    """Returns if ``module`` has parameters managed by FSDP."""
    return len(_module_handles(state, module)) > 0


def _get_sharding_strategy(handles: Iterable):
    """
    Returns the sharding strategy of the group of handles given by ``handles``
    or ``None`` if ``handles`` is empty. The input should be the handles
    corresponding to one module, so we enforce that they all share the same
    sharding strategy.
    """
    sharding_strategy = None
    for handle in handles:
        if sharding_strategy is None:
            sharding_strategy = handle._sharding_strategy
        elif sharding_strategy != handle._sharding_strategy:
            raise AssertionError(
                "Expects each group of handles to have the same sharding "
                f"strategy but got {sharding_strategy} and {handle._sharding_strategy}"
            )
    return sharding_strategy
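
# Illustration (hypothetical handles `h1`/`h2` standing in for `FlatParamHandle`s):
#
#   >>> _get_sharding_strategy([]) is None        # no handles -> None
#   True
#   >>> _get_sharding_strategy([h1, h2])          # both FULL_SHARD
#   <ShardingStrategy.FULL_SHARD: ...>
#
# Mixing strategies within one group raises an `AssertionError`.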


def clean_tensor_name(tensor_name: str) -> str:
    """
    Cleans the parameter or buffer name by removing any module wrapper
    prefixes.
    """
    tensor_name = tensor_name.replace(FSDP_PREFIX, "")
    # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as
    # it couples `CheckpointWrapper` and FSDP and also does not scale for more
    # module wrappers.
    tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "")
    return tensor_name
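
# Example (illustrative FQN; assumes `_CHECKPOINT_PREFIX` is
# "_checkpoint_wrapped_module."): wrapper prefixes are stripped so names line
# up with the original, unwrapped module.
#
#   >>> clean_tensor_name("_fsdp_wrapped_module.layer1._checkpoint_wrapped_module.weight")
#   'layer1.weight'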


def _set_fsdp_flattened(tensor: torch.Tensor) -> None:
    """
    Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to
    avoid re-flattening it during nested construction.
    """
    setattr(tensor, FSDP_FLATTENED, True)


def _is_fsdp_flattened(tensor: torch.Tensor) -> bool:
    """Returns if ``tensor`` has been marked as flattened by FSDP."""
    return getattr(tensor, FSDP_FLATTENED, False)
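
# Example (illustrative): the marker is just a dynamically set attribute, so
# the two helpers form a simple round trip.
#
#   >>> p = nn.Parameter(torch.zeros(3))
#   >>> _is_fsdp_flattened(p)
#   False
#   >>> _set_fsdp_flattened(p)
#   >>> _is_fsdp_flattened(p)
#   True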


def _get_param_to_fqns(
    model: torch.nn.Module,
    dedup_shared_params: bool = True,
) -> Dict[nn.Parameter, List[str]]:
    """
    Constructs a mapping from parameter to a list of its FQNs. Each normal
    parameter maps to a singleton list containing its FQN, while each
    ``FlatParameter`` maps to a list of its original parameter FQNs, which may
    have length greater than one. All FQNs are prefixed starting from
    ``model``.

    Args:
        model (torch.nn.Module): Root module (which may or may not be a
            :class:`FullyShardedDataParallel` instance).
        dedup_shared_params (bool): For shared parameters, if ``True``, only
            includes the FQNs corresponding to the first encounter of the
            shared parameter in the module traversal; if ``False``, then
            includes the FQNs across all encounters. (Default: ``True``)
    """

    def module_fn(module, prefix, param_to_fqns):
        for param_name, param in module.named_parameters(recurse=False):
            local_fqns = (
                param._fqns
                if type(param) is flat_param_file.FlatParameter
                else [param_name]
            )  # prefixed from `module`
            global_fqns = [
                clean_tensor_name(prefix + name) for name in local_fqns
            ]  # prefixed from the top-level `model` (i.e. including `prefix`)
            is_shared_param = param in param_to_fqns
            if not is_shared_param:
                param_to_fqns[param] = global_fqns
            else:
                if type(param) is flat_param_file.FlatParameter:
                    # DMP overwrites `named_parameters` and skips the wrapped
                    # module (e.g. `_dmp_wrapped_module` and
                    # `_fsdp_wrapped_module`), advancing to the next child
                    # module. When a user calls `named_children` to traverse
                    # the module recursively and calls `named_parameters` with
                    # `recurse=False`, parameters will be traversed more than
                    # once.
                    # This hack is specifically designed for DMP + FSDP. We
                    # overwrite the flat parameter traversal result to only
                    # keep the last one, which happens to be the correct one.
                    #
                    # TODO: Remove this hack once DMP + FSDP is not supported.
                    warnings.warn(
                        "FlatParameter is being traversed more than once. "
                        "This case should only happen when using "
                        "DistributedModelParallel with FullyShardedDataParallel."
                    )
                    param_to_fqns[param] = global_fqns
                elif not dedup_shared_params:
                    param_to_fqns[param].extend(global_fqns)

    def return_fn(param_to_fqns):
        return param_to_fqns

    param_to_unflat_param_names: Dict[torch.nn.Parameter, List[str]] = {}
    return _apply_to_modules(
        model,
        module_fn,
        return_fn,
        [key for key, _ in model.named_parameters()],
        param_to_unflat_param_names,
    )
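
# Example (illustrative, names hypothetical): for a plain, non-wrapped module
# every parameter maps to a singleton list holding its prefixed name.
#
#   >>> model = nn.Sequential(nn.Linear(2, 2))
#   >>> sorted(fqns[0] for fqns in _get_param_to_fqns(model).values())
#   ['0.bias', '0.weight']
#
# With an FSDP-wrapped model, each `FlatParameter` would instead map to the
# list of its original parameter FQNs (with wrapper prefixes cleaned).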


def _apply_to_modules(
    root_module: torch.nn.Module,
    module_fn: Callable,
    return_fn: Callable,
    filter_fqns: Optional[List[str]] = None,
    *args,
    **kwargs,
):
    """
    Performs a pre-order traversal of the modules in the hierarchy rooted at
    ``root_module``, applying ``module_fn`` at each module and finally
    returning a value using ``return_fn``. The traversal constructs the full
    module prefix name (e.g. "module.submodule.", just like in the model state
    dict) and makes that available to ``module_fn``.

    ``filter_fqns`` is used because some modules may have their own prefix,
    similar to ``FullyShardedDataParallel``, and overwrite
    ``named_parameters()`` to remove the prefix.
    """

    def f(module: torch.nn.Module, prefix: str, *args, **kwargs):
        # Call the module function before recursing over children (pre-order)
        module_fn(module, prefix, *args, **kwargs)
        for submodule_name, submodule in module.named_children():
            if submodule is None:
                continue
            new_prefix = prefix + submodule_name + "."
            if filter_fqns is not None:
                for fqn in filter_fqns:
                    if fqn.startswith(new_prefix):
                        break
                else:
                    # DMP's `named_parameters()` will mess up the traversal
                    # with `named_children` + `named_parameters(recurse=False)`.
                    # This hack is required to make the traversal work.
                    # TODO: Remove this hack once DMP + FSDP is not supported.
                    if (
                        submodule_name == "_fsdp_wrapped_module"
                        or submodule_name == "_dmp_wrapped_module"
                    ):
                        warnings.warn(
                            "An unexpected prefix is detected. This case "
                            "should only happen when using DMP with FSDP. "
                            f"prefix = {prefix}, "
                            f"submodule_name = {submodule_name}"
                        )
                        new_prefix = prefix
            f(submodule, new_prefix, *args, **kwargs)

    f(root_module, "", *args, **kwargs)
    return return_fn(*args, **kwargs)
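
# Example (illustrative, names hypothetical): collecting the state-dict-style
# prefix of every module in a small hierarchy via a custom `module_fn` and
# `return_fn`, accumulating into an extra positional argument.
#
#   >>> def record_prefix(module, prefix, out):
#   ...     out.append(prefix)
#   >>> model = nn.Sequential(nn.Linear(2, 2), nn.ReLU())
#   >>> _apply_to_modules(model, record_prefix, lambda out: out, None, [])
#   ['', '0.', '1.']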


@no_type_check
def _assert_in_training_states(
    state: _FSDPState,
    training_states: List[TrainingState],
) -> None:
    """Asserts that FSDP is in one of the states ``training_states``."""
    # Raise a `ValueError` instead of using `assert` to ensure that these
    # logical assertions run even if `assert`s are disabled
    if state.training_state not in training_states:
        msg = (
            f"expected to be in states {training_states} but current state is "
            f"{state.training_state}"
        )
        # Print the error on rank 0 in case this is called in the backward pass
        if state.rank == 0:
            if isinstance(state, nn.Module):
                print(f"Asserting FSDP instance is: {state}")
            print(f"ERROR: {msg}")
            traceback.print_stack()
        raise ValueError(msg)


def _get_root_modules(modules: Set[nn.Module]) -> Set[nn.Module]:
    """
    Returns:
        Set[nn.Module]: The subset of ``modules`` that are root modules (i.e.
        parent-less) with respect to the modules in the set itself. In other
        words, these are the modules in ``modules`` that are not the child of
        any other module in ``modules``.
    """
    root_modules: Set[nn.Module] = set()
    module_to_submodules = {module: set(module.modules()) for module in modules}
    for candidate_module in modules:
        is_root_module = True
        for module, submodules in module_to_submodules.items():
            is_child_module = (
                candidate_module is not module and candidate_module in submodules
            )
            if is_child_module:
                is_root_module = False
                break
        if is_root_module:
            root_modules.add(candidate_module)
    return root_modules
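
# Example (illustrative, names hypothetical): with a nested `Sequential`, a
# submodule is dropped from the result when its parent is also in the input set.
#
#   >>> inner = nn.Linear(2, 2)
#   >>> outer = nn.Sequential(inner)
#   >>> _get_root_modules({outer, inner}) == {outer}
#   True
#   >>> _get_root_modules({inner}) == {inner}
#   True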