# Copyright (c) Meta Platforms, Inc. and affiliates
from collections import defaultdict
from typing import cast, Dict, List, Sequence, Tuple

import torch
import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.device_mesh import DeviceMesh
from torch.distributed._tensor.placement_types import (
    _Partial,
    Placement,
    Replicate,
    Shard,
)

_PlacementItem = Tuple[int, Tuple[Placement, Placement]]


def _replicate_then_shard(val: _PlacementItem) -> int:
    """
    Replicate from inner to outer dimension.
    Shard from outer to inner dimension.
    """
    i, (current, target) = val
    if (target.is_replicate() or target.is_partial()) and current.is_shard():
        return -i
    elif (current.is_replicate() or current.is_partial()) and target.is_shard():
        return i
    else:
        return 0


def _decompose_reshard(val: List[_PlacementItem]) -> List[_PlacementItem]:
    """
    Decompose Si -> Sj into Si -> R -> Sj.
    There are two ways shardings can differ within a mesh dimension:
      1) sharding on different tensor dimensions, e.g. Shard(0) -> Shard(1)
      2) different sub-shards of a repeated shard ("mis-aligned sharding"), e.g.
          (Shard(0), Shard(0)) -> (Replicate(), Shard(0))
          Here the Shard(0) -> Shard(0) on the second mesh dimension is actually
          a reshard: in the source placement it is a sub-shard of an already-sharded
          tensor dimension 0, while in the target placement it is the first
          sharding on tensor dimension 0.
    """
    # detect mis-aligned repeated shardings
    repeat_dim_current: Dict[int, int] = defaultdict(int)
    repeat_dim_target: Dict[int, int] = defaultdict(int)

    output: List[_PlacementItem] = []

    for i, (current, target) in val:
        # detect mis-aligned sharding
        if current.is_shard():
            repeat_dim_current[cast(Shard, current).dim] += 1
        if target.is_shard():
            repeat_dim_target[cast(Shard, target).dim] += 1
        if (
            isinstance(current, Shard)
            and isinstance(target, Shard)
            and (
                current.dim != target.dim
                or repeat_dim_current[current.dim] != repeat_dim_target[target.dim]
            )
        ):
            # decompose Shard(i) -> Shard(j) into Shard(i) -> Replicate() -> Shard(j)
            output.append((i, (current, Replicate())))
            output.append((i, (Replicate(), target)))
        else:
            output.append((i, (current, target)))

    return output
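

# Worked example (illustrative comment, not executed): for the mis-aligned case in
# the docstring above, (Shard(0), Shard(0)) -> (Replicate(), Shard(0)), the second
# mesh dimension is a real reshard and gets decomposed through Replicate():
#
#   _decompose_reshard([(0, (Shard(0), Replicate())), (1, (Shard(0), Shard(0)))])
#   # -> [(0, (Shard(0), Replicate())),
#   #     (1, (Shard(0), Replicate())),
#   #     (1, (Replicate(), Shard(0)))]
#
# Sorting that result with key=_replicate_then_shard replicates the inner mesh
# dimension first and shards last, i.e. the items are processed in the order
# (1, Shard(0) -> Replicate()), (0, Shard(0) -> Replicate()), (1, Replicate() -> Shard(0)).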


# Intentionally expose this API to trace ops on local tensors
def _redistribute_with_local_tensor(
    local_tensor: torch.Tensor,
    size: torch.Size,
    device_mesh: DeviceMesh,
    current_placements: Sequence[Placement],
    target_placements: Sequence[Placement],
) -> torch.Tensor:
    new_local_tensor = None

    sorted_placements = list(enumerate(zip(current_placements, target_placements)))
    sorted_placements = _decompose_reshard(sorted_placements)
    sorted_placements.sort(key=_replicate_then_shard)

    for i, (current, target) in sorted_placements:
        my_coordinate = device_mesh.get_coordinate_on_dim(i)
        num_chunks = device_mesh.size(dim=i)
        # TODO: what should happen if rank is not in the mesh?
        # see issue https://github.com/pytorch/tau/pull/492
        assert (
            my_coordinate is not None
        ), "Rank is not part of the mesh"  # TODO: figure out behavior here
        if current == target:
            # shortcut, just use the original local tensor
            new_local_tensor = local_tensor
            continue

        if target.is_replicate():
            # Case 1: target is Replicate
            if current.is_partial():
                partial_spec = cast(_Partial, current)
                new_local_tensor = partial_spec._to_replicate(
                    local_tensor, device_mesh, i
                )
            elif current.is_shard():
                current_placement = cast(Shard, current)
                new_local_tensor = current_placement._to_replicate_tensor(
                    local_tensor, size, device_mesh, i
                )
            else:
                raise RuntimeError(
                    f"redistribute from {current_placements} to {target_placements} not supported yet"
                )
        elif target.is_shard():
            # Case 2: target is Shard
            target_placement = cast(Shard, target)
            if current.is_partial():
                partial_spec = cast(_Partial, current)
                new_local_tensor = partial_spec._to_shard(
                    local_tensor, device_mesh, i, target_placement
                )
            elif current.is_replicate():
                # split the tensor and return the corresponding cloned local shard
                shards, _ = target_placement._split_tensor(
                    local_tensor,
                    num_chunks,
                    with_padding=False,
                    contiguous=False,
                )
                new_local_tensor = shards[my_coordinate].clone()
            else:
                # NOTE: this case shouldn't be reached: _decompose_reshard should have
                # already decomposed Shard(0) -> Shard(1) into Shard(0) -> Replicate() -> Shard(1)
                assert (
                    current.is_shard()
                ), f"Current placement should be shard but found {current}"
                shard_spec = cast(Shard, current)
                if shard_spec.dim != target_placement.dim:
                    # TODO: enable this with all_to_all
                    raise NotImplementedError(
                        "Changing sharding dim is not supported yet!"
                    )
        elif target.is_partial():
            if current.is_replicate():
                # For replicate -> partial, we zero out all other ranks of the current
                # mesh dim and leave only one rank with the data, to perform a
                # "zero cost" reshard.
                if my_coordinate is not None and my_coordinate != 0:
                    new_local_tensor = local_tensor.zero_()
                else:
                    new_local_tensor = local_tensor
            else:
                raise RuntimeError(
                    f"redistribute from {current_placements} to {target_placements} not supported yet"
                )

        assert new_local_tensor is not None
        local_tensor = new_local_tensor

    assert new_local_tensor is not None, "redistribute failed!"

    return new_local_tensor


def redistribute_dtensor(
    input: "dtensor.DTensor",
    device_mesh: DeviceMesh,
    placements: Sequence[Placement],
) -> "dtensor.DTensor":
    if input.device_mesh != device_mesh:
        # TODO: alltoall reshuffling to change device_mesh if they are not the same
        raise NotImplementedError("Cross device mesh comm not supported yet!")

    local_tensor = input._local_tensor
    new_local_tensor = _redistribute_with_local_tensor(
        local_tensor,
        input.size(),
        device_mesh,
        input.placements,
        placements,
    )

    return dtensor.DTensor(
        new_local_tensor,
        device_mesh,
        placements,
        size=input.size(),
        requires_grad=local_tensor.requires_grad,
    )
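

# Local-tensor sketch (illustrative, assumes a 1-D DeviceMesh over 2 ranks bound to
# `mesh`; none of the names below are defined in this file): redistributing a
# replicated 4 x 4 tensor to [Shard(0)] takes the "current is Replicate, target is
# Shard" branch above: each rank splits its full copy into 2 chunks along dim 0
# and keeps the chunk at its own mesh coordinate, with no communication:
#
#   new_local = _redistribute_with_local_tensor(
#       local_tensor,          # this rank's 4 x 4 replicated copy
#       torch.Size([4, 4]),
#       mesh,                  # e.g. DeviceMesh("cuda", [0, 1])
#       [Replicate()],
#       [Shard(0)],
#   )
#   # new_local is a 2 x 4 clone of this rank's chunk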


class Redistribute(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        # pyre-fixme[2]: Parameter must be annotated.
        ctx,
        input: "dtensor.DTensor",
        device_mesh: DeviceMesh,
        placements: List[Placement],
    ):
        ctx.previous_placement = input.placements
        ctx.previous_device_mesh = input.device_mesh
        return redistribute_dtensor(input, device_mesh, placements)

    @staticmethod
    def backward(ctx, grad_output: "dtensor.DTensor"):  # type: ignore[override]
        previous_placement = ctx.previous_placement
        previous_device_mesh = ctx.previous_device_mesh

        # When we run the backward pass of redistribute (i.e. a manual redistribute
        # from user code instead of torch_dispatch), we first scan to see if we need
        # to change the target placement for one special case:
        #   replicate -> partial.
        # In that case we keep the grad as replicate: although converting the
        # replicated gradients back to partial would logically conform to the same
        # layout, it is actually useless, since a reduce would be needed later anyway,
        # which is more expensive than keeping the gradient replicated. For this
        # reason, we keep the replicate grad here.
        # TODO: see if this makes sense for all cases.
        target_placements: List[Placement] = []
        for current, target in zip(grad_output.placements, previous_placement):
            if current.is_replicate() and target.is_partial():
                # keep target placement as replicate instead of partial in this case
                target_placements.append(current)
            else:
                target_placements.append(target)

        return (
            redistribute_dtensor(grad_output, previous_device_mesh, target_placements),
            None,
            None,
        )
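

# Usage sketch (hedged assumption): the user-facing DTensor redistribute call is
# expected to route through Redistribute.apply so that the special-cased backward
# above runs; `dt` and `mesh` below are illustrative names, not defined here.
#
#   out = Redistribute.apply(dt, mesh, [Replicate()])  # forward -> redistribute_dtensor
#   # A backward pass from a loss computed on `out` invokes backward() above, which
#   # redistributes grad_output back toward dt's original placements, except that an
#   # already-replicated gradient is not converted back to partial.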