import inspect
from collections import defaultdict
from functools import wraps
from itertools import chain
from typing import Callable, Dict, Sequence, Union

import torch
import torch.library
from torch._ops import OpOverload, OpOverloadPacket
from torch.utils._pytree import tree_map

__all__ = [
    "decomposition_table",
    "pre_autograd_decomposition_table",
    "meta_table",
    "register_decomposition",
    "get_decompositions",
    "core_aten_decompositions",
]

# TODO: relax key type here; torch registrations should be possible too, but
# right now this type is accurate
global_decomposition_table: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)

decomposition_table = global_decomposition_table["post_autograd"]
pre_autograd_decomposition_table = global_decomposition_table["pre_autograd"]
meta_table = global_decomposition_table["meta"]


def _add_op_to_registry(registry, op, fn):
    """
    This is an internal API for adding an op to the decomposition table.

    If op is an OpOverload, it will be added to the registry directly.
    If op is an OpOverloadPacket, all the valid op_overloads in the packet
    will be added to the registry.
    """
    overloads = []
    if isinstance(op, OpOverload):
        overloads.append(op)
    else:
        assert isinstance(op, OpOverloadPacket)
        for ol in op.overloads():
            overloads.append(getattr(op, ol))

    for op_overload in overloads:
        if op_overload in registry:
            raise RuntimeError(f"duplicate registrations for {op_overload}")
        # TorchScript dumps a bunch of extra nonsense overloads
        # which don't have corresponding dispatcher entries; we need
        # to filter those out, e.g. aten.add.float_int
        if torch._C._dispatch_has_kernel(op_overload.name()):
            registry[op_overload] = fn
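
# Illustrative sketch (the scratch registry and lambda below are hypothetical,
# not part of this module): registering an OpOverloadPacket expands it into
# every overload that has a real dispatcher kernel.
#
#     scratch: Dict[OpOverload, Callable] = {}
#     _add_op_to_registry(scratch, torch.ops.aten.add, lambda *a, **kw: None)
#     # scratch now maps e.g. aten.add.Tensor and aten.add.Scalar to the lambda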


def register_decomposition(aten_op, registry=None, *, type="post_autograd"):
    """
    A decorator to register a function as a decomposition to the Python
    decomposition table. Use it like this::

        @register_decomposition(torch.ops.aten.clamp_min)
        def clamp_min(x, min):
            return torch.clamp(x, min=min)

    If you are writing a new decomposition, consider contributing it
    directly to PyTorch in torch._decomp.decompositions.

    This API is experimental; we are almost certainly going to extend
    the API when we make decompositions eligible for use in transforms (e.g.,
    autograd) and not just backend tracing, where we then need to know if a
    decomposition can be used to simulate a transform.

    By default, we will also register it to the Meta key of the dispatcher,
    replacing the C++ Meta implementation if there is already one.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}

    def decomposition_decorator(f: Callable) -> Callable:
        sig = inspect.signature(f)
        out_annotation = f.__annotations__.get("out")
        # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
        fn = f
        if out_annotation and getattr(out_annotation, "__origin__", None) is tuple:
            out_names = sig.return_annotation._fields
            # If out is a tuple, we need to register a function that unpacks all the out
            # elements as this is what native_functions.yaml expects

            @wraps(f)
            def _fn(*args, **kwargs):
                out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
                # Either all of the out kwargs are set or none of them
                is_none = out_kwargs[0] is None
                assert all((o is None) == is_none for o in out_kwargs)
                return f(*args, **kwargs, out=None if is_none else out_kwargs)

            out_params = [
                inspect.Parameter(
                    o,
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=None,
                    annotation=t,
                )
                for o, t in zip(out_names, out_annotation.__args__)
            ]
            # Drop the out parameter and concatenate the new kwargs in the signature
            params = chain(
                (v for k, v in sig.parameters.items() if k != "out"), out_params
            )
            _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
                parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
            )
            # Drop the out parameter and concatenate the new kwargs in the annotations
            _fn.__annotations__ = {
                k: v for k, v in f.__annotations__.items() if k != "out"
            }
            for o in out_params:
                _fn.__annotations__[o.name] = o.annotation

            fn = _fn

        nonlocal registry
        if registry is None:
            registry = global_decomposition_table[type]

        def register(op):
            _add_op_to_registry(registry, op, fn)

        # To handle allowing multiple aten_ops at once
        tree_map(register, aten_op)

        return fn

    return decomposition_decorator
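
# A minimal usage sketch (the op choice, the private registry, and the
# decomposition body below are illustrative, not part of this module):
# decompositions can be registered into a caller-owned registry instead of
# the global table.
#
#     my_registry: Dict[OpOverload, Callable] = {}
#
#     @register_decomposition(torch.ops.aten.silu, registry=my_registry)
#     def silu(x):
#         return x * torch.sigmoid(x)
#
#     assert torch.ops.aten.silu.default in my_registry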


def get_decompositions(
    aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
    type: str = "post_autograd",
) -> Dict[OpOverload, Callable]:
    """
    Retrieve a dictionary of decompositions corresponding to the list of
    operator overloads and overload packets passed as input. Overload
    packets will include all decomposed overloads in the packet. If there is
    no decomposition for a requested operator, it is silently ignored.

    This API is experimental; we are almost certainly going to give an alternate,
    more recommended formulation, where a user provides the set of operators
    they know how to implement, and we provide decompositions for everything
    not in this set.
    """
    assert type in {"post_autograd", "pre_autograd", "meta"}

    registry = global_decomposition_table[type]
    packets_to_overloads = defaultdict(list)
    for opo in registry:
        packets_to_overloads[opo.overloadpacket].append(opo)
    decompositions = {}
    for op in aten_ops:
        if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
            for op_overload in packets_to_overloads[op]:
                decompositions[op_overload] = registry[op_overload]
        elif isinstance(op, OpOverload) and op in registry:
            decompositions[op] = registry[op]
    return decompositions
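
# Illustrative call-site sketch (the make_fx usage below is an example
# consumer, not something this module does itself): fetch the registered
# decompositions for a couple of ops and trace through them.
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     decomp = get_decompositions([torch.ops.aten.gelu, torch.ops.aten.silu])
#     gm = make_fx(torch.nn.functional.gelu,
#                  decomposition_table=decomp)(torch.randn(4))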


# populate the table
import torch._decomp.decompositions
import torch._refs


# This list was copied from torch/_inductor/decomposition.py, excluding
# decompositions that result in prim ops; the resulting opset of the
# decompositions is core ATen ops.
def core_aten_decompositions() -> Dict[OpOverload, Callable]:
    aten = torch.ops.aten
    return get_decompositions(
        [
            aten._adaptive_avg_pool2d_backward,
            aten.addcdiv,
            aten.addcdiv_,
            aten.addcmul,
            aten.addcmul_,
            aten.addr,
            aten.avg_pool2d_backward,
            aten.binary_cross_entropy,
            aten.binary_cross_entropy_backward,
            aten.binary_cross_entropy_with_logits,
            aten.bucketize,
            aten.celu,
            aten.col2im,
            aten.cudnn_batch_norm,
            aten.cudnn_batch_norm_backward,
            aten.detach,
            aten.diag_embed,
            aten.diagonal,
            aten.dot,
            aten.elu,
            aten.elu_backward,
            aten._embedding_bag,
            aten.embedding_dense_backward,
            aten.expand_as,
            aten.eye,
            aten.fill,
            aten.frac,
            aten._fused_moving_avg_obs_fq_helper,
            aten.gelu,
            aten.gelu_backward,
            aten.glu_backward,
            aten.grid_sampler_2d,
            aten.hardshrink,
            aten.hardshrink_backward,
            aten.hardsigmoid,
            aten.hardsigmoid_backward,
            aten.hardswish,
            aten.hardswish_,
            aten.hardswish_backward,
            aten.hardtanh,
            aten.hardtanh_,
            aten.hardtanh_backward,
            aten.heaviside,
            aten.huber_loss,
            aten.huber_loss_backward,
            aten.im2col,
            aten.index_add,
            aten.index_add_,
            aten.index_copy,
            aten.index_copy_,
            aten.index_fill,
            aten.index_fill_,
            aten.index_select,
            aten.isneginf,
            aten.isposinf,
            aten.l1_loss,
            aten.leaky_relu,
            aten.leaky_relu_,
            aten.leaky_relu_backward,
            aten.lerp,
            aten.linspace,
            aten.logaddexp,
            aten.logit,
            aten.logit_backward,
            aten.log_sigmoid_backward,
            aten.log_sigmoid_forward,
            aten._log_softmax,
            aten._log_softmax_backward_data,
            aten.logspace,
            aten.logsumexp.default,
            aten.masked_fill,
            aten.masked_fill_,
            aten.max_pool2d_with_indices_backward,
            aten.mish,
            aten.mse_loss,
            aten.mse_loss_backward,
            aten.mv,
            aten.mvlgamma,
            aten.nan_to_num,
            aten.narrow,
            aten.native_batch_norm,
            aten.native_batch_norm_backward,
            aten._native_batch_norm_legit,
            aten._native_batch_norm_legit_functional,
            aten.native_dropout_backward,
            aten.native_group_norm,
            aten.native_group_norm_backward,
            aten.native_layer_norm,
            aten.native_layer_norm_backward,
            aten.new_empty,
            aten.new_full,
            aten.new_ones,
            aten.new_zeros,
            aten.nll_loss_backward,
            aten.nll_loss_forward,
            aten.norm,
            aten.ones,
            aten.ones_like,
            aten._prelu_kernel,
            aten._prelu_kernel_backward,
            aten._reshape_alias,
            aten.rot90,
            aten.rsub.Scalar,
            aten.rsub.Tensor,
            aten.select_backward,
            aten.select_scatter,
            aten.sgn,
            aten.sigmoid_backward,
            aten.silu,
            aten.silu_,
            aten.silu_backward,
            aten.sinc,
            aten.slice_backward,
            aten.soft_margin_loss,
            aten.soft_margin_loss_backward,
            aten._softmax,
            aten._softmax_backward_data,
            aten.softplus,
            aten.softplus_backward,
            aten.softshrink,
            aten.softshrink_backward,
            aten.special_entr,
            aten.special_log_ndtr,
            aten.special_xlog1py,
            aten.stack,
            aten.t,
            aten.tanh_backward,
            aten.threshold,
            aten.threshold_backward,
            aten.trace,
            aten.transpose.int,
            aten.tril.default,
            aten.triu.default,
            aten.unfold,
            aten.unfold_backward,
            aten.upsample_bilinear2d,
            aten.upsample_bilinear2d.vec,
            aten.upsample_nearest2d_backward,
            aten.xlogy,
            aten.zero,
            aten.zero_,
            aten.zeros,
            aten.zeros_like,
        ]
    )
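
# Consumer-side sketch (the assertion below is illustrative): the returned
# table is a plain Dict[OpOverload, Callable], keyed by individual overloads
# even where the list above names whole overload packets.
#
#     table = core_aten_decompositions()
#     assert torch.ops.aten.gelu.default in table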