  1. """
  2. This file includes public APIs for FSDP such as the classes used for the
  3. constructor arguments.
  4. """
  5. from dataclasses import dataclass
  6. from enum import auto, Enum
  7. from typing import Optional
  8. import torch
  9. __all__ = [
  10. "ShardingStrategy",
  11. "BackwardPrefetch",
  12. "MixedPrecision",
  13. "CPUOffload",
  14. "StateDictType",
  15. "StateDictConfig",
  16. "FullStateDictConfig",
  17. "LocalStateDictConfig",
  18. "ShardedStateDictConfig",
  19. "OptimStateDictConfig",
  20. "FullOptimStateDictConfig",
  21. "LocalOptimStateDictConfig",
  22. "ShardedOptimStateDictConfig",
  23. "StateDictSettings",
  24. ]


class ShardingStrategy(Enum):
    """
    This specifies the sharding strategy to be used for distributed training by
    :class:`FullyShardedDataParallel`.

    - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded.
      For the parameters, this strategy unshards (via all-gather) before the
      forward, reshards after the forward, unshards before the backward
      computation, and reshards after the backward computation. For gradients,
      it synchronizes and shards them (via reduce-scatter) after the backward
      computation. The sharded optimizer states are updated locally per rank.
    - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during
      computation, and additionally, parameters are sharded outside
      computation. For the parameters, this strategy unshards before the
      forward, does not reshard them after the forward, and only reshards them
      after the backward computation. The sharded optimizer states are updated
      locally per rank. Inside ``no_sync()``, the parameters are not resharded
      after the backward computation.
    - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded
      but instead replicated across ranks similar to PyTorch's
      :class:`DistributedDataParallel` API. For gradients, this strategy
      synchronizes them (via all-reduce) after the backward computation. The
      unsharded optimizer states are updated locally per rank.
    - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate
      parameters across nodes. This results in reduced communication volume as
      expensive all-gathers and reduce-scatters are only done within a node,
      which can be more performant for medium-sized models.
    - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and
      replicate parameters across nodes. This is like ``HYBRID_SHARD``, except
      this may provide even higher throughput since the unsharded parameters
      are not freed after the forward pass, saving the all-gathers in the
      pre-backward.
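
    Example (a minimal usage sketch; it assumes ``model`` is an already
    constructed module, a process group has been initialized, and
    ``sharding_strategy`` is the FSDP constructor keyword that accepts this
    enum)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> # Shard parameters, gradients, and optimizer states (ZeRO-3 style).
        >>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
        >>> # Keep parameters unsharded between forward and backward (ZeRO-2 style).
        >>> fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)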
  55. """
  56. FULL_SHARD = auto()
  57. SHARD_GRAD_OP = auto()
  58. NO_SHARD = auto()
  59. HYBRID_SHARD = auto()
  60. _HYBRID_SHARD_ZERO2 = auto()


class BackwardPrefetch(Enum):
    """
    This configures explicit backward prefetching, which can improve throughput
    but may slightly increase peak memory usage.

    For the NCCL backend, any collectives, even if issued in different streams,
    contend for the same per-device NCCL stream, which is why the relative
    order in which the collectives are issued matters for overlapping. The
    different backward prefetching settings correspond to different orderings.

    - ``BACKWARD_PRE``: This prefetches the next set of parameters before the
      current set of parameters' gradient computation. This improves backward
      pass throughput by overlapping communication (next all-gather) and
      computation (current gradient computation).
    - ``BACKWARD_POST``: This prefetches the next set of parameters after the
      current set of parameters' gradient computation. This may improve
      backward pass throughput by overlapping communication (current
      reduce-scatter) and computation (next gradient computation).
      Specifically, the next all-gather is reordered to be before the current
      reduce-scatter.

    .. note:: If the increase in peak memory usage from prefetching is an
        issue, you may consider passing ``limit_all_gathers=True`` to the FSDP
        constructor, which may help reduce peak memory usage in some cases.
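
    Example (a sketch; it assumes ``model`` is already set up for distributed
    training, and it uses the existing FSDP constructor keywords
    ``backward_prefetch`` and ``limit_all_gathers``)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(
        >>>     model,
        >>>     backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        >>>     limit_all_gathers=True,  # cap memory if prefetching raises peak usage
        >>> )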
  82. """
  83. # NOTE: For both modes, the ordering that defines "current" and "next" is
  84. # not always correct in the current implementation, so this may cause some
  85. # performance regression for some models.
  86. BACKWARD_PRE = auto()
  87. BACKWARD_POST = auto()


@dataclass
class MixedPrecision:
    """
    This configures FSDP-native mixed precision training.

    Attributes:
        param_dtype (torch.dtype): This specifies the dtype for model
            parameters, inputs (when ``cast_forward_inputs`` or
            ``cast_root_forward_inputs`` is set to ``True``), and therefore
            the dtype for computation. However, outside the forward and
            backward passes, parameters are in full precision. Model
            checkpointing always happens in full precision.
        reduce_dtype (torch.dtype): This specifies the dtype for gradient
            reduction, which is permitted to differ from ``param_dtype``.
        buffer_dtype (torch.dtype): This specifies the dtype for buffers. FSDP
            does not shard buffers, casts them to ``buffer_dtype`` in the first
            forward pass, and keeps them in that dtype thereafter. Model
            checkpointing always happens in full precision.
        keep_low_precision_grads (bool): This specifies whether to keep
            gradients in the low precision ``reduce_dtype`` after the backward
            pass instead of upcasting them back to the full parameter
            precision. This may be set to ``True`` to save memory if using
            custom optimizers that can perform the optimizer step in
            ``reduce_dtype``. (Default: ``False``)
        cast_forward_inputs (bool): Cast floating point tensors in the forward
            arguments and keyword arguments to ``param_dtype``.
            (Default: ``False``)
        cast_root_forward_inputs (bool): Cast floating point tensors in the
            forward arguments and keyword arguments to ``param_dtype`` for the
            root FSDP instance. It takes precedence over
            ``cast_forward_inputs`` for the root FSDP instance.
            (Default: ``True``)

    .. note:: This API is experimental and subject to change.

    .. note:: Only floating point tensors are cast to their specified dtypes.

    .. note:: In ``summon_full_params``, parameters are forced to full
        precision, but buffers are not.

    .. note:: ``state_dict`` checkpoints parameters and buffers in full
        precision. For buffers, this is only supported for
        ``StateDictType.FULL_STATE_DICT``.

    .. note:: Each low precision dtype must be specified explicitly. For
        example, ``MixedPrecision(reduce_dtype=torch.float16)`` only specifies
        the reduction dtype to be low precision, and FSDP will not cast
        parameters or buffers.

    .. note:: If ``reduce_dtype`` is not specified, then gradient reduction
        happens in ``param_dtype`` if specified or the original parameter
        dtype otherwise.

    .. note:: If the user passes a model with ``BatchNorm`` modules and an
        ``auto_wrap_policy`` to the FSDP constructor, then FSDP will disable
        mixed precision for ``BatchNorm`` modules by wrapping them separately
        in their own FSDP instance with mixed precision disabled. This is due
        to some missing low precision ``BatchNorm`` kernels. If the user does
        not use an ``auto_wrap_policy``, then the user must take care to not
        use mixed precision for FSDP instances containing ``BatchNorm``
        modules.

    .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and
        ``cast_forward_inputs=False`` by default. For the root FSDP instance,
        its ``cast_root_forward_inputs`` takes precedence over its
        ``cast_forward_inputs``. For non-root FSDP instances, their
        ``cast_root_forward_inputs`` values are ignored. The default setting
        is sufficient for the typical case where each FSDP instance has the
        same ``MixedPrecision`` configuration and only needs to cast inputs to
        the ``param_dtype`` at the beginning of the model's forward pass.

    .. note:: For nested FSDP instances with different ``MixedPrecision``
        configurations, we recommend setting individual ``cast_forward_inputs``
        values to control whether inputs are cast before each instance's
        forward. In such a case, since the casts happen before each FSDP
        instance's forward, a parent FSDP instance should have its non-FSDP
        submodules run before its FSDP submodules to avoid the activation
        dtype being changed due to a different ``MixedPrecision``
        configuration.

    Example::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
        >>> model[1] = FSDP(
        >>>     model[1],
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True),
        >>> )
        >>> model = FSDP(
        >>>     model,
        >>>     mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True),
        >>> )

    The above shows a working example. On the other hand, if ``model[1]``
    were replaced with ``model[0]``, meaning that the submodule using
    different ``MixedPrecision`` ran its forward first, then ``model[1]``
    would incorrectly see ``float16`` activations instead of ``bfloat16``
    ones.
    """

    param_dtype: Optional[torch.dtype] = None
    reduce_dtype: Optional[torch.dtype] = None
    buffer_dtype: Optional[torch.dtype] = None
    keep_low_precision_grads: bool = False
    cast_forward_inputs: bool = False
    cast_root_forward_inputs: bool = True


@dataclass
class CPUOffload:
    """
    This configures CPU offloading.

    Attributes:
        offload_params (bool): This specifies whether to offload parameters to
            CPU when not involved in computation. If enabled, this implicitly
            offloads gradients to CPU as well. This is to support the optimizer
            step, which requires parameters and gradients to be on the same
            device.
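
    Example (a sketch; it assumes ``model`` is already set up for distributed
    training, and ``cpu_offload`` is the FSDP constructor keyword this config
    is passed to)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))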
  188. """
  189. offload_params: bool = False


class StateDictType(Enum):
    """
    This enum indicates which type of ``state_dict`` the FSDP module is
    currently processing (returning or loading). The default value is
    ``FULL_STATE_DICT`` to comply with the PyTorch convention.

    .. note::
        FSDP currently supports three types of ``state_dict``:

        1. ``state_dict``/``load_state_dict``: this pair of APIs return and
           load the non-sharded, unflattened parameters. The semantics are the
           same as using DDP.
        2. ``_local_state_dict``/``_load_local_state_dict``: this pair of APIs
           return and load local sharded, flattened parameters. The values
           returned by ``_local_state_dict`` can be directly used by FSDP and
           are only meaningful to FSDP (because parameters are flattened).
           Note that these APIs are meant for use via the
           :func:`state_dict_type` context manager as follows:

           >>> # xdoctest: +SKIP("undefined variables")
           >>> with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
           ...     state = fsdp.state_dict()  # gets the local state dict

        3. ``_sharded_state_dict``/``_load_sharded_state_dict``: this pair of
           APIs return and load sharded, unflattened parameters. The
           ``state_dict`` returned by ``_sharded_state_dict`` can be used by
           all other parallel schemes (resharding may be required).
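
    Example (a sketch of selecting the sharded flavor; it assumes ``fsdp`` is
    an FSDP-wrapped module and uses the same :func:`state_dict_type` context
    manager pattern shown above)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> with FSDP.state_dict_type(fsdp, StateDictType.SHARDED_STATE_DICT):
        ...     sharded_state = fsdp.state_dict()  # each rank gets its own shard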
  213. """
  214. FULL_STATE_DICT = auto()
  215. LOCAL_STATE_DICT = auto()
  216. SHARDED_STATE_DICT = auto()


@dataclass
class StateDictConfig:
    """
    ``StateDictConfig`` is the base class for all ``state_dict`` configuration
    classes. Users should instantiate a child class (e.g.
    ``FullStateDictConfig``) in order to configure settings for the particular
    type of ``state_dict`` implementation FSDP will use.
    """

    offload_to_cpu: bool = False


@dataclass
class FullStateDictConfig(StateDictConfig):
    """
    ``FullStateDictConfig`` is a config class meant to be used with
    ``StateDictType.FULL_STATE_DICT``. Currently, it accepts two parameters,
    ``offload_to_cpu`` and ``rank0_only``, which can be configured to offload
    the full ``state_dict`` to CPU and to materialize the ``state_dict`` on
    rank 0 only. When used, it is recommended to enable both of these flags
    together to optimize memory savings when taking checkpoints. Note that
    this config class is meant for use via the :func:`state_dict_type`
    context manager as follows:

        >>> # xdoctest: +SKIP("undefined variables")
        >>> fsdp = FSDP(model, auto_wrap_policy=...)
        >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        >>> with FullyShardedDataParallel.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg):
        >>>     state = fsdp.state_dict()
        >>>     # state will be empty on non rank 0 and contain CPU tensors on rank 0.
        >>> # To reload checkpoint for inference, finetuning, transfer learning, etc:
        >>> model = model_fn()  # Initialize model on CPU in preparation for wrapping with FSDP
        >>> if dist.get_rank() == 0:
        >>>     # Load checkpoint only on rank 0 to avoid memory redundancy
        >>>     state_dict = torch.load("my_checkpoint.pt")
        >>>     model.load_state_dict(state_dict)
        >>> # All ranks initialize FSDP module as usual. ``sync_module_states`` argument
        >>> # communicates loaded checkpoint states from rank 0 to rest of the world.
        >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
        >>> # After this point, all ranks have FSDP model with loaded checkpoint.
    """

    rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
    pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
    pass


@dataclass
class OptimStateDictConfig:
    """
    ``OptimStateDictConfig`` is the base class for all optimizer ``state_dict``
    configuration classes. Users should instantiate a child class (e.g.
    ``FullOptimStateDictConfig``) in order to configure settings for the
    particular type of ``optim_state_dict`` implementation FSDP will use.
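
    Example (a sketch; it assumes ``fsdp`` is an FSDP-wrapped model, ``optim``
    is its optimizer, and that :func:`state_dict_type` accepts an
    ``optim_state_dict_config`` argument and ``FSDP.optim_state_dict`` is
    available, as in recent FSDP versions)::

        >>> # xdoctest: +SKIP("undefined variables")
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> with FSDP.state_dict_type(
        >>>     fsdp,
        >>>     StateDictType.FULL_STATE_DICT,
        >>>     state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        >>>     optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
        >>> ):
        >>>     model_state = fsdp.state_dict()
        >>>     optim_state = FSDP.optim_state_dict(fsdp, optim)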
  268. """
  269. # TODO: actually use this flag in the _optim_utils.py
  270. offload_to_cpu: bool = True
  271. @dataclass
  272. class FullOptimStateDictConfig(OptimStateDictConfig):
  273. rank0_only: bool = False
  274. @dataclass
  275. class LocalOptimStateDictConfig(OptimStateDictConfig):
  276. offload_to_cpu: bool = False
  277. @dataclass
  278. class ShardedOptimStateDictConfig(OptimStateDictConfig):
  279. pass


@dataclass
class StateDictSettings:
    state_dict_type: StateDictType
    state_dict_config: StateDictConfig
    optim_state_dict_config: OptimStateDictConfig