- import collections
- import warnings
- from typing import (
- Any,
- Callable,
- Dict,
- Generator,
- Iterable,
- Iterator,
- List,
- no_type_check,
- Optional,
- Set,
- Tuple,
- Type,
- Union,
- )
- import torch
- import torch.distributed as dist
- import torch.distributed.fsdp._exec_order_utils as exec_order_utils
- import torch.distributed.fsdp._traversal_utils as traversal_utils
- import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
- import torch.nn as nn
- from torch.distributed.algorithms._comm_hooks import default_hooks
- from torch.distributed.distributed_c10d import _get_default_group
- from torch.distributed.fsdp._common_utils import (
- _FSDPState,
- _get_module_fsdp_state,
- _is_fsdp_flattened,
- clean_tensor_name,
- TrainingState,
- )
- from torch.distributed.fsdp._limiter_utils import _FreeEventQueue
- from torch.distributed.fsdp._wrap_utils import _get_fully_sharded_module_to_states
- from torch.distributed.fsdp.api import (
- BackwardPrefetch,
- CPUOffload,
- FullOptimStateDictConfig,
- FullStateDictConfig,
- MixedPrecision,
- ShardingStrategy,
- StateDictConfig,
- StateDictType,
- )
- from torch.distributed.fsdp.flat_param import (
- _HandlesKey,
- FlatParameter,
- FlatParamHandle,
- HandleShardingStrategy,
- )
- from torch.distributed.fsdp.wrap import _FSDPPolicy
- from torch.distributed.utils import _sync_params_and_buffers
- from torch.utils.hooks import RemovableHandle
- _TORCHDISTX_AVAIL = True
- try:
- from torchdistx import deferred_init, fake
- except ImportError:
- _TORCHDISTX_AVAIL = False
- PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024)
- FSDP_SYNCED = "_fsdp_synced"
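- # For hybrid sharding strategies, `process_group` may be a pair of
- # (intra-node, inter-node) process groups instead of a single group.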
- HybridShardProcessGroupType = Tuple[dist.ProcessGroup, dist.ProcessGroup]
- ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]]
- SHARDING_STRATEGY_MAP = {
- ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD,
- ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD,
- ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP,
- ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
- ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
- }
- HYBRID_SHARDING_STRATEGIES = {
- ShardingStrategy.HYBRID_SHARD,
- ShardingStrategy._HYBRID_SHARD_ZERO2,
- }
- @no_type_check
- def _init_process_group_state(
- state: _FSDPState,
- process_group: ProcessGroupType,
- sharding_strategy: ShardingStrategy,
- policy: Optional[_FSDPPolicy],
- ) -> _FSDPState:
- if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
- if process_group is None and policy is None:
-
-
-
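- # Manual wrapping (no auto wrap policy) with a hybrid strategy requires the
- # caller to pass the process group(s) explicitly.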
- raise ValueError(
- f"Manual wrapping with {sharding_strategy} requires explicit specification of process group."
- )
- else:
- state = _init_process_group_state_for_hybrid_shard(state, process_group)
- assert (
- state.process_group is not None
- ), "Expected to populate state.process_group for hybrid shard"
- assert (
- state._inter_node_pg is not None
- ), "Expected to populate state._inter_node_pg for hybrid shard"
- assert (
- state._inter_node_state is not None
- ), "Expected to populate state._inter_node_state for hybrid shad."
- else:
- state.process_group = (
- process_group if process_group is not None else _get_default_group()
- )
- state.rank = state.process_group.rank()
- state.world_size = state.process_group.size()
- return state
- @no_type_check
- def _init_process_group_state_for_hybrid_shard(
- state: _FSDPState, process_group
- ) -> _FSDPState:
- if process_group is None:
- default_group = _get_default_group()
- intra_node_group, inter_node_group = _init_intra_and_inter_node_groups(
- default_group
- )
-
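- # Hybrid sharding shards within the node (the intra-node group becomes
- # `state.process_group`) and replicates across nodes, with gradients reduced
- # over `state._inter_node_pg`.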
- state.process_group = intra_node_group
-
- state._inter_node_pg = inter_node_group
- else:
-
- if _is_valid_hybrid_shard_pg_type(process_group):
-
-
- state.process_group, state._inter_node_pg = process_group
- else:
- raise ValueError(
- "Expected process_group to be passed in as either None or "
- f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}"
- )
-
- state._inter_node_state = _get_default_comm_hook_state(
- process_group=state._inter_node_pg,
- )
- return state
- @no_type_check
- def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool:
- return (
- isinstance(process_group, tuple)
- and len(process_group) == 2
- and all(isinstance(pg, dist.ProcessGroup) for pg in process_group)
- )
- @no_type_check
- def _init_intra_node_process_group() -> dist.ProcessGroup:
- """
- Returns a process group across the current node.
- For example, given that each row below is a distinct node:
- 0 1 2 3 4 5 6 7
- 8 9 10 11 12 13 14 15
- This API would return an intra-node subgroup across
- [0, 7] or [8, 15] depending on the process's rank.
- For example, rank 3 would get [0, 7].
- """
- intra_node_subgroup, _ = dist.new_subgroups()
- return intra_node_subgroup
- @no_type_check
- def _init_inter_node_process_group(
- global_process_group: dist.ProcessGroup,
- ) -> dist.ProcessGroup:
- """
- Returns an inter-node process group where each contained rank has
- the same local rank. For example, given that each row below is a distinct node:
- 0 1 2 3 4 5 6 7
- 8 9 10 11 12 13 14 15
- This API would return the inter-node process groups {0, 8}, {1, 9}, {2, 10}, and
- so forth, depending on the process's rank. For example, rank 1 would get {1, 9}
- and rank 5 would get {5, 13}.
- """
-
- inter_node_pg = None
- sharding_backend = dist.get_backend(global_process_group)
- world_size = dist.get_world_size(global_process_group)
-
- num_devices = torch.cuda.device_count()
- num_nodes = world_size // num_devices
- my_local_rank = dist.get_rank(global_process_group) % num_devices
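- # `dist.new_group()` must be called by every rank for every subgroup, so create
- # all inter-node groups and keep only the one containing this rank's local rank.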
- for local_rank in range(num_devices):
- ranks_for_inter_group = [
- local_rank + (i * num_devices) for i in range(num_nodes)
- ]
-
- grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend)
- if local_rank == my_local_rank:
- print(f"{local_rank} created process group for {ranks_for_inter_group}")
- inter_node_pg = grp
- assert (
- inter_node_pg is not None
- ), f"{my_local_rank} expected to assign inter-node pg, but did not"
- return inter_node_pg
- def _init_intra_and_inter_node_groups(
- global_process_group: dist.ProcessGroup,
- ) -> Tuple[dist.ProcessGroup, dist.ProcessGroup]:
- """
- Initializes intra- and inter-node process groups and returns the ones corresponding
- to this process's rank.
- This function can be used to initialize process groups for ``HYBRID_SHARD`` or
- ``_HYBRID_SHARD_ZERO2`` in FSDP.
- This function assumes each node has an equal number of CUDA-enabled devices.
- Returns:
- Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group.
- """
- return (
- _init_intra_node_process_group(),
- _init_inter_node_process_group(global_process_group),
- )
- @no_type_check
- def _init_ignored_module_states(
- state: _FSDPState,
- module: nn.Module,
- ignored_modules: Optional[Iterable[torch.nn.Module]],
- ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
- ) -> _FSDPState:
- assert (
- ignored_modules is None or ignored_parameters is None
- ), "Can not pass `ignored_modules` and `ignored_parameters` at the same time. \
- Please either pass `ignored_modules` or `ignored_parameters`."
- state._ignored_modules = _get_ignored_modules(module, ignored_modules)
- state._ignored_params = _get_ignored_params(
- module,
- state._ignored_modules,
- ignored_parameters,
- )
-
-
-
-
-
- return state
- @no_type_check
- def _init_buffer_state(
- state: _FSDPState,
- module: nn.Module,
- ) -> _FSDPState:
- state._buffer_names = _get_buffer_names(module)
-
-
-
-
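- # Record each buffer's original dtype (keyed by its clean name) so it can be
- # restored after any mixed-precision buffer casts.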
- _buffer_name_to_orig_dtype: Dict[str, torch.dtype] = {}
- for buffer_name, buffer in module.named_buffers():
- buffer_name = clean_tensor_name(buffer_name)
- _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype
- state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype
- return state
- @no_type_check
- def _init_core_state(
- state: _FSDPState,
- sharding_strategy: Optional[ShardingStrategy],
- mixed_precision: Optional[MixedPrecision],
- cpu_offload: Optional[CPUOffload],
- limit_all_gathers: bool,
- use_orig_params: bool,
- backward_prefetch_limit: int,
- forward_prefetch_limit: int,
- ) -> _FSDPState:
-
-
-
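- # Sharding provides no benefit with a world size of 1, so fall back to `NO_SHARD`.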
- if state.world_size == 1:
- if sharding_strategy != ShardingStrategy.NO_SHARD:
- warnings.warn(
- "FSDP is switching to use `NO_SHARD` instead of "
- f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since "
- "the world size is 1."
- )
- sharding_strategy = ShardingStrategy.NO_SHARD
- state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD
- state.mixed_precision = mixed_precision or MixedPrecision()
- state.cpu_offload = cpu_offload or CPUOffload()
- state.limit_all_gathers = limit_all_gathers
- state._use_orig_params = use_orig_params
- state.training_state = TrainingState.IDLE
- state._is_root = None
- _streams: Dict[str, torch.cuda.Stream] = {}
- state._streams = _streams
- _stream_to_name: Dict[torch.cuda.Stream, str] = {}
- state._stream_to_name = _stream_to_name
- state._free_event_queue = _FreeEventQueue()
- state._debug_level = dist.get_debug_level()
- state._exec_order_data = exec_order_utils._ExecOrderData(
- state._debug_level,
- backward_prefetch_limit,
- forward_prefetch_limit,
- )
-
-
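- # Mapping from each fully sharded module to the handle(s) constructed for its
- # parameters.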
- _fully_sharded_module_to_handles: Dict[
- nn.Module, List[FlatParamHandle]
- ] = collections.defaultdict(list)
- state._fully_sharded_module_to_handles = _fully_sharded_module_to_handles
-
-
- _handles: List[FlatParamHandle] = []
- state._handles = _handles
- params: List[FlatParameter] = []
- state.params = params
- return state
- @no_type_check
- def _init_runtime_state(
- state: _FSDPState,
- ) -> _FSDPState:
- _root_pre_forward_handles: List[RemovableHandle] = []
- state._root_pre_forward_handles = _root_pre_forward_handles
- _pre_forward_handles: List[RemovableHandle] = []
- state._pre_forward_handles = _pre_forward_handles
- _post_forward_handles: List[RemovableHandle] = []
- state._post_forward_handles = _post_forward_handles
- state._sync_gradients = True
- state._communication_hook = _get_default_comm_hook(state.sharding_strategy)
- state._communication_hook_state = _get_default_comm_hook_state(state.process_group)
- state._hook_registered = False
-
- _ran_pre_backward_hook: Dict[_HandlesKey, bool] = {}
- state._ran_pre_backward_hook = _ran_pre_backward_hook
- return state
- @no_type_check
- def _init_prefetching_state(
- state: _FSDPState,
- backward_prefetch: BackwardPrefetch,
- forward_prefetch: bool,
- ) -> _FSDPState:
- state.backward_prefetch = backward_prefetch
- state.forward_prefetch = forward_prefetch
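- # Per-handles-key bookkeeping for prefetching: which handles have already been
- # prefetched and which still need to be unsharded before forward/backward.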
- _handles_prefetched: Dict[_HandlesKey, bool] = {}
- state._handles_prefetched = _handles_prefetched
-
- _needs_pre_backward_unshard: Dict[_HandlesKey, bool] = {}
- state._needs_pre_backward_unshard = _needs_pre_backward_unshard
-
- _needs_pre_forward_unshard: Dict[_HandlesKey, bool] = {}
- state._needs_pre_forward_unshard = _needs_pre_forward_unshard
-
-
- return state
- def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
- state._state_dict_type = StateDictType.FULL_STATE_DICT
- state_dict_config: StateDictConfig = FullStateDictConfig()
- state._optim_state_dict_config = FullOptimStateDictConfig()
- state._state_dict_config = state_dict_config
- unshard_params_ctx: Dict[nn.Module, Generator] = {}
- state._unshard_params_ctx = unshard_params_ctx
- return state
- @no_type_check
- def _init_param_handle_from_module(
- state: _FSDPState,
- fully_sharded_module: nn.Module,
- device_id: Optional[Union[int, torch.device]],
- param_init_fn: Optional[Callable[[nn.Module], None]],
- sync_module_states: bool,
- module_wrapper_cls: Type,
- ) -> _FSDPState:
- """
- Initializes a ``FlatParamHandle`` from a module ``fully_sharded_module``.
- This is the module wrapper code path.
- """
- _check_single_device_module(fully_sharded_module, state._ignored_params)
- device_from_device_id = _get_device_from_device_id(device_id, state.rank)
- is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
- fully_sharded_module, state._ignored_params
- )
-
- if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None:
- _materialize_with_param_init_fn(fully_sharded_module, param_init_fn)
- elif is_meta_module:
- _materialize_meta_module(fully_sharded_module, device_from_device_id)
- elif is_torchdistX_deferred_init:
- deferred_init.materialize_module(
- fully_sharded_module,
- check_fn=lambda k: not isinstance(k, module_wrapper_cls),
- )
-
-
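- # Move the (now materialized) module to the target device, if one was given,
- # and record the compute device for this FSDP state.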
- _move_module_to_device(
- fully_sharded_module, state._ignored_params, device_from_device_id
- )
- state.compute_device = _get_compute_device(
- fully_sharded_module,
- state._ignored_params,
- device_from_device_id,
- state.rank,
- )
- managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params))
- if sync_module_states:
- _sync_module_params_and_buffers(
- fully_sharded_module, managed_params, state.process_group
- )
- _init_param_handle_from_params(state, managed_params, fully_sharded_module)
- return state
- @no_type_check
- def _init_param_handles_from_module(
- state: _FSDPState,
- root_module: nn.Module,
- policy: _FSDPPolicy,
- device_id: Optional[Union[int, torch.device]],
- param_init_fn: Optional[Callable[[nn.Module], None]],
- sync_module_states: bool,
- ) -> _FSDPState:
- """
- Initializes all ``FlatParamHandle`` s from a module ``root_module``. This
- is the non-module-wrapper code path. ``root_module`` is guaranteed to be
- a fully sharded module, and some of its submodules may be as well,
- depending on ``policy``. See [Note: Fully Sharded Module].
- """
- fully_sharded_module_to_states = _get_fully_sharded_module_to_states(
- root_module,
- policy,
- state._ignored_modules,
- state._ignored_params,
- )
- _check_single_device_module(root_module, state._ignored_params)
- device_from_device_id = _get_device_from_device_id(device_id, state.rank)
-
-
-
-
-
-
-
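- # Note: the loop below visits the fully sharded modules in reverse order;
- # `state._handles` is reversed again after the loop so that the final handle
- # order matches the original traversal order.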
- materialized_module = False
- for fully_sharded_module, (params, buffers) in reversed(
- fully_sharded_module_to_states.items()
- ):
-
- is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module(
- fully_sharded_module, state._ignored_params
- )
- if is_meta_module or is_torchdistX_deferred_init:
- materialized_module = True
-
-
- param_names, buffer_names = _get_state_names_for_states(
- fully_sharded_module, params, buffers
- )
- if (
- is_meta_module or is_torchdistX_deferred_init
- ) and param_init_fn is not None:
- _materialize_with_param_init_fn(fully_sharded_module, param_init_fn)
- elif is_meta_module:
- _materialize_meta_module(fully_sharded_module, device_from_device_id)
- elif is_torchdistX_deferred_init:
- deferred_init.materialize_module(
- root_module,
- check_fn=lambda _: True,
- )
- if materialized_module:
-
- params = [
- fully_sharded_module.get_parameter(param_name)
- for param_name in param_names
- ]
- buffers = [
- fully_sharded_module.get_buffer(buffer_name)
- for buffer_name in buffer_names
- ]
- _move_states_to_device(params, buffers, device_from_device_id)
- if not hasattr(state, "compute_device"):
- state.compute_device = _get_compute_device(
- fully_sharded_module,
- state._ignored_params,
- device_from_device_id,
- state.rank,
- )
- if sync_module_states:
- _sync_module_states(params, buffers, state.process_group)
- _init_param_handle_from_params(state, params, fully_sharded_module)
-
-
-
- state._handles.reverse()
- return state
- @no_type_check
- def _init_param_handle_from_params(
- state: _FSDPState,
- params: List[nn.Parameter],
- fully_sharded_module: nn.Module,
- ):
- if len(params) == 0:
- return
- handle = FlatParamHandle(
- params,
- fully_sharded_module,
- state.compute_device,
- SHARDING_STRATEGY_MAP[state.sharding_strategy],
- state.cpu_offload.offload_params,
- state.mixed_precision.param_dtype,
- state.mixed_precision.reduce_dtype,
- state.mixed_precision.keep_low_precision_grads,
- state.process_group,
- state._use_orig_params,
- )
-
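- # Shard the newly constructed flat parameter across `state.process_group`.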
- handle.shard()
- assert handle not in state._handles
- state.params.append(handle.flat_param)
- state._handles.append(handle)
- state._fully_sharded_module_to_handles[handle._fully_sharded_module].append(handle)
- num_fully_sharded_module_handles = len(
- state._fully_sharded_module_to_handles[handle._fully_sharded_module]
- )
- assert num_fully_sharded_module_handles == 1, (
- "The current design assumes a module manages at most one "
- f"`FlatParamHandle` but got {num_fully_sharded_module_handles}"
- )
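- # With parameter CPU offloading enabled, keep the sharded flat parameter on CPU.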
- cpu_device = torch.device("cpu")
- if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device:
- handle.flat_param_to(cpu_device)
- def _get_state_names_for_states(
- module: nn.Module,
- params: List[nn.Parameter],
- buffers: List[torch.Tensor],
- ) -> Tuple[List[str], List[str]]:
- """
- Returns the parameter and buffer names of the given ``params`` and
- ``buffers``, where the names are prefixed starting from ``module``. This
- function assumes that the parameters and buffers are in the module tree.
- """
- param_names: List[str] = []
- buffer_names: List[str] = []
- param_to_param_name = {
- param: param_name for param_name, param in module.named_parameters()
- }
- buffer_to_buffer_name = {
- buffer: buffer_name for buffer_name, buffer in module.named_buffers()
- }
- for param in params:
- assert (
- param in param_to_param_name
- ), f"Parameter not in the module tree:\n{module}\n{param}"
- param_names.append(param_to_param_name[param])
- for buffer in buffers:
- assert (
- buffer in buffer_to_buffer_name
- ), f"Buffer not in the module tree:\n{module}\n{buffer}"
- buffer_names.append(buffer_to_buffer_name[buffer])
- return param_names, buffer_names
- def _get_ignored_modules(
- root_module: nn.Module,
- _ignored_modules: Optional[Iterable[torch.nn.Module]],
- ) -> Set[nn.Module]:
- """
- Checks that ``_ignored_modules`` is an iterable of ``nn.Module`` s without
- any FSDP instances, and returns the modules contained in their module
- subtrees as a :class:`set`. Nested FSDP instances are excluded, but their
- already-computed ignored modules are included.
- ``_ignored_modules`` represents the argument passed by the user to FSDP.
- """
- msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s "
- try:
- ignored_root_modules = (
- set(_ignored_modules) if _ignored_modules is not None else set()
- )
- except TypeError as e:
- raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e
- for module in ignored_root_modules:
- if not isinstance(module, torch.nn.Module):
- raise TypeError(msg_prefix + f"but got an iterable with {type(module)}")
- if isinstance(module, fsdp_file.FullyShardedDataParallel):
-
-
- raise ValueError("`ignored_modules` should not include FSDP modules")
-
-
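- # Also treat modules that are not composable with FSDP
- # (per `traversal_utils._composable`) as ignored.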
- for module in root_module.modules():
- if not traversal_utils._composable(module):
- ignored_root_modules.add(module)
-
-
-
- ignored_modules = {
- child
- for module in ignored_root_modules
- for child in module.modules()
- if not isinstance(child, fsdp_file.FullyShardedDataParallel)
- }
- if root_module in ignored_modules:
- warnings.warn(
- "Trying to ignore the top-level module passed into the FSDP "
- "constructor itself will result in all parameters being "
- f"ignored and is not well-supported: {module}"
- )
-
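- # Include the ignored modules already recorded by any nested FSDP instances.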
- for submodule in root_module.modules():
- optional_fsdp_state = _get_module_fsdp_state(submodule)
- if optional_fsdp_state is not None:
- assert hasattr(optional_fsdp_state, "_ignored_modules")
- ignored_modules.update(optional_fsdp_state._ignored_modules)
- return ignored_modules
- def _get_ignored_params(
- root_module: torch.nn.Module,
- ignored_modules: Set[torch.nn.Module],
- ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
- ) -> Set[torch.nn.Parameter]:
- """
- Returns the parameters of the modules in ``ignored_modules`` and
- the parameters in ``ignored_parameters``, excluding any :class:`FlatParameter` s.
- """
- all_ignored_params: Set[torch.nn.Parameter] = set()
- params_in_ignored_modules = {
- p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p)
- }
- all_ignored_params.update(params_in_ignored_modules)
- if ignored_parameters is not None:
- params_in_ignored_parameters = {
- p for p in ignored_parameters if not _is_fsdp_flattened(p)
- }
- all_ignored_params.update(params_in_ignored_parameters)
-
- for submodule in root_module.modules():
- optional_fsdp_state = _get_module_fsdp_state(submodule)
- if optional_fsdp_state is not None:
- assert hasattr(optional_fsdp_state, "_ignored_params")
- all_ignored_params.update(optional_fsdp_state._ignored_params)
- return all_ignored_params
- def _get_buffer_names(root_module: nn.Module) -> Set[str]:
- """
- Returns the fully prefixed names of all buffers in the module hierarchy
- rooted at ``root_module`` as a :class:`set`.
- """
- return {
- clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers()
- }
- def _check_single_device_module(
- module: nn.Module,
- ignored_params: Set[nn.Parameter],
- ) -> None:
- """
- Raises an error if ``module`` has original parameters on multiple devices,
- ignoring the parameters in ``ignored_params``. Thus, after this method, the
- module must be either fully on the CPU or fully on a non-CPU device.
- """
- devices = {param.device for param in _get_orig_params(module, ignored_params)}
- if len(devices) > 1:
- raise RuntimeError(
- f"FSDP only supports single device modules but got params on {devices}"
- )
- def _get_device_from_device_id(
- device_id: Optional[Union[int, torch.device]],
- rank: int,
- ) -> Optional[torch.device]:
- """
- Processes ``device_id`` and returns either the corresponding device or
- ``None`` if ``device_id`` is ``None``.
- """
- if device_id is None:
- return None
- device = (
- device_id if isinstance(device_id, torch.device) else torch.device(device_id)
- )
- if device == torch.device("cuda"):
- warnings.warn(
- f"FSDP got the argument `device_id` {device_id} on rank "
- f"{rank}, which does not have an explicit index. "
- f"FSDP will use the current device {torch.cuda.current_device()}. "
- "If this is incorrect, please explicitly call `torch.cuda.set_device()` "
- "before FSDP initialization or pass in the explicit device "
- "index as the `device_id` argument."
- )
- device = torch.device("cuda", torch.cuda.current_device())
- return device
- def _need_to_materialize_module(
- module: nn.Module,
- ignored_params: Set[nn.Parameter],
- ) -> Tuple[bool, bool]:
- """
- Returns whether ``module`` has parameters on the meta device and whether
- ``module`` is using torchdistX deferred initialization. At most one of the
- returned bools can be ``True``. If either is ``True``, then ``module``
- needs to be materialized.
- """
- # Materialize the generator so that both `any()` checks below see every parameter.
- managed_params = list(_get_orig_params(module, ignored_params))
- is_meta_module = any(param.is_meta for param in managed_params)
- is_torchdistX_deferred_init = (
- not is_meta_module
- and _TORCHDISTX_AVAIL
- and any(fake.is_fake(param) for param in managed_params)
- )
- return is_meta_module, is_torchdistX_deferred_init
- def _materialize_with_param_init_fn(
- module: nn.Module,
- param_init_fn,
- ) -> None:
- if not callable(param_init_fn):
- raise ValueError(
- f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}"
- )
- param_init_fn(module)
- def _materialize_meta_module(
- module: nn.Module,
- device_from_device_id: Optional[torch.device],
- ):
-
- materialization_device = device_from_device_id or torch.device(
- torch.cuda.current_device()
- )
- module.to_empty(device=materialization_device)
- try:
- with torch.no_grad():
- module.reset_parameters()
- except BaseException as e:
- warnings.warn(
- "Unable to call `reset_parameters()` for module on meta "
- f"device with error {str(e)}. Please ensure your "
- "module implements a `reset_parameters()` method."
- )
- raise e
- def _move_module_to_device(
- module: nn.Module,
- ignored_params: Set[nn.Parameter],
- device_from_device_id: Optional[torch.device],
- ) -> None:
- """
- Moves ``module`` depending on ``device_from_device_id`` and its current
- device. This includes moving ignored modules' parameters.
- - If ``device_from_device_id`` is not ``None``, then this moves
- ``module`` to the device.
- - If ``device_from_device_id`` is ``None``, then this does not move
- ``module`` but warns the user if it is on CPU.
- Precondition: ``_check_single_device_module()``.
- """
- param = next(_get_orig_params(module, ignored_params), None)
- if param is None:
- return
- cpu_device = torch.device("cpu")
- if device_from_device_id is not None:
- if param.device == cpu_device:
-
- module = module.to(device_from_device_id)
-
-
-
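- # `module.to()` above may have moved flat parameters that nested FSDP instances
- # offloaded to CPU; move those back to CPU.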
- for submodule in module.modules():
- if (
- isinstance(submodule, fsdp_file.FullyShardedDataParallel)
- and submodule.cpu_offload.offload_params
- ):
- for handle in submodule._handles:
- handle.flat_param_to(torch.device("cpu"))
- elif param.device == cpu_device:
- _warn_cpu_init()
- def _move_states_to_device(
- params: List[nn.Parameter],
- buffers: List[torch.Tensor],
- device_from_device_id: Optional[torch.device],
- ) -> None:
- """
- Precondition: ``_check_single_device_module()``.
- """
- if len(params) == 0 and len(buffers) == 0:
- return
- if len(params) > 0:
- current_device = params[0].device
- elif len(buffers) > 0:
- current_device = buffers[0].device
- cpu_device = torch.device("cpu")
- if device_from_device_id is not None:
-
-
- for param in params:
- with torch.no_grad():
- param.data = param.to(device_from_device_id)
- if param.grad is not None:
- param.grad.data = param.grad.to(device_from_device_id)
- for buffer in buffers:
- buffer.data = buffer.to(device_from_device_id)
- elif current_device == cpu_device:
- _warn_cpu_init()
- def _warn_cpu_init():
- warnings.warn(
- "The passed-in `module` is on CPU and will thus have FSDP's sharding "
- "initialization run on CPU, which may be slower than on GPU. We "
- "recommend passing in the `device_id` argument for FSDP to move "
- "`module` to GPU for the sharding initialization. `module` must also "
- "be on GPU device to work with the `sync_module_states=True` flag "
- "since that requires GPU communication."
- )
- def _get_compute_device(
- module: nn.Module,
- ignored_params: Set[nn.Parameter],
- device_from_device_id: Optional[torch.device],
- rank: int,
- ) -> torch.device:
- """
- Determines and returns this FSDP instance's compute device. If the module
- is already on a non-CPU device, then the compute device is that non-CPU
- device. If the module is on CPU, then the compute device is the current
- device.
- Since this method should be called after materializing the module, any
- non-CPU device should not be the meta device. For now, the compute device
- is always a CUDA GPU device with its explicit index.
- Precondition: ``_check_single_device_module()`` and
- ``_move_module_to_device()``.
- """
-
-
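- # Prefer the device of an existing managed parameter; otherwise fall back to
- # the current CUDA device.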
- param = next(_get_orig_params(module, ignored_params), None)
- if param is not None and param.device.type == "cuda":
- compute_device = param.device
- else:
- compute_device = torch.device("cuda", torch.cuda.current_device())
- if device_from_device_id is not None and compute_device != device_from_device_id:
- raise ValueError(
- f"Inconsistent compute device and `device_id` on rank {rank}: "
- f"{compute_device} vs {device_from_device_id}"
- )
- return compute_device
- def _sync_module_params_and_buffers(
- module: nn.Module,
- params: List[nn.Parameter],
- process_group: dist.ProcessGroup,
- ) -> None:
- """
- Synchronizes module states (i.e. parameters ``params`` and all
- not-yet-synced buffers) by broadcasting from rank 0 to all ranks.
- Precondition: ``sync_module_states == True`` and ``self.process_group`` has
- been set.
- """
- _check_params_for_sync_module_states(params)
- module_states: List[torch.Tensor] = []
- for buffer in module.buffers():
-
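- # Skip buffers already marked as synced so the same buffer is not broadcast
- # more than once.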
- if not getattr(buffer, FSDP_SYNCED, False):
- setattr(buffer, FSDP_SYNCED, True)
- module_states.append(buffer.detach())
- module_states.extend(param.detach() for param in params)
- _sync_params_and_buffers(
- process_group,
- module_states,
- PARAM_BROADCAST_BUCKET_SIZE,
- src=0,
- )
- def _sync_module_states(
- params: List[nn.Parameter],
- buffers: List[torch.Tensor],
- process_group: dist.ProcessGroup,
- ) -> None:
- _check_params_for_sync_module_states(params)
-
-
- params_and_buffers = [param.detach() for param in params] + [
- buffer.detach() for buffer in buffers
- ]
- _sync_params_and_buffers(
- process_group,
- params_and_buffers,
- PARAM_BROADCAST_BUCKET_SIZE,
- src=0,
- )
- def _check_params_for_sync_module_states(
- params: List[nn.Parameter],
- ) -> None:
- if params and any(param.device == torch.device("cpu") for param in params):
- raise ValueError(
- "The module has CPU parameters when `sync_module_states=True`, "
- "which only works when all parameters are on GPU. Please specify "
- "the `device_id` argument or move the module to GPU before passing "
- "into FSDP."
- )
- def _get_orig_params(
- module: nn.Module,
- ignored_params: Set[nn.Parameter],
- ) -> Iterator[nn.Parameter]:
- """
- Returns an iterator over the original parameters in ``module``, ignoring
- the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be
- present due to nested FSDP wrapping), and any original parameters already
- flattened (only relevant when ``use_orig_params=True``).
- """
- param_gen = module.parameters()
- try:
- while True:
- param = next(param_gen)
- if param not in ignored_params and not _is_fsdp_flattened(param):
- yield param
- except StopIteration:
- pass
- def _check_orig_params_flattened(
- fsdp_module,
- ignored_params: Set[nn.Parameter],
- ) -> None:
- """
- Checks that all original parameters have been flattened and hence made
- invisible to ``named_parameters()`` for the module hierarchy rooted at
- ``fsdp_module``. This should be called as a sanity check after flattening
- the wrapped module's parameters.
- """
- for param_name, param in fsdp_module.named_parameters():
- if param not in ignored_params and not _is_fsdp_flattened(param):
- raise RuntimeError(
- f"Found an unflattened parameter: {param_name}; "
- f"{param.size()} {param.__class__}"
- )
- def _get_default_comm_hook(sharding_strategy: ShardingStrategy):
- return (
- default_hooks.allreduce_hook
- if sharding_strategy == ShardingStrategy.NO_SHARD
- else default_hooks.reduce_scatter_hook
- )
- def _get_default_comm_hook_state(
- process_group: dist.ProcessGroup,
- ) -> default_hooks.DefaultState:
- return default_hooks.DefaultState(process_group=process_group)