symbolic_opset14.py

  1. """This file exports ONNX ops for opset 14.
  2. Note [ONNX operators that are added/updated in opset 14]
  3. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  4. New operators:
  5. HardSwish, Trilu
  6. Updated operators:
  7. Reshape
  8. Add, Sub, Mul, Div
  9. GRU, LSTM, RNN
  10. BatchNorm, Cumsum, Relu
  11. """
  12. # EDITING THIS FILE? READ THIS FIRST!
  13. # see Note [Edit Symbolic Files] in README.md

import functools

import torch
from torch.onnx import symbolic_helper
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=14)
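
# Usage sketch (illustrative, not part of this module): the symbolic functions
# registered below are dispatched automatically when a model is exported at
# this opset, e.g.
#
#     import torch
#
#     model = torch.nn.Hardswish()
#     torch.onnx.export(
#         model, torch.randn(2, 3), "hardswish.onnx", opset_version=14
#     )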


@_onnx_symbolic("aten::hardswish")
@symbolic_helper.parse_args("v")
@_beartype.beartype
def hardswish(g: jit_utils.GraphContext, self):
    return g.op("HardSwish", self)


@_onnx_symbolic("aten::tril")
@_beartype.beartype
def tril(g: jit_utils.GraphContext, self, diagonal, out=None):
    return g.op("Trilu", self, diagonal, upper_i=0)


@_onnx_symbolic("aten::triu")
@_beartype.beartype
def triu(g: jit_utils.GraphContext, self, diagonal, out=None):
    return g.op("Trilu", self, diagonal, upper_i=1)
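
# Example (illustrative): torch.tril and torch.triu both lower to a single
# ONNX Trilu node; only the `upper` attribute differs (0 keeps the lower
# triangle, 1 the upper). The wrapper module is hypothetical, used only to
# give torch.onnx.export something to trace:
#
#     class Tril(torch.nn.Module):
#         def forward(self, x):
#             return torch.tril(x, diagonal=0)
#
#     torch.onnx.export(Tril(), torch.randn(4, 4), "tril.onnx", opset_version=14)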


@_onnx_symbolic("aten::reshape")
@symbolic_helper.parse_args("v", "v")
@_beartype.beartype
def reshape(g: jit_utils.GraphContext, self, shape):
    # NOTE: Due to a bug in ORT (https://github.com/microsoft/onnxruntime/issues/10664),
    # Reshape export cannot use the new `allowzero` attribute introduced in opset 14.
    return symbolic_helper._reshape_helper(g, self, shape, allowzero=0)
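
# Background on the attribute (summary of the ONNX Reshape spec): with
# allowzero=0, a 0 in the target shape copies the corresponding dimension from
# the input, e.g. reshaping a (2, 3, 4) tensor with shape (0, -1) yields
# (2, 12); with allowzero=1 the 0 is taken literally, producing an empty
# dimension. Passing allowzero=0 keeps the pre-opset-14 behavior.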


@_onnx_symbolic("aten::batch_norm")
@symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "f", "f", "i")
@_beartype.beartype
def batch_norm(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    running_mean,
    running_var,
    training,
    momentum,
    eps,
    cudnn_enabled,
):
    if (
        torch.is_autocast_enabled()
        and not symbolic_helper.args_have_same_dtype(
            [input, weight, bias, running_mean, running_var]
        )
        and GLOBALS.export_onnx_opset_version < 15
    ):
        return symbolic_helper._onnx_opset_unsupported_detailed(
            "BatchNormalization",
            14,
            15,
            "All input tensors must have the same `dtype`."
            " Turn off Autocast or export using opset version 15.",
            input,
        )
    symbolic_helper.check_training_mode(training, "batch_norm")
    weight, bias, running_mean, running_var = symbolic_helper._batchnorm_helper(
        g, input, weight, bias, running_mean, running_var
    )
    out = g.op(
        "BatchNormalization",
        input,
        weight,
        bias,
        running_mean,
        running_var,
        epsilon_f=eps,
        # PyTorch and ONNX use complementary momentum conventions: ONNX computes
        # running_mean = momentum * running_mean + (1 - momentum) * mean, so
        # PyTorch's momentum is flipped to 1 - momentum here.
        momentum_f=1 - momentum,
        training_mode_i=0 if not training else 1,
        outputs=1 if not training else 3,
    )
    if not training:
        return out
    else:
        res, new_running_mean, new_running_var = out
        new_running_mean.setType(running_mean.type())
        new_running_var.setType(running_var.type())
        return res
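
# Example (illustrative): in eval mode the exported BatchNormalization node
# has a single output (outputs=1 above); in training mode it also emits the
# updated running statistics:
#
#     model = torch.nn.BatchNorm2d(3).eval()
#     torch.onnx.export(
#         model, torch.randn(1, 3, 8, 8), "bn.onnx", opset_version=14
#     )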


@_onnx_symbolic("quantized::hardswish")
@_beartype.beartype
def quantized_hardswish(g: jit_utils.GraphContext, x, op_scale, op_zero_point):
    x, _, _, _ = symbolic_helper.dequantize_helper(g, x)
    output = hardswish(g, x)
    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
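
# Note (summary of the pattern above): quantized ops are exported by
# dequantizing the input, running the float symbolic (`hardswish`), and
# re-quantizing the result with the op's scale and zero point, which should
# yield a DequantizeLinear -> HardSwish -> QuantizeLinear subgraph.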