import functools
import sys
import warnings
from typing import Callable

import torch
import torch._C._onnx as _C_onnx
import torch.onnx
from torch import _C

# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx import (
    _constants,
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
)
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 10.
# Opset 10 is supported by ONNX release 1.5.0,
# released on 04/24/19.

__all__ = [
    "dequantize",
    "div",
    "embedding_bag",
    "fake_quantize_per_tensor_affine",
    "flip",
    "fmod",
    "isfinite",
    "isinf",
    "nan_to_num",
    "quantize_per_tensor",
    "quantized_add_relu",
    "quantized_add",
    "quantized_cat",
    "quantized_conv1d_relu",
    "quantized_conv2d_relu",
    "quantized_conv2d",
    "quantized_group_norm",
    "quantized_hardswish",
    "quantized_instance_norm",
    "quantized_layer_norm",
    "quantized_leaky_relu",
    "quantized_linear",
    "quantized_mul",
    "quantized_sigmoid",
    "slice",
    "sort",
    "topk",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=10)


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply
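

# For illustration only: `_apply_params` specializes a higher-order symbolic
# factory at registration time. A minimal sketch, assuming a hypothetical
# factory `make_fn` (not part of this module):
#
#     @_apply_params("some_op", 2)
#     def make_fn(name, ndims):
#         def symbolic_fn(g, input):
#             ...  # build ONNX ops using `name` and `ndims`
#         return symbolic_fn
#
# After decoration, the name `make_fn` is bound to the inner `symbolic_fn`,
# already specialized with name="some_op" and ndims=2.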


@_onnx_symbolic("aten::div")
@_beartype.beartype
def div(g: jit_utils.GraphContext, self, other, *args):
    if len(args) == 0:
        return opset9.true_divide(g, self, other)
    else:
        return _div_rounding_mode(g, self, other, *args)


@symbolic_helper.parse_args("v", "v", "s")
@_beartype.beartype
def _div_rounding_mode(g: jit_utils.GraphContext, self, other, rounding_mode):
    if rounding_mode == "floor":
        return _floor_divide(g, self, other)
    else:
        return opset9._div_rounding_mode(g, self, other, rounding_mode)


@_onnx_symbolic("aten::_floor_divide")
@_beartype.beartype
def _floor_divide(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._is_fp(self) or symbolic_helper._is_fp(other):
        out = opset9.true_divide(g, self, other)
        return g.op("Floor", out)
    else:
        # Integer division does truncation rounding
        div = g.op("Div", self, other)
        # Division is negative if: self < 0 != other < 0
        zero = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
        negative = g.op("Xor", g.op("Less", self, zero), g.op("Less", other, zero))
        # For negative numbers with self % other != 0, subtract 1 to round down instead of up
        mod = g.op("Mod", self, other, fmod_i=0)
        fixup_mask = g.op("And", negative, g.op("Not", g.op("Equal", mod, zero)))
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        fixup = g.op("Sub", div, one)
        return g.op("Where", fixup_mask, fixup, div)


@_onnx_symbolic("aten::sort")
@symbolic_helper.parse_args("v", "i", "i", "none")
@_beartype.beartype
def sort(g: jit_utils.GraphContext, self, dim, decending, out=None):
    return symbolic_helper._sort_helper(g, self, dim, decending=decending, out=out)


@_onnx_symbolic("aten::topk")
@symbolic_helper.parse_args("v", "v", "i", "i", "i", "none")
@_beartype.beartype
def topk(g: jit_utils.GraphContext, self, k, dim, largest, sorted, out=None):
    return symbolic_helper._topk_helper(
        g, self, k, dim, largest=largest, sorted=sorted, out=out
    )


@_onnx_symbolic(
    "aten::max_pool1d",
    decorate=[
        _apply_params(
            "max_pool1d", torch.nn.modules.utils._single, 1, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d",
    decorate=[
        _apply_params(
            "max_pool2d", torch.nn.modules.utils._pair, 2, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d",
    decorate=[
        _apply_params(
            "max_pool3d", torch.nn.modules.utils._triple, 3, return_indices=False
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool1d_with_indices",
    decorate=[
        _apply_params(
            "max_pool1d_with_indices",
            torch.nn.modules.utils._single,
            1,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool2d_with_indices",
    decorate=[
        _apply_params(
            "max_pool2d_with_indices",
            torch.nn.modules.utils._pair,
            2,
            return_indices=True,
        )
    ],
)
@_onnx_symbolic(
    "aten::max_pool3d_with_indices",
    decorate=[
        _apply_params(
            "max_pool3d_with_indices",
            torch.nn.modules.utils._triple,
            3,
            return_indices=True,
        )
    ],
)
@_beartype.beartype
def _max_pool(name: str, tuple_fn: Callable, ndims: int, return_indices: bool):
    @symbolic_helper.quantized_args(True, False, False, False, False, False)
    @symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if not stride:
            stride = kernel_size
        kwargs = {
            "kernel_shape_i": tuple_fn(kernel_size),
            "pads_i": tuple_fn(padding) * 2,
            "strides_i": tuple_fn(stride),
            "ceil_mode_i": ceil_mode,
        }
        if set(tuple_fn(dilation)) != {1}:
            kwargs["dilations_i"] = tuple_fn(dilation)
        # Easy but hacky way to get flattened indices values, used to convert
        # the indices to the non-flattened form.
        # In ONNX the indices are computed as a flattened 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by PyTorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This results in a tensor of indices in which each index has its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step results in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information:
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        if return_indices:
            r, indices = g.op("MaxPool", input, outputs=2, **kwargs)
            _, flattened_indices = g.op(
                "MaxPool",
                input,
                outputs=2,
                kernel_shape_i=[1 for _ in range(ndims)],
                strides_i=[1 for _ in range(ndims)],
            )
            # convert indices to have non-flattened indices values
            s = symbolic_helper._slice_helper(
                g,
                flattened_indices,
                axes=[2 + i for i in range(ndims)],
                starts=tuple_fn(0),
                ends=tuple_fn(1),
            )
            indices = opset9.sub(g, indices, s)
            return r, indices
        else:
            r = g.op("MaxPool", input, outputs=1, **kwargs)
            return r

    return symbolic_fn
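

# Worked example (illustrative only): for a 1-D input of shape (N=1, C=2, L=4),
# ONNX MaxPool returns flattened indices in [0, 1*2*4) = [0, 8); the second
# channel's indices start at 4. The kernel-1/stride-1 MaxPool above yields
# indices [[0, 1, 2, 3], [4, 5, 6, 7]]; slicing out the first element of each
# spatial axis gives the per-channel base offsets (0 and 4), and subtracting
# them converts the flattened indices to PyTorch's per-channel values in [0, L).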


@_onnx_symbolic(
    "aten::avg_pool1d",
    decorate=[_apply_params("avg_pool1d", torch.nn.modules.utils._single)],
)
@_onnx_symbolic(
    "aten::avg_pool2d",
    decorate=[_apply_params("avg_pool2d", torch.nn.modules.utils._pair)],
)
@_onnx_symbolic(
    "aten::avg_pool3d",
    decorate=[_apply_params("avg_pool3d", torch.nn.modules.utils._triple)],
)
@_beartype.beartype
def _avg_pool(name, tuple_fn):
    # Although onnx::AveragePool provides count_include_pad and ceil_mode,
    # PyTorch's average pooling with ceil_mode=True allows the sliding window
    # to go off-bound, a corner case that requires this accommodation.
    # More detail at https://github.com/pytorch/pytorch/issues/57178
    return opset9._avg_pool(name, tuple_fn)
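

# Example of the corner case (illustrative only): a 1-D input of length 5 with
# kernel=2, stride=2, ceil_mode=True yields ceil((5 - 2) / 2) + 1 = 3 windows;
# the last window starts at index 4 and extends past the input, so PyTorch
# averages only the in-bounds element, and a direct mapping to
# onnx::AveragePool needs the accommodation referenced above.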


@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
@_beartype.beartype
def _interpolate(name, dim, interpolate_mode):
    @symbolic_helper.quantized_args(True, False, False)
    @_beartype.beartype
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        if scales is None:
            scales = symbolic_helper._interpolate_size_to_scales(
                g, input, output_size, dim
            )
        return g.op("Resize", input, scales, mode_s=interpolate_mode)

    return symbolic_fn
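

# Example (illustrative only): upsampling an NCHW input by 2x with
# upsample_nearest2d produces a Resize whose `scales` input is the constant
# [1.0, 1.0, 2.0, 2.0]: batch and channel dims keep scale 1 while the two
# spatial dims are doubled. When only `output_size` is given,
# `_interpolate_size_to_scales` derives an equivalent scales tensor from the
# output/input size ratio.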
- @_onnx_symbolic("aten::__interpolate")
- @_beartype.beartype
- def __interpolate(
- g: jit_utils.GraphContext,
- input,
- size,
- scale_factor,
- mode,
- align_corners,
- recompute_scale_factor,
- antialias,
- ):
- scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
- g, input, size, scale_factor, mode, align_corners
- )
- return g.op("Resize", input, scales, mode_s=mode)


@_beartype.beartype
def _slice(
    g: jit_utils.GraphContext,
    input,
    axes,
    starts,
    ends,
    steps=None,
    dynamic_slice=False,
):
    if dynamic_slice:
        starts = symbolic_helper._unsqueeze_helper(g, starts, [0])
        ends = symbolic_helper._unsqueeze_helper(g, ends, [0])
        if isinstance(axes, int):
            axes = g.op("Constant", value_t=torch.tensor(axes))
        axes = symbolic_helper._unsqueeze_helper(g, axes, [0])
    else:
        assert len(starts) == len(ends)
        assert len(starts) == len(axes)
        assert steps is None or len(starts) == len(steps)
        if (
            len(starts) == 1
            and starts[0] == 0
            and ends[0] == _constants.INT64_MAX
            and (steps is None or (len(steps) == 1 and steps[0] == 1))
        ):
            return input
        if ends[0] > _constants.INT64_MAX:
            ends[0] = _constants.INT64_MAX
        axes = g.op("Constant", value_t=torch.tensor(axes))
        starts = g.op("Constant", value_t=torch.tensor(starts))
        ends = g.op("Constant", value_t=torch.tensor(ends))
    if steps is None:
        return g.op("Slice", input, starts, ends, axes)
    steps = g.op("Constant", value_t=torch.tensor(steps))
    return g.op("Slice", input, starts, ends, axes, steps)
- @_onnx_symbolic("aten::slice")
- @_beartype.beartype
- def slice(g: jit_utils.GraphContext, self, *args):
- if len(args) == 4:
- # aten::slice(Tensor self, int dim, int? start=None, int? end=None, int step=1) -> Tensor
- dim, start, end, step = args
- elif len(args) == 3:
- # aten::slice(t[] l, int? start=None, int? end=None, int step=1) -> t[]
- start, end, step = args
- dim = 0
- else:
- raise errors.SymbolicValueError("Unknown aten::slice signature", self)
- is_start_none = start.node().kind() == "prim::Constant" and isinstance(
- start.type(), _C.NoneType
- )
- is_end_none = end.node().kind() == "prim::Constant" and isinstance(
- end.type(), _C.NoneType
- )
- is_start_onnx_const = start.node().kind() == "onnx::Constant"
- is_end_onnx_const = end.node().kind() == "onnx::Constant"
- step = symbolic_helper._parse_arg(step, "i")
- if (
- (not is_start_none and not is_start_onnx_const)
- or (not isinstance(end, int) and not is_end_none and not is_end_onnx_const)
- or (not isinstance(dim, int) and dim.node().kind() != "onnx::Constant")
- ):
- dynamic_slice = True
- if is_start_none:
- start = g.op("Constant", value_t=torch.tensor(0))
- if is_end_none:
- end = g.op("Constant", value_t=torch.tensor(_constants.INT64_MAX))
- else:
- start = [0 if is_start_none else symbolic_helper._parse_arg(start, "i")]
- end = [
- _constants.INT64_MAX
- if is_end_none
- else symbolic_helper._parse_arg(end, "i")
- ]
- dim = [symbolic_helper._parse_arg(dim, "i")]
- dynamic_slice = False
- return symbolic_helper._slice_helper(
- g,
- self,
- axes=dim,
- starts=start,
- ends=end,
- steps=[step],
- dynamic_slice=dynamic_slice,
- )
- @_onnx_symbolic("aten::flip")
- @symbolic_helper.parse_args("v", "is")
- @_beartype.beartype
- def flip(g: jit_utils.GraphContext, input, dims):
- return symbolic_helper._slice_helper(
- g,
- input,
- axes=dims,
- starts=[-1] * len(dims),
- ends=[-_constants.INT64_MAX] * len(dims),
- steps=[-1] * len(dims),
- )
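

# Example (illustrative only): flip(x, dims=[0]) lowers to a single
# onnx::Slice with starts=[-1], ends=[-INT64_MAX], steps=[-1] on axis 0,
# walking the axis backwards from its last element like Python's x[::-1].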
- @_onnx_symbolic("aten::fmod")
- @_beartype.beartype
- def fmod(g: jit_utils.GraphContext, input, other):
- return g.op("Mod", input, other, fmod_i=1)
- @_onnx_symbolic("aten::embedding_bag")
- @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i", "v", "i", "i")
- @_beartype.beartype
- def embedding_bag(
- g: jit_utils.GraphContext,
- embedding_matrix,
- indices,
- offsets,
- scale_grad_by_freq,
- mode,
- sparse,
- per_sample_weights,
- include_last_offset,
- padding_idx,
- ):
- if scale_grad_by_freq and GLOBALS.export_training:
- return symbolic_helper._onnx_unsupported(
- "embedding_bag with scale_grad_by_freq for training mode"
- )
- if padding_idx is not None and padding_idx >= 0:
- raise RuntimeError("embedding_bag with padding_idx")
- warnings.warn(
- "Export of embedding_bag with dynamic input/offsets shape is not supported in opset 10. "
- "Please use opset 11 or higher to export model for dynamic input shape.'"
- )
- offsets_dim_0 = symbolic_helper._get_tensor_dim_size(offsets, 0)
- if offsets_dim_0 is not None:
- if include_last_offset:
- offset_len = offsets_dim_0 - 1
- offsets_extended = offsets
- else:
- offset_len = offsets_dim_0
- offsets_extended = [
- offsets,
- g.op("Constant", value_t=torch.tensor([sys.maxsize])),
- ]
- offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
- list_ = []
- for i in range(offset_len):
- start_ = symbolic_helper._unsqueeze_helper(
- g,
- opset9.select(g, offsets_extended, torch.tensor(0), torch.tensor(i)),
- [0],
- )
- end_ = symbolic_helper._unsqueeze_helper(
- g,
- opset9.select(
- g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)
- ),
- [0],
- )
- axes_ = g.op("Constant", value_t=torch.tensor([0]))
- indices_row = g.op("Slice", indices, start_, end_, axes_)
- embeddings = g.op("Gather", embedding_matrix, indices_row)
- if not symbolic_helper._is_none(per_sample_weights):
- per_sample_weights_row = g.op(
- "Slice", per_sample_weights, start_, end_, axes_
- )
- per_sample_weights_row = symbolic_helper._unsqueeze_helper(
- g, per_sample_weights_row, [1]
- )
- embeddings = g.op("Mul", embeddings, per_sample_weights_row)
- if mode == 0:
- embeddings = symbolic_helper._reducesum_helper(
- g, embeddings, axes_i=[0], keepdims_i=0
- )
- elif mode == 1:
- embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
- else:
- embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
- embeddings = symbolic_helper._unsqueeze_helper(g, embeddings, [0])
- list_.append(embeddings)
- output = g.op("Concat", *list_, axis_i=0)
- # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
- # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
- return output, None, None, None
- else:
- return symbolic_helper._onnx_unsupported(
- "embedding_bag with unknown shape of offsets for opset 10 is not supported. "
- "please use opset 11 or higher."
- )
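

# Example (illustrative only): with indices=[1, 4, 2, 5, 3] and offsets=[0, 2]
# (include_last_offset=False), the loop above slices bag 0 as indices[0:2] and
# bag 1 as indices[2:sys.maxsize] (the appended sentinel clamps to the end of
# the tensor), gathers each bag's rows from embedding_matrix, and reduces them
# per `mode` (0=sum, 1=mean, otherwise max) before concatenating the per-bag
# results.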
- @_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
- @symbolic_helper.parse_args("v", "v", "v", "i", "i")
- @_beartype.beartype
- def fake_quantize_per_tensor_affine(
- g: jit_utils.GraphContext,
- inputs,
- scale,
- zero_point,
- quant_min=-128,
- quant_max=127,
- ):
- # NOTE: (0, 127) is a special case. PyTorch restricts activations to be in the range (0, 127).
- # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
- if (quant_min, quant_max) == (0, 127):
- symbolic_helper._onnx_opset_unsupported_detailed(
- "fake_quantize_per_tensor_affine",
- 10,
- 13,
- "Quantize range (0, 127) not supported, requires opset 13 Clip",
- inputs,
- )
- if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
- raise errors.SymbolicValueError(
- f"For (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). "
- f"Got ({quant_min}, {quant_max})",
- inputs,
- )
- scale = symbolic_helper._maybe_get_scalar(scale)
- if scale is None:
- symbolic_helper._onnx_opset_unsupported_detailed(
- "fake_quantize_per_tensor_affine",
- 10,
- 13,
- "Non-constant scale not supported",
- inputs,
- )
- scale = scale.float().data # Avoid exporter generating double type
- if quant_min == 0:
- zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
- else:
- zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
- return g.op(
- "DequantizeLinear",
- g.op("QuantizeLinear", inputs, scale, zero_point),
- scale,
- zero_point,
- )
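

# Worked example (illustrative only): with scale=0.1 and zero_point=0 over the
# uint8 range (0, 255), QuantizeLinear maps 1.23 to round(1.23 / 0.1) = 12,
# and DequantizeLinear maps it back to (12 - 0) * 0.1 = 1.2, so the
# QuantizeLinear/DequantizeLinear pair reproduces fake quantization's
# round-trip error.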
- @_onnx_symbolic("aten::isinf")
- @_beartype.beartype
- def isinf(g: jit_utils.GraphContext, input):
- return g.op("IsInf", g.op("Cast", input, to_i=_C_onnx.TensorProtoDataType.DOUBLE))
- @_onnx_symbolic("aten::isfinite")
- @_beartype.beartype
- def isfinite(g: jit_utils.GraphContext, input):
- inf_node = isinf(g, input)
- nan_node = opset9.isnan(g, input)
- return opset9.__not_(g, opset9.__or_(g, inf_node, nan_node))
- @_onnx_symbolic("aten::quantize_per_tensor")
- @_beartype.beartype
- def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
- dtype = symbolic_helper._get_const(dtype, "i", "dtype")
- # TODO(justinchuby): Extract all the cast ops into a helper function.
- zero_point = g.op(
- "Cast", zero_point, to_i=_type_utils.JitScalarType(dtype).onnx_type()
- )
- scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
- return symbolic_helper.quantize_helper(g, input, scale, zero_point)
- @_onnx_symbolic("aten::dequantize")
- @_beartype.beartype
- def dequantize(g: jit_utils.GraphContext, input):
- return symbolic_helper.dequantize_helper(g, input)[0]
- @_onnx_symbolic("aten::nan_to_num")
- @symbolic_helper.parse_args("v", "f", "f", "f")
- @_beartype.beartype
- def nan_to_num(g: jit_utils.GraphContext, input, nan, posinf, neginf):
- # Cannot create a int type tensor with inf/nan values, so we simply
- # return the original tensor
- if not symbolic_helper._is_fp(input):
- return input
- input_dtype = _type_utils.JitScalarType.from_value(input).dtype()
- if nan is None:
- nan = 0.0
- nan_cond = opset9.isnan(g, input)
- nan_result = g.op(
- "Where",
- nan_cond,
- g.op("Constant", value_t=torch.tensor([nan], dtype=input_dtype)),
- input,
- )
- # For None values of posinf, neginf we use the greatest/lowest finite
- # value representable by input’s dtype.
- finfo = torch.finfo(input_dtype)
- if posinf is None:
- posinf = finfo.max
- posinf_cond = opset9.logical_and(
- g,
- isinf(g, nan_result),
- opset9.gt(g, nan_result, g.op("Constant", value_t=torch.LongTensor([0]))),
- )
- nan_posinf_result = g.op(
- "Where",
- posinf_cond,
- g.op("Constant", value_t=torch.tensor([posinf], dtype=input_dtype)),
- nan_result,
- )
- if neginf is None:
- neginf = finfo.min
- neginf_cond = opset9.logical_and(
- g,
- isinf(g, nan_posinf_result),
- opset9.lt(
- g, nan_posinf_result, g.op("Constant", value_t=torch.LongTensor([0]))
- ),
- )
- return g.op(
- "Where",
- neginf_cond,
- g.op("Constant", value_t=torch.tensor([neginf], dtype=input_dtype)),
- nan_posinf_result,
- )
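

# Example (illustrative only): for a float32 input, the three Where ops above
# map [nan, inf, -inf, 1.0] to [0.0, finfo.max, finfo.min, 1.0] when nan,
# posinf, and neginf are left as None, i.e. torch.nan_to_num's documented
# defaults.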


# Quantized symbolics ---------------------------------------------------------
# https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
# Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were
# introduced in opset version 10.
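

# All of the symbolics below follow the same dequantize -> float op ->
# requantize pattern. A minimal sketch of the shared shape, assuming a
# hypothetical unary `float_op` (the helpers used below are real):
#
#     def quantized_unary(g, x, op_scale, op_zero_point):
#         x, _, _, _ = symbolic_helper.dequantize_helper(g, x)   # to float
#         output = float_op(g, x)                                # fp compute
#         return symbolic_helper.quantize_helper(                # back to int
#             g, output, op_scale, op_zero_point
#         )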
- @_onnx_symbolic("quantized::linear")
- @_beartype.beartype
- def quantized_linear(
- g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
- ):
- input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
- weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
- q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
- bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
- output = opset9.linear(g, input, weight, bias)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
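

# Note (illustrative): `requantize_bias_helper` requantizes the float bias
# using the product input_scale * weight_scale as its scale with zero point 0,
# the usual bias convention for quantized linear/conv kernels, before the
# bias is dequantized again for the floating-point compute above.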
- @_onnx_symbolic("quantized::add")
- @_beartype.beartype
- def quantized_add(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
- x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
- y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
- output = opset9.add(g, x, y)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::add_relu")
- @_beartype.beartype
- def quantized_add_relu(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
- x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
- y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
- output = opset9.add(g, x, y)
- output = opset9.relu(g, output)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::mul")
- @_beartype.beartype
- def quantized_mul(g: jit_utils.GraphContext, x, y, op_scale, op_zero_point):
- x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
- y, _, _, _ = symbolic_helper.dequantize_helper(g, y)
- output = opset9.mul(g, x, y)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::hardswish")
- @_beartype.beartype
- def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
- x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
- output = opset9.hardswish(g, x)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::sigmoid")
- @_beartype.beartype
- def quantized_sigmoid(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
- x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
- output = opset9.sigmoid(g, x)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::leaky_relu")
- @_beartype.beartype
- def quantized_leaky_relu(
- g: jit_utils.GraphContext, x, negative_slope, inplace, op_scale, op_zero_point
- ):
- x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
- output = opset9.leaky_relu(g, x, negative_slope, inplace)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::layer_norm")
- @_beartype.beartype
- def quantized_layer_norm(
- g: jit_utils.GraphContext,
- x,
- normalized_shape,
- weight,
- bias,
- eps,
- op_scale,
- op_zero_point,
- ):
- x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
- output = opset9.layer_norm(g, x, normalized_shape, weight, bias, eps, False)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::group_norm")
- @_beartype.beartype
- def quantized_group_norm(
- g: jit_utils.GraphContext,
- x,
- num_groups,
- weight,
- bias,
- eps,
- op_scale,
- op_zero_point,
- ):
- x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
- output = opset9.group_norm(g, x, num_groups, weight, bias, eps, False)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::instance_norm")
- @symbolic_helper.parse_args("v", "v", "v", "f", "v", "v")
- @_beartype.beartype
- def quantized_instance_norm(
- g: jit_utils.GraphContext,
- q_input,
- weight,
- bias,
- eps,
- op_scale,
- op_zero_point,
- ):
- input, _, _, _ = symbolic_helper.dequantize_helper(g, q_input)
- output = opset9.instance_norm(
- g, input, weight, bias, None, None, False, 0.0, eps, False
- )
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::conv1d_relu")
- @_beartype.beartype
- def quantized_conv1d_relu(
- g: jit_utils.GraphContext,
- q_input,
- q_weight,
- bias,
- stride,
- padding,
- dilation,
- groups,
- op_scale,
- op_zero_point,
- ):
- input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
- weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
- q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
- bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
- output = opset9.conv1d(g, input, weight, bias, stride, padding, dilation, groups)
- output = opset9.relu(g, output)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::conv2d_relu")
- @_beartype.beartype
- def quantized_conv2d_relu(
- g: jit_utils.GraphContext,
- q_input,
- q_weight,
- bias,
- stride,
- padding,
- dilation,
- groups,
- op_scale,
- op_zero_point,
- ):
- input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
- weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
- q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
- bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
- output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
- output = opset9.relu(g, output)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::conv2d")
- @_beartype.beartype
- def quantized_conv2d(
- g: jit_utils.GraphContext,
- q_input,
- q_weight,
- bias,
- stride,
- padding,
- dilation,
- groups,
- op_scale,
- op_zero_point,
- ):
- input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
- weight, weight_scale, _, _ = symbolic_helper.dequantize_helper(g, q_weight)
- q_bias = symbolic_helper.requantize_bias_helper(g, bias, input_scale, weight_scale)
- bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)
- output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
- return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
- @_onnx_symbolic("quantized::cat")
- @symbolic_helper.parse_args("v", "i", "v", "v")
- @_beartype.beartype
- def quantized_cat(
- g: jit_utils.GraphContext,
- q_inputs: _C.Value,
- dim: int,
- op_scale: _C.Value,
- op_zero_point: _C.Value,
- ) -> _C.Value:
- unpacked_inputs = symbolic_helper._unpack_list(q_inputs)
- dequantized = [
- symbolic_helper.dequantize_helper(g, input)[0] for input in unpacked_inputs
- ]
- concatenated = g.op("Concat", *dequantized, axis_i=dim)
- return symbolic_helper.quantize_helper(g, concatenated, op_scale, op_zero_point)