symbolic_opset18.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
"""This file exports ONNX ops for opset 18.

Note [ONNX Operators that are added/updated in opset 18]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-18-of-the-default-onnx-operator-set
Operators that are new or updated in opset 18 include:
    CenterCropPad
    Col2Im
    Mish
    OptionalGetElement
    OptionalHasElement
    Pad
    Resize
    ScatterElements
    ScatterND
"""
  16. import functools
  17. from typing import Sequence
  18. from torch import _C
  19. from torch.onnx import symbolic_helper
  20. from torch.onnx._internal import _beartype, registration
  21. # EDITING THIS FILE? READ THIS FIRST!
  22. # see Note [Edit Symbolic Files] in symbolic_helper.py
  23. __all__ = ["col2im"]
  24. _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18)
  25. @_onnx_symbolic("aten::col2im")
  26. @symbolic_helper.parse_args("v", "v", "v", "is", "is", "is")
  27. @_beartype.beartype
  28. def col2im(
  29. g,
  30. input: _C.Value,
  31. output_size: _C.Value,
  32. kernel_size: _C.Value,
  33. dilation: Sequence[int],
  34. padding: Sequence[int],
  35. stride: Sequence[int],
  36. ):
  37. # convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
  38. adjusted_padding = []
  39. for pad in padding:
  40. for _ in range(2):
  41. adjusted_padding.append(pad)
  42. num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
  43. if not adjusted_padding:
  44. adjusted_padding = [0, 0] * num_dimensional_axis
  45. if not dilation:
  46. dilation = [1] * num_dimensional_axis
  47. if not stride:
  48. stride = [1] * num_dimensional_axis
  49. return g.op(
  50. "Col2Im",
  51. input,
  52. output_size,
  53. kernel_size,
  54. dilations_i=dilation,
  55. pads_i=adjusted_padding,
  56. strides_i=stride,
  57. )