symbolic_opset15.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
"""This file exports ONNX ops for opset 15.

Note [ONNX operators that are added/updated in opset 15]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/master/docs/Changelog.md#version-15-of-the-default-onnx-operator-set

New operators:
    Bernoulli
    CastLike
    Optional
    OptionalGetElement
    OptionalHasElement

Updated operators:
    BatchNormalization https://github.com/onnx/onnx/pull/3545
        Backwards compatible
        TODO: test coverage for mixed types inputs.
    Pow https://github.com/onnx/onnx/pull/3412
        Backwards compatible
        TODO: bfloat16 support.
    Shape https://github.com/onnx/onnx/pull/3580
        Backwards compatible
        TODO: optional start/end attribute.
"""
# EDITING THIS FILE? READ THIS FIRST!
# See Note [Edit Symbolic Files] in README.md
  24. import functools
  25. import torch
  26. from torch import _C
  27. from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
  28. from torch.onnx._internal import _beartype, jit_utils, registration
  29. _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=15)
  30. @_onnx_symbolic("aten::__is_")
  31. @_beartype.beartype
  32. def aten__is_(g: jit_utils.GraphContext, self, other):
  33. if symbolic_helper._is_none(other):
  34. if isinstance(self.type(), _C.OptionalType):
  35. none = g.op("OptionalHasElement", self)
  36. return g.op("Not", none)
  37. else:
  38. return g.op("Constant", value_t=torch.BoolTensor([0]))
  39. return opset9.eq(g, self, other)
  40. @_onnx_symbolic("aten::__isnot_")
  41. @opset9.wrap_logical_op_with_negation # type: ignore[has-type]
  42. @_beartype.beartype
  43. def aten__isnot_(g: jit_utils.GraphContext, self, other):
  44. return aten__is_(g, self, other)
  45. @_onnx_symbolic("aten::bernoulli")
  46. @_beartype.beartype
  47. def bernoulli(g: jit_utils.GraphContext, input, p=None, generator=None, out=None):
  48. if out is not None and not symbolic_helper._is_none(out):
  49. symbolic_helper._unimplemented(
  50. "Bernoulli", "out parameter is not supported for bernoulli", input
  51. )
  52. if generator is not None and not symbolic_helper._is_none(generator):
  53. symbolic_helper._unimplemented(
  54. "Bernoulli", "generator is not supported for bernoulli", input
  55. )
  56. if p is None or symbolic_helper._is_none(p):
  57. return g.op("Bernoulli", input)
  58. return opset9.bernoulli(g, input, p, generator, out)
  59. @_onnx_symbolic("prim::unchecked_cast")
  60. @_beartype.beartype
  61. def prim_unchecked_cast(g: jit_utils.GraphContext, self):
  62. # exists to refine the type of the Value
  63. # if x is Optional[Tensor], unchecked_cast will cast
  64. # x to Tensor, so the rest of the graph knows that x is a Tensor.
  65. if isinstance(self.type(), _C.OptionalType):
  66. return g.op("OptionalGetElement", self)
  67. return self