123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470 |
- """
- Note [ONNX operators that are added/updated from opset 8 to opset 9]
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- New operators:
- Compress
- ConstantOfShape
- EyeLike
- MaxUnpool
- OneHot
- Sinh
- Cosh
- Asinh
- Acosh
- Atanh
- Shrink
- IsNaN
- Sign
- Erf
- Scatter
- Where
- NonZero
- TfIdfVectorizer
- MeanVarianceNormalization
- Updated operators:
- BatchNormalization: removed spatial attribute.
- Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types (integers) supported.
- Cast: more data types (string) supported.
- Upsample: moved scales from attribute to input.
- Scan
- """
- import functools
- import warnings
- import torch
- from torch._C import _onnx as _C_onnx
- from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
- from torch.onnx._internal import jit_utils, registration
# Registration decorator pre-bound to opset 8.
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)

# ATen operators that are block-listed for opset 8. Each name below is registered
# with the symbolic returned by `symbolic_helper._block_list_in_opset(name)`, which
# presumably reports the operator as unsupported in this opset at export time.
block_listed_operators = (
    "nonzero", "where", "scatter", "scatter_add", "erf",
    "sign", "isnan", "gather", "arange", "masked_fill",
    "index_fill", "index_copy", "repeat_interleave", "any", "all",
)

for block_listed_op in block_listed_operators:
    # Build the registering decorator first, then hand it the block-listed symbolic.
    _register_blocked = _onnx_symbolic(f"aten::{block_listed_op}")
    _register_blocked(symbolic_helper._block_list_in_opset(block_listed_op))
- def _apply_params(*args, **kwargs):
- """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""
- def _apply(fn):
- return fn(*args, **kwargs)
- return _apply
@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
    """Build the opset-8 symbolic function for one ``aten::upsample_*`` overload.

    Args:
        name: Operator name used in "unimplemented" diagnostics.
        dim: Rank of the input tensor (3 for the 1d overloads, 4 for 2d, 5 for 3d).
        interpolate_mode: ONNX ``Upsample`` mode ("nearest" or "linear").

    Returns:
        A symbolic function emitting an ONNX ``Upsample`` node with static scales.
    """

    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)

        # align_corners=True is not supported by this export path.
        if symbolic_helper._maybe_get_scalar(align_corners):
            return symbolic_helper._unimplemented(name, "align_corners == True", input)

        output_size = symbolic_helper._maybe_get_const(output_size, "is")
        if symbolic_helper._is_value(output_size):
            # Scales are emitted as a static attribute, so a runtime size is unusable.
            return symbolic_helper._unimplemented(
                name, "torch._C.Value (output_size) indexing"
            )

        if scales is None:
            # Derive per-axis scales from the requested output size; the first two
            # axes always keep scale 1.0.
            scales = []
            for axis in range(dim):
                if axis < 2:
                    scales.append(1.0)
                    continue
                from_end = axis - dim  # equivalent to -(dim - axis)
                scales.append(
                    float(output_size[from_end])
                    / float(input.type().sizes()[from_end])
                )

        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)

    return symbolic_fn
@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    """Export ``aten::__interpolate`` as an ONNX ``Upsample`` node with static scales.

    Returns an "unimplemented" marker when ``align_corners`` is True, or when
    ``scale_factor`` or ``size`` is only known at runtime.
    ``recompute_scale_factor`` and ``antialias`` are accepted but unused.
    """

    def is_runtime_value(arg):
        # True for a non-None argument that is still a graph Value (not a constant).
        return not symbolic_helper._is_none(arg) and symbolic_helper._is_value(arg)

    align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
    if not symbolic_helper._is_none(align_corners) and align_corners:
        return symbolic_helper._unimplemented("interpolate", "align_corners == True")

    if is_runtime_value(scale_factor):
        return symbolic_helper._unimplemented(
            "interpolate", "dynamic scales in opset 8"
        )
    if is_runtime_value(size):
        return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")

    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, mode_s=mode, scales_f=scales)
# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
# issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
# is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
    """Cast every argument to FLOAT when the first one has a known, non-floating type.

    The decision is made from ``args[0]`` alone:
      * unknown type  -> emit a warning, return the inputs unchanged;
      * HALF/FLOAT/DOUBLE -> return the inputs unchanged;
      * anything else -> insert an ONNX ``Cast`` to FLOAT on every argument.

    Returns:
        ``(old_type, *args)``. ``old_type`` is ``JitScalarType.scalar_name()`` of
        the original type when casts were inserted, otherwise ``None``.
    """
    first_type = _type_utils.JitScalarType.from_value(
        args[0], _type_utils.JitScalarType.UNDEFINED
    )

    if first_type == _type_utils.JitScalarType.UNDEFINED:
        warnings.warn(
            "Only floating datatype is supported for these operators: "
            "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
            "the onnx model to be incorrect, if inputs have integer datatypes."
        )
        return (None,) + args

    already_floating = (
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.DOUBLE,
    )
    if first_type in already_floating:
        return (None,) + args

    original_name = first_type.scalar_name()
    float_args = tuple(
        g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT) for arg in args
    )
    return (original_name,) + float_args
def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
    """Cast ``input`` back to ``to_type`` via ``opset9._cast_<to_type>``.

    ``to_type`` is a scalar-type name such as the one returned by
    ``_try_cast_integer_to_float``. ``None`` means no cast is needed, and
    ``input`` is returned as-is.
    """
    if to_type is not None:
        cast_symbolic = getattr(opset9, f"_cast_{to_type}")
        return cast_symbolic(g, input, False)
    return input
def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
    """Emit the binary ONNX comparison ``op_name`` on ``input`` and ``other``.

    A scalar ``other`` is first given ``input``'s scalar type. Both operands then
    go through ``_try_cast_integer_to_float``; the result is not cast back.
    """
    rhs = symbolic_helper._if_scalar_type_as(
        symbolic_helper._maybe_get_scalar(other), input
    )
    _unused_old_type, lhs, rhs = _try_cast_integer_to_float(g, input, rhs)
    return g.op(op_name, lhs, rhs)
# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
# integer inputs are not supported in opset 8, so they are cast to float when possible.
@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::gt`` as ONNX ``Greater``."""
    return _comparison_operator(g, input, other, op_name="Greater")
