- import contextlib
- import warnings
- from enum import auto, Enum
- from itertools import accumulate, chain
- from typing import (
- Any,
- Dict,
- Generator,
- Iterator,
- List,
- NamedTuple,
- no_type_check,
- Optional,
- Sequence,
- Set,
- Tuple,
- Union,
- )
- import torch
- import torch.distributed as dist
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from torch.distributed._tensor import DTensor
- from torch.distributed.fsdp._common_utils import (
- _set_fsdp_flattened,
- HandleTrainingState,
- )
- from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
- from ._utils import (
- _alloc_storage,
- _free_storage,
- _no_dispatch_record_stream,
- _same_storage,
- p_assert,
- )
- __all__ = [
- "FlatParameter",
- "FlatParamHandle",
- "FlatParamShardMetadata",
- "ParamInfo",
- "SharedParamInfo",
- "HandleShardingStrategy",
- ]
- """
- [Note: Fully Sharded Module]
- We define the "fully sharded module" to be the original ``nn.Module`` that owns
- a ``FlatParamHandle``. It is the *single* module logically responsible for the
- *single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
- forward or backward pass. The fully sharded module should be passed to the
- ``FlatParamHandle`` constructor.
- For the wrapper code path:
- - The ``FullyShardedDataParallel`` module wrapping the fully sharded module
- runs the unshard/reshard on behalf of the fully sharded module by overriding
- ``nn.Module.forward``.
- - The fully sharded module is exactly the module passed to the
- ``FullyShardedDataParallel`` constructor's ``module`` argument.
- For the non-wrapper code path:
- - Hooks registered on the fully sharded module run the unshard/reshard.
- - The fully sharded module may either be the direct argument to ``fully_shard``
- or a submodule chosen by the provided wrapping policy.
- """
- class ParamInfo(NamedTuple):
- """Information for an original module parameter."""
- param_name: str # unprefixed
- module: nn.Module
- module_name: str
- class SharedParamInfo(NamedTuple):
- """
- Additional information for a shared parameter.
- For each shared parameter, we designate one module and its parameter
- variable to be the primary owner, determined as the first one encountered
- in the parameter walk. These are prefixed with "prim". The primary module
- and parameter do not have their own :class:`SharedParamInfo` instance.
- """
- param_name: str # unprefixed
- module: nn.Module
- module_name: str
- prim_param_name: str # unprefixed
- prim_module: nn.Module
- prim_module_name: str
- class FlatParamShardMetadata(NamedTuple):
- """
- This holds metadata specific to this rank's shard of the flattened
- parameter.
- Attributes:
- param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
- shard of the parameters; see :class:`FlatParameter`.
- param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
- shard of the parameters; see :class:`FlatParameter`.
- param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
- of the parameters; see :class:`FlatParameter`.
- param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
- units of numels) giving this rank's part of each flattened
- original module parameter.
- """
- param_names: Tuple[str, ...]
- param_shapes: Tuple[torch.Size, ...]
- param_numels: Tuple[int, ...]
- param_offsets: Tuple[Tuple[int, int], ...]
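- # A minimal sketch of one rank's metadata (hypothetical values, for
- # illustration only): two original parameters "a.weight" (2x3) and "b.bias"
- # (4,), where this rank owns the last two elements of the first and the
- # first three of the second:
- #     >>> FlatParamShardMetadata(
- #     ...     param_names=("a.weight", "b.bias"),
- #     ...     param_shapes=(torch.Size([2, 3]), torch.Size([4])),
- #     ...     param_numels=(6, 4),
- #     ...     param_offsets=((4, 5), (0, 2)),
- #     ... )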
- # TODO (awgu): Prefix these with "Handle" for now to avoid circular imports and
- # inadvertent misuses; coalesce with those in fully_sharded_data_parallel.py
- # later
- class HandleShardingStrategy(Enum):
- FULL_SHARD = auto()
- SHARD_GRAD_OP = auto()
- NO_SHARD = auto()
- HYBRID_SHARD = auto()
- _HYBRID_SHARD_ZERO2 = auto()
- class FlatParameter(nn.Parameter):
- """
- This is the flattened parameter used by :class:`FullyShardedDataParallel`.
- It comprises one or more original parameters, which are flattened and
- concatenated to construct the flattened parameter.
- Under the current design, this parameter logically represents both the
- unsharded and sharded flattened parameter, and its data changes storages
- dynamically.
- - In the :class:`FullyShardedDataParallel` constructor, the parameter
- is initialized as unsharded and then sharded in-place.
- - At runtime, the parameter is lazily (re)-initialized. The sharded
- parameter data is saved in ``self._local_shard``, and a new ``Tensor``
- ``self._full_param_padded`` is created, which is the all-gather
- destination and owns the unsharded parameter storage thereafter. (See
- :meth:`FlatParamHandle.init_flat_param_attributes`.)
- - Throughout runtime, the parameter data changes storages as needed,
- e.g. to the sharded flattened parameter, reduced-precision sharded
- flattened parameter, or the unsharded flattened parameter.
- Attributes:
- _unpadded_unsharded_size (torch.Size): Unsharded flattened parameter's
- size without padding.
- _padded_unsharded_size (torch.Size): Unsharded flattened parameter's
- size with padding. This is only set for sharded strategies since
- they require padding for the all-gather.
- _sharded_size (torch.Size): Sharded flattened parameter's size with
- padding. This is also set for ``NO_SHARD``, in which case it is the
- same as the unsharded sizes. (We omit "padded" because there is no
- analogous unpadded one.)
- _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
- entry; see :class:`ParamInfo`.
- _numels (Tuple[int, ...]): Each parameter's numel.
- _shapes (Tuple[torch.Size, ...]): Each parameter's shape.
- _fqns (Tuple[str, ...]): The original parameters' FQNs prefixed from
- the owning handle's ``_fully_sharded_module``. The names are
- guaranteed to be unique within the subtree rooted at that module.
- _num_params (int): Number of original parameters flattened into this
- flattened parameter; this is the length of ``_param_infos``,
- ``_numels``, ``_shapes``, and ``_fqns``.
- _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
- info entries; see :class:`SharedParamInfo`.
- _param_extensions (Tuple[Optional[Any], ...]): Parameter extensions
- (i.e. some per-parameter state) used to customize pre-flatten and
- post-unflatten behavior. This is experimental, and users should not
- depend on its existence in the future.
- _modules (Set[nn.Module]): Modules that contain some original parameter
- that is flattened into the ``FlatParameter``.
- _shard_param_offsets (List[Tuple[int, int]]): [start, end] offsets (in
- units of numel) giving this rank's part of each flattened original
- module parameter; for any parameter ``p`` that is not sharded
- across ranks, this will be [0, ``p.numel()``-1].
- _shard_indices (Tuple[int, int]): [start, end] indices (in units of
- parameters) for this rank's shard of the original model parameters,
- where the parameters follow the order in which they were originally
- flattened; this indexes appropriately into any data structure that
- follows the flattening order (e.g. ``_param_infos``, ``_numels``,
- etc.).
- _shard_numel_padded (int): Numel padded for this rank's sharded
- flattened parameter.
- _local_shard (Tensor): Sharded flattened parameter with padding if
- using a sharded strategy. If using ``NO_SHARD``, then this is the
- unpadded unsharded flattened parameter, and there is no notion of a
- sharded flattened parameter or padded unsharded flattened
- parameter.
- _full_param_padded (Tensor): Unsharded flattened parameter with
- padding. This is not defined for ``NO_SHARD``. When using mixed
- precision for parameters, this has the low precision.
- _full_prec_full_param_padded (Tensor): Full precision unsharded
- flattened parameter with padding. This is used for unsharding
- outside of computation when using mixed precision for parameters.
- This is never defined for ``NO_SHARD``.
- _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
- Flattened parameter's :class:`AccumulateGrad` object and
- post-backward hook handle.
- _mp_shard (Tensor): Low precision sharded flattened parameter with
- padding. This is only defined when parameter mixed precision is
- enabled. For ``NO_SHARD``, this is used for computation.
- _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
- This is only defined when offloading parameters is enabled.
- _saved_grad_shard (Tensor): Sharded gradient with padding from previous
- iterations for gradient accumulation without :meth:`no_sync`.
- _params (Optional[List[nn.Parameter]]): The original parameter
- variables if ``use_orig_params=True`` and ``None`` otherwise.
- _shared_params (Optional[List[nn.Parameter]]): The original shared
- parameter variables if ``use_orig_params=True`` and ``None``
- otherwise.
- _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
- views created in the forward and tracked by autograd when
- ``use_orig_params=True`` and is ``None`` otherwise. Preserving these
- ``Tensor`` variables for the backward ensures that the
- ``FlatParameter`` 's ``AccumulateGrad`` object does not change, since
- if it changed, the post-backward hook would not run. This is relevant
- for cases like reentrant activation checkpointing.
- _is_grad_none (Optional[List[bool]]): A mask over the original
- parameters' gradients indicating whether each is logically ``None``,
- defined when ``use_orig_params=True`` and ``None`` otherwise. This is
- needed because only some of the parameters may have ``None`` gradient,
- in which case the ``FlatParameter`` gradient must be non-``None`` and
- must use zeros to approximate those original ``None`` gradients.
- This mask informs FSDP to set the original parameter gradients to
- ``None`` (instead of zeros) as needed.
- """
- def _init_metadata(
- self,
- param_infos: List[ParamInfo],
- numels: List[int],
- shapes: List[torch.Size],
- fqns: List[str],
- shared_param_infos: List[SharedParamInfo],
- param_extensions: List[Any],
- params: Optional[List[nn.Parameter]],
- shared_params: Optional[List[nn.Parameter]],
- ) -> None:
- """
- Initializes attributes holding metadata about the original parameters
- comprising the flattened parameter.
- We expose this method separately from the constructor to keep the
- constructor only responsible for the flattened parameter's tensor data.
- This method should only be called once per model, while the constructor
- may be called multiple times, e.g. when reloading from a checkpoint, in
- which case only the tensor data needs to be passed to the constructor.
- Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the
- metadata is correctly assumed to be unchanged.
- Args:
- See the Attributes in the class docstring.
- """
- assert len(param_infos) == len(numels)
- assert len(param_infos) == len(shapes)
- assert len(param_infos) == len(fqns)
- assert len(param_infos) == len(param_extensions)
- self._num_params = len(param_infos)
- self._param_infos = tuple(param_infos)
- self._numels = tuple(numels)
- self._shapes = tuple(shapes)
- self._fqns = tuple(fqns)
- self._shared_param_infos = tuple(shared_param_infos)
- self._param_extensions = tuple(param_extensions)
- self._modules = {pi.module for pi in self._param_infos}.union(
- {spi.module for spi in self._shared_param_infos}
- )
- assert (params is None) == (shared_params is None)
- if params is not None:
- assert shared_params is not None and len(shared_params) == len(
- shared_param_infos
- )
- self._params: Optional[List[nn.Parameter]] = params
- self._shared_params: Optional[List[nn.Parameter]] = shared_params
- # Mark the original parameters to avoid flattening them into
- # another `FlatParameter` during recursive construction
- for param in chain(self._params, self._shared_params):
- _set_fsdp_flattened(param)
- self._is_grad_none: Optional[List[bool]] = [
- False for _ in range(len(params))
- ]
- self._tensors: Optional[List[Optional[Tensor]]] = [
- None for _ in range(len(self._params))
- ]
- else:
- self._params = None
- self._shared_params = None
- self._is_grad_none = None
- self._tensors = None
- self._unpadded_unsharded_size = self.size()
- _set_fsdp_flattened(self)
- # Tracks whether the `FlatParameter`'s post-backward hook has been
- # called to modify the behavior of the post-backward callback
- self._post_backward_called = False
- class FlatParamHandle:
- """
- This handle manages a flattened parameter (:class:`FlatParameter`). This
- includes sharding and view management.
- Args:
- params (Sequence[nn.Parameter]): The parameters to use for the
- flattened parameter.
- fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
- device (torch.device): The compute and communication device, which
- should be a non-CPU device. We refer to it as the compute device.
- sharding_strategy (ShardingStrategy): Sharding strategy to apply to
- this handle's ``FlatParameter``.
- offload_params (bool): Whether to offload the handle's
- ``FlatParameter`` to CPU.
- mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
- setting passed to the FSDP constructor.
- mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
- precision setting passed to the FSDP constructor.
- keep_low_precision_grads (bool): Whether to keep gradients in low
- precision.
- use_orig_params (bool): If ``True``, then FSDP preserves the original
- parameter variables and returns them from ``named_parameters()``
- (e.g. to support different optimizer hyperparameters within one
- :class:`FlatParameter`). If ``False``, then FSDP reconstructs the
- parameter every iteration and returns the :class:`FlatParameter` s
- from ``named_parameters()``.
- """
- ##################
- # INITIALIZATION #
- ##################
- def __init__(
- self,
- params: Sequence[nn.Parameter],
- fully_sharded_module: nn.Module,
- device: torch.device,
- sharding_strategy: HandleShardingStrategy,
- offload_params: bool,
- mp_param_dtype: Optional[torch.dtype],
- mp_reduce_dtype: Optional[torch.dtype],
- keep_low_precision_grads: bool,
- process_group: dist.ProcessGroup,
- use_orig_params: bool,
- ):
- super().__init__()
- self.device = device
- self.process_group = process_group
- self.rank = process_group.rank()
- self.world_size = process_group.size()
- self._sharding_strategy = sharding_strategy
- self._offload_params = offload_params
- self._use_orig_params = use_orig_params
- self._keep_low_precision_grads = keep_low_precision_grads
- self._training_state = HandleTrainingState.IDLE
- self._debug_level = dist.get_debug_level()
- self._fully_sharded_module = fully_sharded_module
- self._init_flat_param(params, fully_sharded_module, use_orig_params)
- self._orig_param_dtype = self.flat_param.dtype
- self._use_unsharded_views(as_params=False)
- self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
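- # A hedged construction sketch (assumes a default process group `pg` has
- # already been initialized and `linear` is an `nn.Module` on the compute
- # device; both names are hypothetical):
- #     >>> handle = FlatParamHandle(
- #     ...     params=list(linear.parameters()),
- #     ...     fully_sharded_module=linear,
- #     ...     device=torch.device("cuda", torch.cuda.current_device()),
- #     ...     sharding_strategy=HandleShardingStrategy.FULL_SHARD,
- #     ...     offload_params=False,
- #     ...     mp_param_dtype=None,
- #     ...     mp_reduce_dtype=None,
- #     ...     keep_low_precision_grads=False,
- #     ...     process_group=pg,
- #     ...     use_orig_params=False,
- #     ... )
- #     >>> handle.shard()  # keep only this rank's padded shard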
- def _init_flat_param(
- self,
- params: Sequence[Optional[nn.Parameter]],
- module: nn.Module,
- use_orig_params: bool,
- ) -> None:
- """
- Initializes the flattened parameter ``self.flat_param`` by flattening
- the parameters in ``params`` into a single :class:`FlatParameter` and
- saves relevant metadata. Shared parameters are only included in the
- flattened parameter once.
- This checks that all comprising parameters have the same dtype and
- ``requires_grad`` and does not support nested construction of
- :class:`FlatParameter` s.
- Args:
- See the Args in the class docstring.
- """
- params_set = set(params)
- params_set.discard(None)
- if len(params_set) == 0:
- raise ValueError(
- "Cannot initialize a `FlatParameter` from an empty parameter list"
- )
- param_infos: List[ParamInfo] = []
- numels: List[int] = []
- shapes: List[torch.Size] = []
- fqns: List[str] = []
- shared_param_infos: List[SharedParamInfo] = []
- shared_param_memo: Dict[nn.Parameter, Tuple[nn.Module, str, str]] = {}
- params_to_flatten: List[Union[torch.Tensor, nn.Parameter]] = []
- shared_params: List[Union[torch.Tensor, nn.Parameter]] = []
- param_extensions: List[Any] = []
- dtype: Optional[torch.dtype] = None
- requires_grad: Optional[bool] = None
- for submodule_name, submodule in module.named_modules():
- for param_name, param in submodule.named_parameters(recurse=False):
- if param not in params_set:
- continue
- if param in shared_param_memo: # shared reference
- prim_module, prim_module_name, prim_param_name = shared_param_memo[
- param
- ]
- shared_params.append(param)
- shared_param_infos.append(
- SharedParamInfo(
- param_name,
- submodule,
- submodule_name,
- prim_param_name,
- prim_module,
- prim_module_name,
- )
- )
- else:
- if type(param) is FlatParameter:
- raise ValueError("`FlatParameter` does not support nesting")
- if dtype is not None and param.dtype != dtype:
- raise ValueError(
- "`FlatParameter` requires uniform dtype but got "
- f"{dtype} and {param.dtype}"
- )
- if dtype is None and not param.is_floating_point():
- raise ValueError("Integer parameters are unsupported")
- if (
- requires_grad is not None
- and param.requires_grad != requires_grad
- ):
- raise ValueError(
- "`FlatParameter` requires uniform `requires_grad`"
- )
- param, extension = _ext_pre_flatten_transform(param)
- param_extensions.append(extension)
- dtype = param.dtype
- requires_grad = param.requires_grad
- shared_param_memo[param] = (submodule, submodule_name, param_name)
- params_to_flatten.append(param)
- param_infos.append(ParamInfo(param_name, submodule, submodule_name))
- numels.append(param.numel())
- shapes.append(param.shape)
- fqn = (
- submodule_name + "." + param_name
- if submodule_name
- else param_name
- )
- fqns.append(fqn)
- assert requires_grad is not None, (
- "Passed-in `params` were not found in the module tree\n"
- f"params: {params}\nmodule: {module}"
- )
- self.flat_param = FlatParamHandle.flatten_params(
- params_to_flatten, requires_grad
- )
- # For `use_orig_params=True`, ensure that the logical parameters are
- # `nn.Parameter`s (and not plain `torch.Tensor`)
- def convert_to_params(
- tensors: List[Union[torch.Tensor, nn.Parameter]]
- ) -> List[nn.Parameter]:
- return [
- t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors
- ]
- self.flat_param._init_metadata(
- param_infos,
- numels,
- shapes,
- fqns,
- shared_param_infos,
- param_extensions,
- convert_to_params(params_to_flatten) if use_orig_params else None,
- convert_to_params(shared_params) if use_orig_params else None,
- )
- @staticmethod
- def flatten_params(
- params: Sequence[torch.Tensor],
- requires_grad: bool,
- ) -> FlatParameter:
- """
- Flattens the parameters in ``params`` into a single
- :class:`FlatParameter`. This should be the only way to construct
- :class:`FlatParameter` s.
- We expose this factory method for checkpointing (e.g. sharded state
- dict). The flattened parameter's metadata should only be initialized
- once (see :meth:`_init_metadata`), but its tensor data may be reloaded.
- """
- with torch.no_grad():
- flat_params = [
- p.detach().reshape(-1) if isinstance(p, nn.Parameter) else p.reshape(-1)
- for p in params
- ]
- flat_param_data = torch.cat(flat_params, dim=0)
- flat_param = FlatParameter(flat_param_data, requires_grad=requires_grad)
- return flat_param
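- # For illustration (hypothetical tensors): flattening a (2, 3) and a (4,)
- # parameter produces a 1D ``FlatParameter`` with numel 10:
- #     >>> w, b = nn.Parameter(torch.randn(2, 3)), nn.Parameter(torch.randn(4))
- #     >>> FlatParamHandle.flatten_params([w, b], requires_grad=True).shape
- #     torch.Size([10])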
- def _init_param_reduce_dtypes(
- self,
- mp_param_dtype: Optional[torch.dtype],
- mp_reduce_dtype: Optional[torch.dtype],
- ) -> None:
- """
- Precondition: ``self.flat_param`` is set via :meth:`_init_flat_param`.
- This ensures that this handle's parameters have a single dtype.
- Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
- ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype``
- is ``None``, then we assume the original parameter dtype. One special
- case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype``
- is ``None``, in which case we assume the gradient reduction dtype
- matches the forward/backward parameter dtype.
- """
- # Save whether these dtypes were specified so that we permit the
- # parameter dtype to change up until the lazy initialization
- self._low_prec_param_dtype_specified = mp_param_dtype is not None
- self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None
- if (
- self._low_prec_param_dtype_specified
- and not self._low_prec_reduce_dtype_specified
- ):
- # Special case: infer gradient reduction mixed precision
- self._fwd_bwd_param_dtype = mp_param_dtype
- self._reduce_dtype = self._fwd_bwd_param_dtype
- else:
- self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
- self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
- assert self._fwd_bwd_param_dtype is not None
- assert self._reduce_dtype is not None
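- # Illustration of the postcondition for common settings (rows are
- # ``(mp_param_dtype, mp_reduce_dtype) -> (_fwd_bwd_param_dtype, _reduce_dtype)``):
- #     (None,    None)     -> (orig dtype, orig dtype)
- #     (float16, None)     -> (float16,    float16)   # special case: reduce follows param
- #     (float16, float32)  -> (float16,    float32)
- #     (None,    float16)  -> (orig dtype, float16)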
- ###################################
- # SHARD INITIALIZATION & METADATA #
- ###################################
- @torch.no_grad()
- def shard(self):
- """
- Shards the handle's ``FlatParameter``. In terms of memory, this
- allocates new memory for the sharded flattened parameter and frees the
- unsharded flattened parameter's storage.
- Postcondition: ``self.flat_param`` is the sharded flattened parameter.
- Shard metadata attributes are set for all sharding strategies.
- ``process_group``, ``rank``, and ``world_size`` attributes are set if
- using a sharded strategy.
- """
- flat_param = self.flat_param
- if not self.uses_sharded_strategy:
- self._init_shard_metadata(0, 0, flat_param.numel() - 1)
- else:
- p_assert(
- flat_param.storage_offset() == 0,
- "The `FlatParameter` is not the sole occupant of its storage",
- )
- orig_storage = flat_param._typed_storage()
- sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
- flat_param, self.rank, self.world_size
- )
- flat_param.set_(sharded_flat_param) # type: ignore[call-overload]
- start = sharded_flat_param.numel() * self.rank
- end = sharded_flat_param.numel() * (self.rank + 1) - 1 # inclusive
- self._init_shard_metadata(numel_padded, start, end)
- if orig_storage._size() > 0:
- orig_storage._resize_(0)
- if self._use_orig_params:
- self._use_sharded_views()
- def _init_shard_metadata(
- self,
- numel_padded: int,
- start: int,
- end: int,
- ) -> None:
- """
- Initializes shard-related metadata for this rank's shard of the
- flattened parameter: ``_sharded_size``, ``_shard_param_offsets``,
- ``_shard_indices``, and ``_shard_numel_padded``.
- Args:
- numel_padded (int): Numel padded for this rank's sharded flattened
- parameter.
- start (int): Start index in the sharded flattened parameter
- assigned to this rank.
- end (int): End index (inclusive) in the sharded flattened parameter
- assigned to this rank. If this exceeds the sharded flattened
- parameter's numel, then it is truncated.
- Precondition: ``self.flat_param`` 's data is the sharded flattened
- parameter.
- """
- self.flat_param._sharded_size = self.flat_param.size() # type: ignore[attr-defined]
- sharded_flat_param_numel = self.flat_param.numel() # includes `numel_padded`
- p_assert(start >= 0 and start <= end, f"start: {start} end: {end}")
- p_assert(
- numel_padded <= sharded_flat_param_numel,
- f"numel_padded: {numel_padded} "
- f"sharded_flat_param_numel: {sharded_flat_param_numel}",
- )
- (
- self.flat_param._shard_param_offsets, # type: ignore[attr-defined]
- self.flat_param._shard_indices, # type: ignore[attr-defined]
- ) = self._get_shard_metadata(start, end)
- self.flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
- def _get_shard_metadata(
- self,
- start: int,
- end: int,
- ) -> Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]:
- """
- Computes the shard metadata based on ``start`` and ``end``, which give
- the closed interval of the unsharded flattened parameter specifying the
- shard.
- Args:
- start (int): Start index (in units of numel) of this rank's shard
- of the flattened parameter.
- end (int): End index (in units of numel and inclusive) of this
- rank's shard of the flattened parameter.
- Returns:
- Tuple[Tuple[Tuple[int, int], ...], Tuple[int, int]]: See
- ``_shard_param_offsets`` and ``_shard_indices`` in
- :class:`FlatParameter` 's docstring.
- """
- flat_param_offsets = self._get_flat_param_offsets()
- # Indices of the original parameters in this rank's sharded flattened
- # parameter
- shard_param_indices_range = [] # elements will be consecutive
- # [start, end] offsets giving this rank's part of the flattened
- # original module parameter (which will be [0, `p.numel()`-1] for any
- # parameter that is not sharded across ranks)
- shard_param_offsets = []
- for i, (param_start, param_end) in enumerate(flat_param_offsets):
- if start > param_end or end < param_start:
- continue
- if start <= param_start:
- intra_param_start = 0
- else:
- intra_param_start = start - param_start
- intra_param_end = min(param_end, end) - param_start
- shard_param_indices_range.append(i)
- shard_param_offsets.append(
- (intra_param_start, intra_param_end)
- ) # both inclusive
- if len(shard_param_indices_range) == 0:
- shard_param_indices = (0, 0)
- assert len(shard_param_offsets) == 0
- else:
- shard_param_indices = (
- shard_param_indices_range[0],
- shard_param_indices_range[-1],
- )
- assert (
- len(shard_param_offsets)
- == shard_param_indices[-1] - shard_param_indices[0] + 1
- )
- return tuple(shard_param_offsets), shard_param_indices
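- # Worked example (hypothetical numels, for illustration): with
- # ``_numels == (4, 3, 5)`` the flattened offsets are [(0, 3), (4, 6), (7, 11)].
- # For a shard covering ``start=5, end=9``, parameter 0 is skipped, parameter 1
- # contributes intra-parameter offsets (1, 2), and parameter 2 contributes
- # (0, 2), so this returns (((1, 2), (0, 2)), (1, 2)).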
- @staticmethod
- def _get_unpadded_shard(
- tensor: Tensor,
- rank: int,
- world_size: int,
- ) -> Tuple[Tensor, int]:
- """
- Returns the shard of ``tensor`` without any padding for the given
- ``rank`` and ``world_size``, along with the numel to pad for that shard.
- If ``tensor`` is already flattened or may be viewed in the flattened
- shape (which is true in the expected usage), then this method does not
- allocate any new tensor memory.
- """
- chunks = torch.flatten(tensor).chunk(world_size)
- if len(chunks) < (rank + 1):
- # This rank gets an empty chunk fully padded with zeros since there
- # are not enough chunks across ranks
- chunk = chunks[0].new_empty(0)
- else:
- chunk = chunks[rank]
- numel_to_pad = chunks[0].numel() - chunk.numel()
- assert (
- numel_to_pad >= 0
- ), "Chunk's size should be at most the first chunk's size"
- return chunk, numel_to_pad
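- # For illustration (hypothetical sizes): a 10-element tensor with
- # ``world_size=4`` chunks into sizes [3, 3, 3, 1], so rank 3 gets a 1-element
- # chunk that still needs 2 elements of padding:
- #     >>> FlatParamHandle._get_unpadded_shard(torch.arange(10.0), rank=3, world_size=4)
- #     (tensor([9.]), 2)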
- @staticmethod
- def _get_shard(
- tensor: Tensor,
- rank: int,
- world_size: int,
- ) -> Tuple[Tensor, int]:
- """
- Returns the shard of ``tensor`` with padding for the given ``rank`` and
- ``world_size``, along with the numel padded for that shard.
- This method allocates new memory (via :meth:`clone`) since the
- unsharded ``tensor`` may be deallocated after this method returns.
- """
- chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
- tensor, rank, world_size
- )
- shard = chunk.clone()
- if numel_to_pad > 0:
- shard = F.pad(shard, [0, numel_to_pad])
- return shard, numel_to_pad
- @staticmethod
- def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
- """
- Returns the shape of ``tensor`` after sharding including padding. This
- requires ``tensor`` to have 1D shape and ensures that the returned
- shape is 1D.
- """
- assert len(tensor.shape) == 1, f"{tensor.shape}"
- unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
- tensor, rank, world_size
- )
- unpadded_sharded_size = unpadded_sharded_tensor.size()
- assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
- return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
- def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
- """Returns [start, end] offsets of each original parameter's flattened
- data in the unsharded flattened parameter (without padding)."""
- cumulative_sum = list(accumulate(self.flat_param._numels))
- starts = [0] + cumulative_sum[:-1]
- ends = [end - 1 for end in cumulative_sum] # inclusive
- param_offsets = list(zip(starts, ends))
- return param_offsets
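- # Worked example (hypothetical numels): for ``_numels == (2, 3, 4)``,
- # ``accumulate`` gives [2, 5, 9], so this returns
- # [(0, 1), (2, 4), (5, 8)] (both ends inclusive).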
- def shard_metadata(
- self,
- ) -> FlatParamShardMetadata:
- """Returns shard-related metadata specific to this rank's shard of the
- flattened parameter."""
- assert hasattr(self.flat_param, "_shard_indices") and hasattr(
- self.flat_param, "_shard_param_offsets"
- ), "Shard metadata has not been initialized"
- shard_param_start_index = self.flat_param._shard_indices[0] # type: ignore[attr-defined]
- shard_param_end_index = self.flat_param._shard_indices[1] # type: ignore[attr-defined]
- sl = (
- slice(shard_param_start_index, shard_param_end_index + 1)
- if shard_param_start_index <= shard_param_end_index
- else slice(0, 0)
- )
- return FlatParamShardMetadata(
- self.flat_param._fqns[sl],
- self.flat_param._shapes[sl],
- self.flat_param._numels[sl],
- self.flat_param._shard_param_offsets[:], # type: ignore[attr-defined]
- )
- @no_type_check
- @torch.no_grad()
- def init_flat_param_attributes(self) -> None:
- """
- This initializes some attributes on the handle's ``FlatParameter``.
- This should be called during lazy initialization since it requires the
- parameter to be on the compute device if not offloading to CPU and we
- want to give users the chance to move the parameter appropriately after
- the FSDP constructor.
- For each tensor attribute on the ``FlatParameter``, see the unshard and
- reshard methods in this class for the allocation and free pattern.
- """
- flat_param = self.flat_param
- if flat_param.dtype != self._orig_param_dtype:
- # Entering this branch means that the user changed the parameter
- # dtype after FSDP initialization, in which case we may need to
- # refresh some saved dtype attributes (dtypes specified as a part
- # of mixed precision take precedence).
- if not self._low_prec_param_dtype_specified:
- self._fwd_bwd_param_dtype = flat_param.dtype
- # For `reduce_dtype`, require `param_dtype` was not specified since
- # then we infer the `reduce_dtype` from the specified `param_dtype`
- if (
- not self._low_prec_reduce_dtype_specified
- and not self._low_prec_param_dtype_specified
- ):
- self._reduce_dtype = flat_param.dtype
- self._orig_param_dtype = flat_param.dtype
- cpu_device = torch.device("cpu")
- if self._offload_params:
- p_assert(
- flat_param.device == cpu_device,
- f"Expects the `FlatParameter` to be on CPU when parameter CPU "
- f"offloading is enabled, not {flat_param.device}",
- )
- else:
- self._check_on_compute_device(self.flat_param)
- flat_param._local_shard = flat_param.data
- if self._offload_params:
- # Pin the memory for faster H2D transfer
- flat_param._local_shard = flat_param._local_shard.pin_memory()
- # Pre-allocate the sharded gradient on CPU to enable non-blocking
- # D2H transfer during the backward pass
- flat_param._cpu_grad = torch.zeros_like(
- flat_param._local_shard, device=cpu_device
- ).pin_memory()
- if self._uses_param_mixed_precision:
- # For parameter mixed precision, we maintain a low precision
- # sharded tensor on the compute device to be all-gathered (for
- # sharded strategies) or directly used (for `NO_SHARD`) for
- # computation.
- flat_param._mp_shard = torch.zeros_like(
- flat_param._local_shard,
- device=self.device,
- dtype=self._fwd_bwd_param_dtype,
- )
- _free_storage(flat_param._mp_shard)
- if self.uses_sharded_strategy:
- # We maintain a padded unsharded tensor that serves as the
- # all-gather destination and owns the original parameter storages.
- unsharded_param_dtype = (
- self._fwd_bwd_param_dtype
- if self._uses_param_mixed_precision
- else flat_param.dtype
- ) # use low precision if parameter mixed precision is enabled
- padded_unsharded_numel = flat_param.numel() * self.world_size
- flat_param._full_param_padded = torch.zeros(
- padded_unsharded_numel,
- device=self.device,
- dtype=unsharded_param_dtype,
- )
- flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
- _free_storage(flat_param._full_param_padded)
- if self._uses_param_mixed_precision:
- # For parameter mixed precision, we maintain a full precision
- # padded unsharded tensor for when we force full precision.
- flat_param._full_prec_full_param_padded = torch.zeros(
- padded_unsharded_numel,
- device=self.device,
- dtype=flat_param.dtype, # full precision
- )
- _free_storage(flat_param._full_prec_full_param_padded)
- ###################
- # UNSHARD/RESHARD #
- ###################
- def pre_unshard(self) -> bool:
- """
- Returns: ``False`` if this is a no-op and ``True`` otherwise.
- Postcondition: ``self.flat_param`` 's data is on the device for
- communication and is what should be all-gathered. This means that it
- matches the dtype of the expected unsharded parameter.
- """
- ret = False
- if self._use_orig_params:
- ret = self._writeback_orig_params()
- if (
- self.uses_sharded_strategy
- and not self._offload_params
- and not self.needs_unshard()
- ):
- pass # no-op
- elif self._uses_param_mixed_precision and not self._force_full_precision:
- self._use_low_precision_shard()
- ret = True
- elif self._offload_params and self.flat_param.device != self.device:
- # NOTE: This creates a new tensor distinct from any attributes.
- self.flat_param_to(self.device, non_blocking=True)
- ret = True
- self._check_on_compute_device(self.flat_param)
- return ret
- def _use_low_precision_shard(self):
- """
- Allocates the low precision shard directly on the compute device and
- switches to using the low precision sharded flattened parameter.
- """
- self._check_low_precision_shard()
- flat_param = self.flat_param
- _alloc_storage(
- flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined]
- )
- # `copy_()` implicitly casts to the low precision
- flat_param._mp_shard.copy_( # type: ignore[attr-defined]
- flat_param._local_shard.to( # type: ignore[attr-defined]
- self.device, non_blocking=True
- )
- )
- # Invariant: `_mp_shard` is always on the compute device.
- flat_param.data = flat_param._mp_shard # type: ignore[attr-defined]
- def unshard(self):
- """
- Runs the unshard logic. This includes all-gathering the flattened
- parameter and switching to using the unsharded flattened parameter. If
- the handle does not need unsharding, then this only switches to using
- the unsharded flattened parameter. For ``NO_SHARD``, this is a no-op.
- If FSDP is in :meth:`summon_full_params` and the handle uses parameter
- mixed precision, then the parameter is forced to full precision.
- """
- if not self.needs_unshard():
- # Even when not needing an unshard, we should switch to using
- # the unsharded flattened parameter
- unsharded_flat_param = (
- self._get_padded_unsharded_flat_param()
- if self.uses_sharded_strategy
- else self.flat_param
- )
- self._use_unsharded_flat_param(unsharded_flat_param)
- return
- unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
- padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
- self._use_unsharded_flat_param(padded_unsharded_flat_param)
- def needs_unshard(self) -> bool:
- """Returns if the handle's flattened parameter needs to be unsharded."""
- if not self.uses_sharded_strategy:
- return False
- unsharded_flat_param = self._get_padded_unsharded_flat_param()
- already_unsharded = (
- unsharded_flat_param._typed_storage()._size()
- == unsharded_flat_param.numel()
- )
- return not already_unsharded
- def _alloc_padded_unsharded_flat_param(self):
- """
- Allocates the *padded* unsharded flattened parameter. The unpadded
- unsharded flattened parameter is always a view into the padded one.
- This padded parameter is saved to a different attribute on the
- ``FlatParameter`` depending on whether we force full precision.
- """
- self._check_sharded_strategy()
- flat_param = self.flat_param
- unsharded_flat_param = self._get_padded_unsharded_flat_param()
- self._check_storage_freed(unsharded_flat_param)
- _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined]
- return unsharded_flat_param
- def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
- """
- Returns a reference to the padded unsharded flattened parameter
- depending on the calling context. This should only be called if using a
- sharded strategy.
- """
- self._check_sharded_strategy()
- flat_param = self.flat_param
- if self._force_full_precision:
- # When parameter mixed precision is enabled, we use a different
- # tensor as the all-gather destination to preserve the invariant
- # that `_full_param_padded` is in the low precision
- unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
- p_assert(
- unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
- f"Expects full precision but got {self._fwd_bwd_param_dtype}",
- )
- else:
- unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined]
- return unsharded_flat_param
- def _all_gather_flat_param(
- self,
- padded_unsharded_flat_param: Tensor,
- ) -> Tensor:
- """
- All-gathers the handle's flattened parameter to the destination
- ``padded_unsharded_flat_param``, and switches to using the all-gathered
- tensor.
- """
- p_assert(
- hasattr(self, "process_group") and hasattr(self, "world_size"),
- "Expects a process group and world size to have been set via `shard()`",
- )
- sharded_flat_param = self.flat_param.data
- expected_numel = sharded_flat_param.numel() * self.world_size
- p_assert(
- padded_unsharded_flat_param.numel() == expected_numel,
- f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
- )
- dist.all_gather_into_tensor(
- padded_unsharded_flat_param,
- sharded_flat_param,
- self.process_group,
- )
- return padded_unsharded_flat_param
- def _use_unsharded_flat_param(
- self,
- padded_unsharded_flat_param: torch.Tensor,
- ) -> None:
- """
- Switches to using the *unpadded* unsharded flattened parameter, which
- is a view into the *padded* unsharded flattened parameter.
- """
- unsharded_size = self.flat_param._unpadded_unsharded_size
- self.flat_param.data = padded_unsharded_flat_param[
- : unsharded_size.numel()
- ].view(
- unsharded_size
- ) # this `.view()` is not autograd visible
- in_forward = self._training_state == HandleTrainingState.FORWARD
- in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
- if self._use_orig_params:
- # We use `Tensor` views in the forward so that they are tracked by
- # autograd. We use them in the pre-backward as well to support
- # reentrant activation checkpointing, which needs the views to be
- # tracked by autograd in the backward pass's recomputed forward.
- self._use_unsharded_views(
- as_params=(not in_forward and not in_pre_backward)
- )
- elif in_forward:
- self._use_unsharded_views(as_params=False)
- def post_unshard(self):
- """
- Runs the post-unshard logic. This includes freeing the low precision
- shard if needed.
- """
- if self._uses_param_mixed_precision and self.uses_sharded_strategy:
- self._free_low_precision_sharded_param()
- self._check_on_compute_device(self.flat_param)
- def _free_low_precision_sharded_param(self):
- """Frees the low precision sharded flattened parameter."""
- self._check_low_precision_shard()
- # `_mp_shard` is allocated in the pre-unshard stream, consumed in the
- # unshard stream for sharded strategies, and consumed in both the
- # unshard and default streams for `NO_SHARD`. For sharded strategies,
- # the current stream here is the unshard stream, and for `NO_SHARD`,
- # it is the default stream. For `NO_SHARD`, only recording for the
- # default stream suffices since the default stream waits for the
- # unshard stream.
- _no_dispatch_record_stream(
- self.flat_param._mp_shard, torch.cuda.current_stream() # type: ignore[attr-defined]
- )
- _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined]
- @torch.no_grad()
- def unshard_grad(self):
- """
- Unshards the handle's ``FlatParameter`` 's gradient. If all ranks have
- a ``None`` gradient, then all of the original parameters' gradients
- will be ``None`` as well. This
- method performs an all-reduce and an all-gather. The additional
- all-reduce is tolerable since this method is not meant to be used on
- the computation critical path.
- Postcondition: ``_saved_grad_shard`` is defined and contains the value
- to set ``flat_param.grad`` after gradients are resharded.
- """
- if not self.uses_sharded_strategy:
- self._use_unsharded_grad_views()
- return
- flat_param = self.flat_param
- self._check_unsharded(flat_param)
- # Check if all ranks have a `None` gradient
- num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
- num_grad_none[0] = flat_param.grad is None
- dist.all_reduce(num_grad_none, group=self.process_group)
- if num_grad_none[0] == self.world_size:
- flat_param._saved_grad_shard = None # type: ignore[attr-defined]
- self._use_unsharded_grad_views()
- return
- padded_unsharded_grad = torch.empty(
- flat_param._padded_unsharded_size, # type: ignore[attr-defined]
- device=self.device,
- )
- if flat_param.grad is None:
- # In the case that only some ranks have `None` gradient, we use
- # zeros to approximate as a best effort attempt
- if self._debug_level == dist.DebugLevel.DETAIL:
- warnings.warn(
- f"[Rank {self.rank}] Only some but not all ranks have a "
- "`None` `FlatParameter` gradient, so FSDP is using zeros to "
- "approximate those ranks' sharded gradients being `None`"
- )
- flat_param._saved_grad_shard = None # type: ignore[attr-defined]
- sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined]
- else:
- self._check_sharded(flat_param.grad)
- flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined]
- sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
- dist.all_gather_into_tensor(
- padded_unsharded_grad, sharded_grad, self.process_group
- )
- unsharded_size = self.flat_param._unpadded_unsharded_size
- flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
- unsharded_size
- )
- self._use_unsharded_grad_views()
- def reshard_grad(self):
- if self._use_orig_params:
- self._use_sharded_grad_views()
- if not self.uses_sharded_strategy:
- return
- self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined]
- delattr(self.flat_param, "_saved_grad_shard")
- def prepare_gradient_for_backward(self):
- """
- Prepares the gradient for the backward computation by saving and
- clearing any existing sharded gradient in ``.grad`` to enable computing
- a new unsharded gradient.
- """
- p_assert(
- self._training_state
- in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
- "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
- )
- flat_param = self.flat_param
- if flat_param.grad is not None and (
- flat_param.grad.size() != flat_param._unpadded_unsharded_size
- or flat_param.grad.device != flat_param.device # grad on CPU
- ):
- self._check_on_compute_device(self.flat_param)
- grad_offloaded = flat_param.grad.device != self.device
- p_assert(
- not grad_offloaded or self._offload_params,
- f"Expects the sharded gradient to be on {self.device} "
- f"but got {flat_param.grad.device}",
- )
- prev_iter_synced_gradients = (
- flat_param.grad.size()
- == flat_param._local_shard.size() # type: ignore[attr-defined]
- )
- if prev_iter_synced_gradients:
- # TODO (awgu): Gradient accumulation outside `no_sync()`
- # does not work with CPU offloading. The issue should be
- # that, in the post-backward hook, we cannot do an addition
- # between a CPU tensor (the existing sharded gradient) and
- # a GPU tensor (the new sharded gradient).
- if not grad_offloaded:
- flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined]
- sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
- else:
- p_assert(
- hasattr(flat_param, "_cpu_grad"),
- "`_cpu_grad` should be defined if the gradient is on CPU",
- )
- sharded_grad = flat_param._cpu_grad # type: ignore[attr-defined]
- # If user specified to keep the gradient in low precision, then
- # the gradient may still be of the low precision dtype if the
- # user did not set the gradient to `None` after the previous
- # backward, in which case FSDP should cast back to the full
- # precision dtype so that FSDP can accumulate in that dtype in
- # the post-backward hook and assign to `.grad` in that dtype in
- # the post-backward callback.
- local_shard_dtype = flat_param._local_shard.dtype # type: ignore[attr-defined]
- if (
- self._keep_low_precision_grads
- and sharded_grad.dtype != local_shard_dtype
- ):
- sharded_grad.data = sharded_grad.to(local_shard_dtype)
- else:
- padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined]
- p_assert(
- flat_param.grad.size() == padded_unsharded_size,
- "Expects `.grad` to be the unsharded gradient in "
- f"`no_sync()` with size {padded_unsharded_size} "
- f"but got size {flat_param.grad.size()}",
- )
- flat_param.grad = None
- def prepare_gradient_for_optim(self):
- """
- Prepares the gradient for optimizer computation by moving the sharded
- gradient to the ``.grad`` attribute.
- """
- def cast_grad_to_param_dtype_if_needed(flat_param):
- if self._keep_low_precision_grads:
- assert flat_param.grad is not None # mypy
- if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
- flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
- if self._use_orig_params:
- self._use_sharded_grad_views()
- flat_param = self.flat_param
- # TODO (awgu): We should replace these conditional checks to encode
- # the logical intention more directly.
- if hasattr(flat_param, "_cpu_grad"):
- # NOTE: This branch includes `NO_SHARD`.
- self._check_sharded(flat_param)
- self._check_on_cpu(flat_param)
- flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined]
- cast_grad_to_param_dtype_if_needed(flat_param)
- elif hasattr(flat_param, "_saved_grad_shard"):
- self._check_sharded(flat_param)
- self._check_on_compute_device(flat_param)
- self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined]
- # If no sharded gradient was computed this iteration, then there is
- # no need to forward `_saved_grad_shard` to `grad`
- if flat_param._post_backward_called: # type: ignore[attr-defined]
- flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
- cast_grad_to_param_dtype_if_needed(flat_param)
- else:
- p_assert(
- not self.uses_sharded_strategy
- or not flat_param._post_backward_called, # type: ignore[attr-defined]
- "All sharded parameters that received a gradient in the "
- "post-backward should use `_saved_grad_shard`",
- )
- # Delete `_saved_grad_shard` since its existence indicates a previous
- # gradient to accumulate with in the post-backward hook
- if hasattr(flat_param, "_saved_grad_shard"):
- delattr(flat_param, "_saved_grad_shard")
- @contextlib.contextmanager
- def to_cpu(self):
- """
- Moves the unpadded unsharded flattened parameter to CPU while in the
- context and moves it back to the previous device upon exit. For now,
- this assumes the ``FlatParameter`` is the unpadded unsharded flattened
- parameter since (1) there is no reason to include the padding in the
- copy and (2) there is no use case for the sharded flattened parameter.
- Precondition: ``self.flat_param`` 's data is the unpadded unsharded
- flattened parameter on the compute device, and the handle uses a
- sharded strategy.
- Postcondition: Same as the precondition.
- """
- self._check_sharded_strategy()
- p_assert(
- self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
- f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
- )
- self._check_on_compute_device(self.flat_param)
- # Check that the unpadded unsharded flattened parameter is a view into
- # the padded unsharded flattened parameter as expected
- # NOTE: This check is not strictly needed for correctness but is a
- # useful sanity check since the tensor should only be used internally.
- unpadded_storage_ptr = self.flat_param._typed_storage()._data_ptr()
- padded_storage_ptr = (
- self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
- )
- p_assert(
- unpadded_storage_ptr == padded_storage_ptr,
- "Expects the unpadded parameter to be a view into the padded parameter",
- )
- self.flat_param_to(torch.device("cpu"))
- self._free_unsharded_flat_param()
- try:
- yield
- finally:
- p_assert(
- self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
- f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
- )
- padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
- # Copy from CPU to the compute device
- padded_unsharded_flat_param[: self.flat_param.numel()].copy_(
- self.flat_param
- )
- self._use_unsharded_flat_param(padded_unsharded_flat_param)
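- # A hedged usage sketch (assumes the precondition above already holds and
- # `handle` is an existing handle; both are illustrative):
- #     >>> with handle.to_cpu():
- #     ...     # unpadded unsharded data is now on CPU, e.g. snapshot it
- #     ...     cpu_copy = handle.flat_param.detach().clone()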
- def reshard(self, free_unsharded_flat_param: bool):
- """
- Runs the reshard logic. This includes freeing the unsharded flattened
- parameter if ``free_unsharded_flat_param`` and switching to using the
- sharded flattened parameter.
- """
- # Switch to the sharded `FlatParameter` before freeing to prevent
- # "use-after-free"-type bugs with external profiling tools, where for
- # `use_orig_params=True`, the `param` does not point to valid memory
- # when setting `param.data = ...` in `_use_sharded_views()`.
- self._use_sharded_flat_param()
- if free_unsharded_flat_param:
- self._free_unsharded_flat_param()
- def post_reshard(self):
- """
- Runs the post-reshard logic. This includes freeing any memory that
- can now be freed given that the ``FlatParameter`` points to the full
- precision sharded flattened parameter.
- Precondition: ``self.flat_param`` 's data points to the full precision
- sharded flattened parameter.
- """
- # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since
- # it is also the low precision *unsharded* flattened parameter. Hence,
- # we delay the free until the reshard.
- if (
- self._uses_param_mixed_precision
- and not self.uses_sharded_strategy
- and not self._force_full_precision # did not use the low precision shard
- ):
- self._free_low_precision_sharded_param()
- def _free_unsharded_flat_param(self):
- """
- Frees the padded unsharded flattened parameter. The tensor to free
- depends on the calling context since the unshard may have forced full
- precision, in which case a different tensor is used.
- """
- self._check_sharded_strategy()
- unsharded_flat_param = self._get_padded_unsharded_flat_param()
- self._check_storage_allocated(unsharded_flat_param)
- self._check_on_compute_device(unsharded_flat_param)
- # Do not free the memory until all ops in the current stream finish
- _no_dispatch_record_stream(unsharded_flat_param, torch.cuda.current_stream())
- _free_storage(unsharded_flat_param)
- def _use_sharded_flat_param(self) -> None:
- """Switches to using the sharded flattened parameter."""
- flat_param = self.flat_param
- if self._offload_params:
- device = flat_param._local_shard.device # type: ignore[attr-defined]
- p_assert(
- device == torch.device("cpu"),
- f"Expects the local shard to be on CPU but got {device}",
- )
- flat_param.data = flat_param._local_shard # type: ignore[attr-defined]
- if self._use_orig_params:
- self._use_sharded_views()
- # For the post-forward reshard, we may try to use sharded gradient
- # views (or unsharded gradient views if a gradient was accumulated
- # in `no_sync()`), but for the post-backward reshard, we delay the
- # call to after the reduce-scatter.
- if self._training_state == HandleTrainingState.FORWARD:
- # TODO: Change `_unpadded_unsharded_size` if we change the
- # gradient to be computed directly with padding.
- accumulated_grad_in_no_sync = (
- flat_param.grad is not None
- and self.uses_sharded_strategy
- and flat_param.grad.shape == flat_param._unpadded_unsharded_size
- )
- if accumulated_grad_in_no_sync:
- self._use_unsharded_grad_views()
- else:
- self._use_sharded_grad_views()
- #########
- # VIEWS #
- #########
- @staticmethod
- def _get_unflat_views(
- flat_param: FlatParameter,
- tensor: Optional[torch.Tensor] = None,
- ) -> Iterator[Tensor]:
- """
- Returns unflattened ``Tensor`` views into ``tensor`` if it is not
- ``None`` or ``flat_param`` otherwise, where the unflattening is based
- on ``flat_param`` 's metadata.
- In other words, to get views into the unsharded flattened parameter,
- pass ``tensor`` as ``None``, but to get views into tensor optimizer
- state, pass ``tensor`` as the optimizer state tensor.
- """
- if tensor is None:
- tensor = flat_param
- p_assert(
- tensor.numel() == flat_param._unpadded_unsharded_size.numel(),
- f"Expects {flat_param._unpadded_unsharded_size.numel()} numel but got "
- f"{tensor.numel()} numel",
- )
- views = (
- _ext_post_unflatten_transform(subtensor.view(shape), param_extension)
- for (subtensor, shape, param_extension) in zip(
- torch.split(tensor, flat_param._numels, dim=0), # type: ignore[arg-type]
- flat_param._shapes,
- flat_param._param_extensions,
- )
- )
- return views
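- # Illustrative sketch (hypothetical metadata values): if
- # `flat_param._numels == [12, 5]` and `flat_param._shapes == [(3, 4), (5,)]`,
- # then the generator above splits the 17-element flat tensor into chunks of
- # 12 and 5 elements and views them as (3, 4) and (5,) respectively, before
- # any post-unflatten extension transform is applied.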
- def _use_unsharded_views(self, as_params: bool) -> None:
- """
- Unflattens the unsharded flattened parameter by setting the original
- module parameter variables to be views into it.
- Args:
- as_params (bool): If ``True``, then registers the original
- parameters as ``nn.Parameter`` s; if ``False``, then registers
- the original parameters only as ``Tensor`` s. ``False`` should
- be used during forward/backward computation and when hiding the
- original parameters from :meth:`nn.Module.named_parameters`.
- """
- self._check_unsharded(self.flat_param)
- views = self._get_unflat_views(self.flat_param)
- for i, (view, (param_name, module, _)) in enumerate(
- zip(views, self.flat_param._param_infos)
- ):
- if hasattr(module, param_name):
- delattr(module, param_name)
- if self._use_orig_params and as_params:
- if type(view) is DTensor:
- # A `DTensor` `view` is not compatible with assigning
- # `param.data = view`, so we cannot preserve the parameter
- # variable.
- setattr(module, param_name, nn.Parameter(view))
- continue
- param = self.flat_param._params[i] # type: ignore[index]
- setattr(module, param_name, param)
- param.data = view
- elif as_params:
- module.register_parameter(param_name, nn.Parameter(view))
- else: # `as_params=False`
- param_var: Tensor = view
- if self._use_orig_params:
- if self._training_state == HandleTrainingState.FORWARD:
- assert self.flat_param._tensors is not None
- # Save the `Tensor` for the pre-backward
- self.flat_param._tensors[i] = view
- elif self._training_state == HandleTrainingState.BACKWARD_PRE:
- # Use the saved `Tensor` variable from the forward to
- # preserve the autograd graph so that the post-backward
- # hook fires (e.g. for reentrant AC)
- assert self.flat_param._tensors is not None # mypy
- tensor = self.flat_param._tensors[i]
- p_assert(
- tensor is not None,
- "Expects `Tensor` to have been saved in forward",
- )
- tensor.data = view # type: ignore[union-attr]
- assert tensor is not None # mypy
- param_var = tensor
- setattr(module, param_name, param_var)
- if (
- self._use_orig_params
- and self._training_state == HandleTrainingState.FORWARD
- ):
- module._parameters[param_name] = param_var # type: ignore[assignment]
- for i, (
- param_name,
- module,
- _,
- prim_param_name,
- prim_module,
- prim_module_name,
- ) in enumerate(self.flat_param._shared_param_infos):
- if hasattr(module, param_name):
- delattr(module, param_name)
- p_assert(
- hasattr(prim_module, prim_param_name),
- f"Module {prim_module_name} is missing parameter {prim_param_name}",
- )
- prim_param: Union[Tensor, nn.Parameter] = getattr(
- prim_module, prim_param_name
- )
- p_assert(
- not as_params or isinstance(prim_param, nn.Parameter),
- f"as_params={as_params} type(prim_param)={type(prim_param)}",
- )
- if self._use_orig_params and as_params:
- shared_param = self.flat_param._shared_params[i] # type: ignore[index]
- setattr(module, param_name, shared_param)
- shared_param.data = prim_param
- elif as_params:
- assert isinstance(prim_param, nn.Parameter)
- module.register_parameter(param_name, prim_param)
- else:
- setattr(module, param_name, prim_param)
- if (
- self._use_orig_params
- and self._training_state == HandleTrainingState.FORWARD
- ):
- module._parameters[param_name] = prim_param # type: ignore[assignment]
- def _use_unsharded_grad_views(self) -> None:
- """
- Unflattens the unsharded flattened parameter's gradient by setting the
- original module parameter variables' gradients to be views into it.
- """
- # Expects the gradient to be in `flat_param.grad`
- if self.flat_param.grad is None:
- assert self.flat_param._params is not None # mypy
- assert self.flat_param._shared_params is not None # mypy
- for param in chain(
- self.flat_param._params, # type: ignore[attr-defined]
- self.flat_param._shared_params, # type: ignore[attr-defined]
- ):
- param.grad = None
- return
- self._check_unsharded(self.flat_param.grad)
- views = self._get_unflat_views(self.flat_param, self.flat_param.grad)
- for i, (view, (param_name, module, _)) in enumerate(
- zip(views, self.flat_param._param_infos)
- ):
- p_assert(
- hasattr(module, param_name),
- f"{self.flat_param._fqns[i]} is missing",
- )
- param = getattr(module, param_name)
- if param.shape != view.shape or param.dtype != view.dtype:
- # NOTE: This is a hack using `.data` to side step the
- # check that parameter/gradient sizes and dtypes match. Here,
- # `param` can have the sharded size, and `grad` can have the
- # unsharded size. Orthogonally, `param` can have the full
- # precision dtype from `reshard()`, and `grad` can have the
- # parameter low precision dtype. Both of these mismatches
- # happen when running in `no_sync()`.
- if param.grad is None:
- param.grad = torch.empty_like(param)
- param.grad.data = view
- else:
- param.grad = view
- for i, (
- param_name,
- module,
- module_name,
- prim_param_name,
- prim_module,
- _,
- ) in enumerate(self.flat_param._shared_param_infos):
- p_assert(
- hasattr(module, param_name),
- f"{module_name + '.' + param_name if module_name else param_name} is missing",
- ) # did not save FQN info in `_shared_param_infos`
- param = getattr(module, param_name)
- prim_param = getattr(prim_module, prim_param_name)
- if (
- param.shape != prim_param.grad.shape
- or param.dtype != prim_param.grad.dtype
- ):
- # NOTE: This is the same hack to use `.data` to side step the
- # size check.
- if param.grad is None:
- param.grad = torch.empty_like(param)
- param.grad.data = prim_param.grad
- else:
- param.grad = prim_param.grad
- @contextlib.contextmanager
- def unflatten_as_params(self) -> Generator:
- """
- Assumes the flattened parameter is unsharded. When in the context,
- unflattens the original parameters as ``nn.Parameter`` views into the
- flattened parameter, and after the context, restores the original
- parameters as ``Tensor`` views into the flattened parameter.
- """
- self._use_unsharded_views(as_params=True)
- try:
- yield
- finally:
- self._use_unsharded_views(as_params=False)
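- # Minimal usage sketch (hypothetical caller): assuming the flattened
- # parameter is already unsharded,
- #
- #     with handle.unflatten_as_params():
- #         # the original parameters are registered as `nn.Parameter` views
- #         # into the flattened parameter
- #         ...
- #     # afterwards they are restored as plain `Tensor` views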
- @torch.no_grad()
- def _use_sharded_views(self) -> None:
- """
- Sets the original module parameter variables' data to be flattened
- views into the sharded flattened parameter.
- The views are kept as flattened to simplify the case where a parameter
- is sharded across ranks. Parameters whose data is not present in the
- sharded flattened parameter have their data set to a size-0 empty
- tensor. We do not delete them, in order to preserve expected behaviors
- like model printability. Parameters whose data is present must preserve
- their variables to be passable to an optimizer.
- """
- if not self.uses_sharded_strategy:
- # For `NO_SHARD`, use the *unflattened* unsharded views since we
- # have the unsharded parameter
- self._use_unsharded_views(as_params=True)
- return
- self._check_sharded(self.flat_param)
- start, end = self.flat_param._shard_indices # type: ignore[attr-defined]
- offset = 0
- assert self.flat_param._params is not None
- for i, (param, (param_name, module, _)) in enumerate(
- zip(self.flat_param._params, self.flat_param._param_infos)
- ):
- setattr(module, param_name, param)
- in_sharded_flat_param = (
- i >= start
- and i <= end
- and self.flat_param._shard_param_offsets # type: ignore[attr-defined]
- )
- if in_sharded_flat_param:
- param_start, param_end = self.flat_param._shard_param_offsets[i - start] # type: ignore[attr-defined]
- numel_in_shard = param_end - param_start + 1
- param.data = self.flat_param[offset : offset + numel_in_shard]
- offset += numel_in_shard
- else:
- # Allow the original data to be freed via garbage collection
- param.data = torch.empty(
- 0,
- dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
- device=self.flat_param.device,
- requires_grad=False,
- )
- assert self.flat_param._shared_params is not None
- for i, (
- param,
- (param_name, module, _, prim_param_name, prim_module, _),
- ) in enumerate(
- zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
- ):
- setattr(module, param_name, param)
- prim_param = getattr(prim_module, prim_param_name)
- param.data = prim_param # could be both empty and non-empty
- if self._training_state == HandleTrainingState.BACKWARD_POST:
- assert self.flat_param._tensors is not None # mypy
- # Clear the saved `Tensor`s since they are unneeded now
- for i in range(len(self.flat_param._tensors)):
- self.flat_param._tensors[i] = None # type: ignore[index]
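- # Post-condition sketch: after `_use_sharded_views()`, an original parameter
- # whose data lies in this rank's shard is a flattened 1D slice of
- # `self.flat_param`, while a parameter with no data in the shard has
- # `param.numel() == 0`.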
- @torch.no_grad()
- def _use_sharded_grad_views(self) -> None:
- """
- Sets the original module parameter variables' gradients to be flattened
- views into the sharded flattened parameter's gradient. This is a no-op
- if there is no gradient.
- Parameters whose data is not present in the sharded flattened parameter
- and parameters with ``requires_grad=False`` have their gradients set to
- ``None``. Since the gradient variables do not need to be preserved,
- this method does not manipulate existing ``Tensor`` data directly and
- creates new ``Tensor`` variables instead.
- """
- flat_param = self.flat_param
- self._check_sharded(flat_param)
- grad = self.sharded_grad
- if grad is None:
- assert flat_param._params is not None # mypy
- assert flat_param._shared_params is not None # mypy
- for param in chain(flat_param._params, flat_param._shared_params): # type: ignore[attr-defined]
- param.grad = None
- return
- self._check_sharded(grad)
- start, end = flat_param._shard_indices # type: ignore[attr-defined]
- offset = 0
- assert flat_param._params is not None
- for i, param in enumerate(flat_param._params):
- in_sharded_flat_param = (
- i >= start
- and i <= end
- and flat_param._shard_param_offsets # type: ignore[attr-defined]
- )
- if in_sharded_flat_param:
- param_start, param_end = flat_param._shard_param_offsets[i - start] # type: ignore[attr-defined]
- numel_in_shard = param_end - param_start + 1
- assert flat_param._is_grad_none is not None # mypy
- if param.requires_grad and not flat_param._is_grad_none[i]:
- if self._keep_low_precision_grads or param.dtype != grad.dtype:
- # NOTE: This is a hack using `.data` to side step the
- # check that parameter/gradient dtypes match. Here,
- # `param` has full precision; `grad` has low precision.
- if param.grad is None:
- # `.grad` must have the same shape as `param`
- param.grad = torch.empty_like(param)
- param.grad.data = grad[
- offset : offset + numel_in_shard
- ].reshape(param.shape)
- else:
- param.grad = grad[offset : offset + numel_in_shard].reshape(
- param.shape
- )
- else:
- param.grad = None
- offset += numel_in_shard
- else:
- param.grad = None
- assert flat_param._shared_params is not None
- for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
- zip(flat_param._shared_params, flat_param._shared_param_infos)
- ):
- in_sharded_flat_param = hasattr(prim_module, prim_param_name)
- if in_sharded_flat_param and param.requires_grad:
- prim_param = getattr(prim_module, prim_param_name)
- param.grad = prim_param.grad # share the same reference
- else:
- param.grad = None
- @torch.no_grad()
- def _writeback_orig_params(self) -> bool:
- """
- Iterates over the original parameters and writes back any parameters
- that changed storages (due to a non-inplace operator) to the handle's
- ``FlatParameter``. This method preserves the ``FlatParameter`` 's
- device even if an original parameter's device changes.
- Raises:
- RuntimeError: If an original parameter or gradient changes storages
- but no longer has the expected flattened shape.
- Returns: ``True`` if some writeback happened, and ``False`` otherwise.
- """
- if self.uses_sharded_strategy and not self.is_sharded(self.flat_param):
- # Skip the writeback when using a sharded strategy and the parameter is
- # not currently sharded; for `NO_SHARD`, writeback may still be needed
- return False
- flat_param = self.flat_param
- start, end = flat_param._shard_indices # type: ignore[attr-defined]
- offset = 0
- assert flat_param._params is not None
- wroteback = False
- for i, (param, (param_name, module, _)) in enumerate(
- zip(flat_param._params, flat_param._param_infos)
- ):
- if not hasattr(module, param_name):
- # Do not writeback if original parameters are deregistered
- # (e.g. during model checkpointing)
- continue
- in_sharded_flat_param = (
- i >= start
- and i <= end
- and flat_param._shard_param_offsets # type: ignore[attr-defined]
- )
- if not in_sharded_flat_param:
- continue
- param_start, param_end = flat_param._shard_param_offsets[i - start] # type: ignore[attr-defined]
- numel_in_shard = param_end - param_start + 1
- # Check for parameter writeback
- param_changed = getattr(module, param_name) is not param
- needs_param_writeback = (
- param_changed # changed parameter variable itself
- or not _same_storage(param, flat_param) # changed `.data`
- )
- if param_changed:
- # NOTE: The gradient is not preserved after a parameter change.
- param = getattr(module, param_name)
- flat_param._params[i] = param
- if needs_param_writeback:
- expected_shape = torch.Size([numel_in_shard])
- self._writeback_tensor(
- param, flat_param, i, expected_shape, offset, True
- )
- wroteback = True
- # Check for gradient writeback
- # NOTE: Since this method is called in the pre-unshard, which is
- # only called during computation in the pre-forward or
- # pre-backward, the sharded gradient should be guaranteed to be in
- # `.grad`, not in `._saved_grad_shard`.
- if param.grad is None and flat_param.grad is not None:
- expected_shape = torch.Size([numel_in_shard])
- self._writeback_tensor(
- None, flat_param.grad, i, expected_shape, offset, False
- )
- elif param.grad is not None:
- # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
- # memory and owns the gradient storage, so it will never
- # require gradient writeback.
- flat_param_grad = (
- flat_param.grad
- if self.uses_sharded_strategy or not self._offload_params
- else flat_param._cpu_grad # type: ignore[attr-defined]
- )
- needs_grad_writeback = flat_param_grad is None or not _same_storage(
- param.grad, flat_param_grad
- )
- if needs_grad_writeback:
- if flat_param_grad is None:
- flat_param_grad = torch.zeros_like(flat_param)
- expected_shape = torch.Size([numel_in_shard])
- self._writeback_tensor(
- param.grad, flat_param_grad, i, expected_shape, offset, False
- )
- flat_param.grad = flat_param_grad
- offset += numel_in_shard
- # TODO (awgu): Handle shared parameters. We need to re-generate the
- # shared parameter data structures in case sharedness changed.
- for i, (
- param_name,
- module,
- _,
- prim_param_name,
- prim_module,
- _,
- ) in enumerate(flat_param._shared_param_infos):
- if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
- raise NotImplementedError(
- "Changing shared parameters is not supported yet"
- )
- return wroteback
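- # Illustrative trigger (hypothetical user code, with sharded views in
- # effect): if an original parameter's storage is swapped while keeping its
- # flattened sharded shape, e.g.
- #
- #     param = getattr(module, param_name)
- #     param.data = param.data.clone()  # new storage, same shape
- #
- # then `_same_storage(param, flat_param)` is False, the new data is copied
- # into this rank's slice of `flat_param` via `_writeback_tensor`, and the
- # method returns True; if the shape no longer matches the expected
- # flattened shape, a RuntimeError is raised instead.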
- def _writeback_tensor(
- self,
- src_tensor: Optional[Tensor],
- dst_tensor: Tensor,
- tensor_index: int,
- expected_shape: torch.Size,
- offset: int,
- is_param: bool, # else gradient
- ) -> None:
- """
- Writes back ``src_tensor`` to ``dst_tensor`` at offset ``offset``,
- where ``src_tensor`` should have shape ``expected_shape``. ``is_param``
- indicates if the tensor is the parameter (if ``True``) or gradient (if
- ``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
- instead of copying. ``tensor_index`` gives the index of ``src_tensor``
- in the metadata structures.
- Raises:
- RuntimeError: If the ``src_tensor`` does not have the expected
- shape.
- """
- p_assert(
- len(expected_shape) == 1,
- f"Expects a 1D expected shape but got {expected_shape}",
- )
- if self._debug_level == dist.DebugLevel.DETAIL:
- rank = self.rank if hasattr(self, "rank") else dist.get_rank()
- src_shape = src_tensor.shape if src_tensor is not None else None
- src_device = src_tensor.device if src_tensor is not None else None
- warnings.warn(
- f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
- f"writeback in {self._training_state}\n"
- f"expected shape={expected_shape} shape={src_shape} "
- f"expected device={dst_tensor.device} device={src_device}"
- )
- if src_tensor is not None and src_tensor.shape != expected_shape:
- # NOTE: Gradient shape mismatch is not possible in practice since
- # the gradient shape is enforced to match that of the parameter and
- # we already check for parameter shape mismatch.
- raise RuntimeError(
- f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
- f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
- )
- if src_tensor is not None:
- dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
- else:
- dst_tensor[offset : offset + expected_shape.numel()].zero_()
- assert self.flat_param._is_grad_none is not None
- self.flat_param._is_grad_none[tensor_index] = True
- def _clear_grads_if_needed(self):
- """
- When ``use_orig_params=True``, sets the underlying ``flat_param.grad``
- to ``None`` if *all* of the original parameters' ``.grad`` are
- ``None``. This is targeting ``optim.zero_grad(set_to_none=True)``, in
- which case we want to free the gradients as soon after the
- ``zero_grad()`` call as possible.
- """
- if not self._use_orig_params:
- return
- flat_param = self.flat_param
- assert flat_param._params is not None
- if all(param.grad is None for param in flat_param._params):
- flat_param.grad = None
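- # Illustrative interplay (hypothetical caller): with `use_orig_params=True`,
- # after `optim.zero_grad(set_to_none=True)` every original parameter's
- # `.grad` is None, so the next call to `_clear_grads_if_needed()` sets
- # `flat_param.grad = None`, letting the flattened gradient's memory be freed.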
- def _deregister_orig_params(self):
- for (param_name, module, _) in self.flat_param._param_infos:
- if hasattr(module, param_name):
- delattr(module, param_name)
- for (param_name, module, _, _, _, _) in self.flat_param._shared_param_infos:
- if hasattr(module, param_name):
- delattr(module, param_name)
- ###########
- # HELPERS #
- ###########
- def flat_param_to(self, *args, **kwargs):
- """Wraps an in-place call to ``.to()`` for ``self.flat_param``."""
- self.flat_param.data = self.flat_param.to(*args, **kwargs)
- if self._use_orig_params:
- # Refresh the views because their storage may have changed
- if self.is_sharded(self.flat_param):
- self._use_sharded_views()
- else:
- self._use_unsharded_views(as_params=True)
- def _get_modules(self) -> Set[nn.Module]:
- """Returns a :class:`set` of the modules whose parameters are included
- in this handle's flattened parameter."""
- return {pi.module for pi in self.flat_param._param_infos}.union(
- {spi.module for spi in self.flat_param._shared_param_infos}
- )
- def is_sharded(self, tensor: Tensor) -> bool:
- """
- Returns if ``tensor`` is *currently* sharded. For ``NO_SHARD``, we
- choose to have this always return ``False`` for clarity.
- """
- if (
- not hasattr(self.flat_param, "_sharded_size")
- or not self.uses_sharded_strategy
- ):
- # `_sharded_size` is defined iff `handle.shard()` has been called
- return False
- sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
- return tensor.size() == sharded_size
- def parameter_module_names(self) -> Iterator[Tuple[str, str]]:
- shared_param_infos = [
- ParamInfo(param_name, module, module_name)
- for (
- param_name,
- module,
- module_name,
- _,
- _,
- _,
- ) in self.flat_param._shared_param_infos
- ]
- for param_name, _, module_name in chain(
- self.flat_param._param_infos, shared_param_infos
- ):
- yield (param_name, module_name)
- def shared_parameter_module_names(self) -> Iterator[Tuple[str, str]]:
- for param_name, _, module_name in [
- ParamInfo(param_name, module, module_name)
- for (
- param_name,
- module,
- module_name,
- _,
- _,
- _,
- ) in self.flat_param._shared_param_infos
- ]:
- yield (param_name, module_name)
- @property
- def _fqns_in_shard(self) -> List[str]:
- """Returns the FQNs of the parameters present in this rank's shard."""
- fqns_in_shard: List[str] = []
- start, end = self.flat_param._shard_indices # type: ignore[attr-defined]
- for i in range(len(self.flat_param._fqns)):
- if i >= start and i <= end and self.flat_param._shard_param_offsets: # type: ignore[attr-defined]
- fqns_in_shard.append(self.flat_param._fqns[i])
- return fqns_in_shard
- @property
- def sharded_grad(self) -> Optional[Tensor]:
- """Returns the handle's sharded gradient."""
- flat_param = self.flat_param
- # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
- # - CPU offloading: `_cpu_grad`
- # - No CPU offloading + sharded strategies: `_saved_grad_shard`
- # - No CPU offloading + `NO_SHARD`: `grad`
- if hasattr(flat_param, "_cpu_grad"):
- grad = flat_param._cpu_grad # type: ignore[attr-defined]
- elif hasattr(flat_param, "_saved_grad_shard"):
- grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
- else:
- # If in the forward, then there may be an accumulated gradient,
- # which will be in `.grad`
- p_assert(
- flat_param.grad is None
- or not self.uses_sharded_strategy
- or self._training_state == HandleTrainingState.FORWARD,
- "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
- "unless in FORWARD (for the post-forward reshard)",
- )
- grad = flat_param.grad
- return grad
- def _reset_is_grad_none(self) -> None:
- """
- Resets the ``_is_grad_none`` mask as needed. This method should only be
- called in the post-backward after gradient computation, at which point
- any parameter that requires gradient is guaranteed to have received one,
- so its mask entry may be reset to ``False``.
- """
- if not self._use_orig_params:
- return
- p_assert(
- self._training_state == HandleTrainingState.BACKWARD_POST,
- "Expects to only be called in the post-backward after gradient computation",
- )
- flat_param = self.flat_param
- assert flat_param._params is not None # mypy
- for i, param in enumerate(flat_param._params):
- # As long as the parameter requires gradient, it should receive a
- # meaningful gradient (even if the gradient happens to be zeros)
- if param.requires_grad:
- assert flat_param._is_grad_none is not None # mypy
- flat_param._is_grad_none[i] = False
- #######################
- # CHECKS & INVARIANTS #
- #######################
- def _check_sharded_strategy(self):
- p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
- def _check_on_compute_device(self, tensor: Tensor):
- p_assert(
- tensor.device == self.device,
- f"Expects tensor to be on the compute device {self.device}",
- )
- def _check_on_cpu(self, tensor: Tensor):
- p_assert(
- tensor.device == torch.device("cpu"),
- f"Expects tensor to be on CPU but got {tensor.device}",
- )
- @staticmethod
- def _check_storage_freed(tensor: Tensor):
- storage_size: int = tensor._typed_storage()._size()
- p_assert(
- storage_size == 0,
- f"Expects storage to be freed but got storage with size {storage_size}",
- )
- @staticmethod
- def _check_storage_allocated(tensor: Tensor):
- storage_size: int = tensor._typed_storage()._size()
- p_assert(storage_size > 0, "Expects storage to be allocated")
- def _check_low_precision_shard(self):
- p_assert(
- self._uses_param_mixed_precision,
- "Not using low precision for parameters",
- )
- p_assert(
- getattr(self.flat_param, "_mp_shard", None) is not None,
- "Expects `_mp_shard` to exist",
- )
- device = self.flat_param._mp_shard.device # type: ignore[attr-defined]
- p_assert(
- device == self.device,
- f"Expects the low precision shard to be on {self.device} but got {device}",
- )
- def _check_unsharded(self, tensor: Tensor):
- msg_prefix = "Expects tensor to be unsharded "
- p_assert(tensor is not None, msg_prefix + "but got `None`")
- unsharded_size = self.flat_param._unpadded_unsharded_size
- p_assert(
- tensor.size() == unsharded_size,
- msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
- )
- def _check_sharded(self, tensor: Tensor):
- msg_prefix = "Expects tensor to be sharded "
- p_assert(tensor is not None, msg_prefix + "but got `None`")
- sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
- p_assert(
- tensor.size() == sharded_size,
- msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
- )
- ##############
- # PROPERTIES #
- ##############
- @property
- def uses_sharded_strategy(self) -> bool:
- return self._sharding_strategy != HandleShardingStrategy.NO_SHARD
- @property
- def _uses_param_mixed_precision(self) -> bool:
- return self._fwd_bwd_param_dtype != self._orig_param_dtype
- @property
- def _uses_reduce_mixed_precision(self) -> bool:
- return self._reduce_dtype != self._orig_param_dtype
- @property
- def _force_full_precision(self) -> bool:
- return (
- self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
- and self._uses_param_mixed_precision
- )
- # A handles key represents the group of `FlatParamHandle`s involved in a given
- # module's forward. These will be all-gathered together in the pre-forward and
- # pre-backward.
- _HandlesKey = Tuple[FlatParamHandle, ...]