"""This file exports ONNX ops for opset 17.

Note [ONNX Operators that are added/updated in opset 17]

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
New operators:
    BlackmanWindow
    DFT
    HammingWindow
    HannWindow
    LayerNormalization
    MelWeightMatrix
    STFT
    SequenceMap
"""

import functools
from typing import Sequence

import torch
from torch import _C
from torch.onnx import symbolic_helper
from torch.onnx._internal import jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

__all__ = ["layer_norm"]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)


def _is_missing(value) -> bool:
    """Return True when an optional graph input was not supplied.

    An absent optional argument is either a plain Python ``None`` or a
    ``_C.Value`` that ``symbolic_helper._is_none`` reports as a ``None``
    constant.
    """
    if value is None:
        return True
    return isinstance(value, _C.Value) and symbolic_helper._is_none(value)


@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Export ``aten::layer_norm`` as the ONNX ``LayerNormalization`` operator.

    Args:
        g: Graph context used to emit ONNX nodes.
        input: Value to normalize.
        normalized_shape: Sizes of the trailing dimensions that are normalized.
        weight: Scale value, or ``None`` when no scale was given.
        bias: Bias value, or ``None`` when no bias was given.
        eps: Epsilon forwarded as the ``epsilon`` attribute.
        cudnn_enable: Accepted for signature compatibility; not used here.

    Returns:
        The value produced by ``g.op("LayerNormalization", ...)``.
    """
    # layer_norm normalizes over the last D dimensions, where D is the length
    # of normalized_shape, so the first normalized axis counted from the end
    # is -D.
    axis = -len(normalized_shape)

    if _is_missing(weight):
        # ONNX LayerNormalization requires the Scale input. Substitute an
        # all-ones scale (a no-op multiplication) and let CastLike
        # (opset >= 15) match it to the element type of the input.
        unit_scale = g.op("Constant", value_t=torch.ones(normalized_shape))
        weight = g.op("CastLike", unit_scale, input)

    node_inputs = [input, weight]
    if not _is_missing(bias):
        # The bias input (B) is optional in ONNX; pass it only when present.
        node_inputs.append(bias)

    return g.op(
        "LayerNormalization",
        *node_inputs,
        epsilon_f=eps,
        axis_i=axis,
    )