@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
    """Export ``aten::lt`` as ONNX ``Less``."""
    return _comparison_operator(g, input, other, op_name="Less")
@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::bmm`` as ONNX ``MatMul``.

    When ``self`` has a known scalar type, both operands go through
    ``_try_cast_integer_to_float`` and the product is cast back to the original type.
    """
    old_type = None
    # `_try_get_scalar_type` returns a JitScalarType (an IntEnum whose UINT8 member
    # is 0) or None. A plain truthiness test would wrongly skip the cast for uint8
    # inputs, so compare against None explicitly.
    if symbolic_helper._try_get_scalar_type(self) is not None:
        old_type, self, other = _try_cast_integer_to_float(g, self, other)
    # `_cast_to_type` is a no-op when old_type is None.
    return _cast_to_type(g, g.op("MatMul", self, other), old_type)
@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    """Export ``aten::matmul`` exactly like ``aten::bmm`` (ONNX ``MatMul``)."""
    result = bmm(g, self, other)
    return result
@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    """Export ``aten::prelu`` as ONNX ``PRelu``.

    The weight is reshaped so it lines up with ``self``:
      * rank of ``self`` > 2: weight is unsqueezed on axes ``1 .. rank-2``
        (presumably a per-channel ``(C,)`` weight becoming ``(C, 1, ..., 1)``);
      * scalar ``self`` with a weight of sizes ``[1]``: axis 0 of weight is squeezed.
    Integer inputs are cast to float and the result is cast back.
    """
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    if self_rank is not None and self_rank > 2:
        weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
    elif self_rank == 0 and weight_sizes == [1]:
        # self and weight are both scalar but weight has rank == 1, squeeze weight.
        weight = symbolic_helper._squeeze_helper(g, weight, [0])

    old_type = None
    # Compare with None: JitScalarType.UINT8 has value 0 and is falsy.
    if symbolic_helper._try_get_scalar_type(self) is not None:
        old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
    return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    """Export ``aten::mm`` as ONNX ``Gemm`` with ``alpha=1`` and ``beta=0``.

    Raises:
        errors.SymbolicValueError: if neither operand has a known scalar type.
    """
    scalar_type = symbolic_helper._try_get_scalar_type(self, other)
    if scalar_type is None:
        raise errors.SymbolicValueError(
            "mm can only operate on tensors with known types", self
        )
    # Create a dummy C tensor. It is only needed for API purposes; its value is
    # irrelevant because beta = 0.
    zero_constant = g.op(
        "Constant",
        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
    )

    old_type = None
    # Compare with None: JitScalarType.UINT8 has value 0 and is falsy.
    if symbolic_helper._try_get_scalar_type(self) is not None:
        old_type, self, other, zero_constant = _try_cast_integer_to_float(
            g, self, other, zero_constant
        )
    gemm = g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)
    # `_cast_to_type` is a no-op when old_type is None.
    return _cast_to_type(g, gemm, old_type)
@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    """Export ``aten::addmm`` as ONNX ``Gemm(mat1, mat2, self)``.

    ``beta`` and ``alpha`` arrive as constant tensors (parse_args "t") and become
    the Gemm ``beta``/``alpha`` attributes. Integer inputs are cast to float, and
    the result is cast back to the original type.
    """
    old_type = None
    # Compare with None: JitScalarType.UINT8 has value 0 and is falsy.
    if symbolic_helper._try_get_scalar_type(self) is not None:
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
    gemm = g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=symbolic_helper._scalar(beta),
        alpha_f=symbolic_helper._scalar(alpha),
    )
    # `_cast_to_type` is a no-op when old_type is None.
    return _cast_to_type(g, gemm, old_type)
@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    """Export ``aten::flatten``.

    Two cases produce a 2-D output and map directly onto ONNX ``Flatten``:
      * ``start_dim == 1`` and ``end_dim == rank - 1`` -> ``Flatten(axis=1)``;
      * ``start_dim == 0`` and ``end_dim == rank - 2`` -> ``Flatten(axis=rank-1)``.
    Negative dims are normalized against the rank before matching. All other
    cases, including an unknown input rank, are delegated to ``opset9.flatten``
    with the original ``start_dim``/``end_dim`` arguments.
    """
    start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
    end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")

    dim = symbolic_helper._get_tensor_rank(input)
    if dim is None:
        # Rank is unknown, so the fast paths below cannot be matched.
        return opset9.flatten(g, input, start_dim, end_dim)

    if end_dim_i < 0:
        end_dim_i = dim + end_dim_i
    if start_dim_i < 0:
        start_dim_i = dim + start_dim_i

    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        axis = start_dim_i
    elif start_dim_i == 0 and end_dim_i == dim - 2:
        axis = end_dim_i + 1
    else:
        return opset9.flatten(g, input, start_dim, end_dim)

    old_type = None
    # Compare with None: JitScalarType.UINT8 has value 0 and is falsy.
    if symbolic_helper._try_get_scalar_type(input) is not None:
        old_type, input = _try_cast_integer_to_float(g, input)
    # `_cast_to_type` is a no-op when old_type is None.
    return _cast_to_type(g, g.op("Flatten", input, axis_i=axis), old_type)
def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
    """Emit an ONNX ``ConstantFill`` of shape ``sizes`` filled with ``const_value``.

    ``dtype`` is a JitScalarType value; ``None`` selects FLOAT. For a non-floating
    target type, the fill is done in FLOAT and followed by a ``Cast`` to the target.
    """
    target = (
        _type_utils.JitScalarType.FLOAT
        if dtype is None
        else _type_utils.JitScalarType(dtype)
    )
    needs_cast = not target.dtype().is_floating_point
    fill_type = _type_utils.JitScalarType.FLOAT if needs_cast else target

    filled = g.op(
        "ConstantFill",
        sizes,
        dtype_i=fill_type.onnx_type(),
        input_as_shape_i=1,
        value_f=const_value,
    )
    if not needs_cast:
        return filled
    return g.op("Cast", filled, to_i=target.onnx_type())
@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty`` as a zero-filled tensor (``memory_format`` is ignored)."""
    zero_filled = zeros(g, sizes, dtype, layout, device, pin_memory)
    return zero_filled
