import torch
from copy import deepcopy
from torch.utils._pytree import tree_map

# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor


# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
        if "size" not in kwargs:
            size = t.size()
        else:
            size = kwargs["size"]
            del kwargs["size"]
        if "dtype" not in kwargs:
            kwargs["dtype"] = t.dtype
        if "layout" not in kwargs:
            kwargs["layout"] = t.layout
        if "device" not in kwargs:
            kwargs["device"] = t.device
        if "requires_grad" not in kwargs:
            kwargs["requires_grad"] = False
        # Ignore memory_format and pin_memory for now, as it is not clear how to
        # safely access them on a Tensor (if that is possible at all).
        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
        wrapper._validate_methods()
        return wrapper
    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        # Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
        # This is very similar to the `t.new_*(args)` API.
        raise NotImplementedError("You need to implement get_wrapper_properties")
    def _validate_methods(self):
        # Skip this if not in debug mode?
        # Changing these on the Python side is wrong, as the change would not be
        # properly reflected on the C++ side.
        # This doesn't catch attributes set in __init__.
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such a change would "
                                   "not be reflected to C++ callers.")


class DiagTensorBelow(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, diag, requires_grad=False):
        assert diag.ndim == 1
        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}

    def __init__(self, diag, requires_grad=False):
        self.diag = diag

    handled_ops = {}

    # We disable torch function here to avoid any unwanted wrapping of the output.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **kwargs or {})
        else:
            # Note that here, because we don't need to provide the autograd formulas,
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.
            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
                    # The output is a diagonal-only matrix; store just its diagonal.
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"diag={self.diag}")
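

# Illustrative sketch, not part of the original module: `_demo_diag_fallback` is a
# hypothetical helper showing the default fallback path, which densifies inputs
# via `diag.diag()`, runs the op, and re-wraps diagonal-shaped outputs.
def _demo_diag_fallback():
    d = DiagTensorBelow(torch.ones(3))
    out = torch.add(d, d)  # no handler registered, so the dense fallback runs
    assert isinstance(out, DiagTensorBelow)
    assert torch.equal(out.diag, torch.full((3,), 2.0))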


class SparseTensor(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
        assert values.device == indices.device
        return values, {"size": size, "requires_grad": requires_grad}

    def __init__(self, size, values, indices, requires_grad=False):
        self.values = values
        self.indices = indices

    def __repr__(self):
        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")

    def sparse_to_dense(self):
        res = torch.zeros(self.size(), dtype=self.values.dtype)
        res[self.indices.unbind(1)] = self.values
        return res

    @staticmethod
    def from_dense(t):
        indices = t.nonzero()
        values = t[indices.unbind(1)]
        return SparseTensor(t.size(), values, indices)
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        func_name = f"{func.__module__}.{func.__name__}"

        res = cls._try_call_special_impl(func_name, args, kwargs)
        if res is not NotImplemented:
            return res

        # Otherwise, use a default implementation that constructs dense
        # tensors and uses those to compute the values.
        def unwrap(e):
            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e

        # Wrap all output Tensors back into our custom class.
        def wrap(e):
            # Check for zeros and use that to get the indices.
            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
        return rs

    # Defined explicitly to show that dunder methods also route through the
    # subclass machinery.
    def __rmul__(self, other):
        return super().__rmul__(other)

    _SPECIAL_IMPLS = {}

    @classmethod
    def _try_call_special_impl(cls, func, args, kwargs):
        if func not in cls._SPECIAL_IMPLS:
            return NotImplemented
        return cls._SPECIAL_IMPLS[func](args, kwargs)
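

# Illustrative sketch, not part of the original module: `_demo_sparse_tensor` is a
# hypothetical helper. It shows the from_dense/sparse_to_dense round trip and how a
# special implementation can be registered. Keys follow the
# f"{func.__module__}.{func.__name__}" format computed in __torch_dispatch__ above;
# the exact key string ("torch.ops.aten.detach" below) is an assumption and may
# differ across PyTorch versions.
def _demo_sparse_tensor():
    dense = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
    sparse = SparseTensor.from_dense(dense)
    # from_dense and sparse_to_dense round-trip for any dense input.
    assert torch.equal(sparse.sparse_to_dense(), dense)

    # A special implementation bypasses the dense fallback for one op.
    def detach_impl(args, kwargs):
        t = args[0]
        return SparseTensor(t.size(), t.values.detach(), t.indices)

    SparseTensor._SPECIAL_IMPLS["torch.ops.aten.detach"] = detach_impl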


# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.extra_state = {
            'last_func_called': None
        }
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, cls):
            # Do something with the extra state. For the example here, just store the name
            # of the last function called (skip for deepcopy so the copy has the same
            # extra state).
            if func is torch.Tensor.__deepcopy__:
                result.extra_state = deepcopy(args[0].extra_state)
            else:
                result.extra_state = {
                    'last_func_called': func.__name__,
                }

        return result

    # new_empty() must be defined for deepcopy to work.
    def new_empty(self, shape):
        return type(self)(torch.empty(shape))
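

# Illustrative sketch, not part of the original module: `_demo_extra_state` is a
# hypothetical helper showing how the extra state is updated by torch functions
# and preserved across deepcopy.
def _demo_extra_state():
    t = NonWrapperTensor(torch.ones(2))
    out = torch.add(t, t)
    assert out.extra_state['last_func_called'] == 'add'

    copied = deepcopy(out)
    assert copied.extra_state == out.extra_state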


# Class used to store info about subclass tensors used in testing.
class SubclassInfo:

    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor', create_fn=lambda shape: torch.randn(shape)
    ),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
    ),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
    ),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
    ),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}
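

# Illustrative sketch, not part of the original module: `_demo_subclass_db` is a
# hypothetical helper showing how tests can iterate the DB to exercise each
# subclass uniformly. A 1-D shape is used because DiagTensorBelow's create_fn
# expects a 1-D diagonal.
def _demo_subclass_db():
    for cls, info in subclass_db.items():
        t = info.create_fn((3,))
        assert isinstance(t, cls), info.name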