  1. """This file exports ONNX ops for opset 11."""
  2. from __future__ import annotations
  3. import functools
  4. import sys
  5. import warnings
  6. from typing import Optional, Sequence, Union
  7. import torch
  8. from torch import _C
  9. from torch._C import _onnx as _C_onnx
  10. from torch.onnx import (
  11. _type_utils,
  12. errors,
  13. symbolic_helper,
  14. symbolic_opset10 as opset10,
  15. symbolic_opset9 as opset9,
  16. utils,
  17. )
  18. from torch.onnx._globals import GLOBALS
  19. from torch.onnx._internal import _beartype, jit_utils, registration
  20. # EDITING THIS FILE? READ THIS FIRST!
  21. # see Note [Edit Symbolic Files] in README.md
  22. __all__ = [
  23. "add",
  24. "append",
  25. "arange",
  26. "argsort",
  27. "cat",
  28. "chunk",
  29. "clamp_max",
  30. "clamp_min",
  31. "clamp",
  32. "constant_pad_nd",
  33. "cumsum",
  34. "Delete",
  35. "embedding_bag",
  36. "embedding_renorm",
  37. "flatten",
  38. "gather",
  39. "hardtanh",
  40. "im2col",
  41. "index_fill",
  42. "index",
  43. "index_copy",
  44. "index_put",
  45. "insert",
  46. "linalg_det",
  47. "linalg_vector_norm",
  48. "logdet",
  49. "masked_scatter",
  50. "masked_select",
  51. "mm",
  52. "narrow",
  53. "normal",
  54. "pad",
  55. "pixel_shuffle",
  56. "pop",
  57. "prim_constant_chunk",
  58. "reflection_pad",
  59. "relu6",
  60. "remainder",
  61. "replication_pad",
  62. "round",
  63. "scatter",
  64. "select",
  65. "size",
  66. "sort",
  67. "split_with_sizes",
  68. "split",
  69. "squeeze",
  70. "stack",
  71. "topk",
  72. "unbind",
  73. "unique_dim",
  74. "unsqueeze",
  75. ]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=11)


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply
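

# Illustrative example: `_apply_params` turns a higher-order builder into a
# concrete symbolic function at registration time. A registration such as
#
#     @_onnx_symbolic(
#         "aten::upsample_nearest2d",
#         decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
#     )
#     def _interpolate(name, dim, interpolate_mode): ...
#
# registers the function returned by
# `_interpolate("upsample_nearest2d", 4, "nearest")` under that ATen op
# (see the `_interpolate` and `_avg_pool` registrations below).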


@_onnx_symbolic("aten::hardtanh")
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "f", "f")
@_beartype.beartype
def hardtanh(
    g: jit_utils.GraphContext, self: _C.Value, min_val: float, max_val: float
):
    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.FLOAT
    )
    min_val = g.op(
        "Constant",
        value_t=torch.tensor(min_val, dtype=scalar_type.dtype()),
    )
    max_val = g.op(
        "Constant",
        value_t=torch.tensor(max_val, dtype=scalar_type.dtype()),
    )
    return opset9._op_with_optional_float_cast(
        g, "Clip", self, min_val, max_val, opset_before=12
    )


@_onnx_symbolic("aten::clamp")
@_beartype.beartype
def clamp(g: jit_utils.GraphContext, self, min, max):
    @_beartype.beartype
    def _cast_if_not_none(tensor, dtype):
        if tensor is not None and not symbolic_helper._is_none(tensor):
            return g.op(
                "Cast",
                tensor,
                to_i=dtype.onnx_type(),
            )
        else:
            return tensor

    scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if scalar_type != _type_utils.JitScalarType.UNDEFINED:
        min = _cast_if_not_none(min, scalar_type)
        max = _cast_if_not_none(max, scalar_type)

    if symbolic_helper._is_none(min):
        return clamp_max(g, self, max)
    elif symbolic_helper._is_none(max):
        return clamp_min(g, self, min)
    else:
        if (
            symbolic_helper._get_tensor_rank(min) == 0
            and symbolic_helper._get_tensor_rank(max) == 0
        ):
            return opset9._op_with_optional_float_cast(
                g, "Clip", self, min, max, opset_before=12
            )
        else:
            return clamp_max(g, clamp_min(g, self, min), max)
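

# Illustrative note: for `x.clamp(0.0, 6.0)` with both bounds given as scalars
# (rank-0 constants), a single ONNX `Clip` is emitted. When either bound is a
# tensor of rank > 0, `Clip` cannot be used (its min/max inputs must be
# scalars), so the export composes `clamp_min` and `clamp_max` below, which
# lower to broadcasting `Max`/`Min` ops instead.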


@_onnx_symbolic("aten::clamp_min")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_min(g: jit_utils.GraphContext, self, min):
    min = g.op("Cast", min, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
    if symbolic_helper._get_tensor_rank(min) == 0:
        max = opset9.unused(g)
        return opset9._op_with_optional_float_cast(
            g, "Clip", self, min, max, opset_before=12
        )
    else:
        return opset9._op_with_optional_float_cast(g, "Max", self, min, opset_before=12)


@_onnx_symbolic("aten::clamp_max")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def clamp_max(g: jit_utils.GraphContext, self, max):
    max = g.op("Cast", max, to_i=_type_utils.JitScalarType.from_value(self).onnx_type())
    if symbolic_helper._get_tensor_rank(max) == 0:
        min = opset9.unused(g)
        return opset9._op_with_optional_float_cast(
            g, "Clip", self, min, max, opset_before=12
        )
    else:
        return opset9._op_with_optional_float_cast(g, "Min", self, max, opset_before=12)


@_onnx_symbolic("aten::relu6")
@_beartype.beartype
def relu6(g: jit_utils.GraphContext, input):
    relu_ = opset9._op_with_optional_float_cast(g, "Relu", input, opset_before=14)
    scalar_type = _type_utils.JitScalarType.from_value(
        input, _type_utils.JitScalarType.FLOAT
    )
    min_val = g.op(
        "Constant",
        value_t=torch.tensor(0, dtype=scalar_type.dtype()),
    )
    max_val = g.op(
        "Constant",
        value_t=torch.tensor(6, dtype=scalar_type.dtype()),
    )
    return clamp(g, relu_, min_val, max_val)


@_onnx_symbolic("aten::select")
# Opset 11 gather accepts negative indices
@symbolic_helper.quantized_args(True)
@symbolic_helper.parse_args("v", "i", "v")
@_beartype.beartype
def select(g: jit_utils.GraphContext, self, dim, index):
    return g.op("Gather", self, index, axis_i=dim)


@_onnx_symbolic("aten::index_put")
@_beartype.beartype
def index_put(
    g: jit_utils.GraphContext, self, indices_list_value, values, accumulate=False
):
    if symbolic_helper._is_packed_list(indices_list_value):
        indices_list = symbolic_helper._unpack_list(indices_list_value)
    else:
        indices_list = [indices_list_value]
    if symbolic_helper.is_caffe2_aten_fallback():
        args = [self] + indices_list + [values, accumulate]
        return g.at("index_put", *args)

    accumulate = symbolic_helper._parse_arg(accumulate, "b")

    if len(indices_list) == 0:
        return values

    if len(indices_list) > 1:
        for idx_ in range(len(indices_list)):
            if symbolic_helper._is_bool(indices_list[idx_]):
                indices_list[idx_] = g.op("NonZero", indices_list[idx_])
        index = indices_list[0]

        for ind in indices_list[1:]:
            index = opset9.add(g, index, ind)
        broadcast_index_shape = g.op("Shape", index)
        indices_list = [
            symbolic_helper._unsqueeze_helper(
                g, opset9.expand(g, ind, broadcast_index_shape, None), [-1]
            )
            for ind in indices_list
        ]
        index = g.op("Concat", *indices_list, axis_i=-1)
    else:
        # Replace the index_put node with masked_scatter or masked_fill
        # when its inputs contain a single boolean index.
        #
        # index_put -> masked_fill
        #   * input index contains a single tensor of Bool type (e.g.: %24 <- %23).
        #   * input value contains a single element (e.g.: %18).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %16 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::to(%8, %26, %27, %11, %12, %28, %29, %15)
        #   %18 : Float(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
        #   %23 : Bool(8, strides=[1], device=cpu) = aten::view(%16, %22)
        #   %24 : Tensor?[] = prim::ListConstruct(%23)
        #   %25 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) =
        #               aten::index_put(%mask, %24, %18, %30)
        #   return (%25)
        #
        # index_put -> masked_scatter
        #   * input index contains a single tensor of Bool type (e.g.: %32 <- %31).
        #   * input value contains multiple elements (e.g.: %28).
        #
        # Torch IR
        #   %mask : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu) = aten::clone(%0, %6)
        #   %28 : Float(8, strides=[1], requires_grad=0, device=cpu)
        #               = prim::Constant[value= 1 1 1 1 1 1 1 1 [ CPUFloatType{8} ]]()
        #   %15 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::ne(%mask, %some_const)
        #   %23 : Bool(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::to(%15, %34, %35, %18, %19, %36, %37, %22)
        #   %38 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]()
        #   %30 : int[] = prim::Constant[value=[-1]]()
        #   %31 : Bool(8, strides=[1], device=cpu) = aten::view(%23, %30)
        #   %32 : Tensor?[] = prim::ListConstruct(%31)
        #   %33 : Float(2, 2, 2, strides=[4, 2, 1], requires_grad=0, device=cpu)
        #               = aten::index_put(%mask, %32, %28, %38)
        #   return (%33)
        index = indices_list[0]
        bool_inp = index
        if symbolic_helper._is_bool(bool_inp):
            rank = symbolic_helper._get_tensor_rank(values)
            if rank is not None and rank == 0:
                return opset9.masked_fill(g, self, bool_inp, values)
            mask_rank = symbolic_helper._get_tensor_rank(bool_inp)
            self_rank = symbolic_helper._get_tensor_rank(self)
            if (
                mask_rank is not None
                and self_rank is not None
                and self_rank > mask_rank
            ):
                # Unsqueeze 'bool_inp' to be broadcastable to shape of 'self'.
                bool_inp = symbolic_helper._unsqueeze_helper(
                    g, bool_inp, list(range(mask_rank, self_rank))
                )
            return masked_scatter(g, self, bool_inp, values)
        broadcast_index_shape = g.op("Shape", index)
        index = symbolic_helper._unsqueeze_helper(g, index, [-1])
    sub_data_shape = symbolic_helper._slice_helper(
        g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[sys.maxsize]
    )
    values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
    # Check if values is a singular value and expand accordingly
    rank = symbolic_helper._get_tensor_rank(values)
    if rank is not None and rank == 0:
        values = opset9.expand(g, values, values_shape, None)
    values = symbolic_helper._reshape_helper(g, values, values_shape)

    self_scalar_type = _type_utils.JitScalarType.from_value(
        self, _type_utils.JitScalarType.UNDEFINED
    )
    if self_scalar_type != _type_utils.JitScalarType.UNDEFINED:
        values_scalar_type = _type_utils.JitScalarType.from_value(
            values, _type_utils.JitScalarType.UNDEFINED
        )
        if self_scalar_type != values_scalar_type:
            values = g.op("Cast", values, to_i=self_scalar_type.onnx_type())
    elif accumulate:
        raise errors.SymbolicValueError("self does not have a valid scalar type.", self)

    if accumulate:
        zeros = g.op(
            "ConstantOfShape",
            g.op("Shape", self),
            value_t=torch.tensor([0], dtype=self_scalar_type.dtype()),
        )
        result = g.op("ScatterND", zeros, index, values)
        result = add(g, self, result)
    else:
        result = g.op("ScatterND", self, index, values)

    return result
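

# Illustrative example of the advanced-indexing path above: for
#     x[idx0, idx1] = values
# with two integer index tensors, the indices are broadcast together,
# unsqueezed to shape [..., 1], and concatenated along the last axis into the
# [..., 2]-shaped `indices` input expected by ONNX ScatterND. With
# accumulate=True the updates are first scattered onto a zero tensor of x's
# shape and then added to x, which matches index_put_'s accumulation
# semantics (assuming indices do not repeat within a single scatter).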


@_onnx_symbolic("aten::pixel_shuffle")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def pixel_shuffle(g: jit_utils.GraphContext, self, upscale_factor):
    rank = symbolic_helper._get_tensor_rank(self)
    if rank is not None and rank != 4:
        return symbolic_helper._unimplemented("pixel_shuffle", "only support 4d input")
    return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")


@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bicubic2d",
    decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")],
)
@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):
    return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)


@_onnx_symbolic("aten::__interpolate")
@symbolic_helper.quantized_args(True, False, False, False, False, False, False)
@_beartype.beartype
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    return symbolic_helper.__interpolate_helper(
        g, input, size, scale_factor, mode, align_corners, recompute_scale_factor
    )


@_onnx_symbolic("aten::gather")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
    if symbolic_helper._maybe_get_const(sparse_grad, "i"):
        return symbolic_helper._unimplemented("gather", "sparse_grad == True")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("gather", self, dim, index, sparse_grad)
    return g.op("GatherElements", self, index, axis_i=dim)


@_onnx_symbolic("aten::scatter")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("scatter", self, dim, index, src, overload_name="src")
    src_type = _type_utils.JitScalarType.from_value(src)
    src = symbolic_helper._maybe_get_scalar(src)
    if symbolic_helper._is_value(src):
        return g.op("ScatterElements", self, index, src, axis_i=dim)
    else:
        # Check if scalar "src" has the same type as self (PyTorch allows a
        # different type for scalar src, but not when src is a tensor).
        # If not, insert a Cast node.
        if _type_utils.JitScalarType.from_value(self) != src_type:
            src = g.op(
                "Cast",
                src,
                to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
            )
        return g.op(
            "ScatterElements", self, index, opset9.expand_as(g, src, index), axis_i=dim
        )


@_onnx_symbolic("aten::cumsum")
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def cumsum(g: jit_utils.GraphContext, self, dim, dtype=None):
    dim_tensor = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int))
    if dtype and dtype.node().kind() != "prim::Constant":
        parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        cast = g.op(
            "Cast", self, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
        )
    else:
        cast = self
    csum = g.op("CumSum", cast, dim_tensor)
    return csum
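

# Illustrative mapping: `torch.cumsum(x, dim=1, dtype=torch.float)` lowers to
#     Cast(x, FLOAT) -> CumSum(cast, axis=1)
# (opset 11 is the first opset with a CumSum operator, which is why this
# symbolic lives here rather than in opset 9).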


@_onnx_symbolic("aten::masked_select")
@_beartype.beartype
def masked_select(g: jit_utils.GraphContext, self, mask):
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    return g.op("GatherND", self, index)


@_onnx_symbolic("aten::masked_scatter")
@_beartype.beartype
def masked_scatter(g: jit_utils.GraphContext, self, mask, source):
    index = opset9.nonzero(g, opset9.expand_as(g, mask, self))
    # NOTE: source can have more elements than needed.
    # It could also have arbitrary shape.
    # This is not supported by ONNX::ScatterND, so we need to flatten and slice source tensor.
    source = symbolic_helper._reshape_helper(g, source, torch.LongTensor([-1]))
    source = symbolic_helper._slice_helper(
        g,
        source,
        axes=torch.LongTensor([0]),
        starts=torch.LongTensor([0]),
        ends=opset9.size(g, index, torch.LongTensor([0])),
        dynamic_slice=True,
    )
    return g.op("ScatterND", self, index, source)


@_onnx_symbolic("aten::len")
@_beartype.beartype
def _len(g: jit_utils.GraphContext, self):
    if (
        symbolic_helper._is_tensor_list(self)
        or self.node().kind() == "onnx::SplitToSequence"
    ):
        return g.op("SequenceLength", self)
    sz_0 = size(g, self, g.op("Constant", value_t=torch.LongTensor([0])))
    return symbolic_helper._squeeze_helper(g, sz_0, [0])


@_onnx_symbolic("aten::__getitem_")
@_beartype.beartype
def __getitem_(g: jit_utils.GraphContext, self, i):
    if symbolic_helper._is_tensor_list(self):
        # SequenceAt requires that the input be a List of Tensors
        return g.op("SequenceAt", self, i)
    else:
        from torch.onnx.symbolic_opset9 import __getitem_ as getitem

        return getitem(g, self, i)


@_onnx_symbolic("aten::_set_item")
@_beartype.beartype
def _set_item(g: jit_utils.GraphContext, tensor_list, i, v):
    tensor_list = g.op("SequenceErase", tensor_list, i)
    return g.op("SequenceInsert", tensor_list, v, i)


@_onnx_symbolic("aten::append")
@_beartype.beartype
def append(g: jit_utils.GraphContext, self, tensor):
    return g.op("SequenceInsert", self, tensor)


@_onnx_symbolic("aten::add")
@_beartype.beartype
def add(g: jit_utils.GraphContext, self, other, alpha=None):
    if symbolic_helper._is_value(self) and symbolic_helper._is_tensor_list(self):
        tensor_list_node = other.node()
        if tensor_list_node.kind() != "prim::ListConstruct":
            return symbolic_helper._unimplemented(
                "add", "does not support adding dynamic tensor list to another"
            )
        tensors = symbolic_helper._unpack_list(other)
        l = self
        for t in tensors:
            l = g.op("SequenceInsert", l, t)
        return l

    return opset9.add(g, self, other, alpha)


@_onnx_symbolic("aten::insert")
@_beartype.beartype
def insert(g: jit_utils.GraphContext, self, pos, tensor):
    return g.op("SequenceInsert", self, tensor, pos)


@_onnx_symbolic("aten::pop")
@_beartype.beartype
def pop(g: jit_utils.GraphContext, tensor_list, dim):
    return g.op("SequenceErase", tensor_list, dim)


@_onnx_symbolic("aten::Delete")
@_beartype.beartype
def Delete(g: jit_utils.GraphContext, tensor_list, dim):
    return g.op("SequenceErase", tensor_list, dim)


@_onnx_symbolic("aten::cat")
@symbolic_helper.quantized_args(True)
@_beartype.beartype
def cat(g: jit_utils.GraphContext, tensor_list, dim):
    if symbolic_helper._is_packed_list(tensor_list):
        return opset9.cat(g, tensor_list, dim)
    else:
        dim = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=dim)


@_onnx_symbolic("aten::stack")
@_beartype.beartype
def stack(g: jit_utils.GraphContext, tensor_list, dim):
    if symbolic_helper._is_packed_list(tensor_list):
        return opset9.stack(g, tensor_list, dim)
    else:
        dim = symbolic_helper._get_const(dim, "i", "dim")
        return g.op("ConcatFromSequence", tensor_list, axis_i=dim, new_axis_i=1)
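

# Illustrative note: `cat`/`stack` take two paths. A statically known
# prim::ListConstruct of tensors reuses the opset 9 Concat-based lowering,
# while a dynamic ONNX Sequence (e.g. one built with SequenceInsert above)
# uses ConcatFromSequence, new in opset 11; new_axis_i=1 makes it behave like
# stack rather than cat.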


@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts):
    u, indices, inverse_indices, counts = g.op(
        "Unique", self, sorted_i=sorted, outputs=4
    )
    return u, inverse_indices, counts


@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
    @symbolic_helper.quantized_args(True, False, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
    @_beartype.beartype
    def symbolic_fn(
        g,
        input: _C.Value,
        kernel_size: Sequence[int],
        stride: Sequence[int],
        padding: Union[int, Sequence[int]],
        ceil_mode: int,
        count_include_pad: int,
        divisor_override=None,
    ):
        # Although onnx::AveragePool provides count_include_pad and ceil_mode,
        # PyTorch's average pooling with ceil_mode allows the sliding window
        # to go out of bounds, a corner case that requires the accommodation below.
        # More detail at https://github.com/pytorch/pytorch/issues/57178
        if not stride:
            stride = kernel_size
        padding = symbolic_helper._avgpool_helper(
            tuple_fn, padding, kernel_size, stride, divisor_override, name
        )
        assert isinstance(padding, tuple)
        adjusted_padding = padding
        if count_include_pad:
            input = g.op(
                "Pad",
                input,
                g.op("Constant", value_t=torch.tensor(((0,) * 2 + padding) * 2)),
                mode_s="constant",
            )
            adjusted_padding = (0,) * len(padding)
        if ceil_mode:
            padding_ceil = opset9.get_pool_ceil_padding(
                input, kernel_size, stride, padding
            )
            adjusted_padding = adjusted_padding + tuple(
                a + b for (a, b) in zip(padding_ceil, adjusted_padding)
            )
        else:
            adjusted_padding = adjusted_padding * 2
        output = g.op(
            "AveragePool",
            input,
            kernel_shape_i=tuple_fn(kernel_size),
            strides_i=tuple_fn(stride),
            pads_i=adjusted_padding,
        )
        return output

    return symbolic_fn


@_onnx_symbolic("aten::unique_dim")
@symbolic_helper.parse_args("v", "i", "i", "i", "i")
@_beartype.beartype
def unique_dim(
    g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts
):
    u, indices, inverse_indices, counts = g.op(
        "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4
    )
    return u, inverse_indices, counts


@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    return symbolic_helper._topk_helper(
        g, self, k, dim, largest=largest, sorted=sorted, out=out
    )


@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)


@_onnx_symbolic("aten::argsort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def argsort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    _, indices = symbolic_helper._sort_helper(
        g, self, dim, decending=decending, out=out
    )
    return indices


@_onnx_symbolic("aten::round")
@_beartype.beartype
def round(g: jit_utils.GraphContext, self):
    return g.op("Round", self)


@_onnx_symbolic("aten::remainder")
@_beartype.beartype
def remainder(g: jit_utils.GraphContext, input, other):
    if symbolic_helper._is_fp(input) or symbolic_helper._is_fp(other):
        return opset9.remainder(g, input, other)
    return g.op("Mod", input, other, fmod_i=0)


@_onnx_symbolic("aten::split")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
    if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
        split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
        if _outputs is None:
            return split_out

        # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
        if (
            symbolic_helper._is_packed_list(split_size_or_sizes)
            and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
        ):
            split_sizes = [
                symbolic_helper._unsqueeze_helper(g, v, [0])
                for v in symbolic_helper._unpack_list(split_size_or_sizes)
            ]

            start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
            axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
            res = []
            for i in range(_outputs):
                end = g.op(
                    "Add", start, split_sizes[i]
                )  # split_sizes is a list of same length as _outputs
                res.append(g.op("Slice", self, start, end, axis))
                start = end
            return res

        return [
            g.op(
                "SequenceAt",
                split_out,
                g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
            )
            for i in range(_outputs)
        ]
    else:
        return opset9.split(g, self, split_size_or_sizes, dim, _outputs)
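

# Illustrative note on the dynamic-split path: `x.split([a, b], dim=0)` with
# sizes that are traced values (not constants) but a known output count lowers
# to a chain of Slice nodes whose boundaries are computed with Add:
#     start_0 = 0, end_0 = a        -> Slice(x, 0, a, axis=0)
#     start_1 = a, end_1 = a + b    -> Slice(x, a, a + b, axis=0)
# If the output count is unknown (_outputs is None), a SplitToSequence node is
# returned directly; if the count is known but the sizes are a single dynamic
# tensor, the elements are read back with SequenceAt.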


@_onnx_symbolic("aten::split_with_sizes")
@symbolic_helper.parse_args("v", "v", "i", "i")
@_beartype.beartype
def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
    return split(g, self, split_sizes, dim, _outputs)


@_onnx_symbolic("aten::unbind")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
    if _outputs is None:
        return g.op(
            "SplitToSequence",
            self,
            g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
            axis_i=dim,
            keepdims_i=0,
        )
    else:
        return opset9.unbind(g, self, dim, _outputs)


@_beartype.beartype
def _prepare_onnx_paddings(g: jit_utils.GraphContext, input, pad):
    """Generate paddings in ONNX order based on pad in PyTorch.

    Args:
        input: the input tensor.
        pad: the paddings in PyTorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...,
            dim_m_begin, dim_m_end, where m is in range [0, n].
    """
    if (
        not symbolic_helper._is_packed_list(pad)
        and symbolic_helper._is_list(pad)
        and symbolic_helper._is_scalar_list(pad)
    ):
        pad = g.op("ConcatFromSequence", pad, axis_i=0, new_axis_i=1)
    # The desired order of paddings is
    # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end,
    # where n is the dimension of input.
    # Assuming zero dimensions in the beginning, pad the "pad" sequence with zeros in the beginning
    pad_len = opset9.size(g, pad, g.op("Constant", value_t=torch.tensor([0])))
    # Set extension = [0] * (dim * 2 - len(pad))
    rank = symbolic_helper._get_tensor_rank(input)
    if rank is None:
        rank = g.op("Size", g.op("Shape", input))
    else:
        rank = g.op("Constant", value_t=torch.tensor(rank, dtype=torch.int64))
    extension = g.op(
        "Sub",
        g.op("Mul", rank, g.op("Constant", value_t=torch.tensor(2, dtype=torch.int64))),
        pad_len,
    )
    # Concat pad with extension: paddings = [dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, 0, 0, ... ]
    # Currently ONNX only supports int64 type for Pad
    pad = g.op("Cast", pad, to_i=_C_onnx.TensorProtoDataType.INT64)
    paddings = g.op(
        "Concat",
        pad,
        g.op(
            "ConstantOfShape", extension, value_t=torch.tensor([0], dtype=torch.int64)
        ),
        axis_i=0,
    )
    # Reshape, reverse the order, and collate first the beginnings and then the ends:
    # paddings = [[..., 0, dim_n-1_begin, dim_n_begin],
    #             [..., 0, dim_n-1_end, dim_n_end]]
    # Reshape back to 1-D: paddings = [..., 0, dim_n-1_begin, dim_n_begin, ..., 0, dim_n-1_end, dim_n_end]
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1, 2]))
    )
    paddings = g.op("Transpose", opset10.flip(g, paddings, [0]), perm_i=[1, 0])
    paddings = symbolic_helper._reshape_helper(
        g, paddings, g.op("Constant", value_t=torch.tensor([-1]))
    )
    padding_c = g.op("Cast", paddings, to_i=_C_onnx.TensorProtoDataType.INT64)
    return padding_c
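

# Worked example (illustrative): for a 4-D (N, C, H, W) input and PyTorch
# pad = [1, 2, 3, 4] (i.e. W_begin=1, W_end=2, H_begin=3, H_end=4):
#   extend with zeros to 2 * rank: [1, 2, 3, 4, 0, 0, 0, 0]
#   reshape to [-1, 2]:            [[1, 2], [3, 4], [0, 0], [0, 0]]
#   flip along dim 0:              [[0, 0], [0, 0], [3, 4], [1, 2]]
#   transpose:                     [[0, 0, 3, 1], [0, 0, 4, 2]]
#   reshape to [-1]:               [0, 0, 3, 1, 0, 0, 4, 2]
# which is exactly ONNX Pad's (x1_begin, ..., xn_begin, x1_end, ..., xn_end)
# order.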


@_onnx_symbolic("aten::constant_pad_nd")
@_beartype.beartype
def constant_pad_nd(g: jit_utils.GraphContext, input, padding, value=None):
    mode = "constant"
    value = symbolic_helper._maybe_get_scalar(value)
    value = symbolic_helper._if_scalar_type_as(value, input)
    pad = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, pad, value, mode_s=mode)


@_onnx_symbolic("aten::reflection_pad1d")
@_onnx_symbolic("aten::reflection_pad2d")
@_onnx_symbolic("aten::reflection_pad3d")
@_beartype.beartype
def reflection_pad(g: jit_utils.GraphContext, input, padding):
    mode = "reflect"
    paddings = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, paddings, mode_s=mode)


@_onnx_symbolic("aten::replication_pad1d")
@_onnx_symbolic("aten::replication_pad2d")
@_onnx_symbolic("aten::replication_pad3d")
@_beartype.beartype
def replication_pad(g: jit_utils.GraphContext, input, padding):
    mode = "edge"
    paddings = _prepare_onnx_paddings(g, input, padding)
    return g.op("Pad", input, paddings, mode_s=mode)


@_onnx_symbolic("aten::pad")
@_beartype.beartype
def pad(
    g: jit_utils.GraphContext,
    input: _C.Value,
    pad: _C.Value,
    mode: _C.Value,
    value: _C.Value,
):
    mode = symbolic_helper._parse_arg(mode, "s")
    if mode == "replicate":
        return replication_pad(g, input, pad)
    elif mode == "reflect":
        return reflection_pad(g, input, pad)
    elif mode == "constant":
        return constant_pad_nd(g, input, pad, value)
    elif mode == "circular":
        return opset9._pad_circular(g, input, pad)
    else:
        raise errors.SymbolicValueError(f"Unrecognized padding mode {mode}", input)


@_onnx_symbolic("aten::linalg_det")
@_beartype.beartype
def linalg_det(g: jit_utils.GraphContext, self):
    return g.op("Det", self)


@_onnx_symbolic("aten::logdet")
@_beartype.beartype
def logdet(g: jit_utils.GraphContext, input):
    return opset9.log(g, linalg_det(g, input))


@_onnx_symbolic("aten::arange")
@_beartype.beartype
def arange(g: jit_utils.GraphContext, *args):
    def _get_arange_dtype(dtype):
        dtype = symbolic_helper._maybe_get_const(dtype, "i")
        return dtype

    if len(args) == 2 and all(map(lambda val: isinstance(val, int), args)):
        # aten::arange(Scalar start, Scalar end)
        dtype = torch.int64
        # Start index.
        start = g.op(
            "Constant",
            value_t=torch.tensor(args[0], dtype=dtype),
        )
        # End (exclusive) index.
        end = g.op(
            "Constant",
            value_t=torch.tensor(args[1], dtype=dtype),
        )
        # Step size from start to end indexes.
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=dtype),
        )
        return g.op("Range", start, end, delta_default)
    elif len(args) == 2 or len(args) == 5:
        if len(args) == 2:
            # aten::arange(Scalar end, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[1])
        type_, end, start, step = symbolic_helper._arange_cast_helper(
            g, end=args[0], dtype=dtype
        )
        start_default = g.op(
            "Constant",
            value_t=torch.tensor(0, dtype=type_.dtype()),
        )
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=type_.dtype()),
        )
        return g.op("Range", start_default, end, delta_default)
    elif len(args) == 4 or len(args) == 7:
        if len(args) == 4:
            # aten::arange(Scalar start, Scalar end, Scalar step, Tensor out)
            dtype = None
        else:
            # aten::arange(Scalar start, Scalar end, Scalar step, ScalarType dtype, Layout, Device, bool pin_memory)
            dtype = _get_arange_dtype(args[3])
        _, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], step=args[2], dtype=dtype
        )
        return g.op("Range", start, end, step)
    elif len(args) == 6:
        # aten::arange(Scalar start, Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
        dtype = _get_arange_dtype(args[2])
        type_, end, start, step = symbolic_helper._arange_cast_helper(
            g, start=args[0], end=args[1], dtype=dtype
        )
        delta_default = g.op(
            "Constant",
            value_t=torch.tensor(1, dtype=type_.dtype()),
        )
        return g.op("Range", start, end, delta_default)
    else:
        return symbolic_helper._unimplemented(
            "aten::arange", f"with {len(args)} arguments"
        )
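

# Illustrative summary of the overloads handled above, keyed on len(args):
#   2 (both Python ints) -> Range(start, end, 1) as int64
#   2 or 5               -> arange(end[, dtype, ...]):        Range(0, end, 1)
#   4 or 7               -> arange(start, end, step[, ...]):  Range(start, end, step)
#   6                    -> arange(start, end[, dtype, ...]): Range(start, end, 1)
# All of them lower to a single ONNX Range node (new in opset 11), with type
# casts handled by symbolic_helper._arange_cast_helper.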


@_onnx_symbolic("aten::_dim_arange")
@symbolic_helper.parse_args("v", "i")
@_beartype.beartype
def _dim_arange(g: jit_utils.GraphContext, like, dim):
    like_shape = g.op("Shape", like)
    stop = g.op(
        "Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
    )
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.op("_caffe2::Range", stop)
    return arange(g, stop, 4, None, None, None)


@_onnx_symbolic("aten::size")
@_beartype.beartype
def size(g: jit_utils.GraphContext, self, dim=None):
    if dim is None:
        return g.op("Shape", self)
    return symbolic_helper._size_helper(g, self, dim)


@_onnx_symbolic("aten::squeeze")
@_beartype.beartype
def squeeze(g: jit_utils.GraphContext, self, dim=None):
    if dim is None:
        return g.op("Squeeze", self)

    # dim as a tensor
    if not symbolic_helper._is_constant(dim):
        return symbolic_helper._squeeze_helper(g, self, [dim])

    dim = symbolic_helper._get_const(dim, "i", "dim")

    input_rank = symbolic_helper._get_tensor_rank(self)
    adjusted_dim = dim
    if input_rank is not None and dim < 0:
        adjusted_dim += input_rank
    dim_size = symbolic_helper._get_tensor_dim_size(self, adjusted_dim)
    if (dim < 0 and input_rank is None) or dim_size is None:
        # If ONNX shape inference is not on, always export as dynamic,
        # because we cannot tell whether the observed static shape is also
        # static at runtime.
        # Create the "cond" node (condition is shape[i] == 1).
        dim_constant = g.op("Constant", value_t=torch.tensor([dim]))
        size = symbolic_helper._size_helper(g, self, dim_constant)
        const_one = g.op("Constant", value_t=torch.ones(1, dtype=torch.int64))
        cond = g.op("Equal", size, const_one)
        # Create the "If" node and add the "then" and "else" blocks to it.
        if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
            g, "If", cond, n_blocks=2
        )
        squeeze_ = symbolic_helper._squeeze_helper(if_context, self, [dim])
        utils._add_output_to_block(if_context.block, squeeze_)
        identity_ = else_context.op("Identity", self)
        utils._add_output_to_block(else_context.block, identity_)
        return if_op

    # For static input shape
    dim = adjusted_dim
    if dim_size > 1:
        warnings.warn(
            "This model contains a squeeze operation on dimension "
            + str(dim)
            + ". The size of "
            + "this dimension in the given input is "
            + str(dim_size)
            + ". The model will "
            + "be exported without the squeeze node. If the model is intended to be used with dynamic "
            + "input shapes, please export with dynamic_axes argument."
        )
        return self
    return symbolic_helper._squeeze_helper(g, self, [dim])
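

# Illustrative note: when the squeezed dimension's size is not known at export
# time, the code above emits a runtime check instead of guessing:
#     If(Equal(Shape(x)[dim], 1))
#         then: Squeeze(x, axes=[dim])
#         else: Identity(x)
# so the exported graph squeezes only when the dimension is actually 1.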


@_onnx_symbolic("aten::unsqueeze")
@_beartype.beartype
def unsqueeze(g: jit_utils.GraphContext, self, dim):
    if symbolic_helper._is_constant(dim):
        dim = symbolic_helper._get_const(dim, "i", "dim")

    return symbolic_helper._unsqueeze_helper(g, self, [dim])


@_onnx_symbolic("aten::mm")
@_beartype.beartype
def mm(g: jit_utils.GraphContext, self, other):
    return g.op("Gemm", self, other, beta_f=0.0, alpha_f=1.0)


@_onnx_symbolic("aten::index")
@_beartype.beartype
def index(g: jit_utils.GraphContext, self, index):
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index", self, index, overload_name="Tensor")

    if symbolic_helper._is_packed_list(index):
        indices = symbolic_helper._unpack_list(index)
    else:
        indices = [index]

    # Handle single mask index.
    if len(indices) == 1:
        index = indices[0]
        if not symbolic_helper._is_none(index) and (
            symbolic_helper._is_bool(index)
            or _type_utils.JitScalarType.from_value(index)
            == _type_utils.JitScalarType.UINT8
        ):
            index = opset9.nonzero(g, index)
            return g.op("GatherND", self, index)
    return opset9.index(g, self, index)


@_onnx_symbolic("aten::index_fill")
@_beartype.beartype
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at(
            "index_fill",
            self,
            index,
            value,
            overload_name="int_Scalar",
            dim_i=dim_value,
        )

    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    value = symbolic_helper._maybe_get_scalar(value)
    value = symbolic_helper._if_scalar_type_as(value, self)
    expanded_value = opset9.expand(g, value, expanded_index_shape, None)
    return scatter(g, self, dim, expanded_index, expanded_value)


@_onnx_symbolic("aten::index_copy")
@_beartype.beartype
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
    dim_value = symbolic_helper._parse_arg(dim, "i")
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("index_copy", self, index, source, dim_i=dim_value)
    expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
        g, self, dim, index
    )
    return scatter(g, self, dim, expanded_index, source)


@_onnx_symbolic("aten::__rshift_")
@_beartype.beartype
def __rshift_(g: jit_utils.GraphContext, self, other):
    # Make sure to cast other to self's type
    # (when self is long, make sure that other is not float).
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(self):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
        )

    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.UINT8
    ):
        return g.op("BitShift", self, other, direction_s="RIGHT")

    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    two_pow = g.op("Pow", two, other)
    two_pow = g.op(
        "Cast",
        two_pow,
        to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
    )
    rshift = g.op("Div", self, two_pow)
    return rshift


@_onnx_symbolic("aten::__lshift_")
@_beartype.beartype
def __lshift_(g: jit_utils.GraphContext, self, other):
    # Make sure to cast other to self's type
    # (when self is long, make sure that other is not float).
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(self):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
        )

    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.UINT8
    ):
        return g.op("BitShift", self, other, direction_s="LEFT")

    two = g.op("Constant", value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not symbolic_helper._is_fp(self):
        other = g.op("Cast", other, to_i=_C_onnx.TensorProtoDataType.FLOAT)
    two_pow = g.op("Pow", two, other)
    two_pow = g.op(
        "Cast",
        two_pow,
        to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
    )
    lshift = g.op("Mul", self, two_pow)
    return lshift
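

# Illustrative note: ONNX BitShift only supports unsigned integer types, so it
# is used for uint8 inputs only; for everything else the shifts fall back to
# arithmetic identities:
#     x >> n  ==  x / 2**n      (Div after Pow)
#     x << n  ==  x * 2**n      (Mul after Pow)
# with casts inserted because, per the comments above, onnx::Pow needs a
# floating-point exponent here.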


@_beartype.beartype
def _get_im2col_indices_along_dim(
    g: jit_utils.GraphContext, input_d, kernel_size_d, dilation_d, padding_d, stride_d
):
    # Input is always 4-D (N, C, H, W)
    # Calculate indices of sliding blocks along spatial dimension
    # Slide kernel over input along each dim d:
    # each dim d ranges from 0 to input[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1)
    # with steps = stride
    blocks_d = g.op(
        "Add", input_d, g.op("Constant", value_t=torch.tensor(padding_d * 2))
    )
    blocks_d = g.op(
        "Sub",
        blocks_d,
        g.op("Constant", value_t=torch.tensor(dilation_d * (kernel_size_d - 1))),
    )
    # Stride kernel over input and find starting indices along dim d
    blocks_d_indices = g.op(
        "Range",
        g.op("Constant", value_t=torch.tensor(0)),
        blocks_d,
        g.op("Constant", value_t=torch.tensor(stride_d)),
    )
    # Apply dilation on kernel and find its indices along dim d
    kernel_grid = torch.arange(0, kernel_size_d * dilation_d, dilation_d)
    kernel_grid = g.op("Constant", value_t=kernel_grid.unsqueeze(0))
    # Broadcast and add kernel starting positions (indices) with
    # kernel_grid along dim d, to get block indices along dim d
    blocks_d_indices = symbolic_helper._unsqueeze_helper(
        g, blocks_d_indices, [0]
    )  # Reshape to [1, -1]
    kernel_mask = symbolic_helper._reshape_helper(
        g, kernel_grid, g.op("Constant", value_t=torch.tensor([-1, 1]))
    )
    block_mask = g.op("Add", blocks_d_indices, kernel_mask)
    return block_mask


@_beartype.beartype
def _get_im2col_padded_input(g: jit_utils.GraphContext, input, padding_h, padding_w):
    # Input is always 4-D tensor (N, C, H, W)
    # Padding tensor has the following format: (padding_h, padding_w)
    # Reshape the padding to follow ONNX format: (dim1_begin, dim2_begin, ..., dim1_end, dim2_end, ...)
    pad = g.op("Constant", value_t=torch.LongTensor([0, 0, padding_h, padding_w] * 2))
    return g.op("Pad", input, pad)


@_beartype.beartype
def _get_im2col_output_shape(g: jit_utils.GraphContext, input, kernel_h, kernel_w):
    batch_dim = size(g, input, g.op("Constant", value_t=torch.tensor(0)))
    channel_dim = size(g, input, g.op("Constant", value_t=torch.tensor(1)))
    channel_unfolded = g.op(
        "Mul", channel_dim, g.op("Constant", value_t=torch.tensor(kernel_h * kernel_w))
    )
    return g.op(
        "Concat",
        symbolic_helper._unsqueeze_helper(g, batch_dim, [0]),
        symbolic_helper._unsqueeze_helper(g, channel_unfolded, [0]),
        g.op("Constant", value_t=torch.tensor([-1])),
        axis_i=0,
    )


@_onnx_symbolic("aten::im2col")
@symbolic_helper.parse_args("v", "is", "is", "is", "is")
@_beartype.beartype
def im2col(g: jit_utils.GraphContext, input, kernel_size, dilation, padding, stride):
    # Input is always 4-D tensor (N, C, H, W)
    # All other args are int[2]
    input_h = size(g, input, g.op("Constant", value_t=torch.tensor(2)))
    input_w = size(g, input, g.op("Constant", value_t=torch.tensor(3)))

    stride_h, stride_w = stride[0], stride[1]
    padding_h, padding_w = padding[0], padding[1]
    dilation_h, dilation_w = dilation[0], dilation[1]
    kernel_h, kernel_w = kernel_size[0], kernel_size[1]

    blocks_row_indices = _get_im2col_indices_along_dim(
        g, input_h, kernel_h, dilation_h, padding_h, stride_h
    )
    blocks_col_indices = _get_im2col_indices_along_dim(
        g, input_w, kernel_w, dilation_w, padding_w, stride_w
    )

    output_shape = _get_im2col_output_shape(g, input, kernel_h, kernel_w)
    padded_input = _get_im2col_padded_input(g, input, padding_h, padding_w)

    # For a 4-D tensor of size (1, 1, 3, 3) as below with kernel_size=2, stride=1, and dilation=1
    # [[[[1., 2., 3.],
    #    [4., 5., 6.],
    #    [7., 8., 9.]]]]
    # First gather indices along rows (dim=2) with blocks_row_indices = [[0,1], [1,2]] to get:
    # [[[[[1., 2., 3.],
    #     [4., 5., 6.]],
    #    [[4., 5., 6.],
    #     [7., 8., 9.]]]]]
    # And then gather along cols (dim=4) with blocks_col_indices = [[0,1], [1,2]] to get:
    # [[[[[[1., 2.],
    #      [4., 5.]],
    #     [[2., 3.],
    #      [5., 6.]]],
    #    [[[4., 5.],
    #      [7., 8.]],
    #     [[5., 6.],
    #      [8., 9.]]]]]]
    # Transpose dims 3 (depth) and 4 (rows), and then reshape to output shape (1, 4, 4) to get:
    # [[[1., 2., 4., 5.],
    #   [2., 3., 5., 6.],
    #   [4., 5., 7., 8.],
    #   [5., 6., 8., 9.]]]
    output = g.op("Gather", padded_input, blocks_row_indices, axis_i=2)
    output = g.op("Gather", output, blocks_col_indices, axis_i=4)
    output = g.op("Transpose", output, perm_i=[0, 1, 2, 4, 3, 5])
    return symbolic_helper._reshape_helper(g, output, output_shape)
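

# Usage note (illustrative): aten::im2col is what `torch.nn.functional.unfold`
# lowers to for 4-D inputs, so a (1, 1, 3, 3) input with kernel_size=(2, 2)
# exports to the Gather/Gather/Transpose/Reshape subgraph above and yields the
# same (1, 4, 4) result as F.unfold(x, kernel_size=2), matching the worked
# example in the comments.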


@_onnx_symbolic("aten::narrow")
@_beartype.beartype
def narrow(g: jit_utils.GraphContext, input, dim, start, length):
    end = g.op("Add", start, length)
    return symbolic_helper._slice_helper(
        g, input, axes=dim, starts=start, ends=end, dynamic_slice=True
    )


@_onnx_symbolic("aten::flatten")
@symbolic_helper.quantized_args(True, False, False)
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    dim = symbolic_helper._get_tensor_rank(input)
    if dim == 1:
        return input
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim == 1:
        if end_dim == -1 or (dim is not None and end_dim == dim - 1):
            return g.op("Flatten", input, axis_i=start_dim)
    elif start_dim == 0:
        if end_dim == -2 or (dim is not None and end_dim == dim - 2):
            return g.op("Flatten", input, axis_i=end_dim + 1)
    if dim is None:
        return symbolic_helper._unimplemented(
            "dim",
            "ONNX and PyTorch use different strategies to split the input. "
            "Input rank must be known at export time.",
        )
    # if end_dim is negative add dim
    if end_dim < 0:
        end_dim = dim + end_dim

    return symbolic_helper._flatten_helper(g, input, start_dim, end_dim, dim)


@_onnx_symbolic("aten::linalg_vector_norm")
@symbolic_helper.parse_args("v", "f", "is", "b", "v")
@_beartype.beartype
def linalg_vector_norm(
    g: jit_utils.GraphContext,
    self,
    ord,
    dim: Optional[Sequence[int]],
    keepdim: bool,
    dtype,
):
    if ord == 0:
        if dim is None:
            self = symbolic_helper._reshape_helper(
                g, self, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
            )
            keepdim = False

        cond_op = g.op(
            "Not", g.op("Equal", self, g.op("Constant", value_t=torch.LongTensor([0])))
        )
        cond_op = g.op(
            "Cast",
            cond_op,
            to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
        )
        return symbolic_helper._reducesum_helper(
            g, cond_op, axes_i=dim, keepdims_i=keepdim
        )
    else:
        return opset9.linalg_vector_norm(g, self, ord, dim, keepdim, dtype)


@_onnx_symbolic("aten::embedding_bag")
@symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
@_beartype.beartype
def embedding_bag(
    g: jit_utils.GraphContext,
    embedding_matrix,
    indices,
    offsets,
    scale_grad_by_freq,
    mode,
    sparse,
    per_sample_weights,
    include_last_offset,
    padding_idx,
):
    if scale_grad_by_freq and GLOBALS.export_training:
        return symbolic_helper._onnx_unsupported(
            "embedding_bag with scale_grad_by_freq for training mode"
        )
    if padding_idx is not None and padding_idx >= 0:
        raise RuntimeError("embedding_bag with padding_idx")

    loop_condition = g.op("Constant", value_t=torch.tensor(1))
    loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
    zero = g.op("Constant", value_t=torch.tensor([0]))

    indices_len = symbolic_helper._unsqueeze_helper(
        g,
        symbolic_helper._size_helper(
            g, indices, g.op("Constant", value_t=torch.tensor(0))
        ),
        [0],
    )
    if not include_last_offset:
        offsets = [offsets, indices_len]
        offsets = g.op("Concat", *offsets, axis_i=0)

    # Offsets holds the starting index position of each bag. So we create a list of the indices slices (determined by
    # offsets) and gather those indices in indices_row. Then we use this subset of indices to gather from embeddings.
    # The embeddings output is a loop scan output, so we can avoid creating a sequence and inserting elements in.
    offsets_starts = symbolic_helper._slice_helper(
        g, offsets, axes=[0], starts=[0], ends=[sys.maxsize], steps=[1]
    )
    offsets_ends = symbolic_helper._slice_helper(
        g, offsets, axes=[0], starts=[1], ends=[sys.maxsize], steps=[1]
    )

    loop_len = symbolic_helper._size_helper(
        g, offsets_ends, g.op("Constant", value_t=torch.tensor(0))
    )

    loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
        g, "Loop", loop_len, loop_condition, n_blocks=1
    )
    loop_block = loop_context.block

    # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return
    block_input_iter = utils._add_input_to_block(loop_block)
    cond = utils._add_input_to_block(loop_block)

    indices_start = loop_context.op(
        "Gather", offsets_starts, block_input_iter, axis_i=0
    )
    indices_end = loop_context.op("Gather", offsets_ends, block_input_iter, axis_i=0)
    indices_start = symbolic_helper._unsqueeze_helper(loop_context, indices_start, [0])
    indices_end = symbolic_helper._unsqueeze_helper(loop_context, indices_end, [0])

    indices_row = loop_context.op("Slice", indices, indices_start, indices_end, zero)
    embeddings = loop_context.op("Gather", embedding_matrix, indices_row, axis_i=0)
    if not symbolic_helper._is_none(per_sample_weights):
        per_sample_weights_row = loop_context.op(
            "Slice", per_sample_weights, indices_start, indices_end, zero
        )
        per_sample_weights_row = symbolic_helper._unsqueeze_helper(
            loop_context, per_sample_weights_row, [1]
        )
        embeddings = loop_context.op("Mul", embeddings, per_sample_weights_row)
    if mode == 0:
        embeddings = symbolic_helper._reducesum_helper(
            loop_context, embeddings, axes_i=[0], keepdims_i=0
        )
    elif mode == 1:
        embeddings = loop_context.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
    else:
        embeddings = loop_context.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)

    cond_out = loop_context.op(
        "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
    )
    utils._add_output_to_block(loop_block, cond_out)
    utils._add_output_to_block(loop_block, embeddings)

    # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
    # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
    return loop.node().output(), None, None, None


@_onnx_symbolic("aten::embedding_renorm")
@symbolic_helper.parse_args("v", "v", "f", "f")
@_beartype.beartype
def embedding_renorm(g: jit_utils.GraphContext, weight, indices, max_norm, norm_type):
    unique_indices = g.op("Unique", indices)
    partial_weight = g.op("Gather", weight, unique_indices)
    norm_type = int(norm_type)
    if norm_type == 1:
        norm_type = "ReduceL1"
    elif norm_type == 2:
        norm_type = "ReduceL2"
    else:
        raise errors.SymbolicValueError(
            f"Unsupported: ONNX export of embedding_renorm with norm: {norm_type}. "
            "Only 1. and 2. are supported.",
            weight,
        )
    partial_weight_norm = g.op(norm_type, partial_weight, axes_i=[1], keepdims_i=1)
    # https://github.com/pytorch/pytorch/blob/0a07488ed2c47765e337e290bd138c0e6e459cbd/aten/src/ATen/native/Embedding.cpp#L177
    # Add 1e-7 to prevent division by zero.
    partial_weight_norm_ = g.op(
        "Add", partial_weight_norm, g.op("Constant", value_t=torch.tensor(1e-7))
    )
    max_norm = torch.tensor(max_norm)
    scales = g.op("Div", max_norm, partial_weight_norm_)
    partial_weight_renorm = g.op("Mul", partial_weight, scales)
    partial_weight_renorm = g.op(
        "Where",
        g.op("Greater", partial_weight_norm, max_norm),
        partial_weight_renorm,
        partial_weight,
    )
    return g.op(
        "ScatterND",
        weight,
        symbolic_helper._unsqueeze_helper(g, unique_indices, [1]),
        partial_weight_renorm,
    )


@_onnx_symbolic("aten::chunk")
@_beartype.beartype
def chunk(g: jit_utils.GraphContext, self, chunks, dim):
    # Calculate chunk size for dynamic chunk
    dim_size = g.op("Gather", g.op("Shape", self), dim, axis_i=0)
    chunk_size_s = g.op(
        "Sub", chunks, g.op("Constant", value_t=torch.tensor([1], dtype=torch.long))
    )
    chunk_size = g.op("Div", g.op("Add", dim_size, chunk_size_s), chunks)
    # Create splits vector
    chunk_vec = [
        opset9.expand(g, chunk_size, chunk_size_s, None),
        g.op("Sub", dim_size, g.op("Mul", chunk_size, chunk_size_s)),
    ]
    chunk_vec = g.op("Concat", *chunk_vec, axis_i=0)
    return split(g, self, chunk_vec, dim)
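

# Worked example (illustrative): chunking a dimension of size 10 into
# chunks=3 pieces with the arithmetic above:
#     chunk_size_s = chunks - 1          = 2
#     chunk_size   = (10 + 2) // 3       = 4   (ceil division)
#     chunk_vec    = [4, 4, 10 - 4 * 2]  = [4, 4, 2]
# which is then handed to `split`, matching torch.chunk's semantics of
# equal-size leading chunks and a smaller final chunk.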


@_onnx_symbolic("aten::normal")
@_beartype.beartype
def normal(
    g: jit_utils.GraphContext,
    mean,
    std,
    sizes=None,
    generator=None,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
):
    # If you can sample from a given distribution with mean 0 and variance 1,
    # then you can easily sample from a scale-location transformation of that
    # distribution, which has mean μ and variance σ². If x is a sample from a
    # mean 0 and variance 1 distribution, then
    #     σx + μ
    # is a sample with mean μ and variance σ².
    if sizes is not None and not symbolic_helper._is_none(sizes):
        mean = opset9.expand(g, mean, sizes, None)
    result = opset9.mul(g, std, g.op("RandomNormalLike", mean))
    return add(g, result, mean)


@_onnx_symbolic("prim::ConstantChunk")
@_beartype.beartype
def prim_constant_chunk(g: jit_utils.GraphContext, self, chunks, dim):
    input_shape = g.op("Shape", self)
    axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
    input_shape_dim = g.op("Gather", input_shape, axis, axis_i=0)
    start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
    chunk_size = g.op("Constant", value_t=torch.tensor([chunks], dtype=torch.long))
    chunk_size_minus_1 = g.op(
        "Constant", value_t=torch.tensor([chunks - 1], dtype=torch.long)
    )
    input_shape_dim_shift = g.op("Add", input_shape_dim, chunk_size_minus_1)
    chunk_dim = g.op("Div", input_shape_dim_shift, chunk_size)
    res = []
    for i in range(chunks):
        index = g.op("Constant", value_t=torch.tensor([i + 1], dtype=torch.long))
        end = g.op("Mul", chunk_dim, index)
        res.append(g.op("Slice", self, start, end, axis))
        start = end
    return res
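

# Worked example (illustrative): prim::ConstantChunk with chunks=3 on a
# dimension of size 8 computes chunk_dim = (8 + 2) // 3 = 3 and emits Slice
# nodes with ends 3, 6, 9; ONNX Slice clamps the last end to the dimension
# size, producing pieces of sizes 3, 3, and 2.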