# symbolic_helper.py
from __future__ import annotations

import functools
import inspect
import sys
import typing
import warnings
from typing import (
    Any,
    Callable,
    List,
    Literal,
    NoReturn,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import torch
import torch._C._onnx as _C_onnx
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import _constants, _deprecation, _type_utils, errors
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils
from torch.types import Number
# Public API of this module; underscore-prefixed helpers are internal.
__all__ = [
    "args_have_same_dtype",
    "cast_pytorch_to_onnx",
    "check_training_mode",
    "dequantize_helper",
    "is_caffe2_aten_fallback",
    "is_complex_value",
    "parse_args",
    "pytorch_name_to_type",
    "quantize_helper",
    "quantized_args",
    "requantize_bias_helper",
    "scalar_name_to_pytorch",
    "scalar_type_to_onnx",
    "scalar_type_to_pytorch_type",
]
# ---------------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------------

# Descriptor letters used by `parse_args`/`_parse_arg` to describe how a
# torch._C.Value argument is converted: "v" keeps the Value, "i"/"f"/"b"/"s"
# convert to int/float/bool/str, "is"/"fs" to lists of ints/floats, "t" to a
# torch.Tensor, and "none" marks an unused argument returned untouched.
_ValueDescriptor = Literal[
    "v",
    "i",
    "is",
    "f",
    "fs",
    "b",
    "s",
    "t",
    "none",
]
@_beartype.beartype
def _parse_arg(
    value,
    desc: _ValueDescriptor,
    arg_name: Optional[str] = None,
    node_name: Optional[str] = None,
):
    """Converts a graph Value into the Python value described by `desc`.

    Non-Value inputs, and descriptors "v"/"none", are returned unchanged.
    Only `onnx::Constant` nodes and (for "is") `prim::ListConstruct` nodes of
    constants can be converted.

    Args:
        value: The value to convert.
        desc: Conversion descriptor (see `_ValueDescriptor`).
        arg_name: Argument name, used only to improve error messages.
        node_name: Node name, used only to improve error messages.

    Raises:
        errors.SymbolicValueError: If `value` is not a constant convertible
            as requested.
    """
    if desc == "none":
        return value
    if desc == "v" or not _is_value(value):
        return value

    node = value.node()
    if node.mustBeNone():
        return None
    if node.kind() == "onnx::Constant":
        node_val = _node_get(node, "value")
        if desc == "i":
            return int(node_val)
        elif desc == "f":
            return float(node_val)
        elif desc == "b":
            return bool(node_val)
        elif desc == "s":
            return str(node_val)
        elif desc == "t":
            return node_val
        elif desc == "is":
            return [int(v) for v in node_val]
        elif desc == "fs":
            return [float(v) for v in node_val]
        else:
            raise errors.SymbolicValueError(
                f"ONNX symbolic does not understand the Constant node '{node}' "
                f"specified with descriptor '{desc}'.",
                value,
            )
    elif node.kind() == "prim::ListConstruct":
        if desc == "is":
            # Every list element must itself be an onnx::Constant.
            for v in node.inputs():
                element_node = v.node()
                if element_node.kind() != "onnx::Constant":
                    raise errors.SymbolicValueError(
                        f"Failed to export a node '{element_node}' "
                        f"(in list node {node}) "
                        f"because it is not constant. "
                        f"Please try to make things (e.g. kernel sizes) static if possible.",
                        value,
                    )
            return [int(_node_get(v.node(), "value")) for v in value.node().inputs()]
        else:
            raise errors.SymbolicValueError(
                f"ONNX symbolic does not know how to unpack the ListConstruct node that "
                f"is not a list of integers: '{node}'",
                value,
            )

    # Fall-through: not a constant; raise the most informative error possible.
    if arg_name is None or node_name is None:
        raise errors.SymbolicValueError(
            f"Expected node type 'onnx::Constant', got '{node.kind()}'.",
            value,
        )
    raise errors.SymbolicValueError(
        "Expected node type 'onnx::Constant' "
        f"for argument '{arg_name}' of node '{node_name}', got '{node.kind()}'.",
        value,
    )
  122. @_beartype.beartype
  123. def _node_get(node: _C.Node, key: str):
  124. """Gets attributes of a node which is polymorphic over return type."""
  125. assert isinstance(node, _C.Node)
  126. sel = node.kindOf(key)
  127. return getattr(node, sel)(key)
  128. @_beartype.beartype
  129. def _is_onnx_constant(value: _C.Value):
  130. """Whether a Value is an ONNX constant."""
  131. return value.node().kind() == "onnx::Constant"
  132. @_beartype.beartype
  133. def _maybe_get_const(
  134. value: Optional[Union[_C.Value, torch.Tensor, Number, Sequence]],
  135. descriptor: _ValueDescriptor,
  136. ):
  137. # NOTE: prim::Constant at this stage usually means something not compatible in ONNX,
  138. # otherwise it'd be converted to onnx::Constant
  139. # TODO(justinchuby): Replace insinstance with _is_value once we figure out mypy
  140. if isinstance(value, _C.Value) and _is_onnx_constant(value):
  141. return _parse_arg(value, descriptor)
  142. return value
  143. @_beartype.beartype
  144. def _maybe_get_scalar(value):
  145. value_t = _maybe_get_const(value, "t")
  146. if isinstance(value_t, torch.Tensor) and value_t.shape == ():
  147. return value_t
  148. return value
  149. @_beartype.beartype
  150. def _get_const(value, desc, arg_name):
  151. if not _is_constant(value):
  152. raise errors.SymbolicValueError(
  153. f"ONNX symbolic expected a constant value of the '{arg_name}' argument, "
  154. f"got '{value}'",
  155. value,
  156. )
  157. return _parse_arg(value, desc)
  158. @_beartype.beartype
  159. def _unpack_list(list_value: _C.Value) -> List[_C.Value]:
  160. list_node = list_value.node()
  161. if list_node.kind() != "prim::ListConstruct":
  162. raise errors.SymbolicValueError(
  163. f"ONNX symbolic expected node type prim::ListConstruct, "
  164. f"got '{list_node}'.",
  165. list_value,
  166. )
  167. return list(list_node.inputs())
  168. @_beartype.beartype
  169. def _unpack_tuple(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
  170. tuple_node = tuple_value.node()
  171. if not _is_tuple_construct(tuple_value):
  172. raise errors.SymbolicValueError(
  173. f"ONNX symbolic expected node type 'prim::TupleConstruct', "
  174. f"got '{tuple_node.kind()}'.",
  175. tuple_value,
  176. )
  177. return tuple(tuple_node.inputs())
  178. @_beartype.beartype
  179. def _unpack_quantized_tensor(tuple_value: _C.Value) -> Tuple[_C.Value, ...]:
  180. """Unpacks a quantized tensor into a tuple of tensor and scale/zero_point.
  181. Args:
  182. tuple_value: A tuple of tensor, scale, zero_point, and optionally axis.
  183. Returns:
  184. A tuple of tensor, scale, zero_point, and optionally axis.
  185. """
  186. tuple_node = tuple_value.node()
  187. # A quantized tensor is represented as tuple of the form (tensor, scale, zero_point, <axis>)
  188. if not _is_tuple_construct(tuple_value):
  189. raise errors.SymbolicValueError(
  190. f"ONNX symbolic expected the output of `{tuple_node}` to be a quantized "
  191. f"tensor. Is this likely due to missing support for quantized "
  192. f"`{tuple_node.kind()}`. Please create an issue on {_constants.PYTORCH_GITHUB_ISSUES_URL}",
  193. tuple_value,
  194. )
  195. unpacked = tuple(tuple_node.inputs())
  196. assert len(unpacked) == 3 or len(unpacked) == 4
  197. return unpacked
  198. # Check if list_value is output from prim::ListConstruct
  199. # This is usually called before _unpack_list to ensure the list can be unpacked.
  200. @_beartype.beartype
  201. def _is_packed_list(list_value: Any) -> bool:
  202. return _is_value(list_value) and list_value.node().kind() == "prim::ListConstruct"
@_beartype.beartype
def parse_args(*arg_descriptors: _ValueDescriptor):
    """A decorator which converts args from torch._C.Value to built-in types.

    For example:

    ```
    @parse_args('v', 'i', 'fs')
    foo(g, a, b, c):
        assert isinstance(a, torch._C.Value)
        assert isinstance(b, int)
        assert isinstance(c, list)
        assert isinstance(c[0], float)
    ```

    Args:
        arg_descriptors: list of str, where each element is
            a string that specifies the type to convert to. Valid descriptors:
            "v": no conversion, keep torch._C.Value.
            "i": int
            "is": list of int
            "f": float
            "fs": list of float
            "b": bool
            "s": str
            "t": torch.Tensor
            "none": the variable is unused
    """

    def decorator(fn):
        fn._arg_descriptors = arg_descriptors

        @functools.wraps(fn)
        def wrapper(g, *args, **kwargs):
            # some args may be optional, so the length may be smaller
            FILE_BUG_MSG = (
                "If you believe this is not due to custom symbolic implementation within your code or "
                "an external library, please file an issue at "
                "https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug."
            )
            assert len(arg_descriptors) >= len(args), (
                f"A mismatch between the number of arguments ({len(args)}) and "
                f"their descriptors ({len(arg_descriptors)}) was found at symbolic function '{fn.__name__}'. "
                f"{FILE_BUG_MSG}"
            )

            try:
                # Skip the first parameter (the graph `g`) so names line up
                # with `args` for better error messages.
                sig = inspect.signature(fn)
                arg_names = list(sig.parameters.keys())[1:]
                fn_name = fn.__name__
            except Exception:
                # FIXME(justinchuby): Avoid catching Exception.
                # Catch a more specific exception instead.
                arg_names = [None] * len(args)  # type: ignore[list-item]
                fn_name = None
            # zip truncates to the shortest sequence, so optional trailing
            # arguments are simply left unconverted.
            args = [
                _parse_arg(arg, arg_desc, arg_name, fn_name)  # type: ignore[assignment]
                for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)
            ]
            # only support _outputs in kwargs
            assert len(kwargs) <= 1, (
                f"Symbolic function {fn.__name__}'s '**kwargs' can contain a single "
                f"key/value entry. "
                f"{FILE_BUG_MSG}"
            )

            if len(kwargs) == 1:
                assert "_outputs" in kwargs, (
                    f"Symbolic function {fn.__name__}'s '**kwargs' can only contain "
                    f"'_outputs' key at '**kwargs'. "
                    f"{FILE_BUG_MSG}"
                )
            return fn(g, *args, **kwargs)

        return wrapper

    return decorator
@_beartype.beartype
def quantized_args(
    *arg_q_descriptors: bool,
    scale: Optional[float] = None,
    zero_point: Optional[int] = None,
):
    """A decorator which extends support for quantized version of the base operator.

    Quantization is detected by examining the arguments that are annotated by
    `arg_q_descriptors`.

    If quantization is detected, the base operator symbolic function will be wrapped with
    argument de-quantization and output quantization.

    Otherwise, only the base symbolic function will be invoked.

    For example:

    ```
    @quantized_args(True, False)
    def foo(g, x, y):
        return x + y
    ```

    is equivalent to

    ```
    def q_foo(g, x, y):
        if is_quantized_tensor(x):
            x = dequantize(x)
            out = foo(g, x, y)
            return quantize(out)
        else:
            return foo(g, x, y)
    ```

    Args:
        arg_q_descriptors: A sequence of bool, where each element represents if the
            argument is QTensor for quantized version of this operator. It defaults
            to False for unspecified (variable length) arguments.
        scale: Quantized output scale. If None, derive from
            the first quantized input scale.
        zero_point: Quantized output zero point. If None,
            derive from the first quantized input zero point.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(g, *args, **kwargs):
            nonlocal scale
            nonlocal zero_point
            # NOTE: `scale`/`zero_point` are only read here; per-call values
            # live in `_scale`/`_zero_point` so one call does not leak state
            # into the next.
            if scale is not None:
                _scale = g.op("Constant", value_t=torch.tensor(scale))
            else:
                _scale = None
            if zero_point is not None:
                _zero_point = g.op("Constant", value_t=torch.tensor(zero_point))
            else:
                _zero_point = None
            # Support variable length arguments by marking unspecified ones as non-quantized
            arg_q_descriptors_extended = arg_q_descriptors + (False,) * (
                len(args) - len(arg_q_descriptors)
            )
            descriptor_args = tuple(zip(arg_q_descriptors_extended, args))

            def _is_arg_quantized(descriptor, arg):
                # A quantized arg appears as a prim::TupleConstruct value.
                return descriptor and _is_value(arg) and _is_tuple_construct(arg)

            # Run regular symbolic function if none of the argument is QTensor.
            is_quantized = list()
            for descriptor, arg in descriptor_args:
                # ListConstruct
                if _is_packed_list(arg):
                    for arg_input in arg.node().inputs():
                        is_quantized.append(_is_arg_quantized(descriptor, arg_input))
                else:
                    is_quantized.append(_is_arg_quantized(descriptor, arg))

            if not any(is_quantized):
                return fn(g, *args, **kwargs)

            # Dequantize arguments that are quantized
            non_quantized_args = []
            for descriptor, arg in descriptor_args:
                if _is_arg_quantized(descriptor, arg):
                    # Quantized arg is a tuple of (value, scale, zero_point)
                    dequantized_arg, arg_scale, arg_zero_point, _ = dequantize_helper(
                        g, arg
                    )
                    non_quantized_args.append(dequantized_arg)
                    # Set scale and zero_point to the first quantized input if not already set
                    if _scale is None:
                        _scale = arg_scale
                    if _zero_point is None:
                        _zero_point = arg_zero_point
                # ListConstruct
                elif _is_packed_list(arg):
                    for arg_input in arg.node().inputs():
                        if _is_arg_quantized(descriptor, arg_input):
                            # Quantized arg is a tuple of (value, scale, zero_point)
                            (
                                dequantized_arg,
                                arg_scale,
                                arg_zero_point,
                                _,
                            ) = dequantize_helper(g, arg_input)
                            # Set scale and zero_point to the first quantized input if not already set
                            if _scale is None:
                                _scale = arg_scale
                            if _zero_point is None:
                                _zero_point = arg_zero_point
                            # Rewire the list element in the graph to the
                            # dequantized value rather than rebuilding the list.
                            arg_input.replaceAllUsesWith(dequantized_arg)
                    non_quantized_args.append(arg)
                else:
                    # Non-quantized arg
                    non_quantized_args.append(arg)
            # TODO(justinchuby): Only single output is supported for now. We may want to
            # support multiple outputs in the future.
            output = fn(g, *non_quantized_args, **kwargs)

            assert _scale is not None, "Bug: Scale must be set for quantized operator"
            assert (
                _zero_point is not None
            ), "Bug: Zero point must be set for quantized operator"

            return quantize_helper(g, output, _scale, _zero_point)

        return wrapper

    return decorator
  384. @_beartype.beartype
  385. def _scalar(x: Any) -> Optional[Number]:
  386. """Convert a scalar tensor into a Python value."""
  387. if isinstance(x, torch.Tensor) and x.shape == ():
  388. return x.item()
  389. return None
  390. @_beartype.beartype
  391. def _if_scalar_type_as(self, tensor):
  392. """
  393. Convert self into the same type of tensor, as necessary.
  394. We only support implicit casting for scalars, so we never
  395. actually need to insert an ONNX cast operator here; just
  396. fix up the scalar.
  397. """
  398. if isinstance(self, _C.Value):
  399. return self
  400. scalar_type = _type_utils.JitScalarType.from_value(
  401. tensor, _type_utils.JitScalarType.UNDEFINED
  402. )
  403. if scalar_type != _type_utils.JitScalarType.UNDEFINED:
  404. ty = scalar_type.scalar_name().lower()
  405. return getattr(self, ty)()
  406. return self
  407. @_beartype.beartype
  408. def _is_none(x: _C.Value) -> bool:
  409. return x.node().mustBeNone()
@_beartype.beartype
def _is_value(x: Any) -> bool:
    """Whether x is a torch._C.Value (an output in a TorchScript graph)."""
    return isinstance(x, _C.Value)
  413. @_beartype.beartype
  414. def _is_constant(value: Any) -> bool:
  415. return not _is_value(value) or value.node().kind() in {
  416. "onnx::Constant",
  417. "prim::Constant",
  418. }
@_beartype.beartype
def _is_tensor(x: _C.Value) -> bool:
    """Whether the value's JIT type is a subtype of Tensor."""
    return x.type().isSubtypeOf(_C.TensorType.get())
  422. # Note: _C.JitType is not exposed to Python and cannot be checked in runtime.
  423. def _as_list_type(jit_type: _C.JitType) -> Optional[_C.ListType]:
  424. if isinstance(jit_type, _C.ListType):
  425. return jit_type
  426. return None
@_beartype.beartype
def _is_list(x: _C.Value) -> bool:
    """Whether the value has a ListType JIT type."""
    return _as_list_type(x.type()) is not None
  430. @_beartype.beartype
  431. def _is_tensor_list(x: _C.Value) -> bool:
  432. x_type = _as_list_type(x.type())
  433. if x_type is None:
  434. return False
  435. return isinstance(x_type.getElementType(), _C.TensorType)
  436. @_beartype.beartype
  437. def _is_scalar_list(x: _C.Value) -> bool:
  438. """Checks if x is a scalar list, for example: List[float], List[int].
  439. Besides checking the type is ListType, we also check if the data type is
  440. a valid ONNX data type.
  441. """
  442. x_type = _as_list_type(x.type())
  443. if x_type is None:
  444. return False
  445. scalar_type = _type_utils.JitScalarType.from_value(x)
  446. return scalar_type.onnx_compatible()
@_beartype.beartype
def _is_tuple_construct(x: _C.Value) -> bool:
    """Whether the value is produced by a prim::TupleConstruct node."""
    return x.node().kind() == "prim::TupleConstruct"
  450. @_beartype.beartype
  451. def is_complex_value(x: _C.Value) -> bool:
  452. assert _is_value(x)
  453. return _type_utils.JitScalarType.from_value(
  454. x, _type_utils.JitScalarType.UNDEFINED
  455. ) in {
  456. _type_utils.JitScalarType.COMPLEX32,
  457. _type_utils.JitScalarType.COMPLEX64,
  458. _type_utils.JitScalarType.COMPLEX128,
  459. }
  460. @_beartype.beartype
  461. def is_caffe2_aten_fallback() -> bool:
  462. return (
  463. GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
  464. and _C_onnx._CAFFE2_ATEN_FALLBACK
  465. )
  466. @_beartype.beartype
  467. def _get_tensor_rank(x: _C.Value) -> Optional[int]:
  468. if not _is_tensor(x) or x.type() is None:
  469. return None
  470. x_type = x.type()
  471. x_type = typing.cast(_C.TensorType, x_type)
  472. return x_type.dim()
  473. @_beartype.beartype
  474. def _get_tensor_sizes(x: _C.Value, allow_nonstatic: bool = True):
  475. if not _is_tensor(x) or x.type() is None:
  476. return None
  477. x_type = x.type()
  478. x_type = typing.cast(_C.TensorType, x_type)
  479. if allow_nonstatic:
  480. # Each individual symbol is returned as None.
  481. # e.g. [1, "a", "b"] -> [1, None, None]
  482. return x_type.varyingSizes()
  483. # returns None, if exists any symbol in sizes.
  484. # e.g. [1, "a", "b"] -> None
  485. return x_type.sizes()
  486. @_beartype.beartype
  487. def _get_tensor_dim_size(x: _C.Value, dim: int) -> Optional[int]:
  488. sizes = _get_tensor_sizes(x)
  489. return sizes[dim] if sizes else None
  490. @_beartype.beartype
  491. def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
  492. if dim == -1:
  493. tensor_rank = _get_tensor_rank(x)
  494. assert tensor_rank is not None
  495. return dim + tensor_rank
  496. # If dim is not given, it defaults to the first dimension found with the size 3
  497. if dim is None:
  498. sizes = _get_tensor_sizes(x)
  499. assert sizes is not None
  500. for index, size in enumerate(sizes):
  501. if size is not None and size == 3:
  502. return index
  503. return dim
  504. @_beartype.beartype
  505. def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
  506. # For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators
  507. if _C_onnx._CAFFE2_ATEN_FALLBACK:
  508. warnings.warn(f"ONNX export failed on {op} because {msg} not supported")
  509. elif GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
  510. _onnx_unsupported(f"{op}, {msg}", value)
  511. @_beartype.beartype
  512. def _onnx_unsupported(op_name: str, value: Optional[_C.Value] = None) -> NoReturn:
  513. message = (
  514. f"Unsupported: ONNX export of operator {op_name}. "
  515. f"Please feel free to request support or submit a pull request "
  516. f"on PyTorch GitHub: {_constants.PYTORCH_GITHUB_ISSUES_URL}"
  517. )
  518. if isinstance(value, _C.Value):
  519. raise errors.SymbolicValueError(
  520. message,
  521. value,
  522. )
  523. raise errors.OnnxExporterError(message)
  524. @_beartype.beartype
  525. def _onnx_opset_unsupported(
  526. op_name: str,
  527. current_opset: int,
  528. supported_opset: int,
  529. value: Optional[_C.Value] = None,
  530. ) -> NoReturn:
  531. message = (
  532. f"Unsupported: ONNX export of {op_name} in opset {current_opset}. "
  533. f"Please try opset version {supported_opset}."
  534. )
  535. if isinstance(value, _C.Value):
  536. raise errors.SymbolicValueError(
  537. message,
  538. value,
  539. )
  540. raise errors.OnnxExporterError(message)
  541. @_beartype.beartype
  542. def _onnx_opset_unsupported_detailed(
  543. op_name: str,
  544. current_opset: int,
  545. supported_opset: int,
  546. reason: str,
  547. value: Optional[_C.Value] = None,
  548. ) -> NoReturn:
  549. message = (
  550. f"Unsupported: ONNX export of {op_name} in "
  551. f"opset {current_opset}. {reason}. Please try opset version {supported_opset}."
  552. )
  553. if isinstance(value, _C.Value):
  554. raise errors.SymbolicValueError(
  555. message,
  556. value,
  557. )
  558. raise errors.OnnxExporterError(message)
  559. @_beartype.beartype
  560. def _block_list_in_opset(name: str):
  561. def symbolic_fn(*args, **kwargs):
  562. raise errors.OnnxExporterError(
  563. f"ONNX export failed on {name}, which is not implemented for opset "
  564. f"{GLOBALS.export_onnx_opset_version}. "
  565. "Try exporting with other opset versions."
  566. )
  567. return symbolic_fn
  568. @_beartype.beartype
  569. def _try_get_scalar_type(*args) -> Optional[_type_utils.JitScalarType]:
  570. for arg in args:
  571. scalar_type = _type_utils.JitScalarType.from_value(
  572. arg, _type_utils.JitScalarType.UNDEFINED
  573. )
  574. if scalar_type != _type_utils.JitScalarType.UNDEFINED:
  575. return scalar_type
  576. return None
@_beartype.beartype
def _select_helper(g: jit_utils.GraphContext, self, dim, index, apply_reshape=True):
    """Emits a Gather implementing `select`-style indexing along `dim`.

    Args:
        g: Graph context.
        self: The input tensor value.
        dim: The dimension to gather along.
        index: The index — a constant scalar or a tensor value.
        apply_reshape: Whether a 0-dim index tensor should be reshaped to a
            size-1 tensor before gathering.
    """
    index_const = _maybe_get_scalar(index)
    index_dim = _get_tensor_rank(index)
    if not _is_value(index_const):
        # Index is a constant scalar. Make it a size 1 constant tensor.
        index = g.op("Constant", value_t=torch.LongTensor([index_const]))
    elif index_dim is not None and apply_reshape:
        if index_dim == 0:
            # Index is a scalar. Reshape it to a size 1 tensor.
            index = _reshape_helper(
                g, index, g.op("Constant", value_t=torch.LongTensor([1]))
            )

    index_scalar_type = _type_utils.JitScalarType.from_value(
        index, _type_utils.JitScalarType.UNDEFINED
    )
    if index_scalar_type not in {
        _type_utils.JitScalarType.INT64,
        _type_utils.JitScalarType.INT,
    }:
        # ONNX Gather requires integer indices; cast anything else to int64.
        index = g.op("Cast", index, to_i=_C_onnx.TensorProtoDataType.INT64)
    return g.op("Gather", self, index, axis_i=dim)
  599. @_beartype.beartype
  600. def _slice_helper(
  601. g: jit_utils.GraphContext,
  602. input,
  603. axes,
  604. starts,
  605. ends,
  606. steps=None,
  607. dynamic_slice=False,
  608. ):
  609. if g.opset <= 9:
  610. from torch.onnx.symbolic_opset9 import _slice as _slice9
  611. return _slice9(g, input, axes, starts, ends)
  612. else:
  613. from torch.onnx.symbolic_opset10 import _slice as _slice10
  614. return _slice10(g, input, axes, starts, ends, steps, dynamic_slice)
  615. @_beartype.beartype
  616. def _is_fp(value) -> bool:
  617. return _type_utils.JitScalarType.from_value(
  618. value, _type_utils.JitScalarType.UNDEFINED
  619. ) in {
  620. _type_utils.JitScalarType.FLOAT,
  621. _type_utils.JitScalarType.DOUBLE,
  622. _type_utils.JitScalarType.HALF,
  623. _type_utils.JitScalarType.BFLOAT16,
  624. }
  625. @_beartype.beartype
  626. def _is_bool(value) -> bool:
  627. return _type_utils.JitScalarType.from_value(
  628. value, _type_utils.JitScalarType.UNDEFINED
  629. ) in {_type_utils.JitScalarType.BOOL}
  630. @_beartype.beartype
  631. def _generate_wrapped_number(g: jit_utils.GraphContext, scalar):
  632. """Creates a wrapped number based on https://github.com/pytorch/pytorch/issues/9515.
  633. A Tensor is a considered a "wrapped number" if it is
  634. auto-wrapped from a C++ or Python number type. Integer types are
  635. wrapped as 0-dim int64 tensors and floating-point types are
  636. wrapped as 0-dim double tensors.
  637. The input to this function is constant value. If the data type
  638. is a floating point type, it is converted to a 0-dim double
  639. tensor, else it is converted to a 0-dim tensor of its original type
  640. """
  641. assert not isinstance(scalar, torch.Tensor)
  642. if isinstance(scalar, float):
  643. return g.op("Constant", value_t=torch.tensor(scalar, dtype=torch.double))
  644. return g.op("Constant", value_t=torch.tensor(scalar))
@_beartype.beartype
def _sort_helper(g: jit_utils.GraphContext, input, dim, decending=True, out=None):
    """Emits TopK over the full size of `dim` to implement `sort`.

    NOTE(review): the parameter is spelled `decending` [sic]; kept as-is
    because symbolic functions may pass it by keyword.
    """
    if out is not None:
        # NOTE(review): `_unimplemented` raises in ONNX export mode but only
        # warns under the Caffe2 fallback, in which case execution continues.
        _unimplemented("Sort", "Out parameter is not supported")
    shape_ = g.op("Shape", input)
    # TopK's k input: the dynamic size of `dim`, gathered from the shape.
    dim_size_ = g.op(
        "Gather",
        shape_,
        g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)),
    )
    if g.opset <= 10:
        # `largest` attribute only exists from opset 11 on.
        if not decending:
            _unimplemented("Sort", "Ascending is not supported")
        return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2)
    else:
        return g.op(
            "TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2
        )
