# symbolic_opset10.py

import functools
import sys
import warnings
from typing import Callable

import torch
import torch._C._onnx as _C_onnx
import torch.onnx
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# See Note [Edit Symbolic Files] in README.md.

# This file exports ONNX ops for opset 10.
# Opset 10 is supported by ONNX release 1.5.0,
# released on 04/24/19.
__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv2d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply
  59. @_onnx_symbolic("aten::div")
  60. @_beartype.beartype
  61. def div(g: jit_utils.GraphContext, self, other, *args):
  62. if len(args) == 0:
  63. return opset9.true_divide(g, self, other)
  64. else:
  65. return _div_rounding_mode(g, self, other, *args)
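

# Illustrative sketch (not part of the exporter): at opset 10, both plain division
# and rounding-mode division route through `div` above. Assuming default export
# settings, a toy module such as the following exercises both paths:
#
#   class Div(torch.nn.Module):  # hypothetical example module
#       def forward(self, a, b):
#           return a / b, torch.div(a, b, rounding_mode="floor")
#
#   torch.onnx.export(
#       Div(), (torch.tensor([-7.0]), torch.tensor([2.0])), "div.onnx", opset_version=10
#   )
#
# The first output lowers via opset9.true_divide, the second via _div_rounding_mode.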


@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    if rounding_mode == "floor":
        return _floor_divide(g, self, other)
    else:
        return opset9._div_rounding_mode(g, self, other, rounding_mode)
  73. @_onnx_symbolic("aten::_floor_divide")
  74. @_beartype.beartype
  75. def _floor_divide(g: jit_utils.GraphContext, self, other):
  76. if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
  77. out = opset9.true_divide(g, self, other)
  78. return g.op("Floor", out)
  79. else:
  80. # Integer division does trunction rounding
  81. div = g.op("Div", self, other)
  82. # Division is negative if: self < 0 != other < 0
  83. zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
  84. negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero))
  85. # For negative numbers with self % other != 0, subtract 1 to round down instead of up
  86. mod = g.op("Mod", self, other, fmod_i=0)
  87. fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))
  88. one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
  89. fixup = g.op("Sub", div, one)
  90. return g.op("Where", fixup_mask, fixup, div)
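

# Worked example of the integer fixup above: for self = -7, other = 2, onnx::Div
# truncates toward zero and yields -3. The operands' signs differ, so `negative` is
# True, and Mod(-7, 2, fmod=0) is nonzero, so the fixup subtracts 1:
# -3 - 1 = -4 = floor(-7 / 2). When the division is exact (e.g. -8 / 2), the mask
# is False and the truncated quotient is already the floor.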
  91. @_onnx_symbolic("aten::sort")
  92. @symbolic_helper.parse_args("v", "i", "i", "none")
  93. @_beartype.beartype
  94. def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
  95. return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)
  96. @_onnx_symbolic("aten::topk")
  97. @symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
  98. @_beartype.beartype
  99. def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
  100. return symbolic_helper._topk_helper(
  101. g, self, k, dim, largest=largest, sorted=sorted, out=out
  102. )


@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        _apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        _apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        _apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[
        _apply_params(
            "max_pool1d_with_indices",
            torch.nn.modules.utils._single,
            1,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[
        _apply_params(
            "max_pool2d_with_indices",
            torch.nn.modules.utils._pair,
            2,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[
        _apply_params(
            "max_pool3d_with_indices",
            torch.nn.modules.utils._triple,
            3,
            return_indices=True,
        )
    ],
)
@_beartype.beartype
def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if not stride:
            stride = kernel_size
        kwargs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": tuple_fn(padding) * 2,
            "strides_i": tuple_fn(stride),
            "ceil_mode_i": ceil_mode,
        }
        if set(tuple_fn(dilation)) != {1}:
            kwargs["dilations_i"] = tuple_fn(dilation)
        # Easy but hacky way to get flattened indices values,
        # used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flattened 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by PyTorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This results in a tensor of indices in which each index has its own value.
        # Using this tensor as a reference, we extract the first index of each axis and
        # subtract it from each index of this axis in the indices to convert.
        # This step results in a tensor where each dimension has values of indices
        # within the dimension it is in.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op(
                "MaxPool",
                input,
                outputs=2,
                kernel_shape_i=[1 for _ in range(ndims)],
                strides_i=[1 for _ in range(ndims)],
            )
            # Convert indices to have non-flattened indices values
            s = symbolic_helper._slice_helper(
                g,
                flattened_indices,
                axes=[2 + i for i in range(ndims)],
                starts=tuple_fn(0),
                ends=tuple_fn(1),
            )
            indices = opset9.sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r

    return symbolic_fn
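

# Illustrative sketch of the index conversion above (values assumed, not produced by
# the exporter): for an input of shape (1, 2, 1, 4), ONNX MaxPool returns indices
# flattened over N x C x H x W, so the second channel's indices live in [4, 8). The
# auxiliary kernel/stride-1 MaxPool makes every element "win" its own window, so its
# indices tensor holds each element's global flat index; slicing out the first
# element per (n, c) plane (starts=0, ends=1 on the spatial axes) yields that plane's
# base offset (0 for channel 0, 4 for channel 1). Subtracting the offset produces
# per-plane indices in [0, H x W), matching aten::max_pool2d_with_indices.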


@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
    # Although onnx::AveragePool provides count_include_pad and ceil_mode,
    # PyTorch's average pooling with ceil_mode=True allows the sliding window
    # to go off-bounds, a corner case that requires this accommodation.
    # More detail at https://github.com/pytorch/pytorch/issues/57178
    return opset9._avg_pool(name, tuple_fn)
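

# Concrete instance of the corner case (assuming no padding): a 1-D input of length 5
# with kernel_size=2, stride=2, ceil_mode=True produces ceil((5 - 2) / 2) + 1 = 3
# windows; the last window starts at index 4 and extends one element past the end.
# PyTorch averages only the in-bounds elements of that window, which is the behavior
# the opset 9 helper has to reproduce.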


@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
  273. @_onnx_symbolic("aten::__interpolate")
  274. @_beartype.beartype
  275. def __interpolate(
  276. g: jit_utils.GraphContext,
  277. input,
  278. size,
  279. scale_factor,
  280. mode,
  281. align_corners,
  282. recompute_scale_factor,
  283. antialias,
  284. ):
  285. scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
  286. g, input, size, scale_factor, mode, align_corners
  287. )
  288. return g.op("Resize", input, scales, mode_s=mode)
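

# Note on the opset 10 Resize contract: it consumes a `scales` input covering every
# dimension of the input, not just the spatial ones. A hedged sketch (assuming
# default export settings):
#
#   x = torch.randn(1, 3, 8, 8)
#   torch.onnx.export(
#       torch.nn.Upsample(scale_factor=2, mode="bilinear"), (x,),
#       "resize.onnx", opset_version=10,
#   )
#
# exports a Resize node with scales [1.0, 1.0, 2.0, 2.0] and mode "linear";
# align_corners=True is rejected above because opset 10 Resize cannot express it.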


@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input,
    axes,
    starts,
    ends,
    steps=None,
    dynamic_slice=False,
):
    if dynamic_slice:
        starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(g, ends, [0])
        if isinstance(axes, int):
            axes = g.op("Constant", value_t=torch.tensor(axes))
        axes = symbolic_helper._unsqueeze_helper(g, axes, [0])
    else:
        assert len(starts) == len(ends)
        assert len(starts) == len(axes)
        assert steps is None or len(starts) == len(steps)
        if (
            len(starts) == 1
            and starts[0] == 0
            and ends[0] == _constants.INT64_MAX
            and (steps is None or (len(steps) == 1 and steps[0] == 1))
        ):
            return input
        if ends[0] > _constants.INT64_MAX:
            ends[0] = _constants.INT64_MAX
        axes = g.op("Constant", value_t=torch.tensor(axes))
        starts = g.op("Constant", value_t=torch.tensor(starts))
        ends = g.op("Constant", value_t=torch.tensor(ends))
    if steps is None:
        return g.op("Slice", input, starts, ends, axes)
    steps = g.op("Constant", value_t=torch.tensor(steps))
    return g.op("Slice", input, starts, ends, axes, steps)
  325. @_onnx_symbolic("aten::slice")
  326. @_beartype.beartype
  327. def slice(g: jit_utils.GraphContext, self, *args):
  328. if len(args) == 4:
  329. # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
  330. dim, start, end, step = args
  331. elif len(args) == 3:
  332. # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
  333. start, end, step = args
  334. dim = 0
  335. else:
  336. raise errors.SymbolicValueError("Unknown aten::slice signature", self)
  337. is_start_none = start.node().kind() == "prim::Constant" and isinstance(
  338. start.type(), _C.NoneType
  339. )
  340. is_end_none = end.node().kind() == "prim::Constant" and isinstance(
  341. end.type(), _C.NoneType
  342. )
  343. is_start_onnx_const = start.node().kind() == "onnx::Constant"
  344. is_end_onnx_const = end.node().kind() == "onnx::Constant"
  345. step = symbolic_helper._parse_arg(step, "i")
  346. if (
  347. (not is_start_none and not is_start_onnx_const)
  348. or (not isinstance(end, int) and not is_end_none and not is_end_onnx_const)
  349. or (not isinstance(dim, int) and dim.node().kind() != "onnx::Constant")
  350. ):
  351. dynamic_slice = True
  352. if is_start_none:
  353. start = g.op("Constant", value_t=torch.tensor(0))
  354. if is_end_none:
  355. end = g.op("Constant", value_t=torch.tensor(_constants.INT64_MAX))
  356. else:
  357. start = [0 if is_start_none else symbolic_helper._parse_arg(start, "i")]
  358. end = [
  359. _constants.INT64_MAX
  360. if is_end_none
  361. else symbolic_helper._parse_arg(end, "i")
  362. ]
  363. dim = [symbolic_helper._parse_arg(dim, "i")]
  364. dynamic_slice = False
  365. return symbolic_helper._slice_helper(
  366. g,
  367. self,
  368. axes=dim,
  369. starts=start,
  370. ends=end,
  371. steps=[step],
  372. dynamic_slice=dynamic_slice,
  373. )
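

# Illustrative mapping (assuming constant start/end/step): `x[:, 1:4]` reaches this
# symbolic as dim=1, start=1, end=4, step=1 and becomes a static onnx::Slice with
# axes=[1], starts=[1], ends=[4]. `x[2:]` arrives with end=None, which is replaced
# by INT64_MAX so the slice runs to the end of the axis. Non-constant starts/ends
# take the dynamic_slice path instead, where they remain graph values feeding the
# Slice node (opset 10 moved starts/ends/axes/steps from attributes to inputs).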
  374. @_onnx_symbolic("aten::flip")
  375. @symbolic_helper.parse_args("v", "is")
  376. @_beartype.beartype
  377. def flip(g: jit_utils.GraphContext, input, dims):
  378. return symbolic_helper._slice_helper(
  379. g,
  380. input,
  381. axes=dims,
  382. starts=[-1] * len(dims),
  383. ends=[-_constants.INT64_MAX] * len(dims),
  384. steps=[-1] * len(dims),
  385. )
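

# Example of the reversal-as-slice trick above: flipping [1, 2, 3] along dim 0
# lowers to Slice(starts=[-1], ends=[-INT64_MAX], steps=[-1]), i.e. walk backwards
# from the last element; the very negative end is clamped by the runtime so the
# slice runs down through index 0, yielding [3, 2, 1].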
  386. @_onnx_symbolic("aten::fmod")
  387. @_beartype.beartype
  388. def fmod(g: jit_utils.GraphContext, input, other):
  389. return g.op("Mod", input, other, fmod_i=1)
  390. @_onnx_symbolic("aten::embedding_bag")
  391. @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
  392. @_beartype.beartype
  393. def embedding_bag(
  394. g: jit_utils.GraphContext,
  395. embedding_matrix,
  396. indices,
  397. offsets,
  398. scale_grad_by_freq,
  399. mode,
  400. sparse,
  401. per_sample_weights,
  402. include_last_offset,
  403. padding_idx,
  404. ):
  405. if scale_grad_by_freq and GLOBALS.export_training:
  406. return symbolic_helper._onnx_unsupported(
  407. "embedding_bag with scale_grad_by_freq for training mode"
  408. )
  409. if padding_idx is not None and padding_idx >= 0:
  410. raise RuntimeError("embedding_bag with padding_idx")
  411. warnings.warn(
  412. "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
  413. "Please use opset 11 or higher to export model for dynamic input shape.'"
  414. )
  415. offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
  416. if offsets_dim_0 is not None:
  417. if include_last_offset:
  418. offset_len = offsets_dim_0 - 1
  419. offsets_extended = offsets
  420. else:
  421. offset_len = offsets_dim_0
  422. offsets_extended = [
  423. offsets,
  424. g.op("Constant", value_t=torch.tensor([sys.maxsize])),
  425. ]
  426. offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
  427. list_ = []
  428. for i in range(offset_len):
  429. start_ = symbolic_helper._unsqueeze_helper(
  430. g,
  431. opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
  432. [0],
  433. )
  434. end_ = symbolic_helper._unsqueeze_helper(
  435. g,
  436. opset9.select(
  437. g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)
  438. ),
  439. [0],
  440. )
  441. axes_ = g.op("Constant", value_t=torch.tensor([0]))
  442. indices_row = g.op("Slice", indices, start_, end_, axes_)
  443. embeddings = g.op("Gather", embedding_matrix, indices_row)
  444. if not symbolic_helper._is_none(per_sample_weights):
  445. per_sample_weights_row = g.op(
  446. "Slice", per_sample_weights, start_, end_, axes_
  447. )
  448. per_sample_weights_row = symbolic_helper._unsqueeze_helper(
  449. g, per_sample_weights_row, [1]
  450. )
  451. embeddings = g.op("Mul", embeddings, per_sample_weights_row)
  452. if mode == 0:
  453. embeddings = symbolic_helper._reducesum_helper(
  454. g, embeddings, axes_i=[0], keepdims_i=0
  455. )
  456. elif mode == 1:
  457. embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
  458. else:
  459. embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
  460. embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
  461. list_.append(embeddings)
  462. output = g.op("Concat", *list_, axis_i=0)
  463. # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
  464. # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
  465. return output, None, None, None
  466. else:
  467. return symbolic_helper._onnx_unsupported(
  468. "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
  469. "please use opset 11 or higher."
  470. )
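

# Worked example of the static unrolling above (mode=0, i.e. sum; values assumed):
# with indices = [1, 2, 4, 5, 4, 3] and offsets = [0, 2, 4], the loop emits three
# Slice/Gather/ReduceSum chains over index ranges [0, 2), [2, 4), and [4, 6)
# (the sys.maxsize sentinel closes the last bag), then concatenates the three
# reduced rows into an output of shape (3, embedding_dim). Because the loop is
# unrolled at export time, the offsets shape must be static, hence the warning.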
  471. @_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
  472. @symbolic_helper.parse_args("v", "v", "v", "i", "i")
  473. @_beartype.beartype
  474. def fake_quantize_per_tensor_affine(
  475. g: jit_utils.GraphContext,
  476. inputs,
  477. scale,
  478. zero_point,
  479. quant_min=-128,
  480. quant_max=127,
  481. ):
  482. # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
  483. # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
  484. if (quant_min, quant_max) == (0, 127):
  485. symbolic_helper._onnx_opset_unsupported_detailed(
  486. "fake_quantize_per_tensor_affine",
  487. 10,
  488. 13,
  489. "Quantize range (0, 127) not supported, requires opset 13 Clip",
  490. inputs,
  491. )
  492. if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
  493. raise errors.SymbolicValueError(
  494. f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
  495. f"Got ({quant_min}, {quant_max})",
  496. inputs,
  497. )
  498. scale = symbolic_helper._maybe_get_scalar(scale)
  499. if scale is None:
  500. symbolic_helper._onnx_opset_unsupported_detailed(
  501. "fake_quantize_per_tensor_affine",
  502. 10,
  503. 13,
  504. "Non-constant scale not supported",
  505. inputs,
  506. )
  507. scale = scale.float().data # Avoid exporter generating double type
  508. if quant_min == 0:
  509. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
  510. else:
  511. zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
  512. return g.op(
  513. "DequantizeLinear",
  514. g.op("QuantizeLinear", inputs, scale, zero_point),
  515. scale,
  516. zero_point,
  517. )
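

# Numeric sketch of the QuantizeLinear -> DequantizeLinear round trip emitted above
# (values assumed, not taken from the exporter): with scale = 0.1, zero_point = 0,
# and quant range (-128, 127), an input of 0.07 quantizes to round(0.07 / 0.1) = 1
# and dequantizes to 1 * 0.1 = 0.1, reproducing fake_quantize's rounding error.
# Inputs outside [quant_min * scale, quant_max * scale] saturate: 20.0 quantizes
# to 127 and dequantizes to 12.7.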
  518. @_onnx_symbolic("aten::isinf")
  519. @_beartype.beartype
  520. def isinf(g: jit_utils.GraphContext, input):
  521. return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE))
  522. @_onnx_symbolic("aten::isfinite")
  523. @_beartype.beartype
  524. def isfinite(g: jit_utils.GraphContext, input):
  525. inf_node = isinf(g, input)
  526. nan_node = opset9.isnan(g, input)
  527. return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node))
  528. @_onnx_symbolic("aten::quantize_per_tensor")
  529. @_beartype.beartype
  530. def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
  531. dtype = symbolic_helper._get_const(dtype, "i", "dtype")
  532. # TODO(justinchuby): Extract all the cast ops into a helper function.
  533. zero_point = g.op(
  534. "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type()
  535. )
  536. scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
  537. return symbolic_helper.quantize_helper(g, input, scale, zero_point)
  538. @_onnx_symbolic("aten::dequantize")
  539. @_beartype.beartype
  540. def dequantize(g: jit_utils.GraphContext, input):
  541. return symbolic_helper.dequantize_helper(g, input)[0]
  542. @_onnx_symbolic("aten::nan_to_num")
  543. @symbolic_helper.parse_args("v", "f", "f", "f")
  544. @_beartype.beartype
  545. def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
  546. # Cannot create a int type tensor with inf/nan values, so we simply
  547. # return the original tensor
  548. if not symbolic_helper._is_fp(input):
  549. return input
  550. input_dtype = _type_utils.JitScalarType.from_value(input).dtype()
  551. if nan is None:
  552. nan = 0.0
  553. nan_cond = opset9.isnan(g, input)
  554. nan_result = g.op(
  555. "Where",
  556. nan_cond,
  557. g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)),
  558. input,
  559. )
  560. # For None values of posinf, neginf we use the greatest/lowest finite
  561. # value representable by input’s dtype.
  562. finfo = torch.finfo(input_dtype)
  563. if posinf is None:
  564. posinf = finfo.max
  565. posinf_cond = opset9.logical_and(
  566. g,
  567. isinf(g, nan_result),
  568. opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))),
  569. )
  570. nan_posinf_result = g.op(
  571. "Where",
  572. posinf_cond,
  573. g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)),
  574. nan_result,
  575. )
  576. if neginf is None:
  577. neginf = finfo.min
  578. neginf_cond = opset9.logical_and(
  579. g,
  580. isinf(g, nan_posinf_result),
  581. opset9.lt(
  582. g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0]))
  583. ),
  584. )
  585. return g.op(
  586. "Where",
  587. neginf_cond,
  588. g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)),
  589. nan_posinf_result,
  590. )
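

# Behavior being reproduced (float32, replacement args left as None):
#
#   torch.nan_to_num(torch.tensor([float("nan"), float("inf"), float("-inf"), 1.0]))
#   # -> tensor([0.0000e+00, 3.4028e+38, -3.4028e+38, 1.0000e+00])
#
# NaNs become 0.0 and +/-inf become torch.finfo(torch.float32).max/min, matching
# the three chained Where nodes above.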


# Quantized symbolics ---------------------------------------------------------
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
# introduced in opset version 10.


@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
    output = opset9.linear(g, input, weight, bias)
    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
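

# All quantized::* symbolics in this file follow the pattern shown in
# quantized_linear above: dequantize the inputs, run the regular float symbolic
# from opset 9, then requantize with the op's output scale and zero point.
# Schematically, quantized::add lowers to
#
#   QuantizeLinear(Add(DequantizeLinear(x), DequantizeLinear(y)), op_scale, op_zero_point)
#
# which is why this support begins at opset 10, where those two ops were introduced.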
  606. @_onnx_symbolic("quantized::add")
  607. @_beartype.beartype
  608. def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
  609. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  610. y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
  611. output = opset9.add(g, x, y)
  612. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  613. @_onnx_symbolic("quantized::add_relu")
  614. @_beartype.beartype
  615. def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
  616. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  617. y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
  618. output = opset9.add(g, x, y)
  619. output = opset9.relu(g, output)
  620. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  621. @_onnx_symbolic("quantized::mul")
  622. @_beartype.beartype
  623. def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
  624. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  625. y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
  626. output = opset9.mul(g, x, y)
  627. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  628. @_onnx_symbolic("quantized::hardswish")
  629. @_beartype.beartype
  630. def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
  631. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  632. output = opset9.hardswish(g, x)
  633. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  634. @_onnx_symbolic("quantized::sigmoid")
  635. @_beartype.beartype
  636. def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
  637. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  638. output = opset9.sigmoid(g, x)
  639. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  640. @_onnx_symbolic("quantized::leaky_relu")
  641. @_beartype.beartype
  642. def quantized_leaky_relu(
  643. g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
  644. ):
  645. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  646. output = opset9.leaky_relu(g, x, negative_slope, inplace)
  647. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  648. @_onnx_symbolic("quantized::layer_norm")
  649. @_beartype.beartype
  650. def quantized_layer_norm(
  651. g: jit_utils.GraphContext,
  652. x,
  653. normalized_shape,
  654. weight,
  655. bias,
  656. eps,
  657. op_scale,
  658. op_zero_point,
  659. ):
  660. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  661. output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False)
  662. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  663. @_onnx_symbolic("quantized::group_norm")
  664. @_beartype.beartype
  665. def quantized_group_norm(
  666. g: jit_utils.GraphContext,
  667. x,
  668. num_groups,
  669. weight,
  670. bias,
  671. eps,
  672. op_scale,
  673. op_zero_point,
  674. ):
  675. x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
  676. output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False)
  677. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  678. @_onnx_symbolic("quantized::instance_norm")
  679. @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
  680. @_beartype.beartype
  681. def quantized_instance_norm(
  682. g: jit_utils.GraphContext,
  683. q_input,
  684. weight,
  685. bias,
  686. eps,
  687. op_scale,
  688. op_zero_point,
  689. ):
  690. input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  691. output = opset9.instance_norm(
  692. g, input, weight, bias, None, None, False, 0.0, eps, False
  693. )
  694. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  695. @_onnx_symbolic("quantized::conv1d_relu")
  696. @_beartype.beartype
  697. def quantized_conv1d_relu(
  698. g: jit_utils.GraphContext,
  699. q_input,
  700. q_weight,
  701. bias,
  702. stride,
  703. padding,
  704. dilation,
  705. groups,
  706. op_scale,
  707. op_zero_point,
  708. ):
  709. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  710. weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
  711. q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
  712. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  713. output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
  714. output = opset9.relu(g, output)
  715. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  716. @_onnx_symbolic("quantized::conv2d_relu")
  717. @_beartype.beartype
  718. def quantized_conv2d_relu(
  719. g: jit_utils.GraphContext,
  720. q_input,
  721. q_weight,
  722. bias,
  723. stride,
  724. padding,
  725. dilation,
  726. groups,
  727. op_scale,
  728. op_zero_point,
  729. ):
  730. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  731. weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
  732. q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
  733. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  734. output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
  735. output = opset9.relu(g, output)
  736. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  737. @_onnx_symbolic("quantized::conv2d")
  738. @_beartype.beartype
  739. def quantized_conv2d(
  740. g: jit_utils.GraphContext,
  741. q_input,
  742. q_weight,
  743. bias,
  744. stride,
  745. padding,
  746. dilation,
  747. groups,
  748. op_scale,
  749. op_zero_point,
  750. ):
  751. input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
  752. weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
  753. q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
  754. bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
  755. output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
  756. return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
  757. @_onnx_symbolic("quantized::cat")
  758. @symbolic_helper.parse_args("v", "i", "v", "v")
  759. @_beartype.beartype
  760. def quantized_cat(
  761. g: jit_utils.GraphContext,
  762. q_inputs: _C.Value,
  763. dim: int,
  764. op_scale: _C.Value,
  765. op_zero_point: _C.Value,
  766. ) -> _C.Value:
  767. unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
  768. dequantized = [
  769. symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
  770. ]
  771. concatenated = g.op("Concat", *dequantized, axis_i=dim)
  772. return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)