# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import cast, Dict, List, Sequence, Tuple

import torch
import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import (
    _Partial,
    Placement,
    Replicate,
    Shard,
)

_PlacementItem = Tuple[int, Tuple[Placement, Placement]]


def _replicate_then_shard(val: _PlacementItem) -> int:
    """
    Replicate from inner to outer dimension.
    Shard from outer to inner dimension.
    """
    i, (current, target) = val
    if (target.is_replicate() or target.is_partial()) and current.is_shard():
        return -i
    elif (current.is_replicate() or current.is_partial()) and target.is_shard():
        return i
    else:
        return 0
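

# Illustration of the sort key above (hypothetical placements, comments only):
# for sorted_placements == [(0, (Shard(0), Shard(0))),
#                           (1, (Shard(1), Replicate())),
#                           (2, (Replicate(), Shard(1)))]
# the keys are [0, -1, 2], so sorting applies the replication on mesh dim 1
# first, keeps the no-op on mesh dim 0 in the middle, and defers the new shard
# on mesh dim 2 until the end.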


def _decompose_reshard(val: List[_PlacementItem]) -> List[_PlacementItem]:
    """
    Decompose Si -> Sj into Si -> R -> Sj.
    There are two ways shardings can differ within a mesh dimension:
    1) sharding on different tensor dimensions, e.g. Shard(0) -> Shard(1)
    2) different sub-shards of a repeated shard ("mis-aligned sharding"), e.g.
       (Shard(0), Shard(0)) -> (Replicate(), Shard(0))
       Here the Shard(0) -> Shard(0) on the second mesh dimension is actually
       a reshard, because on the left it is a sub-sharding of tensor dimension 0,
       which is already sharded, while on the right it is the first sharding on
       tensor dimension 0.
    """
    # detect mis-aligned repeated shardings
    from collections import defaultdict

    repeat_dim_current: Dict[int, int] = defaultdict(int)
    repeat_dim_target: Dict[int, int] = defaultdict(int)

    output: List[_PlacementItem] = []

    for i, (current, target) in val:
        # detect mis-aligned sharding
        if current.is_shard():
            repeat_dim_current[cast(Shard, current).dim] += 1
        if target.is_shard():
            repeat_dim_target[cast(Shard, target).dim] += 1
        if (
            isinstance(current, Shard)
            and isinstance(target, Shard)
            and (
                current.dim != target.dim
                or repeat_dim_current[current.dim] != repeat_dim_target[target.dim]
            )
        ):
            # decompose Shard(i) -> Shard(j) into Shard(i) -> Replicate() -> Shard(j)
            output.append((i, (current, Replicate())))
            output.append((i, (Replicate(), target)))
        else:
            output.append((i, (current, target)))

    return output
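

# Illustration (comments only): for the mis-aligned example in the docstring,
#   _decompose_reshard([(0, (Shard(0), Replicate())), (1, (Shard(0), Shard(0)))])
# keeps the first entry as-is but splits the second into
#   (1, (Shard(0), Replicate())) and (1, (Replicate(), Shard(0))),
# because tensor dim 0 is sharded twice on the current side but only once on
# the target side.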


# Intentionally expose this API to trace ops on local tensors
def _redistribute_with_local_tensor(
    local_tensor: torch.Tensor,
    size: torch.Size,
    device_mesh: DeviceMesh,
    current_placements: Sequence[Placement],
    target_placements: Sequence[Placement],
) -> torch.Tensor:
    new_local_tensor = None
    sorted_placements = list(enumerate(zip(current_placements, target_placements)))
    sorted_placements = _decompose_reshard(sorted_placements)
    sorted_placements.sort(key=_replicate_then_shard)
    for i, (current, target) in sorted_placements:
        my_coordinate = device_mesh.get_coordinate_on_dim(i)
        num_chunks = device_mesh.size(dim=i)
        # TODO: what should happen if the rank is not in the mesh?
        # see issue https://github.com/pytorch/tau/pull/492
        assert (
            my_coordinate is not None
        ), "Rank is not part of the mesh"  # TODO: figure out behavior here
        if current == target:
            # short cut, just use the original local tensor
            new_local_tensor = local_tensor
            continue
        if target.is_replicate():
            # Case 1: target is Replicate
            if current.is_partial():
                partial_spec = cast(_Partial, current)
                new_local_tensor = partial_spec._to_replicate(
                    local_tensor, device_mesh, i
                )
            elif current.is_shard():
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, size, device_mesh, i
                )
            else:
                raise RuntimeError(
                    f"redistribute from {current_placements} to {target_placements} not supported yet"
                )
        elif target.is_shard():
            # Case 2: target is Shard
            target_placement = cast(Shard, target)
            if current.is_partial():
                partial_spec = cast(_Partial, current)
                new_local_tensor = partial_spec._to_shard(
                    local_tensor, device_mesh, i, target_placement
                )
            elif current.is_replicate():
                # split the tensor and return the corresponding cloned local shard
                shards, _ = target_placement._split_tensor(
                    local_tensor,
                    num_chunks,
                    with_padding=False,
                    contiguous=False,
                )
                new_local_tensor = shards[my_coordinate].clone()
            else:
                # NOTE: this case shouldn't be hit after _decompose_reshard, which
                # decomposes Shard(0) -> Shard(1) into Shard(0) -> Replicate() -> Shard(1)
                assert (
                    current.is_shard()
                ), f"Current placement should be shard but found {current}"
                shard_spec = cast(Shard, current)
                if shard_spec.dim != target_placement.dim:
                    # TODO: enable this with all_to_all
                    raise NotImplementedError(
                        "Changing sharding dim is not supported yet!"
                    )
        elif target.is_partial():
            if current.is_replicate():
                # For replicate -> partial, we zero out all ranks of the current mesh
                # dim except rank 0, leaving only one rank with the data, to perform
                # a "zero cost" reshard.
                if my_coordinate is not None and my_coordinate != 0:
                    new_local_tensor = local_tensor.zero_()
                else:
                    new_local_tensor = local_tensor
            else:
                raise RuntimeError(
                    f"redistribute from {current_placements} to {target_placements} not supported yet"
                )

        assert new_local_tensor is not None
        local_tensor = new_local_tensor

    assert new_local_tensor is not None, "redistribute failed!"
    return new_local_tensor
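

# Example (sketch, comments only, assuming a 1-D mesh of 4 ranks and a
# replicated 8x8 local tensor):
#   _redistribute_with_local_tensor(
#       local_tensor, torch.Size([8, 8]), mesh, [Replicate()], [Shard(0)]
#   )
# would return this rank's 2x8 chunk without any collective communication,
# while [Shard(0)] -> [Replicate()] would reassemble the full tensor via
# Shard._to_replicate_tensor.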


def redistribute_dtensor(
    input: "dtensor.DTensor",
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
) -> "dtensor.DTensor":
    if input.device_mesh != device_mesh:
        # TODO: alltoall reshuffling to change device_mesh if they are not the same
        raise NotImplementedError("Cross device mesh comm not supported yet!")

    local_tensor = input._local_tensor
    new_local_tensor = _redistribute_with_local_tensor(
        local_tensor,
        input.size(),
        device_mesh,
        input.placements,
        placements,
    )

    return dtensor.DTensor(
        new_local_tensor,
        device_mesh,
        placements,
        size=input.size(),
        requires_grad=local_tensor.requires_grad,
    )
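

# Example (sketch, comments only; `dt` is assumed to be a DTensor placed as
# [Shard(0)] on a 1-D device mesh):
#   replicated = redistribute_dtensor(dt, dt.device_mesh, [Replicate()])
# returns a new DTensor with the same global size whose local tensor holds the
# full, gathered data on every rank.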


class Redistribute(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        input: "dtensor.DTensor",
        device_mesh: DeviceMesh,
        placements: List[Placement],
    ):
        ctx.previous_placement = input.placements
        ctx.previous_device_mesh = input.device_mesh
        return redistribute_dtensor(input, device_mesh, placements)

    @staticmethod
    def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
        previous_placement = ctx.previous_placement
        previous_device_mesh = ctx.previous_device_mesh

        # When we run the backward pass of redistribute (i.e. a manual redistribute
        # from user code instead of torch_dispatch), we first scan the placements to
        # see if we need to change the target placement for one special case:
        # replicate -> partial.
        # In this case we keep the grad as replicate. Although converting the
        # replicated gradients back to partial would conform to the same logical
        # layout, it is actually useless: the partial gradients would have to be
        # reduced again later, which is more expensive than simply keeping them
        # replicated. For this reason, we keep the replicated grad here.
        # TODO: see if this makes sense for all cases.
        target_placements: List[Placement] = []
        for current, target in zip(grad_output.placements, previous_placement):
            if current.is_replicate() and target.is_partial():
                # keep the target placement as replicate instead of partial in this case
                target_placements.append(current)
            else:
                target_placements.append(target)

        return (
            redistribute_dtensor(grad_output, previous_device_mesh, target_placements),
            None,
            None,
        )
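

# Example (sketch, comments only): DTensor.redistribute is expected to route
# through this autograd.Function, roughly as
#   out = Redistribute.apply(dt, device_mesh, [Replicate()])
# so that the backward pass redistributes gradients to `dt`'s original
# placements, keeping them replicated instead of partial per the note above.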