"""This file exports ONNX ops for opset 14.

Note [ONNX operators that are added/updated in opset 14]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    HardSwish, Trilu

Updated operators:
    Reshape
    Add, Sub, Mul, Div
    GRU, LSTM, RNN
    BatchNorm, Cumsum, Relu
"""

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

import functools

import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration

# Decorator factory: every symbolic below is registered for opset 14.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)


@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g: jit_utils.GraphContext, self):
    """Lower ``aten::hardswish`` to the native ONNX ``HardSwish`` operator.

    Args:
        g: Graph context used to emit ONNX nodes.
        self: The input value.

    Returns:
        The output value of the emitted ``HardSwish`` node.
    """
    hardswish_node = g.op("HardSwish", self)
    return hardswish_node


def _trilu(g: jit_utils.GraphContext, self, diagonal, upper: bool):
    """Emit one ``Trilu`` node; ``upper`` picks the triangle (False -> 0, True -> 1)."""
    return g.op("Trilu", self, diagonal, upper_i=int(upper))


@_onnx_symbolic("aten::tril")
@_beartype.beartype
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
    """Lower ``aten::tril`` to ONNX ``Trilu`` with ``upper=0`` (lower triangle).

    ``diagonal`` is passed through unchanged as the second ``Trilu`` input.
    ``out`` is accepted only for signature compatibility and is ignored.
    """
    return _trilu(g, self, diagonal, upper=False)


@_onnx_symbolic("aten::triu")
@_beartype.beartype
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
    """Lower ``aten::triu`` to ONNX ``Trilu`` with ``upper=1`` (upper triangle).

    ``diagonal`` is passed through unchanged as the second ``Trilu`` input.
    ``out`` is accepted only for signature compatibility and is ignored.
    """
    return _trilu(g, self, diagonal, upper=True)


@_onnx_symbolic("aten::reshape")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
    """Lower ``aten::reshape`` via the shared ``_reshape_helper``.

    Opset 14 introduced the ``allowzero`` attribute on ``Reshape``, but this
    export pins it to 0 on purpose. A bug in ONNX Runtime
    (https://github.com/microsoft/onnxruntime/issues/10664) keeps the exporter
    from making use of the new attribute.
    """
    return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)


@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    """Lower ``aten::batch_norm`` to a single ONNX ``BatchNormalization`` node.

    Args:
        g: Graph context used to emit ONNX nodes.
        input, weight, bias, running_mean, running_var: Graph values forwarded
            (after ``_batchnorm_helper``) as the node's five inputs.
        training: Integer flag; non-zero selects training mode.
        momentum: PyTorch-style momentum; emitted as ``1 - momentum``.
        eps: Forwarded as the ``epsilon`` attribute.
        cudnn_enabled: Accepted for signature compatibility; unused here.

    Returns:
        The normalized output value. Alternatively, returns the result of
        ``_onnx_opset_unsupported_detailed`` when autocast has mixed the input
        dtypes. In training mode the node also yields updated running
        statistics. Their types are copied from the running inputs, but they
        are not returned.
    """
    # Mixed input dtypes (possible under autocast) require opset 15. The
    # ``and`` chain short-circuits, so the dtype comparison only runs while
    # autocast is enabled.
    dtype_mismatch_unsupported = (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    )
    if dtype_mismatch_unsupported:
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            14,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )

    symbolic_helper.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )

    # Training mode asks for the extra running-statistic outputs.
    if training:
        training_mode, num_outputs = 1, 3
    else:
        training_mode, num_outputs = 0, 1

    # ONNX and PyTorch weight the old/new statistics oppositely, hence the
    # ``1 - momentum`` conversion.
    bn_out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        training_mode_i=training_mode,
        outputs=num_outputs,
    )
    if not training:
        return bn_out

    normalized, updated_mean, updated_var = bn_out
    updated_mean.setType(running_mean.type())
    updated_var.setType(running_var.type())
    return normalized


@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::hardswish`` as dequantize -> ``HardSwish`` -> quantize.

    Args:
        g: Graph context used to emit ONNX nodes.
        x: The quantized input value; it is passed through ``dequantize_helper``.
        op_scale: Output scale forwarded to ``quantize_helper``.
        op_zero_point: Output zero point forwarded to ``quantize_helper``.

    Returns:
        The re-quantized result of applying ``hardswish``.
    """
    # Only the first element of the helper's 4-tuple is needed here.
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = hardswish(g, dequantized)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)