- """This file exports ONNX ops for opset 16.
- Note [ONNX Operators that are added/updated in opset 16]
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
- New operators:
- GridSample https://github.com/onnx/onnx/pull/3557
- Updated operators:
- Identity
- If
- LeakyRelu
- Loop
- PRelu
- RoiAlign
- Scan
- ScatterElements
- ScatterND
- Where
- GreaterOrEqual
- LessOrEqual
- """
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

import functools

import torch
from torch.nn.functional import (
    GRID_SAMPLE_INTERPOLATION_MODES,
    GRID_SAMPLE_PADDING_MODES,
)
from torch.onnx import _type_utils, symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)

# note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
# Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
@_onnx_symbolic("aten::grid_sampler")
@symbolic_helper.parse_args("v", "v", "i", "i", "b")
@_beartype.beartype
def grid_sampler(
    g: jit_utils.GraphContext,
    input,
    grid,
    mode_enum,
    padding_mode_enum,
    align_corners,
):
    # Check the input and grid tensor rank beforehand.
    if symbolic_helper._get_tensor_rank(input) == 5:
        return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
    # Invert the {name: enum} mappings to recover the string attributes that the
    # ONNX GridSample op expects for `mode` and `padding_mode`.
    mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum]  # type: ignore[call-arg]
    padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum]  # type: ignore[call-arg]
    return g.op(
        "GridSample",
        input,
        grid,
        align_corners_i=int(align_corners),
        mode_s=mode_s,
        padding_mode_s=padding_mode_s,
    )
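
# A minimal usage sketch (illustrative only; `GridSampleModel` is a made-up name):
# exporting a model that calls `torch.nn.functional.grid_sample` with
# opset_version=16 routes through the `grid_sampler` symbolic above and emits an
# ONNX `GridSample` node.
#
#     import torch
#     import torch.nn.functional as F
#
#     class GridSampleModel(torch.nn.Module):
#         def forward(self, x, grid):
#             return F.grid_sample(x, grid, mode="bilinear", align_corners=False)
#
#     x = torch.randn(1, 3, 8, 8)
#     grid = torch.rand(1, 4, 4, 2) * 2 - 1  # sampling locations in [-1, 1]
#     torch.onnx.export(
#         GridSampleModel(), (x, grid), "grid_sample.onnx", opset_version=16
#     )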
- @_onnx_symbolic("aten::scatter_add")
- @symbolic_helper.parse_args("v", "i", "v", "v")
- @_beartype.beartype
- def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
- if symbolic_helper.is_caffe2_aten_fallback():
- return g.at("scatter", self, dim, index, src, overload_name="src")
- src_type = _type_utils.JitScalarType.from_value(
- src, _type_utils.JitScalarType.UNDEFINED
- )
- src_sizes = symbolic_helper._get_tensor_sizes(src)
- index_sizes = symbolic_helper._get_tensor_sizes(index)
- if len(src_sizes) != len(index_sizes):
- return symbolic_helper._unimplemented(
- "scatter_add",
- f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
- )
- # PyTorch only allows index shape <= src shape, so we can only consider
- # taking index as subset size to src, like PyTorch does. When sizes for src
- # and index are not matched or there are dynamic axes, we take index shape to
- # slice src to accommodate.
- if src_sizes != index_sizes or None in index_sizes:
- adjusted_shape = g.op("Shape", index)
- starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
- src = g.op("Slice", src, starts, adjusted_shape)
- src = symbolic_helper._maybe_get_scalar(src)
- if symbolic_helper._is_value(src):
- return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
- else:
- # Check if scalar "src" has same type as self (PyTorch allows different
- # type for scalar src (but not when src is tensor)). If not, insert Cast node.
- if _type_utils.JitScalarType.from_value(self) != src_type:
- src = g.op(
- "Cast",
- src,
- to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
- )
- return g.op(
- "ScatterElements",
- self,
- index,
- src,
- axis_i=dim,
- reduction_s="add",
- )
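
# A minimal usage sketch (illustrative only; `ScatterAddModel` is a made-up name):
# `Tensor.scatter_add` exported with opset_version=16 lowers to `ScatterElements`
# with reduction="add" via the symbolic above.
#
#     import torch
#
#     class ScatterAddModel(torch.nn.Module):
#         def forward(self, self_tensor, index, src):
#             return self_tensor.scatter_add(0, index, src)
#
#     self_tensor = torch.zeros(3, 5)
#     index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
#     src = torch.ones(2, 5)
#     torch.onnx.export(
#         ScatterAddModel(), (self_tensor, index, src), "scatter_add.onnx", opset_version=16
#     )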