1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- """This file exports ONNX ops for opset 17.
- Note [ONNX Operators that are added/updated in opset 17]
- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
- New operators:
- BlackmanWindow
- DFT
- HammingWindow
- HannWindow
- LayerNormalization
- MelWeightMatrix
- STFT
- SequenceMap
- """
- import functools
- from typing import Sequence
- from torch import _C
- from torch.onnx import symbolic_helper
- from torch.onnx._internal import jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# Names exported by ``from ... import *``: the symbolic functions defined here.
__all__ = ["layer_norm"]

# Registration decorator pre-bound to opset 17, so each symbolic below only
# names the ATen op it handles, e.g. ``@_onnx_symbolic("aten::layer_norm")``.
_onnx_symbolic = functools.partial(
    registration.onnx_symbolic,
    opset=17,
)
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Export ``aten::layer_norm`` as a single ONNX ``LayerNormalization`` node.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: The value to normalize.
        normalized_shape: Sizes of the trailing dimensions that are normalized
            (parsed to a list of ints by ``parse_args("is")``).
        weight: Elementwise scale, or a None value when the caller supplied no
            weight (e.g. a LayerNorm without affine parameters).
        bias: Elementwise shift, or a None value when no bias was supplied.
        eps: Value added to the variance for numerical stability.
        cudnn_enable: Part of the ATen signature only; it has no ONNX
            counterpart and is intentionally ignored.

    Returns:
        The output value of the emitted ``LayerNormalization`` node.
    """
    # Local import: ``torch`` is only needed here to build default-parameter
    # constants; the module itself imports just ``torch._C``.
    import torch

    # aten::layer_norm normalizes over the last D dimensions, where
    # D == len(normalized_shape). ONNX expresses this via ``axis``, the first
    # normalized dimension; a negative axis counts from the end.
    axis = -len(normalized_shape)

    # ONNX LayerNormalization requires the Scale input (B is optional).
    # Forwarding a None weight would leave the required Scale input without a
    # tensor and yield an invalid model, so substitute the identity affine
    # transform: ones for the scale and zeros for the bias. CastLike (opset 15+)
    # matches the constants to the input's element type (e.g. fp16 / fp64).
    if symbolic_helper._is_none(weight):
        weight = g.op(
            "CastLike",
            g.op("Constant", value_t=torch.ones(normalized_shape)),
            input,
        )
    if symbolic_helper._is_none(bias):
        bias = g.op(
            "CastLike",
            g.op("Constant", value_t=torch.zeros(normalized_shape)),
            input,
        )

    return g.op(
        "LayerNormalization",
        input,
        weight,
        bias,
        epsilon_f=eps,
        axis_i=axis,
    )
|