symbolic_opset17.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. """This file exports ONNX ops for opset 17.
  2. Note [ONNX Operators that are added/updated in opset 17]
  3. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  4. https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
  5. New operators:
  6. BlackmanWindow
  7. DFT
  8. HammingWindow
  9. HannWindow
  10. LayerNormalization
  11. MelWeightMatrix
  12. STFT
  13. SequenceMap
  14. """
import functools
from typing import Sequence

import torch
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration
  20. # EDITING THIS FILE? READ THIS FIRST!
  21. # see Note [Edit Symbolic Files] in README.md
  22. __all__ = ["layer_norm"]
  23. _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
  24. @_onnx_symbolic("aten::layer_norm")
  25. @symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
  26. def layer_norm(
  27. g: jit_utils.GraphContext,
  28. input: _C.Value,
  29. normalized_shape: Sequence[int],
  30. weight: _C.Value,
  31. bias: _C.Value,
  32. eps: float,
  33. cudnn_enable: bool,
  34. ):
  35. # normalized_shape: input shape from an expected input of size
  36. # axis: The first normalization dimension.
  37. # layer_norm normalizes on the last D dimensions,
  38. # where D is the size of normalized_shape
  39. axis = -len(normalized_shape)
  40. return g.op(
  41. "LayerNormalization",
  42. input,
  43. weight,
  44. bias,
  45. epsilon_f=eps,
  46. axis_i=axis,
  47. )