@_beartype.beartype
def _topk_helper(
    g: jit_utils.GraphContext, input, k, dim, largest=True, sorted=False, out=None
):
    """Emits TopK, normalizing `k` to a 1-element int64 tensor.

    NOTE(review): `_unimplemented` raises in ONNX export mode but only warns
    under the Caffe2 fallback, in which case execution continues below.
    """
    if out is not None:
        _unimplemented("TopK", "Out parameter is not supported")
    if not _is_value(k):
        # Python scalar k -> 1-element int64 constant.
        k = g.op("Constant", value_t=torch.tensor([k], dtype=torch.int64))
    else:
        # Graph value k -> reshape to rank 1 and cast to int64 if needed.
        k = _reshape_helper(g, k, g.op("Constant", value_t=torch.tensor([1])))
        if _try_get_scalar_type(k) != _type_utils.JitScalarType.INT64:
            k = g.op("Cast", k, to_i=_C_onnx.TensorProtoDataType.INT64)
    if g.opset <= 10:
        # `largest`/`sorted` attributes only exist from opset 11 on.
        if not largest:
            _unimplemented("TopK", "Ascending is not supported")
        return g.op("TopK", input, k, axis_i=dim, outputs=2)
    else:
        return g.op(
            "TopK", input, k, axis_i=dim, largest_i=largest, sorted_i=sorted, outputs=2
        )
  683. @_beartype.beartype
  684. def _lt_helper(g: jit_utils.GraphContext, input, other):
  685. if g.opset <= 8:
  686. from torch.onnx.symbolic_opset8 import lt as _lt8
  687. return _lt8(g, input, other)
  688. else:
  689. from torch.onnx.symbolic_opset9 import lt as _lt9
  690. return _lt9(g, input, other)
  691. @_beartype.beartype
  692. def _interpolate_warning(interpolate_mode):
  693. onnx_op = (
  694. "onnx:Resize" if GLOBALS.export_onnx_opset_version >= 10 else "onnx:Upsample"
  695. )
  696. warnings.warn(
  697. "You are trying to export the model with "
  698. + onnx_op
  699. + " for ONNX opset version "
  700. "" + str(GLOBALS.export_onnx_opset_version) + ". "
  701. "This operator might cause results to not match the expected results by PyTorch.\n"
  702. "ONNX's Upsample/Resize operator did not match Pytorch's Interpolation until opset 11. "
  703. "Attributes to determine how to transform the input were added in onnx:Resize in opset 11 "
  704. "to support Pytorch's behavior (like coordinate_transformation_mode and nearest_mode).\n"
  705. "We recommend using opset 11 and above for models using this operator."
  706. )
  707. @_beartype.beartype
  708. def _unsqueeze_helper(g: jit_utils.GraphContext, input, axes_i):
  709. if _is_constant(axes_i[0]):
  710. if g.opset >= 13:
  711. axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
  712. return g.op("Unsqueeze", input, axes)
  713. return g.op("Unsqueeze", input, axes_i=axes_i)
  714. # Tensor type
  715. if g.opset < 13:
  716. raise errors.SymbolicValueError(
  717. "Opset version must be >= 13 for Unsqueeze with dynamic axes.", input
  718. )
  719. return g.op("Unsqueeze", input, axes_i[0])
  720. @_beartype.beartype
  721. def _squeeze_helper(g: jit_utils.GraphContext, input, axes_i):
  722. if _is_constant(axes_i[0]):
  723. if g.opset >= 13:
  724. axes = g.op("Constant", value_t=torch.tensor(axes_i, dtype=torch.long))
  725. return g.op("Squeeze", input, axes)
  726. return g.op("Squeeze", input, axes_i=axes_i)
  727. # Tensor type
  728. if g.opset < 13:
  729. raise errors.SymbolicValueError(
  730. "Opset version must be >= 13 for Squeeze with dynamic axes.", input
  731. )
  732. axes_t = axes_i[0]
  733. axes_rank = _get_tensor_rank(axes_t)
  734. assert axes_rank is not None
  735. if axes_rank > 1:
  736. raise errors.SymbolicValueError(
  737. "For Squeeze axses as input, the axes rank must be one in ONNX spec.", input
  738. )
  739. elif axes_rank == 0:
  740. # The axes is a scalar. Unsqueeze it to a rank 1 tensor.
  741. axes_t = _unsqueeze_helper(g, axes_t, [0])
  742. return g.op("Squeeze", input, axes_t)
  743. return g.op("Squeeze", input, axes_t)
  744. @_beartype.beartype
  745. def _reducesum_helper(
  746. g: jit_utils.GraphContext,
  747. input,
  748. axes_i=None,
  749. keepdims_i=1,
  750. noop_with_empty_axes_i=0,
  751. ):
  752. keepdims_i = _maybe_get_const(keepdims_i, "i")
  753. if g.opset >= 13:
  754. if axes_i:
  755. if not _is_value(axes_i):
  756. axes_i = g.op(
  757. "Constant", value_t=torch.tensor(axes_i, dtype=torch.long)
  758. )
  759. return g.op(
  760. "ReduceSum",
  761. input,
  762. axes_i,
  763. keepdims_i=keepdims_i,
  764. noop_with_empty_axes_i=noop_with_empty_axes_i,
  765. )
  766. return g.op(
  767. "ReduceSum",
  768. input,
  769. keepdims_i=keepdims_i,
  770. noop_with_empty_axes_i=noop_with_empty_axes_i,
  771. )
  772. else:
  773. return g.op("ReduceSum", input, axes_i=axes_i, keepdims_i=keepdims_i)
  774. @_beartype.beartype
  775. def _interpolate_size_to_scales(g: jit_utils.GraphContext, input, output_size, dim):
  776. output_size = _maybe_get_const(output_size, "is")
  777. if _is_value(output_size):
  778. offset = 2
  779. offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.float32))
  780. dividend = g.op("Cast", output_size, to_i=_C_onnx.TensorProtoDataType.FLOAT)
  781. divisor = _slice_helper(
  782. g, g.op("Shape", input), axes=[0], ends=[sys.maxsize], starts=[offset]
  783. )
  784. divisor = g.op("Cast", divisor, to_i=_C_onnx.TensorProtoDataType.FLOAT)
  785. scale_dims = g.op("Div", dividend, divisor)
  786. scales = g.op("Concat", offsets, scale_dims, axis_i=0)
  787. else:
  788. scales_constant = [
  789. 1.0
  790. if i < 2
  791. else float(output_size[-(dim - i)])
  792. / float(input.type().sizes()[-(dim - i)])
  793. for i in range(0, dim)
  794. ]
  795. scales = g.op(
  796. "Constant", value_t=torch.tensor(scales_constant, dtype=torch.float32)
  797. )
  798. return scales
  799. @_beartype.beartype
  800. def _interpolate_get_scales_if_available(g: jit_utils.GraphContext, scales):
  801. available_scales = _maybe_get_const(scales[0], "fs") != -1 and not _is_none(
  802. scales[0]
  803. )
  804. if not available_scales:
  805. return None
  806. offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
  807. scales_list = g.op(
  808. "Constant", value_t=torch.tensor(_maybe_get_const(scales[0], "fs"))
  809. )
  810. scales = g.op("Concat", offsets, scales_list, axis_i=0)
  811. return scales
  812. @_beartype.beartype
  813. def _get_interpolate_attributes(g: jit_utils.GraphContext, mode, args):
  814. if mode == "nearest":
  815. align_corners = None
  816. scales = args[0:]
  817. else:
  818. align_corners = args[0]
  819. scales = args[1:]
  820. scales = _interpolate_get_scales_if_available(g, scales)
  821. return scales, align_corners
  822. @_beartype.beartype
  823. def _interpolate_get_scales(g: jit_utils.GraphContext, scale_factor, dim):
  824. offsets = g.op("Constant", value_t=torch.ones(2, dtype=torch.float32))
  825. scale_factor_rank = _get_tensor_rank(scale_factor)
  826. if isinstance(scale_factor.type(), _C.ListType) or (
  827. scale_factor_rank is not None and scale_factor_rank > 0
  828. ):
  829. return g.op("Concat", offsets, scale_factor, axis_i=0)
  830. else:
  831. scale_factor = _unsqueeze_helper(g, scale_factor, [0])
  832. scale_factor = g.op(
  833. "Cast", scale_factor, to_i=_C_onnx.TensorProtoDataType.FLOAT
  834. )
  835. scales = [scale_factor for i in range(dim - 2)]
  836. scale_factor = g.op("Concat", offsets, *scales, axis_i=0)
  837. return scale_factor
  838. @_beartype.beartype
  839. def _interpolate_get_scales_and_mode(
  840. g: jit_utils.GraphContext, input, size, scale_factor, mode, align_corners
  841. ):
  842. mode = _maybe_get_const(mode, "s")
  843. if "linear" in mode:
  844. mode = "linear"
  845. if "cubic" in mode:
  846. mode = "cubic"
  847. _interpolate_warning(mode)
  848. align_corners = _maybe_get_const(align_corners, "b")
  849. if isinstance(align_corners, bool) and align_corners:
  850. return _unimplemented("interpolate", "align_corners == True")
  851. if not input.type().dim():
  852. return _unimplemented("interpolate", "missing input shape")
  853. dim = input.type().dim()
  854. if not _is_none(scale_factor):
  855. scale_factor = _interpolate_get_scales(g, scale_factor, dim)
  856. elif not _is_none(size):
  857. if not _is_packed_list(size):
  858. is_scalar = _maybe_get_const(size, "t").dim() == 0
  859. if is_scalar:
  860. size = _unsqueeze_helper(g, size, [0])
  861. size = [size for i in range(dim - 2)]
  862. size = g.op("Concat", *size, axis_i=0)
  863. scale_factor = _interpolate_size_to_scales(g, input, size, dim)
  864. else:
  865. return _unimplemented(
  866. "interpolate", "Both size and scales are None in __interpolate"
  867. )
  868. return scale_factor, mode
  869. @_beartype.beartype
  870. def _argmin_argmax_helper(
  871. g: jit_utils.GraphContext,
  872. input: torch._C.Value,
  873. dim: torch._C.Value,
  874. keepdim: bool,
  875. op_name: str,
  876. ):
  877. def op_wrapper(input, axis_i, keepdims_i):
  878. if g.opset >= 12:
  879. return g.op(
  880. op_name,
  881. input,
  882. axis_i=axis_i,
  883. keepdims_i=keepdims_i,
  884. select_last_index_i=False,
  885. )
  886. return g.op(op_name, input, axis_i=axis_i, keepdims_i=keepdims_i)
  887. if _is_none(dim):
  888. flattened = _reshape_helper(
  889. g, input, g.op("Constant", value_t=torch.tensor([-1]))
  890. )
  891. output = op_wrapper(flattened, axis_i=0, keepdims_i=False)
  892. if keepdim:
  893. input_shape = g.op("Shape", input)
  894. input_shape_shape = g.op("Shape", input_shape)
  895. new_shape = g.op(
  896. "ConstantOfShape",
  897. input_shape_shape,
  898. value_t=torch.tensor([1], dtype=torch.int64),
  899. )
  900. output = g.op("Reshape", output, new_shape)
  901. return output
  902. dim = _parse_arg(dim, "i")
  903. return op_wrapper(input, axis_i=dim, keepdims_i=keepdim)
