- # Copyright (c) Meta Platforms, Inc. and affiliates
- import copy
- import warnings
- from typing import Callable, cast, Dict, Optional, Sequence
- import torch
- import torch.nn as nn
- import torch.distributed._tensor.dispatch as op_dispatch
- from torch.distributed._tensor.device_mesh import DeviceMesh, get_global_device_mesh
- from torch.distributed._tensor.placement_types import (
- _Partial,
- DTensorSpec,
- Placement,
- Replicate,
- Shard,
- )
- from torch.distributed._tensor.sharding_prop import ShardingPropagator
- from torch.distributed._tensor.redistribute import Redistribute
- from torch.utils._pytree import tree_flatten
- __all__ = ["DTensor", "distribute_tensor", "distribute_module"]
- # NOTE [Autograd interaction between torch.Tensor]
- #
- # The autograd functions defined below are used by the public
- # facing APIs (i.e. from_local, to_local) to ensure that DTensor
- # works together with torch.Tensor within the autograd engine. This
- # allows DistributedTensor to exist on part of the module hierarchy
- # and still be able to calculate gradients across the torch.Tensor and
- # DistributedTensor boundary.
- # As an example, suppose we have a module that consists of submodules
- # A, B, and C; the execution flow would look like:
- # input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
- #
- # Suppose we only want to make Module B a sharded module with
- # DistributedTensor params; then we would need the following
- # flow to work:
- #
- # input(torch.Tensor) -> Module A
- # -> DTensor input -> Sharded Module B -> DTensor output
- # -> output (torch.Tensor) -> Module C -> output (torch.Tensor)
- #
- # We need the conversion from Module A's output to the DTensor input,
- # which is `from_local`, and the conversion from the DTensor output back to a
- # torch.Tensor output, which is `to_local`; thus these two functions must be autograd functions.
- #
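- # A minimal sketch of that boundary (illustrative only; `module_a/b/c`, the
- # mesh construction, and the placements below are assumptions, not part of this file):
- #
- #   mesh = DeviceMesh("cuda", list(range(world_size)))
- #   def forward(x: torch.Tensor) -> torch.Tensor:
- #       x = module_a(x)
- #       dt_in = DTensor.from_local(x, mesh, [Replicate()])  # torch.Tensor -> DTensor
- #       dt_out = sharded_module_b(dt_in)                    # compute with DTensor params
- #       x = dt_out.to_local()   # DTensor -> torch.Tensor (assuming the output
- #                               # placement yields the tensor Module C expects)
- #       return module_c(x)
- #
- # Because from_local/to_local are autograd functions, gradients flow from
- # module_c back through sharded_module_b's DTensor params and into module_a.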
- class _ToTorchTensor(torch.autograd.Function):
- @staticmethod
- def forward(ctx, input: "DTensor"): # type: ignore[override]
- ctx.dtensor_device_mesh = input.device_mesh
- ctx.dtensor_placements = input.placements
- ctx.dtensor_shape = input.shape
- ctx.dtensor_requires_grad = input.requires_grad
- return input._local_tensor.detach()
- @staticmethod
- def backward(ctx, grad_output: torch.Tensor): # type: ignore[override]
- device_mesh = ctx.dtensor_device_mesh
- placements = ctx.dtensor_placements
- return DTensor(
- grad_output,
- device_mesh,
- placements,
- size=ctx.dtensor_shape,
- requires_grad=grad_output.requires_grad,
- )
- class _FromTorchTensor(torch.autograd.Function):
- @staticmethod
- def forward( # type: ignore[override]
- ctx, # pyre-ignore[2]: Parameter must be annotated.
- input: torch.Tensor,
- device_mesh: DeviceMesh,
- placements: Sequence[Placement],
- run_check: bool,
- ) -> "DTensor":
- ctx.previous_placement = placements
- ctx.previous_device_mesh = device_mesh
- if run_check:
- # TODO: by default check tensor metas across rank
- # TODO: See if we need to make this run_check logic
- # have a corresponding backward.
- for idx, placement in enumerate(placements):
- if placement.is_replicate():
- # broadcast rank 0 tensor to all ranks
- # only broadcast if run_check is True
- input = input.contiguous()
- device_mesh.broadcast(input, mesh_dim=idx)
- # if run_check is not set, we assume the user is certain that each
- # rank has a tensor of the same shape, and we simply use the local shape
- # to calculate the global shape
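- # e.g. (illustrative numbers) a local shape of (4, 8) with placements
- # [Shard(0)] on a mesh dimension of size 4 yields a global shape of (16, 8)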
- tensor_shape = list(input.size())
- for idx, placement in enumerate(placements):
- if placement.is_shard():
- shard_dim = cast(Shard, placement).dim
- local_dim_size = tensor_shape[shard_dim]
- tensor_shape[shard_dim] = local_dim_size * device_mesh.size(idx)
- dist_tensor = DTensor(
- input,
- device_mesh,
- placements,
- size=torch.Size(tensor_shape),
- # requires_grad of the dist tensor depends on if input
- # requires_grad or not
- requires_grad=input.requires_grad,
- )
- return dist_tensor
- @staticmethod
- def backward(ctx, grad_output: "DTensor"): # type: ignore[override]
- previous_placement = ctx.previous_placement
- previous_device_mesh = ctx.previous_device_mesh
- # reshard to the placements used when the DistributedTensor was created
- # so that the gradient layout matches, and we can return the
- # local gradients directly
- if grad_output.placements != previous_placement:
- # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
- grad_output = Redistribute.apply(
- grad_output, previous_device_mesh, previous_placement
- )
- # TODO: backward is also differentiable now, add a test
- # to test higher level gradients.
- return grad_output.to_local(), None, None, None
- class DTensor(torch.Tensor): # pyre-ignore[13]: pyre is bad at __new__
- _local_tensor: torch.Tensor
- _spec: DTensorSpec
- __slots__ = ["_local_tensor", "_spec"]
- # class attribute that handles operator placement propagation
- # rules, keyed by aten op name; the value is the propagation func
- _propagator: ShardingPropagator = ShardingPropagator()
- # class attribute that holds custom registered ops; all handled
- # custom ops should appear in this table, overriding the default
- # operators covered by _op_to_rules or the fallbacks
- # (custom operators take the highest priority when dispatching).
- # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
- _custom_dispatch_ops: Dict[str, Callable] = {}
- @staticmethod
- def __new__(
- cls,
- local_tensor: torch.Tensor,
- device_mesh: DeviceMesh,
- placements: Sequence[Placement],
- *,
- size: torch.Size,
- requires_grad: bool = False,
- ) -> "DTensor":
- """
- Construct a DTensor from a local tensor, device mesh, and placement and
- other tensor properties (i.e. shape, requires_grad, strides, etc).
- Note: This is not a public API and it's only supposed to be used by the
- operator implementations and internals. If you want to construct a
- DTensor from a local tensor, consider using `DTensor.from_local`, if
- you want to construct a DTensor from a "global" tensor (where you
- already have tensor initialized and want to shard this tensor),
- consider using `distribute_tensor`.
- """
- # recover tensor strides from local tensor strides and global size info
- # in the case of sharding
- # TODO: we should try to use meta tensor for shape and stride calculation
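- # worked example (illustrative numbers): a contiguous local tensor of size
- # (4, 8) with strides (8, 1) sharded on dim 1 across 4 ranks has global size
- # (4, 32); the dim-0 stride is rescaled from 8 to (8 // 8) * 32 = 32, giving
- # global strides (32, 1), while the stride on the shard dim itself is unchanged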
- tensor_stride = list(local_tensor.stride())
- local_size = list(local_tensor.size())
- for placement in placements:
- if isinstance(placement, Shard):
- shard_dim = placement.dim
- # recover the tensor stride by rescaling any stride that is larger than
- # or equal to the current stride on the shard_dim
- for i in range(len(tensor_stride)):
- if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
- # rescale the stride by the shard size
- tensor_stride[i] = (
- tensor_stride[i] // local_size[shard_dim]
- ) * size[shard_dim]
- elif not isinstance(placement, (Replicate, _Partial)):
- raise RuntimeError(f"placement type {type(placement)} not supported!")
- if requires_grad != local_tensor.requires_grad:
- warnings.warn(
- "To construct DTensor from torch.Tensor, it's recommended to "
- "use local_tensor.detach() and make requires_grad consistent."
- )
- # the new method constructs the wrapper tensor from local_tensor's metadata and
- # attaches the placement spec; it does not do any actual distribution
- r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
- cls,
- size,
- strides=tensor_stride,
- dtype=local_tensor.dtype,
- device=local_tensor.device,
- layout=local_tensor.layout,
- requires_grad=requires_grad,
- )
- # deepcopy and set spec
- r._spec = DTensorSpec(device_mesh, copy.deepcopy(placements), shape=r.size())
- # detach the local tensor from the autograd graph as we initialize the
- # distributed tensor; autograd will work on top of
- # the wrapper tensor directly instead of the local torch.Tensor
- r._local_tensor = local_tensor.detach()
- return r
- # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
- # pyre-fixme[3]: Return type must be annotated.
- def __repr__(self):
- # TODO: consider all_gather the local tensors for better debugging
- return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"
- @classmethod
- # pyre-fixme[3]: Return type must be annotated.
- # pyre-fixme[2]: Parameter must be annotated.
- def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
- # check that we are not getting mixed vanilla and Distributed tensors
- arg_list, _ = tree_flatten(args)
- for arg in arg_list:
- if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
- raise RuntimeError(
- f"{func}: got mixed distributed and non-distributed tensors."
- )
- if kwargs is None:
- kwargs = {}
- return op_dispatch.operator_dispatch(
- func,
- args,
- kwargs,
- DTensor._propagator,
- DTensor._custom_dispatch_ops,
- )
- @classmethod
- def from_local(
- cls,
- local_tensor: torch.Tensor,
- device_mesh: Optional[DeviceMesh] = None,
- placements: Optional[Sequence[Placement]] = None,
- run_check: bool = True,
- ) -> "DTensor":
- """
- Create a :class:`DTensor` from a local torch.Tensor on each rank
- according to the `device_mesh` and `placements` specified.
- Args:
- local_tensor (torch.Tensor): local torch.Tensor on each rank.
- device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
- tensor, if not specified, must be called under a DeviceMesh
- context manager, default: None
- placements (List[:class:`Placement`], optional): the placements that
- describe how to place the local torch.Tensor on the DeviceMesh; must
- have the same number of elements as `device_mesh.ndim`. If not
- specified, we will by default replicate the tensor across the
- `device_mesh` from the first rank of each dimension of the `device_mesh`.
- run_check (bool, optional): indicates whether to run checks across ranks
- on the meta information and data. If there is a :class:`Replicate` in
- `placements`, the data on the first rank of that device mesh dimension
- will be broadcast to the other ranks.
- Returns:
- A :class:`DTensor` object
- .. note:: `from_local` is differentiable; the `requires_grad` of the created
- `DTensor` object depends on whether `local_tensor` requires grad.
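- Example (a minimal sketch; assumes the default process group is already
- initialized across 4 ranks and that the names below are illustrative)::
- >>> # each rank holds a different (4, 8) local shard
- >>> mesh = DeviceMesh("cuda", list(range(4)))
- >>> local_tensor = torch.randn(4, 8)
- >>> dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
- >>> dtensor.size()             # global size: torch.Size([16, 8])
- >>> dtensor.to_local().size()  # local size: torch.Size([4, 8])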
- """
- # if the shape/dtype are the same across ranks, there is no need to run_check;
- # if not, we must all-gather the metadata to check the size/dtype across ranks.
- # There should be no data communication unless there is a replication
- # placement, in which case we broadcast the replicated data from the first rank
- # in the mesh dimension
- device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
- # convert the local tensor to the desired device based on the device mesh's device_type
- if not local_tensor.is_meta:
- local_tensor = local_tensor.to(device_mesh.device_type)
- # set default placements to replicated if not specified
- if placements is None:
- placements = [Replicate() for _ in range(device_mesh.ndim)]
- # `from_local` is differentiable, and the gradient of the dist tensor created by this
- # function should flow back to local_tensor, so we use an autograd
- # function to construct the dist tensor instead.
- return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func
- local_tensor, device_mesh, placements, run_check
- )
- def to_local(self) -> torch.Tensor:
- """
- Get the local tensor of this DTensor on its current rank. For a sharded placement it
- returns the local shard of the logical tensor view; for replication it returns the
- replica on its current rank.
- Returns:
- A :class:`torch.Tensor` object that represents the local tensor of its current rank.
- .. note:: `to_local` is differentiable; the `requires_grad` of the local tensor returned
- depends on whether the `DTensor` requires grad.
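- Example (reusing the hypothetical 4-rank, Shard(0) setup shown in `from_local`)::
- >>> sharded = DTensor.from_local(torch.randn(4, 8), mesh, [Shard(0)])
- >>> sharded.to_local().size()  # torch.Size([4, 8]), this rank's shard only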
- """
- return _ToTorchTensor.apply(self) # pyre-ignore[16]: autograd func
- def redistribute(
- self,
- device_mesh: Optional[DeviceMesh] = None,
- placements: Optional[Sequence[Placement]] = None,
- ) -> "DTensor":
- """
- `redistribute` performs the collective operations needed to redistribute the current
- DTensor from its current placements to new placements, or from its current DeviceMesh
- to a new DeviceMesh, e.g. we can turn a sharded DTensor into a replicated DTensor by
- specifying a Replicate placement for each dimension of the DeviceMesh.
- Args:
- device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
- DTensor, if not specified, must be called under a DeviceMesh
- context manager, default: None
- placements (List[:class:`Placement`], optional): the new placements that
- describe how to place the DTensor on the DeviceMesh; must
- have the same number of elements as `device_mesh.ndim`.
- Returns:
- A :class:`DTensor` object
- .. note:: `redistribute` is differentiable.
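- Example (a minimal sketch on a hypothetical 1-D mesh of 4 ranks; the necessary
- collectives are issued by `Redistribute` under the hood)::
- >>> mesh = DeviceMesh("cuda", list(range(4)))
- >>> sharded = DTensor.from_local(torch.randn(4, 8), mesh, [Shard(0)])
- >>> replicated = sharded.redistribute(mesh, [Replicate()])
- >>> replicated.to_local().size()  # torch.Size([16, 8]) on every rank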
- """
- # This API performs the necessary transformations and returns
- # a new DTensor with the new spec, i.e. for
- # sharding it's a reshard behavior.
- # Note that redistribute currently only supports out-of-place
- # redistribution, i.e. it always creates a new
- # DTensor object and leaves the original one unchanged.
- device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
- # raise error if new placements not specified
- if placements is None:
- raise RuntimeError("placements is needed for redistribute!")
- for placement in placements:
- if placement.is_partial():
- raise RuntimeError(
- "Can not redistribute to _Partial, _Partial is for internal use only!"
- )
- # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
- return Redistribute.apply(self, device_mesh, placements)
- @property
- def device_mesh(self) -> DeviceMesh:
- """
- The :class:`DeviceMesh` attribute associated with this DTensor object.
- .. note:: device_mesh is a read-only property; it cannot be set.
- """
- return self._spec.mesh
- @property
- def placements(self) -> Sequence[Placement]:
- """
- The placements attribute of this DTensor that describes the layout of this
- DTensor on its DeviceMesh.
- .. note:: placements is a read-only property; it cannot be set.
- """
- return self._spec.placements
- def distribute_tensor(
- tensor: torch.Tensor,
- device_mesh: Optional[DeviceMesh] = None,
- placements: Optional[Sequence[Placement]] = None,
- ) -> DTensor:
- """
- Distribute a torch.Tensor to the `device_mesh` according to the `placements`
- specified. The number of elements in `placements` must match `device_mesh.ndim`.
- Args:
- tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
- want to shard a tensor on a dimension whose size is not evenly divisible by
- the number of devices in that mesh dimension, we use `torch.tensor_split`
- semantics to shard the tensor and scatter the shards.
- device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
- tensor, if not specified, must be called under a DeviceMesh context
- manager, default: None
- placements (List[:class:`Placement`], optional): the placements that
- describe how to place the tensor on the DeviceMesh; must have the same
- number of elements as `device_mesh.ndim`. If not specified, we will
- by default replicate the tensor across the `device_mesh` from the
- first rank of each dimension of the `device_mesh`.
- Returns:
- A :class:`DTensor` object
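- Example (a minimal sketch; assumes a hypothetical 1-D mesh over 4 ranks and
- that `global_tensor` is constructed identically on every rank)::
- >>> mesh = DeviceMesh("cuda", list(range(4)))
- >>> global_tensor = torch.randn(16, 8)
- >>> dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)])
- >>> dtensor.to_local().size()  # torch.Size([4, 8]) on each rank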
- """
- # get default device mesh if there's nothing specified
- device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
- # convert the tensor to the corresponding device type if it's not already on that device type
- if not tensor.is_meta:
- tensor = tensor.to(device_mesh.device_type)
- # set default placements to replicated if not specified
- if placements is None:
- placements = [Replicate() for _ in range(device_mesh.ndim)]
- if len(placements) != device_mesh.ndim:
- raise ValueError(
- f"`placements` must have the same length as `device_mesh.ndim`! "
- f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
- )
- if isinstance(tensor, DTensor):
- # if the tensor is already a DTensor, we just need to check if the
- # device mesh and placements are the same
- if tensor.device_mesh != device_mesh:
- raise ValueError(
- f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
- f"to a different device mesh {device_mesh}."
- )
- if tensor.placements != placements:
- raise ValueError(
- f"Cannot distribute a DTensor with placements {tensor.placements} "
- f"to a different placements {placements}. do you want to call "
- f"`redistribute` instead?"
- )
- return tensor
- local_tensor = tensor
- # distribute the tensor according to the placements.
- for idx, placement in enumerate(placements):
- if placement.is_shard():
- placement = cast(Shard, placement)
- output = placement._shard_tensor(local_tensor, device_mesh, idx)
- # the scatter call cannot return a tensor with the correct requires_grad
- # field, as ProcessGroupNCCL refuses to take a tensor that requires grad
- # for an in-place update, so we manually set it here
- output.requires_grad_(tensor.requires_grad)
- local_tensor = output
- elif placement.is_replicate():
- local_tensor = local_tensor.contiguous()
- device_mesh.broadcast(local_tensor, mesh_dim=idx)
- else:
- raise RuntimeError(
- f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
- )
- assert local_tensor is not None, "distributing a tensor should not result in None"
- return DTensor(
- local_tensor,
- device_mesh,
- placements,
- size=tensor.size(),
- requires_grad=tensor.requires_grad,
- )
- def distribute_module(
- module: nn.Module,
- device_mesh: Optional[DeviceMesh] = None,
- partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
- input_fn: Optional[Callable[..., None]] = None,
- output_fn: Optional[Callable[..., None]] = None,
- ) -> nn.Module:
- """
- This function converts all module parameters to :class:`DTensor` parameters
- according to the `partition_fn` specified. It can also control the inputs and
- outputs of the module by specifying `input_fn` and `output_fn` (i.e. convert
- the input to :class:`DTensor`, convert the output back to torch.Tensor).
- Args:
- module (:class:`nn.Module`): user module to be partitioned.
- device_mesh (:class:`DeviceMesh`): the device mesh to place the module.
- partition_fn (Callable): the function to partition parameters (i.e. shard certain
- parameters across the `device_mesh`). If `partition_fn` is not specified,
- by default we replicate all module parameters of `module` across the mesh.
- input_fn (Callable): specifies the input distribution, i.e. it can control how the
- input of the module is sharded. `input_fn` will be installed as a module
- `forward_pre_hook` (pre forward hook).
- output_fn (Callable): specifies the output distribution, i.e. it can control how the
- output is sharded, or convert it back to torch.Tensor. `output_fn` will be
- installed as a module `forward_hook` (post forward hook).
- Returns:
- A module that contains parameters/buffers that are all `DTensor`s.
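- Example (a minimal sketch; the toy module and `shard_linear` policy below are
- illustrative assumptions, not part of this API)::
- >>> mesh = DeviceMesh("cuda", list(range(4)))
- >>> def shard_linear(name, submod, mesh):
- ...     # hypothetical policy: shard nn.Linear weights row-wise; everything
- ...     # else is replicated by distribute_module's default pass afterwards
- ...     if isinstance(submod, nn.Linear):
- ...         submod.register_parameter(
- ...             "weight",
- ...             nn.Parameter(distribute_tensor(submod.weight.data, mesh, [Shard(0)])),
- ...         )
- >>> sharded_mod = distribute_module(nn.Linear(8, 8), mesh, partition_fn=shard_linear)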
- """
- if device_mesh is None:
- device_mesh = get_global_device_mesh()
- def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
- # This function loops over the immediate module parameters and
- # buffers and replicates all non-DTensor params/buffers into DTensor
- # parameters/buffers, if they have not already been partitioned in the
- # partition_fn. We can't easily use `module._apply` here
- # because we don't know what happened inside partition_fn, as the
- # user could do anything (i.e. install hooks), and we want to
- # preserve those.
- full_replicate = [Replicate()] * mesh.ndim
- for key, param in m._parameters.items():
- if param is not None and not isinstance(param, DTensor):
- m.register_parameter(
- key,
- nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
- )
- for key, buffer in m._buffers.items():
- if buffer is not None and not isinstance(buffer, DTensor):
- m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)
- if partition_fn is None:
- # if partition_fn not specified, we by default replicate
- # all module params/buffers
- for name, submod in module.named_modules():
- replicate_module_params_buffers(submod, device_mesh)
- else:
- # apply partition_fn to submodules
- for name, submod in module.named_modules():
- partition_fn(name, submod, device_mesh)
- replicate_module_params_buffers(submod, device_mesh)
- # register input_fn as module forward pre hook
- if input_fn is not None:
- module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh)) # type: ignore[misc]
- # register output_fn as module forward hook
- if output_fn is not None:
- module.register_forward_hook(
- lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[misc]
- )
- return module