common_subclass.py

import torch
from copy import deepcopy
from torch.utils._pytree import tree_map

# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor


# Base class for wrapper-style tensors.
class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, *args, **kwargs):
        t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
        if "size" not in kwargs:
            size = t.size()
        else:
            size = kwargs["size"]
            del kwargs["size"]
        if "dtype" not in kwargs:
            kwargs["dtype"] = t.dtype
        if "layout" not in kwargs:
            kwargs["layout"] = t.layout
        if "device" not in kwargs:
            kwargs["device"] = t.device
        if "requires_grad" not in kwargs:
            kwargs["requires_grad"] = False
        # Ignore memory_format and pin memory for now as I don't know how to
        # safely access them on a Tensor (if possible??)
        wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
        wrapper._validate_methods()
        return wrapper

    @classmethod
    def get_wrapper_properties(cls, *args, **kwargs):
        # Should return both an example Tensor and a dictionary of kwargs
        # to override any of that example Tensor's properties.
        # This is very similar to the `t.new_*(args)` API
        raise NotImplementedError("You need to implement get_wrapper_properties")

    def _validate_methods(self):
        # Skip this if not in debug mode?
        # Changing these on the python side is wrong as it would not be properly reflected
        # on the c++ side
        # This doesn't catch attributes set in the __init__
        forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
        for el in forbidden_overrides:
            if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
                raise RuntimeError(f"Subclass {self.__class__.__name__} is overwriting the "
                                   f"property {el} but this is not allowed as such a change would "
                                   "not be reflected to c++ callers.")


class DiagTensorBelow(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, diag, requires_grad=False):
        assert diag.ndim == 1
        return diag, {"size": diag.size() + diag.size(), "requires_grad": requires_grad}

    def __init__(self, diag, requires_grad=False):
        self.diag = diag

    handled_ops = {}

    # We disable torch function here to avoid any unwanted wrapping of the output
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # For everything else, call the handler:
        fn = cls.handled_ops.get(func.__name__, None)
        if fn:
            return fn(*args, **kwargs or {})
        else:
            # Note that here, because we don't need to provide the autograd formulas
            # we can have a default "fallback" that creates a plain Tensor based
            # on the diag elements and calls the func again.

            def unwrap(e):
                return e.diag.diag() if isinstance(e, DiagTensorBelow) else e

            def wrap(e):
                if isinstance(e, torch.Tensor) and e.ndim == 1:
                    return DiagTensorBelow(e)
                if isinstance(e, torch.Tensor) and e.ndim == 2 and e.count_nonzero() == e.diag().count_nonzero():
                    return DiagTensorBelow(e.diag())
                return e

            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
            return rs

    def __repr__(self):
        return super().__repr__(tensor_contents=f"diag={self.diag}")
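

# Minimal illustrative sketch of the fallback above (the helper name is just for
# demonstration): ops without an entry in `handled_ops` are re-run on the dense
# diagonal matrix, and results that are still diagonal get rewrapped.
def _diag_tensor_below_demo():
    d = DiagTensorBelow(torch.arange(1., 4.))
    doubled = d * 2    # result is still diagonal, so it comes back as a DiagTensorBelow
    shifted = d + 1    # result is dense, so it comes back as a plain Tensor
    total = d.sum()    # 0-dim result also stays a plain Tensor
    return doubled, shifted, total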


class SparseTensor(WrapperTensor):
    @classmethod
    def get_wrapper_properties(cls, size, values, indices, requires_grad=False):
        assert values.device == indices.device
        return values, {"size": size, "requires_grad": requires_grad}

    def __init__(self, size, values, indices, requires_grad=False):
        self.values = values
        self.indices = indices

    def __repr__(self):
        return super().__repr__(tensor_contents=f"values={self.values}, indices={self.indices}")

    def sparse_to_dense(self):
        res = torch.zeros(self.size(), dtype=self.values.dtype)
        res[self.indices.unbind(1)] = self.values
        return res

    @staticmethod
    def from_dense(t):
        indices = t.nonzero()
        values = t[indices.unbind(1)]
        return SparseTensor(t.size(), values, indices)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        func_name = f"{func.__module__}.{func.__name__}"

        res = cls._try_call_special_impl(func_name, args, kwargs)
        if res is not NotImplemented:
            return res

        # Otherwise, use a default implementation that constructs dense
        # tensors and uses those to compute the values
        def unwrap(e):
            return e.sparse_to_dense() if isinstance(e, SparseTensor) else e

        # Wrap back all Tensors into our custom class
        def wrap(e):
            # Check for zeros and use that to get indices
            return SparseTensor.from_dense(e) if isinstance(e, torch.Tensor) else e

        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {})))
        return rs

    # To show how things happen later
    def __rmul__(self, other):
        return super().__rmul__(other)

    _SPECIAL_IMPLS = {}

    @classmethod
    def _try_call_special_impl(cls, func, args, kwargs):
        if func not in cls._SPECIAL_IMPLS:
            return NotImplemented
        return cls._SPECIAL_IMPLS[func](args, kwargs)
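

# Minimal illustrative sketch of the SparseTensor round trip and dense fallback
# above (the helper name is just for demonstration). A per-op override can also
# be registered in `_SPECIAL_IMPLS`, keyed by the "{func.__module__}.{func.__name__}"
# string built in __torch_dispatch__; the exact key depends on the op being dispatched.
def _sparse_tensor_demo():
    dense = torch.tensor([[0., 1.], [2., 0.]])
    s = SparseTensor.from_dense(dense)
    roundtrip = s.sparse_to_dense()    # equal to `dense` again
    doubled = s * 2                    # goes through the dense fallback, comes back as a SparseTensor
    return roundtrip, doubled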


# Example non-wrapper subclass that stores extra state.
class NonWrapperTensor(torch.Tensor):
    def __new__(cls, data):
        t = torch.Tensor._make_subclass(cls, data)
        t.extra_state = {
            'last_func_called': None
        }
        return t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        result = super().__torch_function__(func, types, args, kwargs)

        if isinstance(result, cls):
            # Do something with the extra state. For the example here, just store the name of the
            # last function called (skip for deepcopy so the copy has the same extra state).
            if func is torch.Tensor.__deepcopy__:
                result.extra_state = deepcopy(args[0].extra_state)
            else:
                result.extra_state = {
                    'last_func_called': func.__name__,
                }

        return result

    # new_empty() must be defined for deepcopy to work
    def new_empty(self, shape):
        return type(self)(torch.empty(shape))
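

# Minimal illustrative sketch of the extra-state tracking above (the helper name
# is just for demonstration): every call that returns a NonWrapperTensor records
# the name of the function that produced it, and deepcopy carries the state over.
def _non_wrapper_tensor_demo():
    t = NonWrapperTensor(torch.randn(3))
    out = t.clone()            # out.extra_state['last_func_called'] == 'clone'
    copied = deepcopy(out)     # copied.extra_state matches out.extra_state
    return copied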


# Class used to store info about subclass tensors used in testing.
class SubclassInfo:

    __slots__ = ['name', 'create_fn', 'closed_under_ops']

    def __init__(self, name, create_fn, closed_under_ops=True):
        self.name = name
        self.create_fn = create_fn  # create_fn(shape) -> tensor instance
        self.closed_under_ops = closed_under_ops


subclass_db = {
    torch.Tensor: SubclassInfo(
        'base_tensor', create_fn=lambda shape: torch.randn(shape)
    ),
    NonWrapperTensor: SubclassInfo(
        'non_wrapper_tensor',
        create_fn=lambda shape: NonWrapperTensor(torch.randn(shape))
    ),
    LoggingTensor: SubclassInfo(
        'logging_tensor',
        create_fn=lambda shape: LoggingTensor(torch.randn(shape))
    ),
    SparseTensor: SubclassInfo(
        'sparse_tensor',
        create_fn=lambda shape: SparseTensor.from_dense(torch.randn(shape).relu())
    ),
    DiagTensorBelow: SubclassInfo(
        'diag_tensor_below',
        create_fn=lambda shape: DiagTensorBelow(torch.randn(shape)),
        closed_under_ops=False  # sparse semantics
    ),
}
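

# Minimal illustrative sketch of how the database above can be consumed in a test
# (the helper name is just for demonstration). A 1-D shape is used so that
# DiagTensorBelow's assert on diag.ndim is satisfied.
def _check_subclass_db(shape=(5,)):
    for cls, info in subclass_db.items():
        t = info.create_fn(shape)
        assert isinstance(t, cls), f"{info.name} did not produce a {cls.__name__}"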