@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::empty_like`` as ``zeros_like`` (``memory_format`` is ignored)."""
    zero_filled = zeros_like(g, input, dtype, layout, device, pin_memory)
    return zero_filled
@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::zeros`` as a constant fill of 0."""
    # NOTE: ONNX has no way to express device or layout, so they are ignored.
    fill_value = 0
    return _constant_fill(g, sizes, dtype, fill_value)
@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::zeros_like`` as a 0 fill shaped by ONNX ``Shape(input)``."""
    return _constant_fill(g, g.op("Shape", input), dtype, 0)
@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    """Export ``aten::ones`` as a constant fill of 1; device/layout are ignored."""
    fill_value = 1
    return _constant_fill(g, sizes, dtype, fill_value)
@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::ones_like`` as a 1 fill shaped by ONNX ``Shape(input)``."""
    return _constant_fill(g, g.op("Shape", input), dtype, 1)
@_onnx_symbolic("aten::full")
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    """Export ``aten::full``.

    A constant ``value`` becomes a direct constant fill. A value known only at
    runtime is exported as ``zeros(sizes) + value`` via ``opset9.add`` with alpha 1.
    """
    const_value = symbolic_helper._maybe_get_const(value, "t")
    if not symbolic_helper._is_value(const_value):
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, const_value)

    zero_filled = zeros(g, sizes, dtype, layout, device)
    alpha_one = g.op("Constant", value_t=torch.tensor(1))
    return opset9.add(g, zero_filled, value, alpha_one)
@_onnx_symbolic("aten::full_like")
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    """Export ``aten::full_like`` as a ``fill_value`` fill shaped by ``Shape(input)``."""
    return _constant_fill(g, g.op("Shape", input), dtype, fill_value)
@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
    """Export ``aten::repeat`` as ONNX ``Tile``.

    When ``self`` has a complete type and ``repeats`` has more entries than
    ``self`` has dimensions, ``self`` is first viewed with leading size-1 axes.
    This gives ``Tile`` one repeat per input axis.
    """
    if not symbolic_helper._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))

    if symbolic_helper._is_packed_list(repeats):
        num_repeats = len(symbolic_helper._unpack_list(repeats))
    else:
        num_repeats = len(symbolic_helper._maybe_get_const(repeats, "is"))

    if self.isCompleteTensor():
        input_sizes = self.type().sizes()
        missing_dims = num_repeats - len(input_sizes)
        if missing_dims > 0:
            padded_shape = g.op(
                "Constant", value_t=torch.tensor([1] * missing_dims + input_sizes)
            )
            self = opset9.view(g, self, padded_shape)

    return g.op("Tile", self, repeats)
|