_traversal_utils.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
"""
NOTE: This file must be imported like
``import torch.distributed.fsdp._traversal_utils`` and not like
``from torch.distributed.fsdp._traversal_utils import ...`` to avoid circular
imports. For brevity, we may import the file as ``traversal_utils``.
"""
  7. import collections
  8. from typing import Deque, List, Set, Tuple
  9. import torch.nn as nn
  10. from torch.distributed._composable.contract import _get_registry
  11. from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state
"""
[Note: FSDP State Traversal]
For the wrapper code path, ``_FSDPState`` is the ``FullyShardedDataParallel``
module wrapping a fully sharded module, and for the non-wrapper code path,
``_FSDPState`` is an object that gets embedded on a fully sharded module.
See [Note: Fully Sharded Module] for the definition.

There are three common traversal idioms: Given a root module,
- ``_get_fsdp_states()`` returns all ``_FSDPState`` s in the tree.
- ``get_fsdp_root_states()`` returns all local root ``_FSDPState`` s in the
  tree (i.e. those with ``_is_root == True``).
- ``_get_fsdp_handles()`` returns all ``FlatParamHandle`` s in the tree.

All of these methods must take in the root module (i.e. an ``nn.Module``) and
not a general ``_FSDPState`` because ``_FSDPState`` does not support a graph
traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal.
"""
  27. def _composable(module: nn.Module) -> bool:
  28. """
  29. Returns if ``module`` can compose with ``fully_shard``.
  30. """
  31. # TODO: Add any other composable APIs that are mutually exclusive.
  32. return "replicate" not in _get_registry(module)
  33. # TODO (awgu): We may be able to remove this function if we retired the
  34. # `use_orig_params=False` code path since so far we only need the module for
  35. # `FlatParameter` registration, which is not needed for `use_orig_params=True`.
  36. def _get_fsdp_states_with_modules(
  37. module: nn.Module,
  38. ) -> Tuple[List[_FSDPState], List[nn.Module]]:
  39. """
  40. Returns a tuple containing:
  41. 1. A list of the ``_FSDPState`` instances in the module tree rooted at
  42. ``module`` without any duplicates and following the ``module.modules()``
  43. traversal order (which is assumed to be depth-first).
  44. 2. A corresponding list of the modules owning the states in the first list.
  45. For the wrapper code path, both returned lists are the same, each
  46. containing all ``FullyShardedDataParallel`` instances. For the composable
  47. code path, this returns a list of all composable state instances and a list
  48. of the corresponding fully sharded modules. See [Note: Fully Sharded
  49. Module].
  50. NOTE: The traversal does not proceed into any module annotated by an
  51. incompatible API (e.g. ``replicate``).
  52. """
  53. fsdp_states: List[_FSDPState] = []
  54. fsdp_modules: List[nn.Module] = []
  55. # Track the visited FSDP states since multiple modules may share the same
  56. # one and we want to return a de-duplicated list
  57. visited_fsdp_states: Set[_FSDPState] = set()
  58. # Track the visited modules in case of shared modules, which implies the
  59. # module graph is no longer a tree
  60. visited_modules: Set[nn.Module] = set()
  61. # Perform depth-first search from `module` to ensure that we do not
  62. # traverse into an incompatible API's subtree (use DFS instead of BFS to
  63. # match `.modules()` order)
  64. deque: Deque[nn.Module] = collections.deque()
  65. deque.append(module)
  66. while deque:
  67. submodule = deque.popleft()
  68. visited_modules.add(submodule)
  69. if not _composable(submodule):
  70. continue
  71. for child_module in reversed(list(submodule.children())):
  72. if child_module not in visited_modules:
  73. deque.appendleft(child_module)
  74. optional_state = _get_module_fsdp_state(submodule)
  75. if optional_state is not None and optional_state not in visited_fsdp_states:
  76. visited_fsdp_states.add(optional_state)
  77. fsdp_states.append(optional_state)
  78. fsdp_modules.append(submodule)
  79. return fsdp_states, fsdp_modules
  80. def _get_fsdp_states(module: nn.Module) -> List[_FSDPState]:
  81. """See :func:`_get_fsdp_states_with_modules`."""
  82. fsdp_states, _ = _get_fsdp_states_with_modules(module)
  83. return fsdp_states
  84. def _get_fsdp_handles(module: nn.Module) -> List:
  85. """
  86. Returns all ``FlatParamHandle`` s in the module tree rooted at ``module``
  87. following the rules in :func:`_get_fsdp_state`.
  88. """
  89. return [
  90. handle
  91. for fsdp_state in _get_fsdp_states(module)
  92. for handle in fsdp_state._handles
  93. ]