- import functools
- import math
- import warnings
- from typing import Any, Callable, cast, Dict, Iterator, no_type_check, Tuple
- import torch
- import torch.distributed as dist
- import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper
- import torch.distributed.fsdp._traversal_utils as traversal_utils
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.distributed._shard.sharded_tensor import (
- init_from_local_shards,
- Shard,
- ShardedTensor,
- )
- from torch.distributed.fsdp._common_utils import (
- _FSDPState,
- _has_fsdp_params,
- _is_composable,
- _module_handles,
- clean_tensor_name,
- FSDP_PREFIX,
- FSDP_WRAPPED_MODULE,
- )
- from torch.distributed.fsdp._runtime_utils import (
- _cast_buffers_to_dtype_and_device,
- _clear_grads_if_needed,
- _get_buffer_dtypes,
- _lazy_init,
- )
- from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType
- from torch.distributed.utils import _replace_by_prefix
- from ._fsdp_extensions import (
- _ext_chunk_tensor,
- _ext_pre_load_state_dict_transform,
- _extensions as _user_extensions,
- )
- from ._unshard_param_utils import (
- _deregister_orig_params,
- _register_orig_params,
- _unshard_fsdp_state_params,
- FLAT_PARAM,
- )
- from .flat_param import FlatParamHandle
- def _convert_to_wrapped_module_name(module_name: str) -> str:
- module_name = module_name.replace(f"{FSDP_PREFIX}", "")
- module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "")
- if module_name:
- module_name = f"{module_name}."
- # `CheckpointWrapper` adds a prefix that has to be removed as well.
- module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "")
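- # Illustrative example (assuming the usual "_fsdp_wrapped_module" wrapper prefix):
- # a module name such as "_fsdp_wrapped_module.layer1" becomes "layer1." here, so
- # appending a parameter name yields a clean FQN like "layer1.weight".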
- return module_name
- def _param_fqns(
- module: nn.Module, fsdp_state: _FSDPState
- ) -> Iterator[Tuple[str, str, str]]:
- if not _has_fsdp_params(fsdp_state, module):
- return
- for param_name, module_name in _module_handles(fsdp_state, module)[
- 0
- ].parameter_module_names():
- module_name = _convert_to_wrapped_module_name(module_name)
- fqn = f"{module_name}{param_name}"
- yield fqn, param_name, module_name
- def _shared_param_fqns(module: nn.Module, fsdp_state) -> Iterator[Tuple[str, str, str]]:
- for param_name, module_name in _module_handles(fsdp_state, module)[
- 0
- ].shared_parameter_module_names():
- module_name = _convert_to_wrapped_module_name(module_name)
- fqn = f"{module_name}{param_name}"
- yield fqn, param_name, module_name
- @no_type_check
- def _enter_unshard_params_ctx(
- module: nn.Module,
- fsdp_state: _FSDPState,
- writeback: bool = False,
- rank0_only: bool = False,
- offload_to_cpu: bool = False,
- with_grads: bool = False,
- ) -> None:
- """
- state_dict hooks cannot use a plain ``with`` statement because the checkpoint
- flow requires entering the context in the pre-hook but exiting it in the
- post-hook. This API enters the context of ``_unshard_fsdp_state_params``.
- """
- assert module not in fsdp_state._unshard_params_ctx, (
- "Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] "
- "already exists."
- )
- fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params(
- module,
- fsdp_state,
- writeback=writeback,
- rank0_only=rank0_only,
- offload_to_cpu=offload_to_cpu,
- with_grads=with_grads,
- )
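- # Enter the context now; the matching ``__exit__`` is issued later by
- # ``_exit_unshard_params_ctx`` from the corresponding post-hook.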
- fsdp_state._unshard_params_ctx[module].__enter__()
- @no_type_check
- def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None:
- """A helper function to exit ``_unshard_fsdp_state_params`` context."""
- fsdp_state._unshard_params_ctx[module].__exit__(None, None, None)
- fsdp_state._unshard_params_ctx.pop(module)
- def _common_pre_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- ) -> None:
- """Performs the pre-state_dict tasks shared by all state_dict types."""
- if torch.cuda.is_available():
- torch.cuda.synchronize()
- # TODO: need to check if this is always correct for composable FSDP.
- _lazy_init(fsdp_state, module)
- # TODO: change to this call after pre_state_dict_hook is in `nn.Module`.
- if fsdp_state._is_root:
- _clear_grads_if_needed(traversal_utils._get_fsdp_handles(module))
- def _common_unshard_pre_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- offload_to_cpu: bool,
- rank0_only: bool,
- ) -> None:
- """
- Performs the pre-state_dict tasks shared by all state_dict types that require
- ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook.
- """
- _enter_unshard_params_ctx(
- module,
- fsdp_state,
- writeback=False,
- offload_to_cpu=offload_to_cpu,
- rank0_only=rank0_only,
- )
- # TODO: change to the decorator style. See ``_full_pre_state_dict_hook``.
- @no_type_check
- def _common_unshard_post_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- state_dict: Dict[str, Any],
- prefix: str,
- param_hook: Callable,
- ) -> Dict[str, Any]:
- """
- The post-state_dict flow shared by all state_dict types that require
- ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use
- this hook.
- """
- _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix)
- # Return early for trivial cases
- if not state_dict or not _has_fsdp_params(fsdp_state, module):
- _exit_unshard_params_ctx(module, fsdp_state)
- return state_dict
- # If a rank does not have unsharded parameters (i.e., `rank0_only=True`
- # and `rank != 0`), then the rank only needs to participate in the
- # all-gather and does not need to save the state dict. We simply check
- # `rank0_only` to detect this case.
- rank0_only = (
- fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT
- and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only
- )
- # no_fsdp_return means the state_dict returned by this rank should contain
- # only non-FSDP controlled parameters and buffers.
- no_fsdp_return = rank0_only and fsdp_state.rank != 0
- if no_fsdp_return and not fsdp_state._use_orig_params:
- for clean_key in fsdp_state._buffer_names:
- # This is a hack to support activation checkpoint.
- clean_key = clean_key.replace(
- f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", ""
- )
- state_dict.pop(f"{prefix}{clean_key}", None)
- # Non-zero ranks still have the flat_param key when rank0_only=True because
- # rank0_only=True is passed to the unshard context, but non-zero ranks
- # reshard early, which makes the flat_param appear in their state_dict.
- state_dict.pop(f"{prefix}{FLAT_PARAM}")
- _exit_unshard_params_ctx(module, fsdp_state)
- return state_dict
- # Loop only the parameters saved in this instance's wrapped module to
- # avoid processing buffers.
- for fqn, param_name, module_name in _param_fqns(module, fsdp_state):
- fqn = f"{prefix}{fqn}"
- if no_fsdp_return:
- state_dict.pop(fqn)
- continue
- assert fqn in state_dict, (
- f"FSDP assumes {fqn} is in the state_dict but the state_dict only "
- f"has {state_dict.keys()}. "
- f"prefix={prefix}, module_name={module_name}, "
- f"param_name={param_name} rank={fsdp_state.rank}."
- )
- param_hook(state_dict, prefix, fqn)
- _exit_unshard_params_ctx(module, fsdp_state)
- cpu_device = torch.device("cpu")
- buffer_clean_fqns = []
- buffers = []
- for clean_key in fsdp_state._buffer_names:
- # This is a hack to support activation checkpoint.
- clean_key = clean_tensor_name(clean_key)
- fqn = f"{prefix}{clean_key}"
- if fqn not in state_dict:
- # A buffer can be registered as non-persistent.
- continue
- if no_fsdp_return:
- state_dict.pop(fqn)
- else:
- buffer = state_dict[fqn]
- if (
- fsdp_state._state_dict_config.offload_to_cpu
- and buffer.device != cpu_device
- ):
- state_dict[fqn] = buffer.to(cpu_device)
- # TODO: for composable FSDP, this should be clean_tensor_name(clean_key),
- buffer_clean_fqns.append(clean_key)
- buffers.append(state_dict[fqn])
- if buffers:
- mixed_precision_enabled_for_buffers = (
- fsdp_state._mixed_precision_enabled_for_buffers()
- if not _is_composable(fsdp_state)
- else (fsdp_state.mixed_precision.buffer_dtype is not None)
- )
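- # Buffers may have been cast to a low-precision dtype for mixed precision; cast
- # them back to their original dtypes (on the compute device) before cloning them
- # into the state_dict below.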
- if mixed_precision_enabled_for_buffers:
- buffer_dtypes = _get_buffer_dtypes(fsdp_state, buffer_clean_fqns)
- _cast_buffers_to_dtype_and_device(
- buffers, buffer_dtypes, fsdp_state.compute_device
- )
- for buffer, clean_fqn in zip(buffers, buffer_clean_fqns):
- fqn = f"{prefix}{clean_fqn}"
- state_dict[fqn] = buffer.clone()
- return state_dict
- @no_type_check
- def _full_pre_state_dict_hook(
- fsdp_state: _FSDPState,
- module: nn.Module,
- *args,
- **kwargs,
- ) -> None:
- """
- Hook that runs before model.state_dict() is called. A pre-state_dict hook is
- not actually supported by ``nn.Module``. As a result, this API is called
- from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict
- is supported in ``nn.Module``, this hook will be registered as a hook in
- ``nn.Module``.
- TODO: clean up the callsites and hacks after ``pre_state_dict_hook`` is supported
- in ``nn.Module``.
- """
- _common_pre_state_dict_hook(module, fsdp_state)
- _common_unshard_pre_state_dict_hook(
- module,
- fsdp_state,
- offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu,
- rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only,
- )
- @no_type_check
- def _full_post_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- state_dict: Dict[str, Any],
- prefix: str,
- ) -> Dict[str, Any]:
- """
- Hook that runs after model.state_dict() is called, before returning the result
- to the user. For FSDP, we may have to clone the tensors in the state_dict since
- parameters go back to their sharded versions after ``_unshard_fsdp_state_params``
- ends; we also remove the ``FSDP_WRAPPED_MODULE`` prefix.
- """
- def param_hook(
- state_dict: Dict[str, Any],
- prefix: str,
- fqn: str,
- ) -> None:
- clean_key = fqn
- clean_prefix = clean_tensor_name(prefix)
- # Strip the prefix out of the key if needed: buffer and parameter names are
- # collected outside the `state_dict` call, so they do not include the
- # prefix.
- if clean_key.startswith(clean_prefix):
- clean_key = clean_key[len(clean_prefix) :]
- # Clone parameters before exiting the `_unshard_fsdp_state_params()` context.
- if not getattr(state_dict[fqn], "_has_been_cloned", False):
- try:
- state_dict[fqn] = state_dict[fqn].clone().detach()
- state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined]
- except BaseException as e:
- warnings.warn(
- f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. "
- "This may mean that this state_dict entry could point to invalid "
- "memory regions after returning from state_dict() call if this "
- "parameter is managed by FSDP. Please check clone "
- f"implementation of {fqn}. Error: {str(e)}"
- )
- return _common_unshard_post_state_dict_hook(
- module, fsdp_state, state_dict, prefix, param_hook
- )
- def _full_pre_load_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- state_dict: Dict[str, Any],
- prefix: str,
- ) -> None:
- _lazy_init(fsdp_state, module)
- _enter_unshard_params_ctx(module, fsdp_state, writeback=True)
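- # writeback=True so that values loaded into the unsharded parameters are written
- # back to the local shards when the context exits in the post-load hook.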
- # Add FSDP_PREFIX only for wrapper-based FSDP.
- if not _is_composable(fsdp_state):
- _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
- def _full_post_load_state_dict_hook(
- module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
- ) -> None:
- _exit_unshard_params_ctx(module, fsdp_state)
- def _local_pre_state_dict_hook(
- fsdp_state: _FSDPState,
- module: nn.Module,
- *args,
- **kwargs,
- ) -> None:
- """
- Hook that runs before model.state_dict() is called. Right now, a pre-state_dict
- hook is not supported by PyTorch core, so this API is called from
- ``_local_post_state_dict_hook()`` to simulate the case.
- """
- if (
- _has_fsdp_params(fsdp_state, module)
- and not _module_handles(fsdp_state, module)[0].uses_sharded_strategy
- ):
- raise RuntimeError(
- "``local_state_dict`` can only be used when parameters are flattened "
- "and sharded."
- )
- _common_pre_state_dict_hook(module, fsdp_state)
- @no_type_check
- def _local_post_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- state_dict: Dict[str, Any],
- prefix: str,
- ) -> Dict[str, Any]:
- """
- This hook creates a ShardedTensor from the local flat_param and replaces
- state_dict[f"{prefix}{FLAT_PARAM}"] with the ShardedTensor. No copy
- happens; the underlying storage is the same.
- """
- _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix)
- if not _has_fsdp_params(fsdp_state, module):
- return state_dict
- # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor
- # value as the flat_param, but it is a plain Tensor because
- # nn.Module.state_dict() detaches the parameter. Therefore, we access
- # flat_param directly to get its metadata.
- assert _module_handles(fsdp_state, module), "Should have returned early"
- flat_param = _module_handles(fsdp_state, module)[0].flat_param
- # Constructs a ShardedTensor from the flat_param "without" padding.
- # Removing the padding allows users to change the number of ranks
- # when loading the local_state_dict.
- full_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined]
- shard_offset = flat_param.numel() * fsdp_state.rank
- valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
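- # Illustrative example: with world_size=2 and an original (unpadded) numel of 5,
- # each rank holds a padded shard of 3 elements; rank 1 then has valid_data_size=2
- # and one element of padding, which is excluded from the ShardedTensor below.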
- if valid_data_size > 0:
- # If the FlatParameter were returned directly, FlatParameter._local_shard would
- # cause a pickling issue (it can be torch.save()-ed but not torch.load()-ed).
- # Since there is no benefit to returning the actual FlatParameter class from
- # state_dict, a view (a plain tensor) of the FlatParameter is returned instead.
- flat_param = flat_param[:valid_data_size].view(valid_data_size)
- local_shards = [
- Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank)
- ]
- else:
- local_shards = []
- sharded_tensor = init_from_local_shards(
- local_shards, full_numel, process_group=fsdp_state.process_group
- ) # type: ignore[assignment]
- if fsdp_state._state_dict_config.offload_to_cpu:
- sharded_tensor = sharded_tensor.cpu()
- state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
- return state_dict
- def _local_post_load_state_dict_hook(
- module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
- ) -> None:
- pass
- def _local_pre_load_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- state_dict: Dict[str, Any],
- prefix: str,
- ) -> None:
- """
- This hook finds the local flat_param for this FSDP module from the
- state_dict. The flat_param should be a ShardedTensor. This hook converts
- the ShardedTensor to a tensor. No copy happens unless padding is required.
- """
- _lazy_init(fsdp_state, module)
- _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}")
- fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"
- if fqn not in state_dict:
- assert not _has_fsdp_params(fsdp_state, module), (
- "No `FlatParameter` in `state_dict` for this FSDP instance "
- "but it has parameters"
- )
- return
- load_tensor = state_dict[fqn]
- assert isinstance(
- load_tensor, ShardedTensor
- ), "Tensors in local_state_dict should be ShardedTensor."
- # Convert the ShardedTensor to a Tensor.
- flat_param = _module_handles(fsdp_state, module)[0].flat_param
- assert flat_param is not None
- valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
- shards = load_tensor.local_shards()
- if valid_data_size > 0:
- assert len(shards), "load_local_state_dict assumes one shard per ShardedTensor."
- load_tensor = shards[0].tensor
- # Get the metadata of the flat_param to decide whether to pad the loaded
- # tensor.
- if flat_param._shard_numel_padded > 0:
- assert load_tensor.numel() < flat_param.numel(), (
- f"Local shard size = {flat_param.numel()} and the tensor in "
- f"the state_dict is {load_tensor.numel()}."
- )
- load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded])
- else:
- load_tensor = flat_param
- state_dict[fqn] = load_tensor
- def _sharded_pre_state_dict_hook(
- fsdp_state: _FSDPState,
- module: nn.Module,
- *args,
- **kwargs,
- ) -> None:
- """
- Hook that runs before model.state_dict() is called. See
- ``_full_pre_state_dict_hook`` for details.
- """
- if (
- _has_fsdp_params(fsdp_state, module)
- and not _module_handles(fsdp_state, module)[0].uses_sharded_strategy
- ):
- raise RuntimeError(
- "``sharded_state_dict`` can only be used when parameters are flattened "
- "and sharded."
- )
- _common_pre_state_dict_hook(module, fsdp_state)
- # Setting offload_to_cpu here does not work even if offload_to_cpu is True.
- # We have to create ShardedTensor first then move it to CPU.
- _common_unshard_pre_state_dict_hook(
- module,
- fsdp_state,
- offload_to_cpu=False,
- rank0_only=False,
- )
- @no_type_check
- def _sharded_post_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- state_dict: Dict[str, Any],
- prefix: str,
- ) -> Dict[str, Any]:
- """
- The hook replaces the unflattened, unsharded parameter in the state_dict
- with an unflattened, sharded parameter (a ShardedTensor).
- """
- def param_hook(state_dict: Dict[str, Any], prefix: str, fqn: str):
- param = state_dict[fqn]
- sharded_tensor = _ext_chunk_tensor(
- tensor=param,
- rank=fsdp_state.rank,
- world_size=fsdp_state.world_size,
- num_devices_per_node=torch.cuda.device_count(),
- pg=fsdp_state.process_group,
- )
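- # Chunk the unsharded parameter into a ShardedTensor across the process group so
- # that each rank keeps only its own shard of the parameter in the state_dict.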
- if fsdp_state._state_dict_config.offload_to_cpu:
- sharded_tensor = sharded_tensor.cpu()
- state_dict[fqn] = sharded_tensor
- return _common_unshard_post_state_dict_hook(
- module, fsdp_state, state_dict, prefix, param_hook
- )
- @no_type_check
- def _sharded_post_load_state_dict_hook(
- module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs
- ) -> None:
- if fsdp_state._use_orig_params:
- _register_orig_params(module, fsdp_state)
- @no_type_check
- def _sharded_pre_load_state_dict_hook(
- module: nn.Module,
- fsdp_state: _FSDPState,
- state_dict: Dict[str, Any],
- prefix: str,
- ) -> None:
- """
- The hook combines the unflattened, sharded parameters (ShardedTensors) into
- a new FlatParameter and shards the new FlatParameter to get the local chunk.
- """
- _lazy_init(fsdp_state, module)
- _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}")
- if not _has_fsdp_params(fsdp_state, module):
- return
- if not _module_handles(fsdp_state, module)[0].uses_sharded_strategy:
- raise RuntimeError(
- "load_sharded_state_dict can only be called when parameters "
- "are flattened and sharded."
- )
- nonsharded_tensors = []
- shared_fqns = [fqn for fqn, _, _ in _shared_param_fqns(module, fsdp_state)]
- loaded_shapes = []
- for fqn, _, _ in _param_fqns(module, fsdp_state):
- full_fqn = f"{prefix}{FSDP_PREFIX}{fqn}"
- param = state_dict.pop(full_fqn)
- if fqn in shared_fqns:
- continue
- # All-gather the param (ShardedTensor)
- param, shards = _ext_pre_load_state_dict_transform(param)
- loaded_shapes.append(param.size())
- assert len(shards) < 2, (
- "Expects 0 or 1 shard per rank "
- f"but got {len(shards)} shards on rank {fsdp_state.rank}."
- )
- param_numel = param.size().numel()
- dim_0_size = param.size()[0]
- chunk_size = (
- math.ceil(dim_0_size / fsdp_state.world_size) * param_numel // dim_0_size
- )
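- # The sharded save path chunks each parameter along dim 0 across ranks, so each
- # rank's chunk covers ceil(dim_0_size / world_size) "rows" of
- # param_numel // dim_0_size elements; `chunk_size` mirrors that layout so the
- # all-gather below reassembles the full parameter.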
- if len(shards) == 1:
- local_tensor = shards[0].tensor.flatten()
- if not local_tensor.is_cuda:
- local_tensor = local_tensor.cuda()
- num_padding = chunk_size - local_tensor.numel()
- if num_padding > 0:
- local_tensor = F.pad(local_tensor, [0, num_padding])
- else:
- local_tensor = torch.zeros(chunk_size, dtype=param.dtype).cuda()
- tensor = torch.empty(
- chunk_size * fsdp_state.world_size, dtype=local_tensor.dtype
- ).cuda()
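- # `all_gather_into_tensor` requires the output tensor to have exactly
- # world_size * chunk_size elements; the gathered result is narrowed back to
- # `param_numel` below because the last rank's chunk may include padding.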
- dist.all_gather_into_tensor(
- tensor, local_tensor, group=fsdp_state.process_group
- )
- tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
- nonsharded_tensors.append(tensor)
- # Create a new flat_param from the loaded, non-sharded tensors.
- flat_param = _module_handles(fsdp_state, module)[0].flat_param
- loaded_flat_param = FlatParamHandle.flatten_params(
- nonsharded_tensors, requires_grad=False
- )
- # Get the chunk from the loaded flat_param for the local rank.
- loaded_flat_tensor, num_to_pad = FlatParamHandle._get_shard(
- loaded_flat_param,
- fsdp_state.rank,
- fsdp_state.world_size,
- )
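- # `_get_shard` returns this rank's chunk of the flattened tensor along with the
- # number of padded elements; the asserts below check that the loaded chunk matches
- # the existing flat_param's sharding (numel and padding) for the current world size.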
- # `Tensor.to()` is not in-place, so keep the returned tensor on flat_param's device.
- loaded_flat_tensor = loaded_flat_tensor.to(flat_param.device)
- assert all(s1 == s2 for s1, s2 in zip(loaded_shapes, flat_param._shapes)), (
- f"The original shapes in FSDP are {flat_param._shapes}. "
- f"The loaded shapes are {loaded_shapes}. "
- f"FSDP extension is {'NOT' if _user_extensions is not None else ''} None."
- )
- assert flat_param.numel() == loaded_flat_tensor.numel(), (
- f"The loaded local chunk has different numel({loaded_flat_tensor.numel()}) "
- f"from the local chunk {flat_param.numel()}."
- )
- assert flat_param._shard_numel_padded == num_to_pad, (
- f"The loaded local chunk has different padding({num_to_pad}) "
- f"from the local chunk {flat_param._shard_numel_padded}."
- )
- state_dict[f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}"] = loaded_flat_tensor
- if fsdp_state._use_orig_params:
- _deregister_orig_params(module, fsdp_state)
- @no_type_check
- @torch.no_grad()
- def _post_state_dict_hook(
- fsdp_state: _FSDPState,
- module: nn.Module,
- state_dict: Dict[str, Any],
- prefix: str,
- *args: Any,
- ) -> Dict[str, Any]:
- """
- _post_state_dict_hook() is called after the state_dict() of this
- FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide
- what postprocessing will be done.
- """
- _post_state_dict_hook_fn = {
- StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
- StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
- StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
- }
- processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
- module, fsdp_state, state_dict, prefix
- )
- return processed_state_dict
- @no_type_check
- @torch.no_grad()
- def _pre_state_dict_hook(
- fsdp_state: _FSDPState,
- module: nn.Module,
- *args,
- **kwargs,
- ) -> None:
- """
- This is called before the core state dict saving logic of ``module``.
- ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
- be done.
- """
- _pre_state_dict_hook_fn = {
- StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
- StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
- StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
- }
- _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
- fsdp_state,
- module,
- *args,
- **kwargs,
- )
- @no_type_check
- @torch.no_grad()
- def _pre_load_state_dict_hook(
- fsdp_state: _FSDPState,
- module: nn.Module,
- state_dict: Dict[str, Any],
- prefix: str,
- *args: Any,
- ) -> None:
- """
- This is called before ``module._load_from_state_dict()``.
- ``fsdp_state._state_dict_type`` is used to decide what preprocessing will
- be done.
- """
- _pre_load_state_dict_hook_fn = {
- StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
- StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
- StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
- }
- # Code that is common for all state_dict impls
- if torch.cuda.is_available():
- torch.cuda.synchronize()
- # Dispatch into state_dict specific implementation of pre-hook.
- _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
- module, fsdp_state, state_dict, prefix
- )
- @no_type_check
- @torch.no_grad()
- def _post_load_state_dict_hook(
- fsdp_state: _FSDPState,
- module: nn.Module,
- *args: Any,
- ) -> None:
- _post_load_state_dict_hook_fn = {
- StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
- StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
- StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
- }
- # Code that is common for all state_dict impls
- # Dispatch into state_dict type specific implementation of post-hook for
- # loading state_dict.
- _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
- def _register_all_state_dict_hooks(state: _FSDPState):
- """
- Registers pre-save, post-save, pre-load, and post-load state dict hooks.
- """
- for hook_registration_fn_str, hook, hook_registration_fn_kwargs in (
- ("register_state_dict_pre_hook", _pre_state_dict_hook, {}),
- ("_register_state_dict_hook", _post_state_dict_hook, {}),
- (
- "_register_load_state_dict_pre_hook",
- _pre_load_state_dict_hook,
- {"with_module": True},
- ),
- ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}),
- ):
- _register_state_dict_hooks_base(
- state, hook_registration_fn_str, hook, hook_registration_fn_kwargs
- )
- @no_type_check
- def _register_state_dict_hooks_base(
- state: _FSDPState,
- hook_registration_fn_name: str,
- hook: Callable,
- hook_registration_fn_kwargs: Dict[str, Any],
- ) -> None:
- """Registers ``hook`` using ``hook_registration_fn``."""
- # TODO: Use `_get_submodule_state(module)` in each hook instead of
- # `partial`: https://github.com/pytorch/pytorch/issues/90788
- hook_with_state = functools.partial(hook, state)
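- # For wrapper-based FSDP, `state` is itself the FSDP `nn.Module`, so the hook is
- # registered directly on it; for composable FSDP, register the hook on each
- # handle's fully sharded module instead.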
- if not _is_composable(state):
- getattr(state, hook_registration_fn_name)(
- hook_with_state, **hook_registration_fn_kwargs
- )
- else:
- for handle in state._handles:
- getattr(handle._fully_sharded_module, hook_registration_fn_name)(
- hook_with_state, **hook_registration_fn_kwargs
- )
|