from contextlib import contextmanager

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
    _PartialTensor
)

from .replicated_tensor import ReplicatedTensor
from .sharding_spec import (
    ShardingSpec,
    ChunkShardingSpec
)
from .sharding_plan import (
    ShardingPlan
)
from .sharder import Sharder


def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank whose data is used
    as the ground truth and scattered as shards across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): The tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the tensor that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate that src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)

    return st
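
# Example (illustrative sketch, not part of this module): sharding a 16 x 4 tensor
# row-wise across 4 ranks. Assumes a 4-rank process group is already initialized and
# uses the "rank:N/cuda:N" placement convention of ``ChunkShardingSpec``.
#
#     spec = ChunkShardingSpec(
#         dim=0,
#         placements=[f"rank:{r}/cuda:{r}" for r in range(4)],
#     )
#     # Called collectively on every rank; rank 0's data is the ground truth that
#     # gets scattered as shards.
#     st = _shard_tensor(torch.randn(16, 4), spec, src_rank=0)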


def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in that
    module, shards that parameter according to the provided ``sharding_spec``.
    ``src_rank`` denotes the source rank whose data is used as the ground truth
    and scattered as shards across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not hasattr(module, param_name):
        raise AttributeError(f'{module._get_name()} has no attribute `{param_name}`')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f'Expected {type(module).__name__}.{param_name} to be a Tensor, '
            f'but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace the param with a ShardedTensor.
    module.register_parameter(param_name, nn.Parameter(st))
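
# Example (illustrative sketch): sharding the weight of a Linear layer across 4 ranks.
# Assumes torch.distributed is initialized with 4 ranks and one GPU per rank; ``fc``
# and ``spec`` are placeholder names for this example.
#
#     spec = ChunkShardingSpec(
#         dim=0,
#         placements=[f"rank:{r}/cuda:{r}" for r in range(4)],
#     )
#     fc = torch.nn.Linear(16, 16).cuda()
#     shard_parameter(fc, "weight", spec, src_rank=0)
#     # fc.weight now holds a ShardedTensor; rank 0's original weight was used as
#     # the ground truth and scattered as shards.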


def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
    """
    Given a :class:`torch.Tensor`, mark it as a ReplicatedTensor where all
    ranks have the same value.

    Args:
        tensor (:class:`torch.Tensor`): the tensor to be marked as replicated.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to replicate on.
            If None, the default process group will be used.

    Returns:
        A :class:`ReplicatedTensor` from the given tensor.
    """
    return ReplicatedTensor(tensor, process_group=process_group)
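
# Example (illustrative sketch): a tensor that is expected to hold the same value on
# every rank, such as a bias, could be wrapped so that sharded ops can treat it as
# replicated:
#
#     replicated_bias = _replicate_tensor(torch.zeros(16))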


# Tracks the current process group in the load context manager.
_CURRENT_PROCESS_GROUP = None


@contextmanager
def load_with_process_group(process_group):
    """
    Context manager to set the process group with which to load a ShardedTensor/ReplicatedTensor.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is not None:
        raise RuntimeError(
            'ProcessGroup already set by previous "load_with_process_group" '
            'context manager')
    _CURRENT_PROCESS_GROUP = process_group
    try:
        yield process_group
    finally:
        _CURRENT_PROCESS_GROUP = None
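
# Example (illustrative sketch): loading a checkpoint that contains ShardedTensors on a
# subgroup instead of the default process group. The checkpoint path and subgroup are
# placeholders.
#
#     subgroup = dist.new_group(ranks=[0, 1])
#     with load_with_process_group(subgroup):
#         state_dict = torch.load("checkpoint_rank0.pt")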


def _get_current_process_group():
    """
    Retrieves the current process group set by ``load_with_process_group``.
    If not set, it just returns the default group.
    """
    global _CURRENT_PROCESS_GROUP
    if _CURRENT_PROCESS_GROUP is None:
        return distributed_c10d._get_default_group()
    else:
        return _CURRENT_PROCESS_GROUP


def _reshard_output(
        module: torch.nn.Module,
        resharding_spec: ShardingSpec) -> torch.nn.Module:
    """
    Hook a module with output resharding in the forward pass according
    to the given ``resharding_spec``.

    Args:
        module (:class:`torch.nn.Module`): Module whose output needs to be resharded.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            The specification describing how the output of the module will be resharded.

    Returns:
        A :class:`torch.nn.Module` object with reshard API hooked.
    """
    def hook_func(_module, _input, output):
        if isinstance(output, (ShardedTensor, _PartialTensor)):
            return output.reshard(resharding_spec)
        return output
    module.register_forward_hook(hook_func)
    return module
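
# Example (illustrative sketch): asking for a module's ShardedTensor output to be
# resharded along a different dimension on the way out. ``sharded_linear`` and the
# placements are placeholders.
#
#     reshard_spec = ChunkShardingSpec(
#         dim=1,
#         placements=[f"rank:{r}/cuda:{r}" for r in range(4)],
#     )
#     _reshard_output(sharded_linear, reshard_spec)
#     # The registered forward hook calls output.reshard(reshard_spec) whenever the
#     # output is a ShardedTensor or _PartialTensor.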


def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module:
    """
    Hook a module with local shards collection in the forward pass.

    This API is typically used to convert a sharded representation back to a data parallel
    representation. In particular, it returns the local tensor for this Shard. If the
    size along the sharding dimension for the local tensor is 1, this dimension is removed
    from the final result. For example, a [4, 16] ShardedTensor across 4 ranks is typically
    a local Tensor of size [16] on each rank and not [1, 16] on each rank.

    Args:
        module (:class:`torch.nn.Module`): Module whose output is a ShardedTensor and whose
            local tensor value needs to be returned.

    Returns:
        A :class:`torch.nn.Module` object with collection API hooked.
    """

    def hook_func(_module, _input, output):
        if isinstance(output, ShardedTensor):
            local_tensor = output.local_tensor()
            # Squeeze the number of dimensions manually, only applicable to ChunkShardingSpec.
            sharding_spec = output._sharding_spec
            if isinstance(sharding_spec, ChunkShardingSpec) \
                    and local_tensor.size(sharding_spec.dim) == 1:  # type: ignore[attr-defined, arg-type]
                local_tensor = local_tensor.squeeze(
                    output._sharding_spec.dim  # type: ignore[attr-defined]
                )
            return local_tensor
    module.register_forward_hook(hook_func)
    return module
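
# Example (illustrative sketch): converting a sharded module's output back to a plain
# local tensor so it can flow into ordinary data-parallel code. ``sharded_linear`` and
# ``inp`` are placeholders for this example.
#
#     _collect_local_shard(sharded_linear)
#     out = sharded_linear(inp)
#     # ``out`` is now a regular torch.Tensor holding this rank's local shard, squeezed
#     # along the ChunkShardingSpec dim if that dim has size 1.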


def shard_module(
    module: nn.Module,
    plan: ShardingPlan,
    src_rank=0,
    process_group=None
):
    """
    Shards a given module according to the provided sharding ``plan``. This method
    first shards all the parameters according to the given sharding ``plan``. Then, if
    ``output_plan`` and ``return_local_tensor`` are specified in the sharding ``plan``, it
    will tag the output of modules according to ``output_plan`` and convert the module's
    output back to data parallel according to ``return_local_tensor``.

    Needs to be called on all ranks in an SPMD fashion.

    Args:
        module (:class:`torch.nn.Module`): The module to apply sharding to.
        plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`):
            The ShardingPlan which specifies, for each parameter name, the ShardingSpec
            to apply to that parameter.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the module that would be sharded and scattered across the rest
            of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
    """
    # Record Sharder paths for a sanity check on the plan to ensure items in the plan
    # do not conflict with the submodule tree that the Sharder is working with.
    sharder_paths = []
    for name, spec in plan.plan.items():
        if isinstance(spec, Sharder):
            sharder_paths.append(name)

    # Shard the parameters according to the ShardingPlan.
    for name, spec in plan.plan.items():
        if isinstance(spec, ShardingSpec):
            # If we found a sharding spec, try to shard the parameter.
            module_path, _, param_name = name.rpartition(".")

            for sharder_path in sharder_paths:
                if module_path.startswith(sharder_path):
                    raise RuntimeError(f"ShardingPlan is invalid, trying to shard a parameter: {name},"
                                       f" but there's already a Sharder entry for module {sharder_path},"
                                       f" parameter sharding should not conflict with the submodule tree"
                                       f" that a Sharder is working with!")

            mod = module.get_submodule(module_path)
            shard_parameter(
                mod,
                param_name,
                spec,
                src_rank=src_rank,
                process_group=process_group
            )
        elif isinstance(spec, Sharder):
            parent_mod_path, _, mod_name = name.rpartition(".")
            if name == "":
                raise KeyError("Module path must not be empty for custom sharder!")
            mod = module.get_submodule(name)
            parent_mod = module.get_submodule(parent_mod_path)
            sharded_mod = spec.shard(mod)
            # Swap this submodule with the sharded module.
            setattr(parent_mod, mod_name, sharded_mod)
        else:
            raise TypeError(f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'")

    # Reshard the output if there's an entry in `output_plan` for this module.
    if plan.output_plan is not None:
        for module_path, output_spec in plan.output_plan.items():
            if isinstance(output_spec, ShardingSpec):
                mod = module.get_submodule(module_path)
                _reshard_output(mod, output_spec)
            else:
                raise TypeError(f"Only `ShardingSpec` is supported as output_plan for '{module_path}'")

    # Convert the output back to data parallel for the modules that appear in
    # `return_local_tensor` of the plan by calling `_collect_local_shard` to
    # collect the local tensor for the output of those modules.
    if plan.return_local_tensor is not None:
        for module_path in plan.return_local_tensor:
            mod = module.get_submodule(module_path)
            _collect_local_shard(mod)
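
# Example (illustrative sketch): sharding two Linear layers of a model with a single
# ShardingPlan. The model, submodule names, and placements are placeholders; the
# ShardingPlan fields (``plan``, ``output_plan``, ``return_local_tensor``) mirror the
# attributes used by ``shard_module`` above, and constructing it with keyword
# arguments is an assumption for this sketch.
#
#     placements = [f"rank:{r}/cuda:{r}" for r in range(4)]
#     sharding_plan = ShardingPlan(
#         plan={
#             "fc1.weight": ChunkShardingSpec(dim=0, placements=placements),
#             "fc2.weight": ChunkShardingSpec(dim=1, placements=placements),
#         },
#         output_plan={
#             "fc2": ChunkShardingSpec(dim=0, placements=placements),
#         },
#         return_local_tensor=["fc2"],
#     )
#     shard_module(model, sharding_plan)  # called collectively on every rank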