"""This file exports ONNX ops for opset 14.

Note [ONNX operators that are added/updated in opset 14]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
    HardSwish, Trilu

Updated operators:
    Reshape
    Add, Sub, Mul, Div
    GRU, LSTM, RNN
    BatchNormalization, CumSum, Relu
"""

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md
import functools

import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration
# Registration decorator with the opset pinned to 14, so every symbolic
# function in this module is registered under the same opset version.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g: jit_utils.GraphContext, self):
    """Export ``aten::hardswish`` as a single ONNX ``HardSwish`` node.

    Args:
        g: Graph context the node is emitted into.
        self: Input graph value (parsed with the ``"v"`` descriptor).

    Returns:
        The value produced by the emitted ``HardSwish`` op.
    """
    hard_swish_out = g.op("HardSwish", self)
    return hard_swish_out
@_onnx_symbolic("aten::tril")
@_beartype.beartype
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
    """Export ``aten::tril`` as an ONNX ``Trilu`` node with ``upper=0``.

    Args:
        g: Graph context the node is emitted into.
        self: Input graph value.
        diagonal: Graph value forwarded as the second ``Trilu`` input.
        out: Accepted for signature compatibility; not used here.

    Returns:
        The value produced by the emitted ``Trilu`` op.
    """
    # upper=0 selects the lower-triangular variant of Trilu.
    upper_flag = 0
    return g.op("Trilu", self, diagonal, upper_i=upper_flag)
@_onnx_symbolic("aten::triu")
@_beartype.beartype
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
    """Export ``aten::triu`` as an ONNX ``Trilu`` node with ``upper=1``.

    Args:
        g: Graph context the node is emitted into.
        self: Input graph value.
        diagonal: Graph value forwarded as the second ``Trilu`` input.
        out: Accepted for signature compatibility; not used here.

    Returns:
        The value produced by the emitted ``Trilu`` op.
    """
    # upper=1 selects the upper-triangular variant of Trilu.
    upper_flag = 1
    return g.op("Trilu", self, diagonal, upper_i=upper_flag)
@_onnx_symbolic("aten::reshape")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
    """Export ``aten::reshape`` through the shared reshape helper.

    Args:
        g: Graph context the node is emitted into.
        self: Input graph value to reshape.
        shape: Graph value describing the target shape.

    Returns:
        Whatever ``symbolic_helper._reshape_helper`` returns for these inputs.
    """
    # NOTE: Due to bug in ORT https://github.com/microsoft/onnxruntime/issues/10664
    #       Reshape export cannot utilize the new allowzero attribute introduced
    #       in opset 14, so allowzero is always passed as 0.
    reshaped = symbolic_helper._reshape_helper(g, self, shape, allowzero=0)
    return reshaped
@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    """Export ``aten::batch_norm`` as an ONNX ``BatchNormalization`` node.

    Args:
        g: Graph context the node is emitted into.
        input: Graph value to normalize.
        weight: Scale graph value (may be replaced by ``_batchnorm_helper``).
        bias: Shift graph value (may be replaced by ``_batchnorm_helper``).
        running_mean: Running-mean graph value.
        running_var: Running-variance graph value.
        training: Integer flag; truthy selects the three-output training form.
        momentum: Float momentum; the op receives ``1 - momentum``.
        eps: Float passed through as the ``epsilon`` attribute.
        cudnn_enabled: Parsed but not used by this symbolic function.

    Returns:
        The normalized output value. When autocast is enabled, the tensor
        inputs have differing dtypes and the export opset is below 15, the
        result of ``_onnx_opset_unsupported_detailed`` is returned instead.
    """
    # Mixed-dtype inputs under autocast cannot be exported before opset 15.
    # The checks are nested in the same order as a short-circuiting `and`.
    if torch.is_autocast_enabled():
        tensor_args = [input, weight, bias, running_mean, running_var]
        if not symbolic_helper.args_have_same_dtype(tensor_args):
            if GLOBALS.export_onnx_opset_version < 15:
                return symbolic_helper._onnx_opset_unsupported_detailed(
                    "BatchNormalization",
                    14,
                    15,
                    "All input tensors must have the same `dtype`."
                    " Turn off Autocast or export using opset version 15.",
                    input,
                )

    symbolic_helper.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )

    # Training mode emits three outputs (result plus updated running stats);
    # inference mode emits only the result.
    if training:
        training_mode, num_outputs = 1, 3
    else:
        training_mode, num_outputs = 0, 1

    bn_out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        momentum_f=1 - momentum,
        training_mode_i=training_mode,
        outputs=num_outputs,
    )

    if not training:
        return bn_out

    normalized, updated_mean, updated_var = bn_out
    # Give the updated-statistics outputs the same type as the incoming
    # running statistics.
    updated_mean.setType(running_mean.type())
    updated_var.setType(running_var.type())
    return normalized
@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    """Export ``quantized::hardswish`` as dequantize -> HardSwish -> quantize.

    Args:
        g: Graph context the nodes are emitted into.
        x: Quantized input value.
        op_scale: Scale forwarded to ``symbolic_helper.quantize_helper``.
        op_zero_point: Zero point forwarded to ``symbolic_helper.quantize_helper``.

    Returns:
        The result of ``symbolic_helper.quantize_helper`` applied to the
        HardSwish output.
    """
    # Only the first of the four values returned by the helper is needed.
    dequantized, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    activated = hardswish(g, dequantized)
    return symbolic_helper.quantize_helper(g, activated, op_scale, op_zero_point)
|