import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed import distributed_c10d
from torch.overrides import get_default_nowrap_functions

_REPLICATED_WITH_NON_TENSOR_ALLOWLIST = [
    # List of ops where, if the parameters are a combination of ReplicatedTensors
    # and non-tensors, we can still return a ReplicatedTensor as the result.
    torch.unsqueeze,
    torch.Tensor.unsqueeze,
    torch.Tensor.__getitem__,
]


class ReplicatedTensor(torch.Tensor):
    """
    ReplicatedTensor represents a tensor which is replicated across the
    `world_size` and has the same value on each rank.

    ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used
    together with ShardedTensor/Tensor to express different types of
    computation. The inter-op rules are defined as follows (using torch.add as
    an example op):
        ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
        ReplicatedTensor + torch.Tensor = torch.Tensor
        ReplicatedTensor + ShardedTensor = ShardedTensor
        ReplicatedTensor + other type (i.e. Scalar) = other type

    NOTE: We do not guarantee equal content of ReplicatedTensor across nodes after
    its construction. Although we define proper inter-op rules to make sure
    ReplicatedTensor stays the same, there is no enforcement of it (i.e. if you
    manually modify content on some ranks, the modified value will not automatically
    get synced to other nodes). If you wish to manually validate that tensors are
    the same across ranks, use `validate()`.
    """
    _process_group: distributed_c10d.ProcessGroup

    __slots__ = ["_process_group"]

    def __new__(cls, data=None, process_group=None):
        if data is None:
            data = torch.empty(0)
        r = torch.Tensor._make_subclass(cls, data, data.requires_grad)  # type: ignore[arg-type]
        r._process_group = (  # type: ignore[attr-defined]
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )
        return r

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.data.clone(memory_format=torch.preserve_format), self._process_group)
            memo[id(self)] = result
            return result

    def __repr__(self):
        return f"ReplicatedTensor({super().__repr__()})"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # We re-dispatch execution to ShardedTensor's __torch_function__ if we find
        # any ShardedTensor operands. We also check whether args/kwargs are all
        # ReplicatedTensor operands; this is needed to ensure we do not convert the
        # result back to ReplicatedTensor when not all operands are replicated.
        all_replicated = True
        replicated_with_non_tensor = True
        replicated_pg = None

        def dispatch_arg(arg):
            # Returns a tuple: the first element indicates whether the op has been
            # executed, the second element is the result of the execution.
            nonlocal replicated_pg, all_replicated, replicated_with_non_tensor
            if isinstance(arg, ShardedTensor):
                # redispatch to ShardedTensor
                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
                return True, arg.__torch_function__(func, types, args, kwargs)
            if isinstance(arg, ReplicatedTensor):
                if replicated_pg is None:
                    replicated_pg = arg._process_group
                elif replicated_pg != arg._process_group:
                    raise RuntimeError(
                        f"ReplicatedTensor operands must be in the same process group "
                        f"in torch function '{func.__name__}', but found at least two "
                        f"ReplicatedTensor operands in different process groups!"
                    )
            elif isinstance(arg, torch.Tensor):
                replicated_with_non_tensor = False
                all_replicated = False
            else:
                all_replicated = False

            return False, None

        for arg in args:
            redispatched, res = dispatch_arg(arg)
            if redispatched:
                return res

        if kwargs is not None:
            for k, v in kwargs.items():
                redispatched, res = dispatch_arg(v)
                if redispatched:
                    return res

        # We can't use super().__torch_function__() as it implicitly converts the result
        # back to the tensor subclass, whereas in our case we need to control the output
        # type based on the inter-op rules we defined.
        with torch._C.DisableTorchFunctionSubclass():
            rs = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return rs

            result_not_replicated = isinstance(rs, torch.Tensor) and not isinstance(rs, ReplicatedTensor)
            should_convert_to_replicated = all_replicated or (
                replicated_with_non_tensor and func in _REPLICATED_WITH_NON_TENSOR_ALLOWLIST
            )
            if result_not_replicated and should_convert_to_replicated:
                # If all operands are ReplicatedTensors and the op did not get dispatched
                # to ShardedTensor's __torch_function__, the result is a torch.Tensor, so
                # we convert and return a ReplicatedTensor according to our inter-op rules.
                rs = rs.as_subclass(ReplicatedTensor)  # type: ignore[arg-type]

                # propagate the process_group field to the result
                rs._process_group = replicated_pg  # type: ignore[attr-defined]

            return rs

    def validate(self) -> bool:
        """
        Validate that this ReplicatedTensor is legit by all-gathering the tensors on
        all ranks of its process group and checking that they are the same. If some
        ranks have different values, a ValueError is raised.

        Returns:
            True if validation succeeds.
        """
        world_size = dist.get_world_size(self._process_group)
        current_rank = dist.get_rank(self._process_group)

        tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)]

        dist.all_gather(tensors_on_rank, self, group=self._process_group)
        # validate and check if all tensors are equal
        for rank, tensor in enumerate(tensors_on_rank):
            if not torch.allclose(self, tensor):
                raise ValueError(
                    f"ReplicatedTensor has different values on rank {current_rank} and {rank}")

        return True

    def __setstate__(self, state):
        with torch._C.DisableTorchFunctionSubclass():
            self.data = state
            self.requires_grad = state.requires_grad
            from torch.distributed._shard.api import _get_current_process_group
            self._process_group = _get_current_process_group()

    def __getstate__(self):
        return self.data