fully_shard.py

from typing import Callable, Iterable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn as nn

from torch.distributed._composable.contract import contract
from torch.distributed._composable_state import _get_module_state, _insert_module_state
from torch.distributed.fsdp._common_utils import _FSDPState
from torch.distributed.fsdp._init_utils import (
    _init_buffer_state,
    _init_core_state,
    _init_ignored_module_states,
    _init_param_handles_from_module,
    _init_prefetching_state,
    _init_process_group_state,
    _init_runtime_state,
    _init_state_dict_state,
)
from torch.distributed.fsdp._runtime_utils import (
    _register_post_forward_hooks,
    _register_pre_forward_hooks,
    _register_root_pre_forward_hook,
)
from torch.distributed.fsdp._state_dict_utils import _register_all_state_dict_hooks
from torch.distributed.fsdp.api import (
    BackwardPrefetch,
    CPUOffload,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import _FSDPPolicy


@contract(state_cls=_FSDPState)
def fully_shard(
    module: nn.Module,
    *,
    process_group: Optional[dist.ProcessGroup] = None,
    policy: Optional[_FSDPPolicy] = None,
    strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    cpu_offload: Optional[CPUOffload] = None,
    ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
    device_id: Optional[Union[int, torch.device]] = None,
    param_init_fn: Optional[Callable[[nn.Module], None]] = None,
    sync_module_states: bool = False,
) -> nn.Module:
  45. """
  46. Applies ``FullyShardedDataParallel` (FSDP) semantics to ``module``.
  47. """
    # Enforce the new auto wrap policy
    if policy is not None and not isinstance(policy, _FSDPPolicy):
        raise ValueError(f"Expects an `_FSDPPolicy` but got {policy}")
    state = fully_shard.state(module)
    state = _init_ignored_module_states(state, module, ignored_modules)
    state = _init_process_group_state(
        state, process_group, ShardingStrategy.FULL_SHARD, policy
    )
    # These settings are fixed for the composable API and are not exposed as
    # arguments to ``fully_shard``
    limit_all_gathers = True
    use_orig_params = True
    backward_prefetch_limit = 1
    forward_prefetch_limit = 1
    state = _init_core_state(
        state,
        strategy or ShardingStrategy.FULL_SHARD,
        mixed_precision,
        cpu_offload,
        limit_all_gathers,
        use_orig_params,
        backward_prefetch_limit,
        forward_prefetch_limit,
    )
    state = _init_runtime_state(state)
    state = _init_prefetching_state(state, BackwardPrefetch.BACKWARD_PRE, False)
    state = _init_buffer_state(state, module)
    state = _init_param_handles_from_module(
        state,
        module,
        policy,
        device_id,
        param_init_fn,
        sync_module_states,
    )
    state = _init_state_dict_state(state)
    _register_all_state_dict_hooks(state)
    # Register pre-/post-forward hooks on every submodule and the root
    # pre-forward hook on ``module`` itself
    modules = list(module.modules())
    _register_pre_forward_hooks(state, modules)
    _register_post_forward_hooks(state, modules)
    _register_root_pre_forward_hook(state, module)  # prepend last
    # Share the root FSDP state with every non-ignored submodule that does not
    # already have its own state
    for submodule in module.modules():
        if (
            submodule not in state._ignored_modules
            and _get_module_state(submodule) is None
        ):
            _insert_module_state(submodule, state)
    return module
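

# Usage sketch (illustrative; not part of the original file). It assumes this
# module is importable as ``torch.distributed._composable``, that a default
# process group can be initialized, and that ``MyModel`` is a placeholder for
# the caller's ``nn.Module``. ``fully_shard`` registers its hooks and state on
# ``module`` in place and returns the same module, so no wrapper object is
# created:
#
#     import torch
#     import torch.distributed as dist
#     from torch.distributed._composable import fully_shard
#
#     dist.init_process_group(backend="nccl")
#     torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
#
#     model = MyModel().cuda()   # placeholder model
#     fully_shard(model)         # apply FSDP semantics in place
#     loss = model(torch.randn(2, 16, device="cuda")).sum()
#     loss.backward()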