@_beartype.beartype
def _interpolate_helper(name, dim, interpolate_mode):
    """Build a symbolic function that exports interpolate/upsample as ONNX Resize.

    Args:
        name: aten op name (not referenced inside the generated function).
        dim: spatial dimensionality (not referenced inside the generated function).
        interpolate_mode: Resize mode string ("nearest", "linear", or "cubic").

    Returns:
        The symbolic function (g, input, output_size, *args) -> Resize node.
    """

    @quantized_args(True, False, False)
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = _get_interpolate_attributes(g, interpolate_mode, args)
        align_corners = _maybe_get_scalar(align_corners)
        coordinate_transformation_mode = (
            "asymmetric"
            if interpolate_mode == "nearest"
            else "align_corners"
            if align_corners
            else "half_pixel"
        )
        if scales is None:
            # No scales given: drive Resize with an explicit output size made of
            # the input's first two dims (batch, channel) + the requested sizes.
            input_size = g.op("Shape", input)
            input_size_beg = _slice_helper(
                g, input_size, axes=[0], ends=[2], starts=[0]
            )
            output_size = g.op(
                "Cast", output_size, to_i=_C_onnx.TensorProtoDataType.INT64
            )
            output_size = g.op("Concat", input_size_beg, output_size, axis_i=0)
            if g.opset >= 13:
                # Opset 13+ supports true optional inputs for roi/scales.
                empty_roi = _optional_input_placeholder_tensor(g)
                empty_scales = _optional_input_placeholder_tensor(g)
            else:
                # Older opsets use empty float tensors as "absent" inputs.
                empty_roi = g.op(
                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
                )
                empty_scales = g.op(
                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
                )
            return g.op(
                "Resize",
                input,
                empty_roi,
                empty_scales,
                output_size,
                coordinate_transformation_mode_s=coordinate_transformation_mode,
                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                mode_s=interpolate_mode,  # nearest, linear, or cubic
                nearest_mode_s="floor",
            )  # only valid when mode="nearest"
        else:
            # Scales are known: let Resize compute the output size from them.
            if g.opset >= 13:
                empty_roi = _optional_input_placeholder_tensor(g)
            else:
                empty_roi = g.op(
                    "Constant", value_t=torch.tensor([], dtype=torch.float32)
                )
            return g.op(
                "Resize",
                input,
                empty_roi,
                scales,
                coordinate_transformation_mode_s=coordinate_transformation_mode,
                cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
                mode_s=interpolate_mode,  # nearest, linear, or cubic
                nearest_mode_s="floor",
            )  # only valid when mode="nearest"

    return symbolic_fn
