import functools
from typing import Callable, Optional, Union

import torch
from torch.distributed._tensor import DeviceMesh, DTensor

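# Shapes of the user-supplied callbacks that the validators below wrap:
# a prepare-input func maps a (tensor, mesh, mesh dim) triple to a DTensor,
# and a prepare-output func maps a DTensor back to a tensor or DTensor.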
_PrepareInputType = Callable[
    [Union[torch.Tensor, DTensor], Optional[DeviceMesh], Optional[int]], DTensor
]
_PrepareOutputType = Callable[
    [DTensor, Optional[DeviceMesh], Optional[int]], Union[torch.Tensor, DTensor]
]


def _prepare_input_validate(
    _prepare_input_func: _PrepareInputType,
) -> _PrepareInputType:
    """
    Inject common validation logic into ``_prepare_input`` funcs via this
    decorator, verifying that the input is either a :class:`Tensor` or a
    :class:`DTensor` and that only a 1D :class:`DeviceMesh` is passed in.

    Args:
        _prepare_input_func (Callable): The func we want to inject the
            validation into.

    Returns:
        func (Callable): Same input function with validation logic added.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> @_prepare_input_validate
        >>> def make_input_shard_1d(args, kwargs):
        >>>     ...
        >>>
        >>> # xdoctest: +SKIP(failing)
        >>> input = torch.rand(...)
        >>> dtensor = make_input_shard_1d(input, device_mesh, 1)
        >>> # This calls '_prepare_input_validate' first
    """

    @functools.wraps(_prepare_input_func)
    def wrapper(*args, **kwargs):  # pyre-ignore[2, 3]
        assert len(args) >= 1, "_prepare_input needs at least one arg."
        input = args[0]
        # Some modules pass inputs wrapped in a list/tuple; unwrap the first.
        if isinstance(input, (list, tuple)):
            input = input[0]
            args = (input, *args[1:])
        device_mesh = None if len(args) < 2 else args[1]
        if device_mesh is None:
            if isinstance(input, DTensor):
                # Infer the mesh from the DTensor itself when none is given.
                device_mesh = input.device_mesh
                args = (*args[:1], device_mesh, *args[2:])  # pyre-ignore[60]
            else:
                raise RuntimeError(
                    "device_mesh is not passed nor can it be inferred"
                )
        if device_mesh.ndim != 1:
            raise RuntimeError(
                f"device_mesh has dims {device_mesh.ndim} but expected to be 1"
                " for input."
            )
        return _prepare_input_func(*args, **kwargs)

    return wrapper
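

# Illustrative only: a minimal sketch of a prepare-input func that conforms to
# _PrepareInputType and picks up the validation above. The name
# `_demo_make_input_replicate_1d` is hypothetical and not part of this module's
# API; the real helpers live in torch.distributed.tensor.parallel.style.
@_prepare_input_validate
def _demo_make_input_replicate_1d(
    input: Union[torch.Tensor, DTensor],
    device_mesh: Optional[DeviceMesh] = None,
    dim: Optional[int] = 0,
) -> DTensor:
    # Local import to mirror the lazy-import style used for optional deps.
    from torch.distributed._tensor import Replicate, distribute_tensor

    placements = [Replicate()]
    if isinstance(input, DTensor):
        # Already distributed: just change its placements on the same mesh.
        return input.redistribute(device_mesh, placements)
    # Plain tensor: distribute it as replicated across the 1D mesh.
    return distribute_tensor(input, device_mesh, placements)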


def _prepare_output_validate(
    _prepare_output_func: _PrepareOutputType,
) -> _PrepareOutputType:
    """
    Inject common validation logic into ``_prepare_output`` funcs via this
    decorator, verifying that the output is a :class:`DTensor` and that only
    a 1D :class:`DeviceMesh` is passed in.

    Args:
        _prepare_output_func (Callable): The func we want to inject the
            validation into.

    Returns:
        func (Callable): Same output function with validation logic added.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> @_prepare_output_validate
        >>> def make_output_shard_1d(args, kwargs):
        >>>     ...
        >>>
        >>> # xdoctest: +SKIP(failing)
        >>> dt = distribute_tensor(tensor, device_mesh, [Shard(0)])
        >>> make_output_shard_1d(dt, device_mesh, 1)
        >>> # This calls '_prepare_output_validate' first
    """

    @functools.wraps(_prepare_output_func)
    def wrapper(*args, **kwargs):  # pyre-ignore[2, 3]
        assert len(args) >= 1, "_prepare_output needs at least one arg."
        output = args[0]
        assert isinstance(output, DTensor), (
            "Expect output of Tensor Parallel to be a DTensor, but found"
            f" {type(output)}."
        )
        # Fall back to the mesh the output DTensor already lives on.
        if len(args) < 2 or args[1] is None:
            device_mesh = output.device_mesh
            args = (*args[:1], device_mesh, *args[2:])  # pyre-ignore[60]
        else:
            device_mesh = args[1]
        assert device_mesh.ndim == 1, (
            f"device_mesh has dims {device_mesh.ndim} but expected to be 1 for"
            " output."
        )
        return _prepare_output_func(*args, **kwargs)

    return wrapper
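

# Illustrative only: a hedged sketch of a prepare-output func that conforms to
# _PrepareOutputType and relies on the validation above. The name
# `_demo_make_output_tensor` is hypothetical; the real helpers live in
# torch.distributed.tensor.parallel.style.
@_prepare_output_validate
def _demo_make_output_tensor(
    output: DTensor,
    device_mesh: Optional[DeviceMesh] = None,
) -> torch.Tensor:
    from torch.distributed._tensor import Replicate

    # Replicate the (possibly sharded) output across the 1D mesh, then
    # unwrap it into a plain local tensor on every rank.
    return output.redistribute(device_mesh, [Replicate()]).to_local()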


def _create_1d_device_mesh(device_mesh: DeviceMesh, tp_mesh_dim: int = 0) -> DeviceMesh:
    """
    Convert an N-D ``device_mesh`` into a 1D ``device_mesh``
    for 1D Tensor Parallelism.

    Args:
        device_mesh (DeviceMesh):
            :class:`DeviceMesh` object which describes the mesh topology
            of devices for the DTensor.
        tp_mesh_dim (int):
            the dimension of ``device_mesh`` on which we perform
            Tensor Parallelism.

    Returns:
        device_mesh (DeviceMesh): 1-D :class:`DeviceMesh` object on which
            Tensor Parallelism operates.
    """
    assert tp_mesh_dim < device_mesh.ndim and tp_mesh_dim >= -device_mesh.ndim, (
        f"Expect tp_mesh_dim within range [{-device_mesh.ndim},"
        f" {device_mesh.ndim}), but found {tp_mesh_dim}."
    )

    if device_mesh.ndim == 1:
        return device_mesh

    # Swap the current dim to the last dim, then reshape to flatten out the
    # other dims, so we can just extract the list of ranks containing cur_rank.
    cur_rank = device_mesh.get_rank()
    pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, tp_mesh_dim).reshape(
        -1, device_mesh.mesh.size(tp_mesh_dim)
    )
    dim_mesh_1d = pg_ranks_by_dim[torch.any(pg_ranks_by_dim == cur_rank, 1), :]
    sub_pg = device_mesh.get_dim_groups()[tp_mesh_dim]
    return DeviceMesh(device_mesh.device_type, dim_mesh_1d.squeeze(), [sub_pg])
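

# Illustrative usage (hypothetical 2x2 mesh over 4 ranks; requires an
# initialized process group, so this is shown as a sketch rather than run):
#
#   mesh_2d = DeviceMesh("cuda", torch.arange(4).reshape(2, 2))
#   tp_mesh = _create_1d_device_mesh(mesh_2d, tp_mesh_dim=1)
#   # On ranks 0 and 1, tp_mesh.mesh is tensor([0, 1]);
#   # on ranks 2 and 3, it is tensor([2, 3]).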