12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182 |
- """This file exports ONNX ops for opset 15.
- Note [ONNX operators that are added/updated in opset 15]
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set
- New operators:
- Bernoulli
- CastLike
- Optional
- OptionalGetElement
- OptionalHasElement
- Updated operators:
- BatchNormalization https://github.com/onnx/onnx/pull/3545
- Backwards compatible
- TODO: test coverage for mixed types inputs.
- Pow https://github.com/onnx/onnx/pull/3412
- Backwards compatible
- TODO: bfloat16 support.
- Shape https://github.com/onnx/onnx/pull/3580
- Backwards compatible
- TODO: optional start/end attribute.
- """
- # EDITING THIS FILE? READ THIS FIRST!
- # see Note [Edit Symbolic Files] in README.md
- import functools
- import torch
- from torch import _C
- from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
- from torch.onnx._internal import _beartype, jit_utils, registration
- _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
- @_onnx_symbolic("aten::__is_")
- @_beartype.beartype
- def aten__is_(g: jit_utils.GraphContext, self, other):
- if symbolic_helper._is_none(other):
- if isinstance(self.type(), _C.OptionalType):
- none = g.op("OptionalHasElement", self)
- return g.op("Not", none)
- else:
- return g.op("Constant", value_t=torch.BoolTensor([0]))
- return opset9.eq(g, self, other)
- @_onnx_symbolic("aten::__isnot_")
- @opset9.wrap_logical_op_with_negation # type: ignore[has-type]
- @_beartype.beartype
- def aten__isnot_(g: jit_utils.GraphContext, self, other):
- return aten__is_(g, self, other)
- @_onnx_symbolic("aten::bernoulli")
- @_beartype.beartype
- def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
- if out is not None and not symbolic_helper._is_none(out):
- symbolic_helper._unimplemented(
- "Bernoulli", "out parameter is not supported for bernoulli", input
- )
- if generator is not None and not symbolic_helper._is_none(generator):
- symbolic_helper._unimplemented(
- "Bernoulli", "generator is not supported for bernoulli", input
- )
- if p is None or symbolic_helper._is_none(p):
- return g.op("Bernoulli", input)
- return opset9.bernoulli(g, input, p, generator, out)
- @_onnx_symbolic("prim::unchecked_cast")
- @_beartype.beartype
- def prim_unchecked_cast(g: jit_utils.GraphContext, self):
- # exists to refine the type of the Value
- # if x is Optional[Tensor], unchecked_cast will cast
- # x to Tensor, so the rest of the graph knows that x is a Tensor.
- if isinstance(self.type(), _C.OptionalType):
- return g.op("OptionalGetElement", self)
- return self
|