# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 13
import functools

import torch
import torch._C._onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset11 as opset11,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply


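# In opset 13 the ONNX Softmax/LogSoftmax operators reduce over a single `axis`
# (matching PyTorch's `dim`), so the input no longer needs to be flattened to 2-D
# as in earlier opsets; e.g. aten::softmax(x, dim=-1) maps directly to
# Softmax(axis=-1), with an optional trailing Cast when an explicit dtype is given.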
@_onnx_symbolic("aten::softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    softmax = g.op("Softmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        softmax = g.op(
            "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return softmax


@_onnx_symbolic("aten::log_softmax")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
    return_op = g.op("LogSoftmax", input, axis_i=dim)
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return_op = g.op(
            "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    return return_op


@_onnx_symbolic("aten::frobenius_norm")
@symbolic_helper.parse_args("v", "v", "i")
@_beartype.beartype
def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
    dim_val = symbolic_helper._maybe_get_const(dim, "is")
    if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
        return g.op("ReduceL2", self, keepdims_i=0)
    sqr = g.op("Mul", self, self)
    sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
    return g.op("Sqrt", sumsqr)


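# In opset 13, Split takes the split sizes as a second (tensor) input rather than a
# `split` attribute, so static sizes are emitted as a Constant input; e.g.
# torch.split(x, [2, 3], dim=1) becomes Split(x, Constant([2, 3]), axis=1). When the
# sizes or the number of outputs are not statically known, we fall back to
# SplitToSequence.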
@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out

        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
        if (
            symbolic_helper._is_packed_list(split_size_or_sizes)
            and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
        ):
            split_sizes = [
                symbolic_helper._unsqueeze_helper(g, v, [0])
                for v in symbolic_helper._unpack_list(split_size_or_sizes)
            ]

            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            for i in range(_outputs):
                end = g.op(
                    "Add", start, split_sizes[i]
                )  # split_sizes is a list of same length as _outputs
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res

        return [
            g.op(
                "SequenceAt",
                split_out,
                g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
            )
            for i in range(_outputs)
        ]

    split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
    if split_val.dim() > 0:
        return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)

    split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        if _outputs is not None:
            size = split_size * _outputs
        else:
            raise errors.SymbolicValueError(
                "Unknown dimension size not supported", self
            )
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)
    splits = g.op("Constant", value_t=torch.tensor(splits))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


@_onnx_symbolic("aten::split_with_sizes")
@_beartype.beartype
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    return split(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unsafe_split")
@_beartype.beartype
def unsafe_split(
    g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
):
    return split(g, self, split_size_or_sizes, dim, _outputs)


@_onnx_symbolic("aten::unsafe_split_with_sizes")
@_beartype.beartype
def unsafe_split_with_sizes(
    g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
):
    return split_with_sizes(g, self, split_sizes, dim, _outputs)


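# tensor_split covers several cases: statically known indices are emitted as a chain of
# Slice ops and a statically known number of sections becomes a single Split; a 1-D
# tensor of runtime indices is handled with an ONNX Loop that slices between consecutive
# indices; a runtime scalar number of sections computes the split sizes with Div/Mod.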
@_onnx_symbolic("aten::tensor_split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def tensor_split(
    g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None
):
    axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
    axis = opset11.unsqueeze(g, axis, 0)
    const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))

    if symbolic_helper._is_split_static(indices_or_sections, _outputs):
        split_val = symbolic_helper._node_get(indices_or_sections.node(), "value")

        if split_val.dim() > 0:
            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            res = []
            assert _outputs is not None
            for i in range(_outputs - 1):
                end = g.op(
                    "Gather",
                    indices_or_sections,
                    g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
                    axis_i=0,
                )
                res.append(g.op("Slice", self, start, end, axis))
                start = end

            end = symbolic_helper._size_helper(g, self, axis)
            res.append(g.op("Slice", self, start, end, axis))
            return res

        split_size = symbolic_helper._get_const(
            indices_or_sections, "i", "indices_or_sections"
        )

        size = symbolic_helper._get_tensor_dim_size(self, dim)
        if size is None:
            if _outputs is not None:
                size = split_size * _outputs
            else:
                raise errors.SymbolicValueError(
                    "Unknown dimension size not supported", self
                )

        min_split_size = size // split_size
        num_splits_one_extra = size % split_size

        splits = num_splits_one_extra * [min_split_size + 1]
        leftover = (split_size - num_splits_one_extra) * [min_split_size]

        splits = g.op(
            "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long)
        )
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)

    if (
        symbolic_helper._is_tensor(indices_or_sections)
        and symbolic_helper._get_tensor_rank(indices_or_sections) == 1
    ):
        loop_len = symbolic_helper._size_helper(
            g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0))
        )
        loop_len = opset11.unsqueeze(g, loop_len, 0)
        loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL)

        # To make the first slice in the below loop work,
        # we pad a zero to the first position so that it will be the initial start of slice.
        padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
        indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0)

        final_splits = g.op("SequenceEmpty")
        # Loop inputs
        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        cond = utils._add_input_to_block(loop_block)
        final_splits = utils._add_input_to_block(loop_block)

        start = loop_context.op(
            "Gather", indices_or_sections, block_input_iter, axis_i=0
        )
        end = loop_context.op(
            "Gather",
            indices_or_sections,
            loop_context.op("Add", block_input_iter, const_1),
            axis_i=0,
        )

        slice = loop_context.op("Slice", self, start, end, axis)
        final_splits = loop_context.op("SequenceInsert", final_splits, slice)

        # Loop outputs
        cond_out = loop_context.op("Identity", loop_condition)
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, final_splits)

        loop_out = loop.node().output()
        start = g.op(
            "Gather",
            indices_or_sections,
            g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)),
            axis_i=0,
        )
        start = opset11.unsqueeze(g, start, 0)
        end = symbolic_helper._size_helper(g, self, axis)

        last_slice = g.op("Slice", self, start, end, axis)
        return g.op("SequenceInsert", loop_out, last_slice)

    else:  # scalar tensor
        dim_size = symbolic_helper._size_helper(g, self, axis)
        min_split_size = g.op("Div", dim_size, indices_or_sections)
        min_split_size_plus_1 = g.op(
            "Add",
            min_split_size,
            const_1,
        )
        num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections)
        splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra)
        leftover = g.op(
            "Tile",
            min_split_size,
            g.op(
                "Sub",
                opset11.unsqueeze(g, indices_or_sections, 0),
                num_splits_one_extra,
            ),
        )

        splits = g.op("Concat", splits, leftover, axis_i=0)
        if _outputs is None:
            return g.op("SplitToSequence", self, splits, axis_i=dim)
        return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


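# unbind maps to a Split into size-1 pieces along `dim` followed by a Squeeze of that
# axis; when the number of outputs is not statically known, SplitToSequence with
# keepdims=0 produces the same result as a sequence.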
@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )

    splits = g.op("Constant", value_t=torch.tensor([1] * _outputs))
    outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
    outputs = [outputs] if _outputs == 1 else outputs
    squeezed_outputs = [
        g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim])))
        for out in outputs
    ]
    return squeezed_outputs


@_onnx_symbolic("aten::nonzero_numpy")
# Emitted from `torch.nonzero(x, as_tuple=True)`
@_beartype.beartype
def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
    return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)


@_onnx_symbolic("aten::where")
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
    # Assumes that torch.where's first argument takes only Bool and Byte tensors.
    if not symbolic_helper._is_bool(condition):
        condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    if self is None:
        condition = opset9.nonzero(g, condition)
        return symbolic_helper._unbind_helper(
            g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
        )
    return g.op("Where", condition, self, other)


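# Fake quantization maps onto QuantizeLinear/DequantizeLinear, which gained the
# per-channel `axis` attribute in opset 13. The extra Clip handles PyTorch's
# (0, 127) activation range, which ONNX's uint8/int8 quantized types cannot
# express directly.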
@_onnx_symbolic("aten::fake_quantize_per_channel_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
@_beartype.beartype
def fake_quantize_per_channel_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    axis,
    quant_min=-128,
    quant_max=127,
):
    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    # ONNX defines zero_point to be int8 or uint8
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
    if (quant_min, quant_max) == (0, 127):
        quantized = g.op(
            "Clip",
            quantized,
            opset9.unused(g),
            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
        )
    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)


@_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
@symbolic_helper.parse_args("v", "v", "v", "i", "i")
@_beartype.beartype
def fake_quantize_per_tensor_affine(
    g: jit_utils.GraphContext,
    inputs,
    scale,
    zero_point,
    quant_min=-128,
    quant_max=127,
):
    # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
    # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
    if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
        raise errors.SymbolicValueError(
            "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
            f"Got ({quant_min}, {quant_max})",
            inputs,
        )
    if quant_min == 0:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
    else:
        zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
    if (
        _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
        != _type_utils.JitScalarType.FLOAT
    ):
        scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
    if (quant_min, quant_max) == (0, 127):
        quantized = g.op(
            "Clip",
            quantized,
            opset9.unused(g),
            g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
        )
    return g.op("DequantizeLinear", quantized, scale, zero_point)


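# Reductions: starting with opset 13, ReduceSum takes its axes as an input rather than
# an attribute, so `dim` is forwarded as a graph value below; e.g.
# aten::sum(x, dim=[1], keepdim=False) becomes ReduceSum(x, Constant([1]), keepdims=0).
# The optional dtype argument of aten::sum is handled by casting the input before
# reducing.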
@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name):
    @_beartype.beartype
    def symbolic(g, self, dim=None, keepdim=None):
        self = opset9._maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
        else:
            keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
            return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)

    return symbolic


@_onnx_symbolic(
    "aten::sum",
    decorate=[_apply_params("ReduceSum", "sum")],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op, name):
    symbolic = _reduce_op_symbolic(onnx_op)

    @opset9.overload_by_arg_count
    @_beartype.beartype
    def reduce(g, *args, **kwargs):
        @symbolic_helper.parse_args("v", "none")
        @_beartype.beartype
        def reduce_nodim(g, self, dtype):
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                self = g.op(
                    "Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()
                )
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            return symbolic(g, self)

        @symbolic_helper.parse_args("v", "v", "i", "none")
        @_beartype.beartype
        def reduce_dim(g, self, dim, keepdim, dtype):
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                self = g.op(
                    "Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()
                )
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            return symbolic(g, self, dim, keepdim)

        return reduce_nodim, reduce_dim

    return reduce


@_onnx_symbolic("aten::unsafe_chunk")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )

    size = symbolic_helper._get_tensor_dim_size(self, dim)
    if size is None:
        return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size")
    split_size = (size + chunks - 1) // chunks
    splits = [split_size] * (size // split_size)
    leftover = size % split_size
    if leftover:
        splits.append(leftover)

    # TODO: So far we don't have a module using this method. We'll keep
    # this as a constant unless we see a request for dynamic splits in any
    # user's modules.
    splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
    return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)


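# repeat_interleave with a dynamic input size along `dim` or a dynamic repeats vector is
# lowered by splitting the input into single slices along `dim`, then running an ONNX
# Loop that expands each slice by its repeat count and collects the results into a
# sequence, which is finally flattened with ConcatFromSequence. When neither the input
# size nor the repeats vector is dynamic, it falls back to the opset 9 implementation.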
@_onnx_symbolic("aten::repeat_interleave")
@_beartype.beartype
def repeat_interleave(
    g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
):
    input = self
    final_dim = dim
    # if dim is None flatten
    # By default, use the flattened input array, and return a flat output array
    if symbolic_helper._is_none(dim):
        input = symbolic_helper._reshape_helper(
            g, self, g.op("Constant", value_t=torch.tensor([-1]))
        )
        dim = 0
    else:
        dim = symbolic_helper._maybe_get_scalar(dim)

    repeats_dim = symbolic_helper._get_tensor_rank(repeats)
    repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
    input_sizes = symbolic_helper._get_tensor_sizes(input)
    if repeats_dim is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
            self,
        )
    if repeats_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
            self,
        )
    if input_sizes is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of repeat_interleave for unknown input size.",
            self,
        )
    # Handle cases where dim is negative
    if dim < 0:
        dim += len(input_sizes)

    output_sizes = input_sizes.copy()
    for idx, input_size in enumerate(input_sizes):
        if input_size is None:
            output_sizes[idx], input_sizes[idx] = 0, -1

    cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None
    # If input size is dynamic or repeats vector is dynamic
    if output_sizes[dim] == 0 or cond_dynamic_repeats:
        reps = symbolic_helper._size_helper(g, input, dim)
        reps = opset11.unsqueeze(g, reps, 0)

        # Check if repeats vector is a single integer value
        # or a single dimension tensor with non-dynamic values
        if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
            if not symbolic_helper._is_tensor(repeats):
                repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
            repeats = g.op("Expand", repeats, reps)
        # Check if repeats is dynamic
        # As repeats is dynamic, we use a where node as a substitute for the if statement
        # If repeats_dim == 1, expand repeats; otherwise use the original tensor
        elif cond_dynamic_repeats:
            repeat_dim = symbolic_helper._size_helper(
                g, repeats, g.op("Constant", value_t=torch.LongTensor([0]))
            )
            repeat_cond = g.op(
                "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1]))
            )
            repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
    # There are cases where repeats is a 1-d tensor with multiple repeat values while
    # dim points at one of the dynamic axes. A simple example would be
    # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2.
    # Repeat interleaving can then be performed in PyTorch when the value of * matches
    # the number of elements in repeats; for example, if * -> 2, the number of repeats
    # should be 2 as well.
    else:
        return opset9.repeat_interleave(g, self, repeats, final_dim)

    reps_like = g.op(
        "ConstantOfShape",
        g.op("Shape", repeats),
        value_t=torch.tensor([1], dtype=torch.long),
    )
    r_splits = split(g, repeats, reps_like, 0)
    i_splits = split(g, input, reps_like, dim)

    output_sizes[dim], input_sizes[dim] = -1, 1

    # Create a loop to iterate over each value along the dimension
    # and perform individual interleaving using the repeats tensor
    # Loop is of the following pattern
    # input (trip_count, cond)
    #   int trip_count = ...;
    #   bool cond = ...;
    #   for (int i = 0; i < trip_count && cond; ++i) {
    #     cond = ...;
    #   }

    # Loop conditions
    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    loop_len = reps

    # Create an empty sequence to store final expansions
    final_splits = g.op("SequenceEmpty")

    # Loop inputs
    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1
    )

    loop_block = loop_context.block
    block_input_iter = utils._add_input_to_block(loop_block)
    cond = utils._add_input_to_block(loop_block)
    final_splits = utils._add_input_to_block(loop_block)

    r_split = loop_context.op("SequenceAt", r_splits, block_input_iter)
    i_split = loop_context.op("SequenceAt", i_splits, block_input_iter)

    i_split = opset11.unsqueeze(loop_context, i_split, dim + 1)
    r_concat = [
        loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])),
        r_split,
        loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])),
    ]
    r_concat = loop_context.op("Concat", *r_concat, axis_i=0)
    i_split = opset9.expand(loop_context, i_split, r_concat, None)
    i_split = symbolic_helper._reshape_helper(
        loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes))
    )
    final_splits = loop_context.op("SequenceInsert", final_splits, i_split)

    # Loop outputs
    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, final_splits)

    loop_out = loop.node().output()
    loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
    return loop_out


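# diagonal is emulated by moving dim1/dim2 to the end of the shape, multiplying by an
# EyeLike mask (shifted by `offset`) and reduce-summing the last axis, then gathering
# the window of valid diagonal entries with GatherND. An If block returns an empty
# tensor when the offset runs past the matrix, i.e. when the computed diagonal size is
# zero.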
@_onnx_symbolic("aten::diagonal")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
    dim1_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
    )
    dim2_size = opset9.size(
        g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
    )

    # Create appropriate mask
    mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
    mask = opset9.zeros(g, mask_shape, None, None, None)
    mask = g.op("EyeLike", mask, k_i=offset)

    # dim1 and dim2 appended as a dimension at the end of the shape
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None:
        axes = list(range(rank))
        axes.remove(dim1)
        axes.remove(dim2)
        self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
    else:
        return symbolic_helper._unimplemented("diagonal", "unknown input rank")

    # Multiply input and mask to calculate values along diagonal
    # The mask consists of ones at the positions where diagonal values are to be calculated
    # For example:
    # [[1.1, 1.2, 1.3],   *   [[1, 0, 0]   =   [[1.1, 0, 0],
    #  [2.1, 2.2, 2.3],        [0, 1, 0]        [0, 2.2, 0],
    #  [3.1, 3.2, 3.3]]        [0, 0, 1]]       [0, 0, 3.3]]
    result = g.op("Mul", self, mask)
    result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)

    # Calculate gather indices based on offset and dims
    # If offset is greater than zero, set offset to zero as this aids in
    # calculation of selection window
    offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
    if offset >= 0:
        diag_size = g.op(
            "Max",
            g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
        offset = 0
    else:
        diag_size = g.op(
            "Max",
            g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
            g.op("Constant", value_t=torch.LongTensor([0])),
        )
    diag_size = g.op("Concat", diag_size, axis_i=0)

    # Calculate which diagonal values to select
    # For example, in cases with offsets:
    # [[0, 1.1, 0]
    #  [0, 0, 2.2]]
    # we need to select the last two columns, so we create a tensor
    # with all columns that are to be selected
    # So in this example, it is [1, 2]
    select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
    select_window = g.op(
        "CumSum",
        select_window_ones_fill,
        g.op("Constant", value_t=torch.LongTensor([0])),
    )
    select_window = g.op(
        "Add",
        select_window,
        g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
    )

    gather_shape = [
        opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
        for axis in list(range(rank))[:-2]
    ]
    gather_shape.append(diag_size)
    gather_shape = g.op("Concat", *gather_shape, axis_i=0)
    gather_indices = opset9.zeros(g, gather_shape, 4, None, None)

    # There might be cases where the offset value is greater than the number of rows/columns
    # and might cause the diagonal to overrun and, as a result, diag_size would be zero.
    # For example, if
    #     offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
    #     diag_size = max(min(2, (4 - 9)), 0) = 0, based on the calculation above
    # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0
    # In cases without diagonal overrun, we select the appropriate rows/columns along which we
    # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
    # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
    # returning an empty tensor
    overrun_cond = g.op(
        "Not",
        g.op(
            "Equal",
            diag_size,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
        ),
    )

    if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
        g, "If", overrun_cond, n_blocks=2
    )

    gather_indices_if_block = if_context.op("Add", gather_indices, select_window)
    gather_indices_if_block = symbolic_helper._unsqueeze_helper(
        if_context, gather_indices_if_block, [rank - 1]
    )
    final_non_overrun = if_context.op(
        "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
    )
    final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None)
    utils._add_output_to_block(if_context.block, final_non_overrun)
    utils._add_output_to_block(else_context.block, final_overrun)
    return if_op


# Quantized ops
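# Each quantized op follows the same pattern: dequantize the quantized inputs,
# requantize the bias using the input and weight scales, run the regular float
# implementation from opset 9, and quantize the result back with the op's output
# scale and zero point.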
@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)