# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Union

import torch
import torch.nn as nn
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.distributed._tensor.sharding_prop import _CachingPropagator
from torch.distributed.tensor.parallel._utils import _create_1d_device_mesh
from torch.distributed.tensor.parallel.multihead_attention_tp import (
    TensorParallelMultiheadAttention,
)
from torch.distributed.tensor.parallel.style import (
    ColwiseParallel,
    PairwiseParallel,
    ParallelStyle,
    RowwiseParallel,
)

__all__ = [
    "parallelize_module",
]

# switch the DTensor propagator to use the caching propagator to speed up
# the TP eager execution time.
DTensor._propagator = _CachingPropagator(DTensor._propagator.op_to_rules)


def parallelize_module(  # type: ignore[return]
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]],
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    The API to apply Tensor Parallelism (TP) in PyTorch. We parallelize modules
    or sub_modules based on a parallelize_plan. The parallelize_plan contains
    :class:`ParallelStyle`, which indicates how the user wants the module or
    sub_module to be parallelized.

    Users can also specify different parallel styles per module fully qualified
    name (FQN).

    The API supports 2D parallelism natively by accepting an n-dimensional
    device_mesh; users just need to specify the dimension on which to perform
    tensor parallelism.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology
            of devices for the DTensor.
        parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]):
            The plan used to parallelize the module. It can be either a
            :class:`ParallelStyle` object, which contains how
            we prepare input/output for Tensor Parallelism, or it can be a
            dict of module FQNs and their corresponding :class:`ParallelStyle` objects.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel
        >>>
        >>> # Define the module.
        >>> m = Model(...)
        >>> m = parallelize_module(m, device_mesh, PairwiseParallel())
        >>>
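
    The dict form of ``parallelize_plan`` lets you pick a style per submodule FQN.
    A minimal sketch is shown below; the submodule names ``"net1"`` and ``"net2"``
    are illustrative assumptions about the model, not part of the API:

        >>> # xdoctest: +SKIP("distributed")
        >>> from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
        >>> m = Model(...)  # assumed to contain Linear submodules named "net1" and "net2"
        >>> m = parallelize_module(m, device_mesh, {"net1": ColwiseParallel(), "net2": RowwiseParallel()})
        >>>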

    .. warning::
        ``PairwiseParallel`` comes with constraints for now. If you need finer
        granularity, you need to pass in a dict of module FQN and parallel style instead.
    """
    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallelize_plan, ParallelStyle):
        # RowwiseParallel or ColwiseParallel
        if isinstance(parallelize_plan, (ColwiseParallel, RowwiseParallel)):
            return _parallelize_linear(module, device_mesh, parallelize_plan)
        # PairwiseParallel
        if _is_mha_for_pairwise_parallel(module):
            return _parallelize_multihead_attn(module, device_mesh)
        elif _is_mlp_for_pairwise_parallel(module):
            return _parallelize_mlp(module, device_mesh)
        else:
            for n, m in module.named_children():
                module.register_module(
                    n, parallelize_module(m, device_mesh, parallelize_plan)
                )
            return module
    elif isinstance(parallelize_plan, dict):
        for module_path, parallelize_style in parallelize_plan.items():
            sub_module = module.get_submodule(module_path)
            parent_module = module
            if "." in module_path:
                parent_module_path = ".".join(module_path.split(".")[:-1])
                parent_module = module.get_submodule(parent_module_path)
                module_path = module_path.split(".")[-1]
            parent_module.register_module(  # type: ignore[call-arg] # pyre-ignore[20]
                module_path,
                parallelize_module(  # type: ignore[arg-type]
                    sub_module, device_mesh, parallelize_style  # type: ignore[arg-type] # pyre-ignore[6]
                ),
            )
        return module
    else:
        raise RuntimeError(  # pyre-ignore[7]
            "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for"
            f" parallelize_plan, {type(parallelize_plan)} found!"
        )


def _is_mha_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Check whether the given module is a multihead attention module that can be
    handled by Pairwise parallel.

    Args:
        module (:class:`nn.Module`):
            Module to be checked.

    Return:
        A boolean which specifies whether the module is an MHA module supported
        by Pairwise parallel or not.
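
    A small sketch of the check (both the built-in and the TP variant qualify;
    the sizes below are arbitrary assumptions):

    Example::
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> _is_mha_for_pairwise_parallel(nn.MultiheadAttention(embed_dim=16, num_heads=4))
        True
        >>> _is_mha_for_pairwise_parallel(nn.Linear(16, 16))
        False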
    """
    return isinstance(
        module, (TensorParallelMultiheadAttention, nn.MultiheadAttention)
    )


def _is_mlp_for_pairwise_parallel(module: nn.Module) -> bool:
    """
    Traverse through all the immediate children of the given module and count the
    number of ``nn.Linear`` modules. If there is more than one, we return True.

    Args:
        module (:class:`nn.Module`):
            Module to be traversed and counted.

    Return:
        A bool which specifies whether the module is an MLP supported by
        Pairwise parallel or not.
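
    A small sketch of what qualifies (the layer sizes are arbitrary assumptions;
    note that only immediate children are counted):

    Example::
        >>> # xdoctest: +SKIP("illustrative sketch")
        >>> mlp = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 8))
        >>> _is_mlp_for_pairwise_parallel(mlp)
        True
        >>> _is_mlp_for_pairwise_parallel(nn.Sequential(nn.Linear(8, 8), nn.ReLU()))
        False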

    .. warning::
        The traversal is not recursive for now.
    """
    linear_submodules = list(
        filter(lambda x: isinstance(x, nn.Linear), module.children())
    )
    return len(linear_submodules) > 1


def _rowwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`RowwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
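
    A minimal sketch of using this as the ``partition_fn`` of ``distribute_module``
    (the 1-d, 4-device mesh below is an illustrative assumption): the weight is
    sharded on dim 1 across the mesh and the bias is replicated, so each rank
    holds a (16, 4) local weight shard.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> linear = nn.Linear(16, 16)
        >>> distribute_module(linear, mesh, _rowwise_parallelize_linear_fn)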
    """
    for name, param in module.named_parameters():
        dist_spec = (
            [Shard(1)] if name == "weight" else [Replicate()]  # type: ignore[list-item]
        )
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        module.register_parameter(name, dist_param)


def _colwise_parallelize_linear_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    """
    This function parallelizes the input :class:`nn.Linear` module in
    :class:`ColwiseParallel` style.

    Args:
        name (str):
            Name of the input module.
        module (:class:`nn.Module`):
            The :class:`nn.Linear` module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.

    Returns:
        None
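
    A minimal sketch of using this as the ``partition_fn`` of ``distribute_module``
    (the 1-d, 4-device mesh below is an illustrative assumption): both weight and
    bias are sharded on dim 0, i.e. the output features are split, so each rank
    holds a (4, 16) local weight shard.

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> linear = nn.Linear(16, 16)
        >>> distribute_module(linear, mesh, _colwise_parallelize_linear_fn)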
    """
    for name, param in module.named_parameters():
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, [Shard(0)])
        )
        module.register_parameter(name, dist_param)


def _parallelize_linear(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = ColwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function requires that the input module be an object
    of :class:`nn.Linear`.
    The module will be parallelized over a 1-d :class:`DeviceMesh`
    based on the :class:`ParallelStyle`.

    Args:
        module (:class:`nn.Module`):
            The module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices for the :class:`DTensor`.
            If the mesh is more than 1-dimensional, we will use the mesh dim of
            ``device_mesh`` specified by ``tp_mesh_dim``.
        parallel_style (:class:`ParallelStyle`, optional):
            The object which describes how the :class:`nn.Linear` module
            should be distributed over the :class:`DeviceMesh` and how the input
            and output should be prepared for Tensor Parallelism.
            :class:`RowwiseParallel`: weight is sharded on dim 1 and bias is
            replicated.
            :class:`ColwiseParallel`: weight and bias are both sharded on dim 0.
            Default: :class:`ColwiseParallel`
        tp_mesh_dim (int):
            The dimension of :class:`DeviceMesh` on which we
            perform Tensor Parallelism.
            Default: 0

    Return:
        A parallelized :class:`nn.Module` object.
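
    A minimal usage sketch (the 1-d, 4-device mesh below is an illustrative
    assumption about the environment, not part of this function's contract):

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> linear = nn.Linear(16, 16)
        >>> linear = _parallelize_linear(linear, mesh, RowwiseParallel())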
    """
    if not isinstance(module, nn.Linear):
        raise RuntimeError(
            f"Expect a torch.nn.Linear module but received {type(module)}!"
        )

    if not isinstance(parallel_style, ParallelStyle):
        raise RuntimeError(
            "Expect a ParallelStyle object but received" f" {type(parallel_style)}!"
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(parallel_style, RowwiseParallel):
        distribute_module(
            module,
            device_mesh,
            _rowwise_parallelize_linear_fn,
            input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
        )
    elif isinstance(parallel_style, ColwiseParallel):
        distribute_module(
            module,
            device_mesh,
            _colwise_parallelize_linear_fn,
            input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
            output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
        )
    else:
        raise RuntimeError(f"{type(parallel_style)} is not supported!")
    return module


def _parallelize_multihead_attn(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function parallelizes a multihead attention module based on the given
    parallel style. An :class:`nn.MultiheadAttention` module is first swapped for
    a :class:`TensorParallelMultiheadAttention`, whose ``qkv`` projection is then
    sharded col-wise and whose output ``proj`` is sharded row-wise.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.
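
    A minimal usage sketch (the mesh and attention sizes below are illustrative
    assumptions):

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)
        >>> mha = _parallelize_multihead_attn(mha, mesh)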

    .. warning::
        We only support ``PairwiseParallel`` right now.
    """
    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only support PairwiseParallel for Multihead Attention parallelization."
        )

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    if isinstance(module, nn.MultiheadAttention):
        tp_multi_head_attention = TensorParallelMultiheadAttention(
            module.embed_dim,
            module.num_heads,
            device=torch.device(device_mesh.device_type),
            tp_size=device_mesh.size(tp_mesh_dim),
            add_bias_kv=module.bias_k is not None,
        )
        tp_multi_head_attention.copy(module)
        module = tp_multi_head_attention

    if isinstance(module, TensorParallelMultiheadAttention):  # shard TPMA
        for n, m in module.named_children():
            if n == "qkv":
                # Col-wise Parallelize the qkv layer.
                distribute_module(
                    m,
                    device_mesh,
                    _colwise_parallelize_linear_fn,
                    input_fn=parallel_style._prepare_input,  # type: ignore[arg-type, misc] # pyre-ignore[6]
                )
            elif n == "proj":
                # Row-wise Parallelize the proj layer.
                distribute_module(
                    m,
                    device_mesh,
                    _rowwise_parallelize_linear_fn,
                    output_fn=parallel_style._prepare_output,  # type: ignore[arg-type, misc] # pyre-ignore[6]
                )
    return module


def _parallelize_mlp(
    module: nn.Module,
    device_mesh: DeviceMesh,
    parallel_style: ParallelStyle = PairwiseParallel(),
    tp_mesh_dim: int = 0,
) -> nn.Module:
    """
    This function assumes the input module is a sequence of nn.Linear layers
    and parallelizes the module based on the given parallel style. We don't
    change the FQN of each sub-module; we replace each parameter in place.

    Args:
        module (:class:`nn.Module`):
            Module to be parallelized.
        device_mesh (:class:`DeviceMesh`):
            Object which describes the mesh topology of devices.
        parallel_style (:class:`ParallelStyle`):
            Object which contains how we prepare input/output
            for Tensor Parallelism.
        tp_mesh_dim (int):
            The dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Return:
        A parallelized :class:`nn.Module` object.
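
    Linear layers are paired up: even-indexed Linear children are sharded
    col-wise and odd-indexed ones row-wise, following the pairwise
    (Megatron-style) pattern. A minimal usage sketch (the mesh and the layer
    sizes below are illustrative assumptions):

    Example::
        >>> # xdoctest: +SKIP("distributed")
        >>> mesh = DeviceMesh("cuda", list(range(4)))
        >>> mlp = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16))
        >>> mlp = _parallelize_mlp(mlp, mesh)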

    .. warning::
        We only support ``PairwiseParallel`` right now.
    """
    if not isinstance(parallel_style, PairwiseParallel):
        raise NotImplementedError(
            "Only support PairwiseParallel for MLP parallelization."
        )

    if not _is_mlp_for_pairwise_parallel(module):
        raise RuntimeError("More than one nn.Linear needed for a MLP.")

    if device_mesh.ndim > 1:
        device_mesh = _create_1d_device_mesh(device_mesh, tp_mesh_dim)

    linear_submodules = list(
        filter(lambda x: isinstance(x, nn.Linear), module.children())
    )
    mlp_last_even_layer = (len(linear_submodules) // 2) * 2
    for i in range(mlp_last_even_layer):
        m = linear_submodules[i]
        if i % 2 == 0:
            # Col-wise Parallelize the linear layer
            distribute_module(
                m,
                device_mesh,
                _colwise_parallelize_linear_fn,
                input_fn=parallel_style._prepare_input  # type: ignore[arg-type, misc] # pyre-ignore[6]
                if i == 0
                else None,
            )
        else:
            # Row-wise Parallelize the linear layer
            distribute_module(
                m,
                device_mesh,
                _rowwise_parallelize_linear_fn,
                output_fn=parallel_style._prepare_output  # type: ignore[arg-type, misc] # pyre-ignore[6]
                if i == (mlp_last_even_layer - 1)
                else None,
            )
    return module