common_op_utils.py

import torch
from torch.utils._pytree import tree_map
from typing import Optional


def _basic_validation(op, args=(), kwargs=None):
    """
    Common validation shared across all ops goes in here.
    """
    from torch.distributed._shard.partial_tensor import _PartialTensor
    from torch.distributed._shard.replicated_tensor import ReplicatedTensor
    from torch.distributed._shard.sharded_tensor import ShardedTensor

    if len(args) == 0 and (kwargs is None or len(kwargs) == 0):
        raise ValueError(f"No input for '{op.__name__}'!")

    # Validate types: at least one input must be a distributed tensor.
    has_distributed_tensor = False

    def is_distributed_tensor(e):
        nonlocal has_distributed_tensor
        if isinstance(e, (ReplicatedTensor, _PartialTensor, ShardedTensor)):
            has_distributed_tensor = True

    tree_map(is_distributed_tensor, args)
    tree_map(is_distributed_tensor, kwargs)

    if not has_distributed_tensor:
        raise TypeError(
            f"torch function '{op.__name__}', with args: {args} and "
            f"kwargs: {kwargs}, was called without any distributed tensor!"
        )

    # Validate that all distributed tensors use the same ProcessGroup.
    cur_pg: Optional[torch.distributed.ProcessGroup] = None

    def validate_pg(e):
        nonlocal cur_pg
        if isinstance(e, (ReplicatedTensor, _PartialTensor, ShardedTensor)):
            if cur_pg is not None and e._process_group is not cur_pg:
                raise RuntimeError(
                    'All distributed tensors should use the '
                    'same ProcessGroup if used together in an op.'
                )
            cur_pg = e._process_group

    tree_map(validate_pg, args)
    tree_map(validate_pg, kwargs)
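
# Usage sketch (illustrative, not part of the upstream file): _basic_validation is
# meant to run at the top of a sharded-op handler before the real op is dispatched.
# Neither call below needs an initialized process group, since both fail validation
# before the ProcessGroup check is reached, e.g.:
#
#     _basic_validation(torch.add)
#     # -> ValueError: no positional or keyword inputs at all
#
#     _basic_validation(torch.add, args=(torch.ones(2), torch.ones(2)))
#     # -> TypeError: plain tensors only, no ShardedTensor / ReplicatedTensor /
#     #    _PartialTensor anywhere in args or kwargs
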

def _register_default_op(op, decorator):
    @decorator(op)
    def tensor_default_op(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the default tensor ops that
        behave the same as ``torch.Tensor``, such as ``torch.Tensor.shape`` or
        ``torch.Tensor.dtype``. We simply lower to the real op call with the
        ``DisableTorchFunctionSubclass`` context, like
        ``torch.Tensor.__torch_function__``, to avoid recursion.
        """
        if kwargs is None:
            kwargs = {}

        with torch._C.DisableTorchFunctionSubclass():
            return op(*args, **kwargs)
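

# Illustrative demo (not part of the upstream file): a minimal sketch of how
# _register_default_op might be wired up. The _demo_registry dict and the
# _make_registering_decorator helper are hypothetical stand-ins for the real
# sharded-op registration machinery; only _register_default_op itself comes
# from this module.
if __name__ == "__main__":
    _demo_registry = {}

    def _make_registering_decorator(op):
        # Return a decorator that records ``func`` as the handler for ``op``.
        def wrapper(func):
            _demo_registry[op] = func
            return func
        return wrapper

    # Register torch.Tensor.size so its default handler simply re-dispatches
    # to the real op with torch function subclass handling disabled.
    _register_default_op(torch.Tensor.size, _make_registering_decorator)

    handler = _demo_registry[torch.Tensor.size]
    t = torch.ones(2, 3)
    # ``types`` is unused by the default handler, so an empty tuple is fine here.
    print(handler(types=(), args=(t,)))  # -> torch.Size([2, 3])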