# Copyright (c) Meta Platforms, Inc. and affiliates from abc import ABC, abstractmethod from typing import Optional, Union import torch from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard from torch.distributed.tensor.parallel._utils import ( _prepare_input_validate, _prepare_output_validate, _PrepareInputType, _PrepareOutputType, ) __all__ = [ "ParallelStyle", "RowwiseParallel", "ColwiseParallel", "PairwiseParallel", "make_input_replicate_1d", "make_input_shard_1d", "make_input_shard_1d_last_dim", "make_output_replicate_1d", "make_output_tensor", "make_output_shard_1d", ] class ParallelStyle(ABC): """ The parallel style user wants the module or submodule to be parallelized. Users can extend this class to build their own parallel style with customized input/output preparations. """ _prepare_input: _PrepareInputType _prepare_output: _PrepareOutputType @abstractmethod def __init__(self, _prepare_input, _prepare_output) -> None: self._prepare_input = _prepare_input # type: ignore[assignment, misc] self._prepare_output = _prepare_output # type: ignore[assignment, misc] class PairwiseParallel(ParallelStyle): """ PairwiseParallel concatenate colwise and rowwise styles as a fixed pair like what Megatron-LM(https://arxiv.org/abs/1909.08053) is doing. We assume both input and output needs to a replicate DTensor. .. warning:: PairwiseParallel only supports ``nn.Multihead Attention``, ``nn.Transformer`` or even-number-layer MLP for now. """ def __init__(self) -> None: super().__init__(make_input_replicate_1d, make_output_tensor) class RowwiseParallel(ParallelStyle): """ Partitioning the row of a module. We assume the input to be a sharded :class:`DTensor` and output to be a replicated :class:`DTensor`. """ def __init__(self) -> None: super().__init__(make_input_shard_1d_last_dim, make_output_replicate_1d) class ColwiseParallel(ParallelStyle): """ Partitioning the column of a tensor or module. We assume the input to be a replicated :class:`DTensor` and output to be a sharded :class:`DTensor`. """ def __init__(self) -> None: super().__init__(make_input_replicate_1d, make_output_replicate_1d) @_prepare_input_validate # type: ignore[arg-type] # pyre-ignore[56] def make_input_shard_1d( input: Union[torch.Tensor, DTensor], device_mesh: Optional[DeviceMesh] = None, dim: int = 0, ) -> DTensor: """ Shard input tensor on ``dim`` over an 1-D device mesh. This function will be used in ParallelStyle. Args: input (Union[:class:`torch.Tensor`, :class:`DTensor`]): Single tensor will be sharded on dimension ``dim`` over the 1-D :class:`DeviceMesh`. device_mesh (:class:`DeviceMesh`, optional): The 1-D device mesh where ``input`` will be sharded. If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`, `input.device_mesh` will be used. If :class:`DeviceMesh` is not 1-D, an exception will be thrown. Default: ``None`` dim (int, optional): The sharding dimension of ``input`` tensor. Default: 0 Returns: A :class:`DTensor` sharded on dimension ``dim`` over ``device_mesh``. """ shard_spec = [Shard(dim)] if isinstance(input, DTensor): return input.redistribute(device_mesh, shard_spec) elif isinstance(input, torch.Tensor): return DTensor.from_local(input, device_mesh, shard_spec, run_check=False) else: raise RuntimeError( "Tensor parallel module expects torch.Tensor or DTensor input but" f" received {type(input)}!" ) def make_input_shard_1d_last_dim( input: Union[torch.Tensor, DTensor], device_mesh: Optional[DeviceMesh] = None, ) -> DTensor: """ Wrapper func of ``make_input_shard_1d`` with ``dim`` = -1. Args: input (Union[:class:`torch.Tensor`, :class:`DTensor`]): This single tensor will be sharded on dimension ``dim`` over the 1-D :class:`DeviceMesh`. device_mesh (:class:`DeviceMesh`, optional): The 1-D device mesh where ``input`` will be sharded. If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`, `input.device_mesh` will be used. If :class:`DeviceMesh` is not 1-D, an exception will be thrown. Default: ``None`` Returns: A :class:`DTensor` sharded on dimension ``dim`` over ``device_mesh``. """ return make_input_shard_1d(input, device_mesh, dim=-1) # type: ignore[call-arg] @_prepare_input_validate # type: ignore[arg-type] # pyre-ignore[56] def make_input_replicate_1d( input: Union[torch.Tensor, DTensor], device_mesh: Optional[DeviceMesh] = None, ) -> DTensor: """ Replicate input tensor over an 1-D device mesh. This function will be used in ParallelStyle. Args: input (Union[:class:`torch.Tensor`, :class:`DTensor`]): This input tensor will be replicated over the 1-D :class:`DeviceMesh`. device_mesh (:class:`DeviceMesh`, optional): The 1-D device mesh where ``input`` will be replicated. If no :class:`DeviceMesh` is passed and ``input`` is a :class:`DTensor`, ``input.device_mesh`` will be used. If :class:`DeviceMesh` is not 1-D, an exception will be thrown. Default: ``None`` Returns: A :class:`DTensor` replicated over ``device_mesh``. """ replicate = [Replicate()] if isinstance(input, DTensor): return input.redistribute(device_mesh, replicate) elif isinstance(input, torch.Tensor): return DTensor.from_local(input, device_mesh, replicate, run_check=False) else: raise RuntimeError( "Tensor parallel module expects torch.Tensor or DTensor input but" f" received {type(input)}!" ) @_prepare_output_validate # type: ignore[arg-type] # pyre-ignore[56] def make_output_shard_1d( output: DTensor, device_mesh: Optional[DeviceMesh] = None, dim: int = 0 ) -> DTensor: """ Convert Output DTensor to a sharded DTensor. This will be used in ParallelStyle. Args: output (:class:`DTensor`): Output of module to be converted. device_mesh (:class:`DeviceMesh`, optional): Object needed to shard the output and it needs to be a 1D ``device_mesh`` and we will throw exceptions if a non-1D ``device_mesh`` is passed in. If no ``device_mesh`` is passed in, we will reuse the one from output. Default: ``None`` dim (int): Sharding dim for output. Default: 0 Return: A :class:`DTensor` object sharded on the given dim. """ return output.redistribute(device_mesh, [Shard(dim)]) @_prepare_output_validate # type: ignore[arg-type] # pyre-ignore[56] def make_output_replicate_1d( output: DTensor, device_mesh: Optional[DeviceMesh] = None ) -> DTensor: """ Convert Output DTensor to a replicated DTensor. This will be used in ParallelStyle. Args: output (:class:`DTensor`): Output of module to be converted. device_mesh (:class:`DeviceMesh`, optional): Object needed to replicate the output and it needs to be a 1D ``device_mesh`` and we will throw exceptions if a non-1D ``device_mesh`` is passed in. If no ``device_mesh`` is passed in, we will reuse the one from output. Default: ``None`` Return: A :class:`DTensor` object made replicate. """ return output.redistribute(device_mesh, [Replicate()]) @_prepare_output_validate # type: ignore[arg-type] # pyre-ignore[56] def make_output_tensor( output: DTensor, device_mesh: Optional[DeviceMesh] = None ) -> torch.Tensor: """ Convert Output DTensor to a replicated DTensor first and then convert it to Tensor. Args: output (:class:`DTensor`): Output of module to be converted. device_mesh (:class:`DeviceMesh`, optional): Object which is needed to replicate the output and it needs to be a 1D ``device_mesh`` and we will throw exceptions if a non-1D ``device_mesh`` is passed in. If no ``device_mesh`` is passed in, we will reuse the one from output. Default: ``None`` Return: A :class:`torch.Tensor` object converted from output DTensor. """ return make_output_replicate_1d( # type: ignore[attr-defined] output, device_mesh ).to_local() # type: ignore[call-arg]