@_beartype.beartype
def __interpolate_helper(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
):
    """Export aten::__interpolate as ONNX Resize.

    Either `size` or `scale_factor` drives the Resize inputs;
    `recompute_scale_factor` is accepted but not referenced in this body.
    """
    mode = _maybe_get_const(mode, "s")
    # Collapse mode variants (e.g. bilinear/trilinear) onto the ONNX mode names.
    if "linear" in mode:
        mode = "linear"
    if "cubic" in mode:
        mode = "cubic"
    # Non-bool align_corners (e.g. a None constant) is treated as False.
    align_corners = _maybe_get_const(align_corners, "b")
    align_corners = False if not isinstance(align_corners, bool) else align_corners
    coordinate_transformation_mode = (
        "asymmetric"
        if mode == "nearest"
        else "align_corners"
        if align_corners
        else "half_pixel"
    )

    if not _is_none(size):
        input_size = g.op("Shape", input)
        input_size = _slice_helper(g, input_size, axes=[0], ends=[2], starts=[0])
        # in some cases size is not a packed list but size is a scalar
        # We need to also verify that (_maybe_get_const(size, "t").dim() == 0)
        # but this information is not always available. Try to get the dim,
        # and if not assume that it is not a scalar.
        try:
            is_scalar = not _is_packed_list(size) and (
                _maybe_get_const(size, "t").dim() == 0
            )
        except AttributeError:
            is_scalar = not _is_packed_list(size)
            if not is_scalar:
                warnings.warn(
                    "Cannot verify if the output_size is a scalar "
                    "while exporting interpolate. Assuming that it is not a scalar."
                )
        if is_scalar:
            # Broadcast the scalar size to every spatial dimension.
            rank = _get_tensor_rank(input)
            if rank is None:
                return _unimplemented(
                    "interpolate (with a scalar output_size)",
                    "missing input shape (try giving an array of output_size values)",
                )
            size = _unsqueeze_helper(g, size, [0])
            size = [size for i in range(rank - 2)]
            size = g.op("Concat", *size, axis_i=0)
        # Final size input: [batch, channel] from the input + spatial sizes.
        size = g.op("Cast", size, to_i=_C_onnx.TensorProtoDataType.INT64)
        size = g.op("Concat", input_size, size, axis_i=0)
        if g.opset >= 13:
            # Opset 13+ supports true optional inputs for roi/scales.
            empty_roi = _optional_input_placeholder_tensor(g)
            empty_scales = _optional_input_placeholder_tensor(g)
        else:
            # Older opsets use empty float tensors as "absent" inputs.
            empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
            empty_scales = g.op(
                "Constant", value_t=torch.tensor([], dtype=torch.float32)
            )
        return g.op(
            "Resize",
            input,
            empty_roi,
            empty_scales,
            size,
            coordinate_transformation_mode_s=coordinate_transformation_mode,
            cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
            mode_s=mode,  # nearest, linear, or cubic
            nearest_mode_s="floor",
        )
    else:  # if not _is_none(scales)
        rank = _get_tensor_rank(input)
        if rank is None:
            return _unimplemented("interpolate (with scales)", "missing input shape")
        if g.opset >= 13:
            empty_roi = _optional_input_placeholder_tensor(g)
        else:
            empty_roi = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
        scales = _interpolate_get_scales(g, scale_factor, rank)
        return g.op(
            "Resize",
            input,
            empty_roi,
            scales,
            coordinate_transformation_mode_s=coordinate_transformation_mode,
            cubic_coeff_a_f=-0.75,  # only valid when mode="cubic"
            mode_s=mode,  # nearest, linear, or cubic
            nearest_mode_s="floor",
        )  # only valid when mode="nearest"
  1057. @_beartype.beartype
  1058. def _unbind_helper(g: jit_utils.GraphContext, self, dim, _outputs):
  1059. if g.opset < 11:
  1060. from torch.onnx.symbolic_opset9 import unbind
  1061. elif g.opset <= 12:
  1062. from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef]
  1063. else:
  1064. from torch.onnx.symbolic_opset13 import unbind # type: ignore[no-redef]
  1065. return unbind(g, self, dim, _outputs)
  1066. @_beartype.beartype
  1067. def _scatter_helper(g: jit_utils.GraphContext, self, dim, index, src):
  1068. if g.opset <= 10:
  1069. from torch.onnx.symbolic_opset9 import scatter
  1070. else:
  1071. # for mypy, scatter was imported two lines above
  1072. from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef]
  1073. return scatter(g, self, dim, index, src)
  1074. @_beartype.beartype
  1075. def _repeat_interleave_split_helper(g: jit_utils.GraphContext, self, reps, dim):
  1076. if g.opset <= 12:
  1077. split_out = g.op("Split", self, split_i=[1] * reps, axis_i=dim, outputs=reps)
  1078. else:
  1079. from torch.onnx.symbolic_opset13 import split
  1080. repeats = g.op("Constant", value_t=torch.tensor([1] * reps))
  1081. split_out = split(g, self, repeats, dim, _outputs=reps)
  1082. return split_out if reps > 1 else [split_out]
  1083. @_beartype.beartype
  1084. def _arange_cast_helper(
  1085. g: jit_utils.GraphContext, end, start=None, step=None, dtype=None
  1086. ) -> Tuple[
  1087. _type_utils.JitScalarType,
  1088. Optional[_C.Value],
  1089. Optional[_C.Value],
  1090. Optional[_C.Value],
  1091. ]:
  1092. def _is_all_integral(scalars):
  1093. for scalar in scalars:
  1094. scalar_type = _type_utils.JitScalarType.from_value(
  1095. scalar, _type_utils.JitScalarType.UNDEFINED
  1096. )
  1097. if (
  1098. scalar_type != _type_utils.JitScalarType.INT64
  1099. and scalar_type != _type_utils.JitScalarType.UNDEFINED
  1100. ):
  1101. return False
  1102. return True
  1103. # This logic is based on torch.arange docs. If "dtype" is provided,
  1104. # infer input types from dtype. If not, then check if any of start, stop,
  1105. # or step are floating point, and infer the type from get_default.
  1106. # Otherwise, the dtype is inferred to be torch.int64.
  1107. if dtype is None or (_is_value(dtype) and _is_none(dtype)):
  1108. if _is_all_integral([start, end, step]):
  1109. scalar_type = _type_utils.JitScalarType.INT64
  1110. else:
  1111. scalar_type = _type_utils.JitScalarType.from_dtype(
  1112. torch.get_default_dtype()
  1113. )
  1114. else:
  1115. assert isinstance(dtype, int)
  1116. # TODO(justinchuby): Check if dtype is indeed a int.
  1117. scalar_type = _type_utils.JitScalarType(dtype)
  1118. start = g.op("Cast", start, to_i=scalar_type.onnx_type()) if start else None
  1119. end = g.op("Cast", end, to_i=scalar_type.onnx_type()) if end else None
  1120. step = g.op("Cast", step, to_i=scalar_type.onnx_type()) if step else None
  1121. return scalar_type, end, start, step
  1122. @_beartype.beartype
  1123. def _arange_helper(g: jit_utils.GraphContext, *args):
  1124. if g.opset <= 10:
  1125. from torch.onnx.symbolic_opset9 import arange
  1126. else:
  1127. from torch.onnx.symbolic_opset11 import arange # type: ignore[no-redef]
  1128. return arange(g, *args)
  1129. @_beartype.beartype
  1130. def _size_helper(g: jit_utils.GraphContext, self, dim):
  1131. full_shape = g.op("Shape", self)
  1132. from torch.onnx.symbolic_opset9 import select
  1133. return select(g, full_shape, g.op("Constant", value_t=torch.tensor([0])), dim)
  1134. @_beartype.beartype
  1135. def _index_fill_reshape_helper(g: jit_utils.GraphContext, self, dim, index):
  1136. # 1. reshape index => [1, ..., 1, dim, 1, ..., 1]
  1137. # 2. expand index => [..., dim, ...], same shape as self except for dim.
  1138. # 3. expand value as well.
  1139. # 4. apply onnx::scatter.
  1140. from torch.onnx.symbolic_opset9 import expand
  1141. if g.opset <= 10:
  1142. from torch.onnx.symbolic_opset9 import scatter
  1143. else:
  1144. # for mypy, scatter was imported two lines above
  1145. from torch.onnx.symbolic_opset11 import scatter # type: ignore[no-redef]
  1146. if self.type().dim() is None:
  1147. return _unimplemented("index_fill", "input rank not accessible")
  1148. self_dim = self.type().dim()
  1149. dim_value = _parse_arg(dim, "i")
  1150. unsqueezed_index = _unsqueeze_helper(
  1151. g, index, [i for i in range(self_dim) if i != dim_value]
  1152. )
  1153. expanded_index_shape = scatter(
  1154. g, g.op("Shape", self), 0, _unsqueeze_helper(g, dim, [0]), g.op("Shape", index)
  1155. )
  1156. expanded_index = expand(g, unsqueezed_index, expanded_index_shape, None)
  1157. return expanded_index_shape, expanded_index
  1158. # By default, when any value in the 'shape' input is equal to zero
  1159. # the corresponding dimension value is copied from the input tensor dynamically.
  1160. # allowzero=1 indicates that if any value in the 'shape' input is set to zero,
  1161. # the zero value is honored, similar to NumPy.
  1162. # allowzero=1 is only supported for opset version >= 14.
  1163. @_beartype.beartype
  1164. def _reshape_helper(g: jit_utils.GraphContext, input, shape, allowzero=0):
  1165. shape = _maybe_get_const(shape, "is")
  1166. if not _is_value(shape):
  1167. shape = g.op("Constant", value_t=torch.LongTensor(shape))
  1168. if g.opset <= 13:
  1169. if allowzero == 1:
  1170. _onnx_opset_unsupported(
  1171. "Reshape with allowzero=1", GLOBALS.export_onnx_opset_version, 14, input
  1172. )
  1173. return g.op("Reshape", input, shape)
  1174. else:
  1175. return g.op("Reshape", input, shape, allowzero_i=allowzero)
