# function_dispatcher.py
  1. """Dispatcher for AtenLib functions from onnx-script."""
  2. from __future__ import annotations
  3. from typing import Callable, Dict, Union
  4. import onnxscript # type: ignore[import]
  5. from onnxscript import opset18 # type: ignore[import]
  6. from onnxscript.function_libs.torch_aten import ops # type: ignore[import]
  7. import torch
  8. from torch.onnx._internal import _beartype
  9. TORCH_ONNX_OPSET = onnxscript.values.Opset(domain="torch.onnx", version=1)
  10. @onnxscript.script(opset=TORCH_ONNX_OPSET)
  11. def prims_convert_element_type(tensor, dtype: int):
  12. return opset18.Cast(tensor, to=dtype)
  13. @onnxscript.script(opset=TORCH_ONNX_OPSET)
  14. def aten_getitem(self, i):
  15. # TODO(justinchuby): Support
  16. # i = opset18.Unsqueeze(i, opset18.Constant(value_ints=[0]))
  17. # return opset18.Gather(self, i, axis=0)
  18. return opset18.SequenceAt(self, i)
  19. # A simple lookup table for atenlib functions
  20. _ATENLIB_FUNCTIONS = {
  21. "aten::abs": ops.core.aten_abs,
  22. "aten::acos": ops.core.aten_acos,
  23. "aten::acosh": ops.core.aten_acosh,
  24. "aten::adaptive_avg_pool1d": ops.nn.aten_adaptive_avg_pool1d,
  25. "aten::adaptive_avg_pool2d": ops.nn.aten_adaptive_avg_pool2d,
  26. "aten::adaptive_avg_pool3d": ops.nn.aten_adaptive_avg_pool3d,
  27. "aten::add": ops.core.aten_add,
  28. "aten::addmm": ops.core.aten_addmm,
  29. "aten::amax": ops.core.aten_amax,
  30. "aten::amin": ops.core.aten_amin,
  31. "aten::arange": ops.core.aten_arange_start,
  32. "aten::argmax": ops.core.aten_argmax,
  33. "aten::argmin": ops.core.aten_argmin,
  34. "aten::asin": ops.core.aten_asin,
  35. "aten::asinh": ops.core.aten_asinh,
  36. "aten::atan": ops.core.aten_atan,
  37. "aten::atanh": ops.core.aten_atanh,
  38. "aten::bmm": ops.core.aten_bmm,
  39. "aten::ceil": ops.core.aten_ceil,
  40. "aten::celu": ops.nn.aten_celu,
  41. "aten::clamp_max": ops.core.aten_clamp_max,
  42. "aten::clamp_min": ops.core.aten_clamp_min,
  43. "aten::clamp": ops.core.aten_clamp,
  44. "aten::clone": ops.core.aten_clone,
  45. "aten::convolution": ops.core.aten_convolution,
  46. "aten::cos": ops.core.aten_cos,
  47. "aten::cosh": ops.core.aten_cosh,
  48. "aten::detach": ops.core.aten_detach,
  49. "aten::div": ops.core.aten_div,
  50. "aten::dot": ops.core.aten_dot,
  51. "aten::elu": ops.nn.aten_elu,
  52. "aten::embedding": ops.core.aten_embedding,
  53. "aten::empty_like": ops.core.aten_empty_like,
  54. "aten::empty": ops.core.aten_empty,
  55. "aten::eq": ops.core.aten_eq,
  56. "aten::equal": ops.core.aten_equal,
  57. "aten::erf": ops.core.aten_erf,
  58. "aten::exp": ops.core.aten_exp,
  59. "aten::exp2": ops.core.aten_exp2,
  60. "aten::expand": ops.core.aten_expand,
  61. "aten::fmod": ops.core.aten_fmod,
  62. "aten::full_like": ops.core.aten_full_like,
  63. "aten::full": ops.core.aten_full,
  64. "aten::ge": ops.core.aten_ge,
  65. "aten::gelu": ops.nn.aten_gelu,
  66. "aten::gt": ops.core.aten_gt,
  67. "aten::isinf": ops.core.aten_isinf,
  68. "aten::le": ops.core.aten_le,
  69. "aten::leaky_relu": ops.nn.aten_leaky_relu,
  70. "aten::linear": ops.nn.aten_linear,
  71. "aten::log_softmax": ops.special.aten_special_log_softmax,
  72. "aten::log": ops.core.aten_log,
  73. "aten::log10": ops.core.aten_log10,
  74. "aten::log1p": ops.core.aten_log1p,
  75. "aten::log2": ops.core.aten_log2,
  76. "aten::logaddexp": ops.core.aten_logaddexp,
  77. "aten::logaddexp2": ops.core.aten_logaddexp2,
  78. "aten::logcumsumexp": ops.core.aten_logcumsumexp,
  79. "aten::logdet": ops.core.aten_logdet,
  80. "aten::logsigmoid": ops.nn.aten_log_sigmoid,
  81. "aten::logsumexp": ops.core.aten_logsumexp,
  82. "aten::lt": ops.core.aten_lt,
  83. "aten::matmul": ops.core.aten_matmul,
  84. "aten::maximum": ops.core.aten_maximum,
  85. "aten::minimum": ops.core.aten_minimum,
  86. "aten::mm": ops.core.aten_mm,
  87. "aten::mul": ops.core.aten_mul,
  88. "aten::native_layer_norm": ops.core.aten_native_layer_norm,
  89. "aten::ne": ops.core.aten_ne,
  90. "aten::neg": ops.core.aten_neg,
  91. "aten::new_full": ops.core.aten_new_full,
  92. "aten::nonzero": ops.core.aten_nonzero,
  93. "aten::ones_like": ops.core.aten_ones_like,
  94. "aten::ones": ops.core.aten_ones,
  95. "aten::permute": ops.core.aten_permute,
  96. "aten::pow": ops.core.aten_pow,
  97. "aten::reciprocal": ops.core.aten_reciprocal,
  98. "aten::relu": ops.nn.aten_relu,
  99. "aten::relu6": ops.nn.aten_relu6,
  100. "aten::remainder": ops.core.aten_remainder,
  101. "aten::repeat": ops.core.aten_repeat,
  102. "aten::reshape": ops.core.aten_reshape,
  103. "aten::round": ops.core.aten_round,
  104. "aten::rsqrt": ops.core.aten_rsqrt,
  105. "aten::rsub": ops.core.aten_rsub,
  106. "aten::selu": ops.core.aten_selu,
  107. "aten::sigmoid": ops.core.aten_sigmoid,
  108. "aten::sign": ops.core.aten_sign,
  109. "aten::sin": ops.core.aten_sin,
  110. "aten::sinh": ops.core.aten_sinh,
  111. "aten::slice": ops.core.aten_slice,
  112. "aten::softmax": ops.special.aten_special_softmax,
  113. "aten::split": ops.core.aten_split,
  114. "aten::sqrt": ops.core.aten_sqrt,
  115. "aten::sub": ops.core.aten_sub,
  116. "aten::sum": ops.core.aten_sum_dim_IntList,
  117. "aten::t": ops.core.aten_t,
  118. "aten::tan": ops.core.aten_tan,
  119. "aten::tanh": ops.core.aten_tanh,
  120. "aten::topk": ops.core.aten_topk,
  121. "aten::transpose": ops.core.aten_transpose,
  122. "aten::unsqueeze": ops.core.aten_unsqueeze,
  123. "aten::upsample_nearest2d": ops.nn.aten_upsample_nearest2d,
  124. "aten::view": ops.core.aten_view,
  125. "aten::where": ops.core.aten_where,
  126. "aten::xlogy": ops.special.aten_special_xlogy,
  127. "aten::zeros_like": ops.core.aten_zeros_like,
  128. "aten::zeros": ops.core.aten_zeros,
  129. "getitem": aten_getitem,
  130. "prims::convert_element_type": prims_convert_element_type,
  131. }
  132. def _create_op_overload_to_exporter_key_table() -> Dict[
  133. Union[torch._ops.OpOverload, Callable], str
  134. ]:
  135. # TODO(justinchuby): Improve how the table is constructed.
  136. table: Dict[Union[torch._ops.OpOverload, Callable], str] = {}
  137. for op_namespace in (torch.ops.aten, torch.ops.prims):
  138. for attr_name in dir(op_namespace):
  139. op_overload_packet = getattr(op_namespace, attr_name)
  140. if not isinstance(op_overload_packet, torch._ops.OpOverloadPacket):
  141. continue
  142. exporter_look_up_key = op_overload_packet._qualified_op_name
  143. if _ATENLIB_FUNCTIONS.get(exporter_look_up_key) is None:
  144. # This aten op doesn't have ONNX exporter.
  145. continue
  146. for overload_name in op_overload_packet.overloads():
  147. op_overload = getattr(op_overload_packet, overload_name)
  148. # This line maps torch.ops.aten.add.Tensor, torch.ops.aten.add.Scalar, torch.ops.aten.add.out, etc
  149. # to "aten::add". This means the exporter for "aten::add" is used for all overloads of "aten::add".
  150. # This is applied to all ops under torch.ops.aten.
  151. #
  152. # TODO(wechi): in the future, we might want to write individual exporter for each overload, if,
  153. # for example, they have different type promotion rules. If so, just map different overloads to
  154. # different exporter keys.
  155. table[op_overload] = op_overload_packet._qualified_op_name
  156. # TODO(justinchuby): is baddbmm different?
  157. table[torch.ops.aten.baddbmm.default] = "aten::baddbmm"
  158. return table
  159. # Dictionary that maps torch.ops.aten.* to exporter look up key; e.g.,
  160. # _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE[torch.add.Tensor] is "aten::add".
  161. _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE = _create_op_overload_to_exporter_key_table()
  162. @_beartype.beartype
  163. def _create_onnx_friendly_decomposition_table() -> Dict[
  164. torch._ops.OpOverload, Callable
  165. ]:
  166. decomposition_table: Dict[torch._ops.OpOverload, Callable] = {}
  167. for op_overload, decomp_fn in torch._decomp.decomposition_table.items():
  168. # Skip decomposition into "prim::*" ops, because they are not generally supported by ONNX.
  169. # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX exporter.
  170. if (
  171. "torch._refs" in decomp_fn.__module__
  172. or op_overload in _OP_OVERLOAD_TO_EXPORTER_KEY_TABLE
  173. ):
  174. continue
  175. decomposition_table[op_overload] = decomp_fn
  176. return decomposition_table
  177. # This is a subset of PyTorch's built-in aten-to-aten decomposition. If an aten
  178. # op (e.g., torch.ops.aten.add.Tensor) has exporter, we exclude the op's decomposition
  179. # function in the _ONNX_FRIENDLY_DECOMPOSITION_TABLE.
  180. _ONNX_FRIENDLY_DECOMPOSITION_TABLE = _create_onnx_friendly_decomposition_table()