refs.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. from torch.testing._internal.opinfo.core import (
  2. BinaryUfuncInfo,
  3. OpInfo,
  4. ReductionOpInfo,
  5. UnaryUfuncInfo,
  6. )
  7. # NOTE [Python References]
  8. # Python References emulate existing PyTorch operations, but can ultimately
  9. # be expressed in terms of "primitive" operations from torch._prims.
  10. #
  11. # These references are experimental.
  12. # See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577
  13. # for additional context.
  14. #
  15. # Python Reference OpInfos should be added to the python_ref_db list below.
  16. # Tests can opt-into running on these references by including
  17. # that list in the Sequence they pass to the @ops decorator.
  18. #
  19. # When a Python Reference OpInfo is constructed a pointer to an
  20. # existing OpInfo must be provided using the torch_opinfo_name kwarg.
# The existing OpInfo with that name and variant (by default the one with no
# variant; pass torch_opinfo_variant_name to pick another) is the one inherited from.
  23. #
  24. # Instead of just inheriting the existing OpInfo's metadata, the
  25. # Python Reference OpInfos inherit the existing OpInfo's
  26. # construction arguments. These arguments can be overridden
  27. # by adding kwargs to the constructor.
  28. def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None):
  29. """
  30. Finds the OpInfo with the given name that has no variant name.
  31. """
  32. # NOTE: searching the global op_db doesn't work when OpInfos are split into
  33. # different modules, as otherwise the op_db will not be fully constructed
  34. # yet. So, instead the local op_db must be passed in explicitly.
  35. if op_db is None:
  36. from torch.testing._internal.common_methods_invocations import op_db
  37. for opinfo in op_db:
  38. if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name:
  39. return opinfo
  40. def _inherit_constructor_args(name, op, inherited, overrides):
  41. # inherits metadata
  42. common_kwargs = {
  43. "name": name,
  44. "op": op,
  45. "aliases": None, # TODO add a check for alias coverage
  46. "method_variant": None,
  47. "inplace_variant": None, # TODO: add a check for inplace coverage
  48. "supports_scripting": False,
  49. }
  50. # Acquires inherited kwargs
  51. kwargs = inherited.copy()
  52. # Fixes metadata
  53. if "kwargs" in kwargs:
  54. kwargs.update(kwargs["kwargs"])
  55. del kwargs["kwargs"]
  56. if "self" in kwargs:
  57. del kwargs["self"]
  58. if "__class__" in kwargs:
  59. del kwargs["__class__"]
  60. if "skips" in kwargs:
  61. del kwargs["skips"]
  62. if "decorators" in kwargs:
  63. del kwargs["decorators"]
  64. # Overrides metadata
  65. kwargs.update(common_kwargs)
  66. kwargs.update(overrides)
  67. # At the moment no prims support autograd, so we must not run autograd
  68. # tests e.g. when testing dtype support. Once we start writing autograd
  69. # formulas for prims this can be removed.
  70. kwargs["supports_autograd"] = False
  71. kwargs["supports_gradgrad"] = False
  72. kwargs["supports_fwgrad_bwgrad"] = False
  73. kwargs["supports_inplace_autograd"] = False
  74. kwargs["supports_forward_ad"] = False
  75. return kwargs
  76. class PythonRefInfo(OpInfo):
  77. """
  78. An OpInfo for a Python reference of an OpInfo base class operation.
  79. """
  80. def __init__(
  81. self,
  82. name, # the stringname of the callable Python reference
  83. *,
  84. op=None, # the function variant of the operation, populated as torch.<name> if None
  85. op_db=None, # The database of opinfos to search for the parent opinfo
  86. torch_opinfo_name, # the string name of the corresponding torch opinfo
  87. torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
  88. validate_view_consistency=True,
  89. supports_nvfuser=True,
  90. **kwargs,
  91. ): # additional kwargs override kwargs inherited from the torch opinfo
  92. self.torch_opinfo_name = torch_opinfo_name
  93. self.torch_opinfo_variant_name = torch_opinfo_variant_name
  94. self.torch_opinfo = _find_referenced_opinfo(
  95. torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
  96. )
  97. self.validate_view_consistency = validate_view_consistency
  98. self.supports_nvfuser = supports_nvfuser
  99. assert isinstance(self.torch_opinfo, OpInfo)
  100. inherited = self.torch_opinfo._original_opinfo_args
  101. ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
  102. super().__init__(**ukwargs)
  103. class ReductionPythonRefInfo(ReductionOpInfo):
  104. """
  105. An OpInfo for a Python reference of an elementwise unary operation.
  106. """
  107. def __init__(
  108. self,
  109. name, # the stringname of the callable Python reference
  110. *,
  111. op=None, # the function variant of the operation, populated as torch.<name> if None
  112. op_db=None, # The database of opinfos to search for the parent opinfo
  113. torch_opinfo_name, # the string name of the corresponding torch opinfo
  114. torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
  115. supports_nvfuser=True,
  116. **kwargs,
  117. ): # additional kwargs override kwargs inherited from the torch opinfo
  118. self.torch_opinfo_name = torch_opinfo_name
  119. self.torch_opinfo_variant_name = torch_opinfo_variant_name
  120. self.torch_opinfo = _find_referenced_opinfo(
  121. torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
  122. )
  123. self.supports_nvfuser = supports_nvfuser
  124. assert isinstance(self.torch_opinfo, ReductionOpInfo)
  125. inherited = self.torch_opinfo._original_reduction_args
  126. ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
  127. # See https://github.com/pytorch/pytorch/issues/77216
  128. self.validate_view_consistency = False
  129. super().__init__(**ukwargs)
  130. class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
  131. """
  132. An OpInfo for a Python reference of an elementwise unary operation.
  133. """
  134. def __init__(
  135. self,
  136. name, # the stringname of the callable Python reference
  137. *,
  138. op=None, # the function variant of the operation, populated as torch.<name> if None
  139. op_db=None, # The database of opinfos to search for the parent opinfo
  140. torch_opinfo_name, # the string name of the corresponding torch opinfo
  141. torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
  142. validate_view_consistency=True,
  143. supports_nvfuser=True,
  144. **kwargs,
  145. ): # additional kwargs override kwargs inherited from the torch opinfo
  146. self.torch_opinfo_name = torch_opinfo_name
  147. self.torch_opinfo_variant_name = torch_opinfo_variant_name
  148. self.torch_opinfo = _find_referenced_opinfo(
  149. torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
  150. )
  151. self.validate_view_consistency = validate_view_consistency
  152. self.supports_nvfuser = supports_nvfuser
  153. assert isinstance(self.torch_opinfo, UnaryUfuncInfo)
  154. inherited = self.torch_opinfo._original_unary_ufunc_args
  155. ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
  156. super().__init__(**ukwargs)
  157. class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
  158. """
  159. An OpInfo for a Python reference of an elementwise binary operation.
  160. """
  161. def __init__(
  162. self,
  163. name, # the stringname of the callable Python reference
  164. *,
  165. op=None, # the function variant of the operation, populated as torch.<name> if None
  166. op_db=None, # The database of opinfos to search for the parent opinfo
  167. torch_opinfo_name, # the string name of the corresponding torch opinfo
  168. torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo
  169. supports_nvfuser=True,
  170. **kwargs,
  171. ): # additional kwargs override kwargs inherited from the torch opinfo
  172. self.torch_opinfo_name = torch_opinfo_name
  173. self.torch_opinfo_variant_name = torch_opinfo_variant_name
  174. self.torch_opinfo = _find_referenced_opinfo(
  175. torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
  176. )
  177. self.supports_nvfuser = supports_nvfuser
  178. assert isinstance(self.torch_opinfo, BinaryUfuncInfo)
  179. inherited = self.torch_opinfo._original_binary_ufunc_args
  180. ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
  181. super().__init__(**ukwargs)