@_beartype.beartype
def _batchnorm_helper(
    g: jit_utils.GraphContext, input, weight, bias, running_mean, running_var
):
    """Fill in defaults for BatchNorm export.

    Returns (weight, bias, running_mean, running_var) where a missing weight
    defaults to ones, a missing bias to zeros, and missing running statistics
    are computed from the current batch via _var_mean.
    """
    from torch.onnx.symbolic_opset9 import _var_mean

    batch_size = _get_tensor_dim_size(input, 0)
    channel_size = _get_tensor_dim_size(input, 1)

    if weight is None or _is_none(weight):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of batch_norm for unknown channel size.",
                input,
            )
        # Default weight: per-channel ones, matching the input's scalar type.
        weight_value = torch.tensor(
            [1.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        weight = g.op("Constant", value_t=weight_value)
    if bias is None or _is_none(bias):
        if channel_size is None:
            raise errors.SymbolicValueError(
                "Unsupported: ONNX export of batch_norm for unknown channel size.",
                input,
            )
        # Default bias: per-channel zeros, matching the input's scalar type.
        bias_value = torch.tensor(
            [0.0] * channel_size,
            dtype=_type_utils.JitScalarType.from_value(input).dtype(),
        )
        bias = g.op("Constant", value_t=bias_value)
    # If track_running_stats is set to False batch statistics are instead used during evaluation time
    if (
        running_mean is None
        or _is_none(running_mean)
        or running_var is None
        or _is_none(running_var)
    ):
        assert batch_size is not None and channel_size is not None
        # Compute per-channel batch statistics: reshape to [N, C, -1] and
        # transpose so the reduction over dims [0, 1] covers all non-channel axes.
        reshape_in = _reshape_helper(
            g,
            input,
            g.op(
                "Constant",
                value_t=torch.tensor([batch_size, channel_size, -1], dtype=torch.int64),
            ),
        )
        trans_in = g.op("Transpose", reshape_in, perm_i=[0, 2, 1])
        running_var, running_mean = _var_mean(
            g,
            trans_in,
            g.op("Constant", value_t=torch.tensor([0, 1], dtype=torch.int64)),
            False,
            False,
        )
    return weight, bias, running_mean, running_var
  1230. @_beartype.beartype
  1231. def _avgpool_helper(
  1232. tuple_fn: Callable[[Any], Sequence[int]],
  1233. padding: Union[int, Sequence[int]],
  1234. kernel_size,
  1235. stride,
  1236. divisor_override,
  1237. name,
  1238. ) -> Tuple[int, ...]:
  1239. if divisor_override and divisor_override.node().kind() != "prim::Constant":
  1240. _unimplemented(name, "divisor_override")
  1241. return tuple(tuple_fn(padding))
  1242. @_beartype.beartype
  1243. def check_training_mode(op_train_mode: int, op_name: str) -> None:
  1244. """Warns the user if the model's training mode and the export mode do not agree."""
  1245. if GLOBALS.training_mode == _C_onnx.TrainingMode.PRESERVE:
  1246. return
  1247. if op_train_mode:
  1248. op_mode_enum = _C_onnx.TrainingMode.TRAINING
  1249. else:
  1250. op_mode_enum = _C_onnx.TrainingMode.EVAL
  1251. if op_mode_enum == GLOBALS.training_mode:
  1252. # The modes agree. Do nothing
  1253. return
  1254. op_mode_text = f"train={bool(op_train_mode)}"
  1255. # Setting the model mode could result in op_mode != GLOBALS.training_mode
  1256. # if the model is a FuncModule. In this case we warn the user of
  1257. # the state and export depending on op_mode
  1258. # This is to support use-cases of fixing certain layer weights
  1259. # in training.
  1260. warnings.warn(
  1261. f"ONNX export mode is set to {GLOBALS.training_mode}, but operator '{op_name}' "
  1262. f"is set to {op_mode_text}. Exporting with {op_mode_text}."
  1263. )
  1264. @_beartype.beartype
  1265. def _flatten_helper(g: jit_utils.GraphContext, input, start_dim, end_dim, dim):
  1266. input_size = g.op("Shape", input)
  1267. slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
  1268. slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
  1269. if end_dim < dim - 1:
  1270. slice3 = _slice_helper(
  1271. g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim]
  1272. )
  1273. slices = [
  1274. slice1,
  1275. g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
  1276. slice3,
  1277. ]
  1278. final_shape = g.op("Concat", *slices, axis_i=0)
  1279. from torch.onnx.symbolic_opset9 import _reshape_from_tensor
  1280. return _reshape_from_tensor(g, input, final_shape)
  1281. @_beartype.beartype
  1282. def _is_split_static(split_size_or_sizes, _outputs):
  1283. if _outputs is None:
  1284. return False
  1285. if (
  1286. _is_value(split_size_or_sizes)
  1287. and split_size_or_sizes.node().kind() != "onnx::Constant"
  1288. ):
  1289. return False
  1290. return True
  1291. @_beartype.beartype
  1292. def _optional_input_placeholder_tensor(g):
  1293. n = g.op("prim::Constant")
  1294. n.setType(_C.OptionalType.ofTensor())
  1295. return n
  1296. @_beartype.beartype
  1297. def _handle_reduce_dim_none(g: jit_utils.GraphContext, self, op_name):
  1298. rank = _get_tensor_rank(self)
  1299. if rank is not None and any(
  1300. [_get_tensor_dim_size(self, i) == 0 for i in range(rank)]
  1301. ):
  1302. # If input tensor is empty, according to ONNX ReduceSum definition,
  1303. # set keepdims=1 so that the resulted tensor has the same rank as the input.
  1304. return g.op(op_name, self, keepdims_i=1)
  1305. return g.op(op_name, self, keepdims_i=0)
@_beartype.beartype
def dequantize_helper(
    g: jit_utils.GraphContext,
    qtensor: _C.Value,
    qdtype: Optional[_C_onnx.TensorProtoDataType] = None,
) -> Tuple[_C.Value, _C.Value, _C.Value, Optional[_C.Value]]:
    """Appends to graph `g` ONNX nodes that dequantizes `qtensor` into `tensor`.

    Args:
        g: Graph, the ONNX IR graph that is under construction.
        qtensor: torch._C.Value, either a tuple of (quantized_tensor, scale, zero_point)
            for per tensor quantization, or
            (quantized_tensor, scale, zero_point, axis) for per channel quantization,
            representing the quantized tensor.
        qdtype: torch.onnx.TensorProtoDataType default None, if not None, represents the
            data type of quantized tensor. It must be either
            torch.onnx.TensorProtoDataType.UINT8 or torch.onnx.TensorProtoDataType.INT8.

    Returns:
        Tuple of (dequantized value, scale, zero_point, axis-or-None).
    """
    unpacked_qtensors = _unpack_quantized_tensor(qtensor)
    tensor, scale, zero_point = unpacked_qtensors[:3]
    # A fourth element, when present, is the per-channel axis.
    axis = unpacked_qtensors[3] if len(unpacked_qtensors) >= 4 else None
    axis_i = _get_const(axis, "i", "axis")
    input_qdtype = _type_utils.JitScalarType.from_value(tensor)
    if qdtype is None:
        # Infer the quantized dtype from the tensor value; default to UINT8.
        if input_qdtype is not None:
            qdtype = input_qdtype.onnx_type()
        else:
            qdtype = _C_onnx.TensorProtoDataType.UINT8
    value = g.op("Cast", tensor, to_i=qdtype)
    scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    zero_point = g.op("Cast", zero_point, to_i=qdtype)
    if axis_i is not None and GLOBALS.export_onnx_opset_version < 13:
        # Per-channel dequantize (the axis attribute) requires opset 13.
        _onnx_opset_unsupported_detailed(
            "DequantizeLinear",
            GLOBALS.export_onnx_opset_version,
            13,
            "Attribute axis is not supported.",
            qtensor,
        )
    return (
        g.op("DequantizeLinear", value, scale, zero_point, axis_i=axis_i),
        scale,
        zero_point,
        axis,
    )
  1350. @_beartype.beartype
  1351. def quantize_helper(
  1352. g: jit_utils.GraphContext,
  1353. tensor: _C.Value,
  1354. scale: _C.Value,
  1355. zero_point: _C.Value,
  1356. axis: Optional[_C.Value] = None,
  1357. ) -> _C.Value:
  1358. """Appends to graph `g` ONNX nodes that quantizes `tensor` based on `scale`, `zero_point` and `axis`.
  1359. Args:
  1360. g: Graph, the ONNX IR graph that is under construction.
  1361. tensor: torch._C.Value, representing the tensor to be quantized.
  1362. scale: torch._C.Value, quantized scale.
  1363. zero_point: torch._C.Value, quantized zero point.
  1364. axis: Optional[torch._C.Value] default None, if None, represents per tensor quantization.
  1365. Otherwise, represents per channel quantization, along given axis.
  1366. Returns:
  1367. A TupleConstruct storing information of the quantized tensor.
  1368. """
  1369. if (
  1370. axis is not None
  1371. and not _is_none(axis)
  1372. and GLOBALS.export_onnx_opset_version < 13
  1373. ):
  1374. _onnx_opset_unsupported_detailed(
  1375. "QuantizeLinear",
  1376. GLOBALS.export_onnx_opset_version,
  1377. 13,
  1378. "Attribute axis is not supported.",
  1379. tensor,
  1380. )
  1381. assert scale is not None
  1382. if (
  1383. _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
  1384. != _type_utils.JitScalarType.FLOAT
  1385. ):
  1386. scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
  1387. assert zero_point is not None
  1388. if _type_utils.JitScalarType.from_value(
  1389. zero_point, _type_utils.JitScalarType.UNDEFINED
  1390. ) not in {
  1391. _type_utils.JitScalarType.UINT8,
  1392. _type_utils.JitScalarType.INT8,
  1393. }:
  1394. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
  1395. output = g.op(
  1396. "QuantizeLinear",
  1397. tensor,
  1398. scale,
  1399. zero_point,
  1400. axis_i=_get_const(axis, "i", "axis"),
  1401. )
  1402. args = [output, scale, zero_point]
  1403. if axis is not None and not _is_none(axis):
  1404. args.append(axis)
  1405. return g.op("prim::TupleConstruct", *args)
  1406. @_beartype.beartype
  1407. def requantize_bias_helper(
  1408. g: jit_utils.GraphContext, bias, input_scale, weight_scale, axis=None
  1409. ):
  1410. """In PyTorch, bias is float and is quantized to int32 implicitly inside the quantized ATen op kernel.
  1411. In ONNX we need to make the quantization explicit because operators expect all of their inputs to be quantized.
  1412. Since int32 is not a supported output type by ONNX operator `QuantizeLinear`, quantization is exported using
  1413. regular operators.
  1414. """
  1415. bias_scale = g.op("Mul", weight_scale, input_scale)
  1416. bias_scale_shape = g.op("Shape", bias_scale)
  1417. bias_zero_point = g.op(
  1418. "ConstantOfShape", bias_scale_shape, value_t=torch.tensor([0], dtype=torch.int)
  1419. )
  1420. q_bias = g.op(
  1421. "Cast", g.op("Div", bias, bias_scale), to_i=_C_onnx.TensorProtoDataType.INT32
  1422. )
  1423. axis_args = []
  1424. if axis is not None and not _is_none(axis):
  1425. axis_args.append(axis)
  1426. return g.op("prim::TupleConstruct", q_bias, bias_scale, bias_zero_point, *axis_args)
  1427. @_beartype.beartype
  1428. def args_have_same_dtype(args):
  1429. assert args
  1430. base_dtype = _type_utils.JitScalarType.from_value(args[0])
  1431. has_same_dtype = all(
  1432. _type_utils.JitScalarType.from_value(elem) == base_dtype for elem in args
  1433. )
  1434. return has_same_dtype
# TODO(justinchuby): Delete these setters, users should set the vars directly.
@_deprecation.deprecated(
    "1.13",
    "2.0",
    "remove its usage and avoid setting internal variables directly",
)
def _set_opset_version(opset_version: int):
    """Deprecated: set the global ONNX opset version used during export."""
    GLOBALS.export_onnx_opset_version = opset_version
@_deprecation.deprecated(
    "1.13",
    "2.0",
    "remove its usage and avoid setting internal variables directly",
)
def _set_operator_export_type(operator_export_type):
    """Deprecated: set the global operator export type used during export."""
    GLOBALS.operator_export_type = operator_export_type
# This function is for debug use only.
# onnx_shape_inference = True by default.
@_deprecation.deprecated(
    "1.13",
    "2.0",
    "remove its usage and avoid setting internal variables directly",
)
def _set_onnx_shape_inference(onnx_shape_inference: bool):
    """Deprecated: toggle ONNX shape inference during export (debug use only)."""
    GLOBALS.onnx_shape_inference = onnx_shape_inference
# Deprecated. Internally use _type_utils.ScalarType
# TODO: remove these once we support Type's in the JIT IR and we can once again
# use the unified toType operator
# Maps PyTorch scalar-type names (as used by JIT, e.g. "Float") to the
# corresponding ONNX TensorProto data-type enum values.
cast_pytorch_to_onnx = {
    "Byte": _C_onnx.TensorProtoDataType.UINT8,
    "Char": _C_onnx.TensorProtoDataType.INT8,
    "Double": _C_onnx.TensorProtoDataType.DOUBLE,
    "Float": _C_onnx.TensorProtoDataType.FLOAT,
    "Half": _C_onnx.TensorProtoDataType.FLOAT16,
    "Int": _C_onnx.TensorProtoDataType.INT32,
    "Long": _C_onnx.TensorProtoDataType.INT64,
    "Short": _C_onnx.TensorProtoDataType.INT16,
    "Bool": _C_onnx.TensorProtoDataType.BOOL,
    "ComplexFloat": _C_onnx.TensorProtoDataType.COMPLEX64,
    "ComplexDouble": _C_onnx.TensorProtoDataType.COMPLEX128,
    "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16,
    "Undefined": _C_onnx.TensorProtoDataType.UNDEFINED,
}
# Deprecated. Internally use _type_utils.ScalarType
# Maps C++-style scalar type names (e.g. "int64_t") to the PyTorch/JIT
# scalar-type names (e.g. "Long") used as keys elsewhere in this module.
scalar_name_to_pytorch = {
    "uint8_t": "Byte",
    "int8_t": "Char",
    "double": "Double",
    "float": "Float",
    "half": "Half",
    "int": "Int",
    "int64_t": "Long",
    "int16_t": "Short",
    "bool": "Bool",
    "complex64": "ComplexFloat",
    "complex128": "ComplexDouble",
    "qint8": "QInt8",
    "quint8": "QUInt8",
    "qint32": "QInt32",
    "bfloat16": "BFloat16",
}
# Deprecated. Internally use _type_utils.ScalarType
# This indicates each scalar type's corresponding
# torch type. Related source:
# https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h
# Indexed by the c10 ScalarType ordinal; the trailing comment on each entry
# is that ordinal.
scalar_type_to_pytorch_type = [
    torch.uint8,  # 0
    torch.int8,  # 1
    torch.short,  # 2
    torch.int,  # 3
    torch.int64,  # 4
    torch.half,  # 5
    torch.float,  # 6
    torch.double,  # 7
    torch.complex32,  # 8
    torch.complex64,  # 9
    torch.complex128,  # 10
    torch.bool,  # 11
    torch.qint8,  # 12
    torch.quint8,  # 13
    torch.qint32,  # 14
    torch.bfloat16,  # 15
]
# Deprecated. Internally use _type_utils.ScalarType
# source of truth is
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp
# Maps PyTorch/JIT scalar-type names (e.g. "Long") to torch dtype objects.
pytorch_name_to_type = {
    "Byte": torch.uint8,
    "Char": torch.int8,
    "Double": torch.double,
    "Float": torch.float,
    "Half": torch.half,
    "Int": torch.int,
    "Long": torch.int64,
    "Short": torch.short,
    "Bool": torch.bool,
    "ComplexFloat": torch.complex64,
    "ComplexDouble": torch.complex128,
    "QInt8": torch.qint8,
    "QUInt8": torch.quint8,
    "QInt32": torch.qint32,
    "BFloat16": torch.bfloat16,
}
# Deprecated. Internally use _type_utils.ScalarType
# Indexed by the c10 ScalarType ordinal (same indexing as
# scalar_type_to_pytorch_type above); maps each ordinal to an ONNX
# TensorProto data type.
# Notes:
#   - index 8 (complex32 in the parallel table) has no ONNX counterpart and
#     maps to Undefined;
#   - indices 12-14 (qint8/quint8/qint32) map to the plain integer ONNX
#     types Char/Byte/Int.
scalar_type_to_onnx = [
    cast_pytorch_to_onnx["Byte"],  # 0
    cast_pytorch_to_onnx["Char"],  # 1
    cast_pytorch_to_onnx["Short"],  # 2
    cast_pytorch_to_onnx["Int"],  # 3
    cast_pytorch_to_onnx["Long"],  # 4
    cast_pytorch_to_onnx["Half"],  # 5
    cast_pytorch_to_onnx["Float"],  # 6
    cast_pytorch_to_onnx["Double"],  # 7
    cast_pytorch_to_onnx["Undefined"],  # 8
    cast_pytorch_to_onnx["ComplexFloat"],  # 9
    cast_pytorch_to_onnx["ComplexDouble"],  # 10
    cast_pytorch_to_onnx["Bool"],  # 11
    cast_pytorch_to_onnx["Char"],  # 12
    cast_pytorch_to_onnx["Byte"],  # 13
    cast_pytorch_to_onnx["Int"],  # 14
    cast_pytorch_to_onnx["BFloat16"],  # 15
]
# Global set to store the list of quantized operators in the network.
# This is currently only used in the conversion of quantized ops from PT -> C2 via ONNX.
# NOTE(review): declared as Set[int]; presumably holds node/op identifiers —
# the values stored here are not visible in this file.
_quantized_ops: Set[int] = set()