  1. """This file exports ONNX ops for opset 16.
  2. Note [ONNX Operators that are added/updated in opset 16]
  3. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  4. https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-16-of-the-default-onnx-operator-set
  5. New operators:
  6. GridSample https://github.com/onnx/onnx/pull/3557
  7. Updated operators:
  8. Identity
  9. If
  10. LeakyRelu
  11. Loop
  12. PRelu
  13. RoiAlign
  14. Scan
  15. ScatterElements
  16. ScatterND
  17. Where
  18. GreaterOrEqual
  19. LessOrEqual
  20. """
  21. # EDITING THIS FILE? READ THIS FIRST!
  22. # see Note [Edit Symbolic Files] in README.md
  23. import functools
  24. import torch
  25. from torch.nn.functional import (
  26. GRID_SAMPLE_INTERPOLATION_MODES,
  27. GRID_SAMPLE_PADDING_MODES,
  28. )
  29. from torch.onnx import _type_utils, symbolic_helper
  30. from torch.onnx._internal import _beartype, jit_utils, registration
  31. _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=16)
  32. # note (mkozuki): Why `grid_sampler` instead of `grid_sample`?
  33. # Because `torch.nn.functional.grid_sample` calls `torch.grid_sampler`.
  34. @_onnx_symbolic("aten::grid_sampler")
  35. @symbolic_helper.parse_args("v", "v", "i", "i", "b")
  36. @_beartype.beartype
  37. def grid_sampler(
  38. g: jit_utils.GraphContext,
  39. input,
  40. grid,
  41. mode_enum,
  42. padding_mode_enum,
  43. align_corners,
  44. ):
  45. # Check the input and grid tensor rank beforehand.
  46. if symbolic_helper._get_tensor_rank(input) == 5:
  47. return symbolic_helper._onnx_unsupported("GridSample with 5D volumetric input")
  48. mode_s = {v: k for k, v in GRID_SAMPLE_INTERPOLATION_MODES.items()}[mode_enum] # type: ignore[call-arg]
  49. padding_mode_s = {v: k for k, v in GRID_SAMPLE_PADDING_MODES.items()}[padding_mode_enum] # type: ignore[call-arg]
  50. return g.op(
  51. "GridSample",
  52. input,
  53. grid,
  54. align_corners_i=int(align_corners),
  55. mode_s=mode_s,
  56. padding_mode_s=padding_mode_s,
  57. )
  58. @_onnx_symbolic("aten::scatter_add")
  59. @symbolic_helper.parse_args("v", "i", "v", "v")
  60. @_beartype.beartype
  61. def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
  62. if symbolic_helper.is_caffe2_aten_fallback():
  63. return g.at("scatter", self, dim, index, src, overload_name="src")
  64. src_type = _type_utils.JitScalarType.from_value(
  65. src, _type_utils.JitScalarType.UNDEFINED
  66. )
  67. src_sizes = symbolic_helper._get_tensor_sizes(src)
  68. index_sizes = symbolic_helper._get_tensor_sizes(index)
  69. if len(src_sizes) != len(index_sizes):
  70. return symbolic_helper._unimplemented(
  71. "scatter_add",
  72. f"`index` ({index_sizes}) should have the same dimensionality as `src` ({src_sizes})",
  73. )
  74. # PyTorch only allows index shape <= src shape, so we can only consider
  75. # taking index as subset size to src, like PyTorch does. When sizes for src
  76. # and index are not matched or there are dynamic axes, we take index shape to
  77. # slice src to accommodate.
  78. if src_sizes != index_sizes or None in index_sizes:
  79. adjusted_shape = g.op("Shape", index)
  80. starts = g.op("Constant", value_t=torch.tensor([0] * len(index_sizes)))
  81. src = g.op("Slice", src, starts, adjusted_shape)
  82. src = symbolic_helper._maybe_get_scalar(src)
  83. if symbolic_helper._is_value(src):
  84. return g.op("ScatterElements", self, index, src, axis_i=dim, reduction_s="add")
  85. else:
  86. # Check if scalar "src" has same type as self (PyTorch allows different
  87. # type for scalar src (but not when src is tensor)). If not, insert Cast node.
  88. if _type_utils.JitScalarType.from_value(self) != src_type:
  89. src = g.op(
  90. "Cast",
  91. src,
  92. to_i=_type_utils.JitScalarType.from_value(self).onnx_type(),
  93. )
  94. return g.op(
  95. "ScatterElements",
  96. self,
  97. index,
  98. src,
  99. axis_i=dim,
  100. reduction_s="add",
  101. )