import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed import distributed_c10d
from torch.overrides import get_default_nowrap_functions

# Ops for which a mix of ReplicatedTensor and non-tensor arguments can still
# return a ReplicatedTensor as the result.
_REPLICATED_WITH_NON_TENSOR_ALLOWLIST = [
    torch.unsqueeze,
    torch.Tensor.unsqueeze,
    torch.Tensor.__getitem__,
]

class ReplicatedTensor(torch.Tensor):
    """
    ReplicatedTensor represents a tensor that is replicated across ``world_size``
    ranks and has the same value on each rank.

    ReplicatedTensor is a :class:`~torch.Tensor` subclass and can be used together
    with ShardedTensor/Tensor to express different types of computation. The
    inter-op rules are defined as follows (using torch.add as an example op):
        ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
        ReplicatedTensor + torch.Tensor = torch.Tensor
        ReplicatedTensor + ShardedTensor = ShardedTensor
        ReplicatedTensor + other type (e.g. a scalar) = other type

    NOTE: We do not guarantee that the content of a ReplicatedTensor stays equal
    across nodes after construction. The inter-op rules above are designed to keep
    results consistent, but there is no runtime enforcement (i.e. if you manually
    modify the content on some ranks, the modified values will not automatically be
    synced to the other nodes). If you wish to verify that tensors are the same
    across ranks, use :meth:`validate` (see the usage sketch at the end of this
    module).
    """

    _process_group: distributed_c10d.ProcessGroup

    __slots__ = ["_process_group"]

    def __new__(cls, data=None, process_group=None):
        if data is None:
            data = torch.empty(0)
        r = torch.Tensor._make_subclass(cls, data, data.requires_grad)  # type: ignore[arg-type]
        r._process_group = (  # type: ignore[attr-defined]
            process_group
            if process_group is not None
            else distributed_c10d._get_default_group()
        )
        return r

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(
                self.data.clone(memory_format=torch.preserve_format),
                self._process_group,
            )
            memo[id(self)] = result
            return result

    def __repr__(self):
        return f"ReplicatedTensor({super().__repr__()})"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        # Re-dispatch the execution to ShardedTensor's __torch_function__ if there
        # are ShardedTensor operands. Also check whether args/kwargs are all
        # ReplicatedTensor operands; we need this to ensure we do not convert
        # results back to ReplicatedTensor when not all operands are replicated.
        all_replicated = True
        replicated_with_non_tensor = True
        replicated_pg = None

        def dispatch_arg(arg):
            # Returns a tuple: the first element indicates whether the op has
            # already been executed (re-dispatched), the second element is the
            # result of that execution.
            nonlocal replicated_pg, all_replicated, replicated_with_non_tensor
            if isinstance(arg, ShardedTensor):
                # redispatch to ShardedTensor
                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
                return True, arg.__torch_function__(func, types, args, kwargs)
            if isinstance(arg, ReplicatedTensor):
                if replicated_pg is None:
                    replicated_pg = arg._process_group
                elif replicated_pg != arg._process_group:
                    raise RuntimeError(
                        f"ReplicatedTensor operands must be in the same process group "
                        f"in torch function '{func.__name__}', but found at least two "
                        f"ReplicatedTensor operands in different process groups!")
            elif isinstance(arg, torch.Tensor):
                replicated_with_non_tensor = False
                all_replicated = False
            else:
                all_replicated = False

            return False, None

        for arg in args:
            redispatched, res = dispatch_arg(arg)
            if redispatched:
                return res

        for k, v in kwargs.items():
            redispatched, res = dispatch_arg(v)
            if redispatched:
                return res

        # We can't use super().__torch_function__() because it implicitly converts
        # the result back to the tensor subclass; here we need to control the output
        # type based on the inter-op rules defined above.
        with torch._C.DisableTorchFunctionSubclass():
            rs = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return rs

            result_not_replicated = isinstance(rs, torch.Tensor) and not isinstance(rs, ReplicatedTensor)
            should_convert_to_replicated = all_replicated or (
                replicated_with_non_tensor and func in _REPLICATED_WITH_NON_TENSOR_ALLOWLIST
            )
            if result_not_replicated and should_convert_to_replicated:
                # All operands are ReplicatedTensors (or ReplicatedTensors plus
                # allowlisted non-tensors) and the call was not re-dispatched to
                # ShardedTensor's __torch_function__, so the result is a plain
                # torch.Tensor. Convert it back to a ReplicatedTensor according to
                # our inter-op rules.
                rs = rs.as_subclass(ReplicatedTensor)  # type: ignore[arg-type]
                # propagate the process_group field to the result
                rs._process_group = replicated_pg  # type: ignore[attr-defined]

            return rs

    def validate(self) -> bool:
        """
        Validate that this ReplicatedTensor is legitimate by all-gathering the tensor
        from all ranks of its process group and checking that the values are equal.

        If some ranks hold different values, a ValueError is raised.

        Returns:
            True if validation succeeds.
        """
        world_size = dist.get_world_size(self._process_group)
        current_rank = dist.get_rank(self._process_group)

        tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)]

        dist.all_gather(tensors_on_rank, self, group=self._process_group)
        # check that all gathered tensors are equal to the local one
        for rank, tensor in enumerate(tensors_on_rank):
            if not torch.allclose(self, tensor):
                raise ValueError(
                    f"ReplicatedTensor has different values on rank {current_rank} and {rank}")

        return True

    def __setstate__(self, state):
        with torch._C.DisableTorchFunctionSubclass():
            self.data = state
            self.requires_grad = state.requires_grad
            from torch.distributed._shard.api import _get_current_process_group
            self._process_group = _get_current_process_group()

    def __getstate__(self):
        return self.data
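

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It demonstrates
# the inter-op rules and validate() described above. It assumes the script is
# launched with torchrun (or an equivalent launcher) so that the env:// variables
# needed by dist.init_process_group are set; the "gloo" backend and the tensor
# shapes are arbitrary choices for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dist.init_process_group("gloo")

    # Every rank constructs the same value, so the tensor is genuinely replicated.
    replica = ReplicatedTensor(torch.ones(2, 2))

    # ReplicatedTensor + ReplicatedTensor -> ReplicatedTensor
    out = replica + ReplicatedTensor(torch.ones(2, 2))
    assert isinstance(out, ReplicatedTensor)

    # ReplicatedTensor + plain torch.Tensor -> plain torch.Tensor
    mixed = replica + torch.zeros(2, 2)
    assert not isinstance(mixed, ReplicatedTensor)

    # Allowlisted ops with non-tensor arguments keep the ReplicatedTensor type.
    assert isinstance(replica.unsqueeze(0), ReplicatedTensor)

    # All-gather based check that the value really is the same on every rank.
    assert replica.validate()

    # If a rank diverges (here, an in-place modification on rank 0 only), then on
    # runs with more than one rank validate() raises ValueError instead of
    # returning True.
    if dist.get_rank() == 0:
        replica.add_(1)
    try:
        replica.validate()
    except ValueError:
        pass

    dist.destroy_process_group()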