# symbolic_opset12.py

import functools
import sys
from typing import Optional, Tuple

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 12

__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)


@_beartype.beartype
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)


@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
@_beartype.beartype
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)
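
# Illustrative sketch (not part of this module; the MatMul module below is a
# hypothetical example): exporting a model that calls torch.einsum with
# opset_version=12 dispatches through `einsum` above and emits an ONNX Einsum
# node, with Cast wrappers when the inputs are bool:
#
#     class MatMul(torch.nn.Module):
#         def forward(self, a, b):
#             return torch.einsum("ij,jk->ik", a, b)
#
#     torch.onnx.export(MatMul(), (torch.randn(2, 3), torch.randn(3, 4)),
#                       "matmul.onnx", opset_version=12)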


@_onnx_symbolic("aten::outer")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def outer(g: jit_utils.GraphContext, input, other):
    # make sure to cast other to self's type
    if _type_utils.JitScalarType.from_value(
        other, _type_utils.JitScalarType.UNDEFINED
    ) != _type_utils.JitScalarType.from_value(input):
        other = g.op(
            "Cast",
            other,
            to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
        )
    return _einsum_helper(g, "i,j->ij", [input, other])


@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op. That is, if the node's
    # train param is set to False, dropout just returns its input.
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask


@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
    masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return masked


@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g: jit_utils.GraphContext, input, p, train):
    return _dropout_returns_masked_input_and_mask(g, input, p, train)
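
# Sketch of the expected behavior (illustration, not an exporter guarantee):
# with the default eval-mode export, `dropout` is a pass-through; exporting in
# training mode emits an ONNX Dropout node with ratio and training_mode
# inputs, e.g.:
#
#     torch.onnx.export(model, x, "model.onnx", opset_version=12,
#                       training=torch.onnx.TrainingMode.TRAINING)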


@_onnx_symbolic("aten::nll_loss")
@_beartype.beartype
def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    # In the ONNX NegativeLogLikelihoodLoss specification, ignore_index is optional and has no
    # default value, so we set the ignore_index attribute even when it is not specified
    # (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        nllloss = g.op(
            "NegativeLogLikelihoodLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return nllloss
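
# Worked example (illustrative): torch's integer reduction flag maps onto the
# ONNX string attribute as 0 -> "none", 1 -> "mean", 2 -> "sum"; e.g.
# F.nll_loss(log_probs, target, reduction="mean") exports a
# NegativeLogLikelihoodLoss node with reduction_s="mean" and
# ignore_index_i=-100 (the PyTorch default).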


@_onnx_symbolic("aten::nll_loss2d")
@_beartype.beartype
def nll_loss2d(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::nll_loss_nd")
@_beartype.beartype
def nll_loss_nd(
    g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
):
    return nll_loss(g, self, target, weight, reduction, ignore_index)


@_onnx_symbolic("aten::cross_entropy_loss")
@_beartype.beartype
def cross_entropy_loss(
    g: jit_utils.GraphContext,
    self,
    target,
    weight,
    reduction,
    ignore_index,
    label_smoothing,
):
    # none reduction : onnx::Constant[value={0}]
    # mean reduction : onnx::Constant[value={1}]
    # sum reduction : onnx::Constant[value={2}]
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    reduction_vals = ["none", "mean", "sum"]
    reduction = reduction_vals[reduction]

    label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
    if label_smoothing is not None and label_smoothing > 0.0:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX does not support label_smoothing", self
        )

    # In the ONNX SoftmaxCrossEntropyLoss specification, ignore_index is optional and has no
    # default value, so we set the ignore_index attribute even when it is not specified
    # (e.g. ignore_index=-100).
    ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
    if weight.node().mustBeNone():
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )
    else:
        celoss = g.op(
            "SoftmaxCrossEntropyLoss",
            self,
            target,
            weight,
            reduction_s=reduction,
            ignore_index_i=ignore_index,
        )

    return celoss
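
# Illustrative note: aten::cross_entropy_loss lowers to a single ONNX
# SoftmaxCrossEntropyLoss node, so the exported op consumes raw logits (the
# log-softmax is folded into the op). A model built with
# torch.nn.CrossEntropyLoss(label_smoothing=0.1) cannot be exported at opset
# 12 and fails with the SymbolicValueError raised above.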


@_onnx_symbolic("aten::binary_cross_entropy_with_logits")
@symbolic_helper.parse_args("v", "v", "v", "v", "i")
@_beartype.beartype
def binary_cross_entropy_with_logits(
    g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
):
    p = g.op("Constant", value_t=torch.tensor([1]))
    sig_x = opset9.sigmoid(g, input)
    log_sig_x = opset9.log(g, sig_x)
    sub_1_x = opset9.sub(g, p, sig_x)
    sub_1_y = opset9.sub(g, p, target)
    log_1_x = opset9.log(g, sub_1_x)
    if pos_weight is None or symbolic_helper._is_none(pos_weight):
        output = opset9.neg(
            g,
            opset9.add(
                g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
            ),
        )
    else:
        output = opset9.neg(
            g,
            opset9.add(
                g,
                opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
                opset9.mul(g, sub_1_y, log_1_x),
            ),
        )
    if weight is not None and not symbolic_helper._is_none(weight):
        output = opset9.mul(g, weight, output)
    reduction = symbolic_helper._maybe_get_const(reduction, "i")
    if reduction == 0:
        return output
    elif reduction == 1:
        return g.op("ReduceMean", output, keepdims_i=0)
    elif reduction == 2:
        return g.op("ReduceSum", output, keepdims_i=0)
    else:
        return symbolic_helper._onnx_unsupported(
            "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
            input,
        )
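
# Worked example (for reference): with weight=None and pos_weight=None, the
# graph built above computes, elementwise,
#     loss = -(target * log(sigmoid(x)) + (1 - target) * log(1 - sigmoid(x)))
# followed by ReduceMean/ReduceSum when reduction is "mean"/"sum", matching
# torch.nn.functional.binary_cross_entropy_with_logits.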


@_onnx_symbolic("aten::celu")
@_beartype.beartype
def celu(g: jit_utils.GraphContext, self, alpha):
    alpha = symbolic_helper._maybe_get_const(alpha, "f")
    # if the input is of type double cast it to float
    if (
        _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
        == _type_utils.JitScalarType.DOUBLE
    ):
        self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
        out = g.op("Celu", self, alpha_f=alpha)
        return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)

    return g.op("Celu", self, alpha_f=alpha)


@_onnx_symbolic("aten::argmax")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmax(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")


@_onnx_symbolic("aten::argmin")
@symbolic_helper.parse_args("v", "v", "b")
@_beartype.beartype
def argmin(
    g: jit_utils.GraphContext,
    input: torch._C.Value,
    dim: torch._C.Value,
    keepdim: bool,
):
    return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")


@_onnx_symbolic("aten::pow")
@_beartype.beartype
def pow(g: jit_utils.GraphContext, self, exponent):
    return g.op("Pow", self, exponent)


@_onnx_symbolic("aten::ge")
@_beartype.beartype
def ge(g: jit_utils.GraphContext, input, other):
    return g.op("GreaterOrEqual", input, other)


@_onnx_symbolic("aten::le")
@_beartype.beartype
def le(g: jit_utils.GraphContext, input, other):
    return g.op("LessOrEqual", input, other)


@_onnx_symbolic("aten::unfold")
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
    const_size = symbolic_helper._maybe_get_const(size, "i")
    const_step = symbolic_helper._maybe_get_const(step, "i")
    if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
        const_step
    ):
        return opset9.unfold(g, input, dimension, const_size, const_step)
    if symbolic_helper.is_caffe2_aten_fallback():
        return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)

    sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
    if sizedim is not None:
        low_start = g.op("Constant", value_t=torch.tensor(0))
        low_end = g.op("Constant", value_t=torch.tensor(sizedim))
        hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
        low_indices = g.op("Range", low_start, low_end, step)
        hi_indices = g.op("Range", size, hi_end, step)

        low_size = symbolic_helper._size_helper(
            g, low_indices, g.op("Constant", value_t=torch.tensor(0))
        )
        hi_size = symbolic_helper._size_helper(
            g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
        )

        ndim = symbolic_helper._get_tensor_rank(input)
        assert ndim is not None
        perm = list(range(0, ndim))
        perm.append(perm.pop(dimension))

        unsqueeze_list = []
        loop_condition = g.op("Constant", value_t=torch.tensor(1))
        loop_condition = g.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        loop_len = g.op("Min", low_size, hi_size)

        loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
            g, "Loop", loop_len, loop_condition, n_blocks=1
        )

        loop_block = loop_context.block
        block_input_iter = utils._add_input_to_block(loop_block)
        # FIXME(justinchuby): cond is unused?
        cond = utils._add_input_to_block(loop_block)

        starts = loop_context.op("Gather", low_indices, block_input_iter)
        ends = loop_context.op("Gather", hi_indices, block_input_iter)
        axes = loop_context.op("Constant", value_t=torch.tensor([2]))
        starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
        stack = loop_context.op("Slice", input, starts, ends, axes)

        unsqueeze = symbolic_helper._unsqueeze_helper(
            loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
        )
        unsqueeze_list.append(unsqueeze)
        concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
        # Cast the loop-carried condition to BOOL via the `to_i` attribute, as done
        # for `loop_condition` above (the attribute was missing here).
        cond_out = loop_context.op(
            "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
        )
        utils._add_output_to_block(loop_block, cond_out)
        utils._add_output_to_block(loop_block, concat)
        loop_output = loop.node().output()

        perm = [0, 1, 2, 3, 4]
        perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
        transpose = g.op("Transpose", loop_output, perm_i=perm)
        squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])

        return squeeze
    return symbolic_helper._unimplemented("Unfold", "input size not accessible")
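
# Worked example (shapes only): for x of shape (2, 10), x.unfold(1, size=3,
# step=2) has window starts 0, 2, 4, 6 and produces an output of shape
# (2, 4, 3) -- (10 - 3) // 2 + 1 = 4 windows of length 3 along dim 1 --
# which is what the Loop/Slice/Concat graph above reproduces.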


@_onnx_symbolic("aten::tensordot")
@symbolic_helper.parse_args("v", "v", "is", "is", "v")
@_beartype.beartype
def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
    if out is not None:
        symbolic_helper._unimplemented(
            "Tensordot", "Out parameter is not supported for tensordot."
        )

    dim_count_a = symbolic_helper._get_tensor_rank(input_a)
    if dim_count_a is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
            input_a,
        )

    dim_count_b = symbolic_helper._get_tensor_rank(input_b)
    if dim_count_b is None:
        raise errors.SymbolicValueError(
            "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
            input_b,
        )

    dims_a = [
        (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
        for i in range(len(dims_a))
    ]
    dims_b = [
        (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
        for i in range(len(dims_b))
    ]

    left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
    left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]

    new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
    new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)

    input_shape = g.op("Shape", new_input_a)
    left_sizes_a = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
    )
    shape_sizes = [
        left_sizes_a,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", output_a)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)

    input_shape = g.op("Shape", new_input_b)
    left_sizes_b = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
    )
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
    )
    shape_sizes = [
        slices,
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    input_shape = g.op("Shape", output_b)
    slices = symbolic_helper._slice_helper(
        g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
    )
    shape_sizes = [
        g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
        slices,
    ]
    output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)

    output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))

    shape_sizes = [left_sizes_a, left_sizes_b]
    return opset9._reshape_from_tensor(g, output, shape_sizes)
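
# Worked example (shapes only): for a of shape (3, 4, 5) and b of shape
# (4, 5, 6), torch.tensordot(a, b, dims=([1, 2], [0, 1])) contracts the (4, 5)
# dims, so the symbolic above flattens a to (3, 20) and b to (20, 6),
# multiplies them via Einsum "ij,jk->ik", and reshapes the result back to
# (3, 6).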