# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 13

import functools

import torch
import torch._C._onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset11 as opset11,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=13)


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply
- @_onnx_symbolic("aten::softmax")
- @symbolic_helper.parse_args("v", "i", "none")
- @_beartype.beartype
- def softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
- softmax = g.op("Softmax", input, axis_i=dim)
- if dtype and dtype.node().kind() != "prim::Constant":
- parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
- softmax = g.op(
- "Cast", softmax, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
- )
- return softmax
- @_onnx_symbolic("aten::log_softmax")
- @symbolic_helper.parse_args("v", "i", "none")
- @_beartype.beartype
- def log_softmax(g: jit_utils.GraphContext, input, dim, dtype=None):
- return_op = g.op("LogSoftmax", input, axis_i=dim)
- if dtype and dtype.node().kind() != "prim::Constant":
- parsed_dtype = symbolic_helper._get_const(dtype, "i", "dtype")
- return_op = g.op(
- "Cast", return_op, to_i=_type_utils.JitScalarType(parsed_dtype).onnx_type()
- )
- return return_op
- @_onnx_symbolic("aten::frobenius_norm")
- @symbolic_helper.parse_args("v", "v", "i")
- @_beartype.beartype
- def frobenius_norm(g: jit_utils.GraphContext, self, dim=None, keepdim=False):
- dim_val = symbolic_helper._maybe_get_const(dim, "is")
- if not symbolic_helper._is_value(dim_val) and len(dim_val) == 0:
- return g.op("ReduceL2", self, keepdims_i=0)
- sqr = g.op("Mul", self, self)
- sumsqr = symbolic_helper._reducesum_helper(g, sqr, dim, keepdims_i=keepdim)
- return g.op("Sqrt", sumsqr)
- @_onnx_symbolic("aten::split")
- @symbolic_helper.parse_args("v", "v", "i", "i")
- @_beartype.beartype
- def split(g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None):
- if not symbolic_helper._is_split_static(split_size_or_sizes, _outputs):
- split_out = g.op("SplitToSequence", self, split_size_or_sizes, axis_i=dim)
- if _outputs is None:
- return split_out
- # Convert to multiple slice nodes iff number of splits and number of outputs are statically known.
- if (
- symbolic_helper._is_packed_list(split_size_or_sizes)
- and len(symbolic_helper._unpack_list(split_size_or_sizes)) == _outputs
- ):
- split_sizes = [
- symbolic_helper._unsqueeze_helper(g, v, [0])
- for v in symbolic_helper._unpack_list(split_size_or_sizes)
- ]
- start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
- axis = g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long))
- res = []
- for i in range(_outputs):
- end = g.op(
- "Add", start, split_sizes[i]
- ) # split_sizes is a list of same length as _outputs
- res.append(g.op("Slice", self, start, end, axis))
- start = end
- return res
- return [
- g.op(
- "SequenceAt",
- split_out,
- g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
- )
- for i in range(_outputs)
- ]
- split_val = symbolic_helper._node_get(split_size_or_sizes.node(), "value")
- if split_val.dim() > 0:
- return g.op("Split", self, split_size_or_sizes, axis_i=dim, outputs=_outputs)
- split_size = symbolic_helper._get_const(split_size_or_sizes, "i", "split_size")
- size = symbolic_helper._get_tensor_dim_size(self, dim)
- if size is None:
- if _outputs is not None:
- size = split_size * _outputs
- else:
- raise errors.SymbolicValueError(
- "Unknown dimension size not supported", self
- )
- splits = [split_size] * (size // split_size)
- leftover = size % split_size
- if leftover:
- splits.append(leftover)
- splits = g.op("Constant", value_t=torch.tensor(splits))
- return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
- @_onnx_symbolic("aten::split_with_sizes")
- @_beartype.beartype
- def split_with_sizes(g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None):
- return split(g, self, split_sizes, dim, _outputs)
- @_onnx_symbolic("aten::unsafe_split")
- @_beartype.beartype
- def unsafe_split(
- g: jit_utils.GraphContext, self, split_size_or_sizes, dim, _outputs=None
- ):
- return split(g, self, split_size_or_sizes, dim, _outputs)
- @_onnx_symbolic("aten::unsafe_split_with_sizes")
- @_beartype.beartype
- def unsafe_split_with_sizes(
- g: jit_utils.GraphContext, self, split_sizes, dim, _outputs=None
- ):
- return split_with_sizes(g, self, split_sizes, dim, _outputs)
- @_onnx_symbolic("aten::tensor_split")
- @symbolic_helper.parse_args("v", "v", "i", "i")
- @_beartype.beartype
- def tensor_split(
- g: jit_utils.GraphContext, self, indices_or_sections, dim, _outputs=None
- ):
- axis = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.long))
- axis = opset11.unsqueeze(g, axis, 0)
- const_1 = g.op("Constant", value_t=torch.tensor(1, dtype=torch.long))
- if symbolic_helper._is_split_static(indices_or_sections, _outputs):
- split_val = symbolic_helper._node_get(indices_or_sections.node(), "value")
- if split_val.dim() > 0:
- start = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
- res = []
- assert _outputs is not None
- for i in range(_outputs - 1):
- end = g.op(
- "Gather",
- indices_or_sections,
- g.op("Constant", value_t=torch.tensor([i], dtype=torch.long)),
- axis_i=0,
- )
- res.append(g.op("Slice", self, start, end, axis))
- start = end
- end = symbolic_helper._size_helper(g, self, axis)
- res.append(g.op("Slice", self, start, end, axis))
- return res
- split_size = symbolic_helper._get_const(
- indices_or_sections, "i", "indices_or_sections"
- )
- size = symbolic_helper._get_tensor_dim_size(self, dim)
- if size is None:
- if _outputs is not None:
- size = split_size * _outputs
- else:
- raise errors.SymbolicValueError(
- "Unknown dimension size not supported", self
- )
- min_split_size = size // split_size
- num_splits_one_extra = size % split_size
- splits = num_splits_one_extra * [min_split_size + 1]
- leftover = (split_size - num_splits_one_extra) * [min_split_size]
- splits = g.op(
- "Constant", value_t=torch.tensor(splits + leftover, dtype=torch.long)
- )
- return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
- if (
- symbolic_helper._is_tensor(indices_or_sections)
- and symbolic_helper._get_tensor_rank(indices_or_sections) == 1
- ):
- loop_len = symbolic_helper._size_helper(
- g, indices_or_sections, g.op("Constant", value_t=torch.tensor(0))
- )
- loop_len = opset11.unsqueeze(g, loop_len, 0)
- loop_condition = g.op("Cast", const_1, to_i=_C_onnx.TensorProtoDataType.BOOL)
- # To make the first slice in the below loop work,
- # we pad a zero to the first position so that it will be the initial start of slice.
- padding_0 = g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))
- indices_or_sections = g.op("Concat", padding_0, indices_or_sections, axis_i=0)
- final_splits = g.op("SequenceEmpty")
- # Loop inputs
- loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
- g, "Loop", loop_len, loop_condition, final_splits, outputs=1, n_blocks=1
- )
- loop_block = loop_context.block
- block_input_iter = utils._add_input_to_block(loop_block)
- cond = utils._add_input_to_block(loop_block)
- final_splits = utils._add_input_to_block(loop_block)
- start = loop_context.op(
- "Gather", indices_or_sections, block_input_iter, axis_i=0
- )
- end = loop_context.op(
- "Gather",
- indices_or_sections,
- loop_context.op("Add", block_input_iter, const_1),
- axis_i=0,
- )
- slice = loop_context.op("Slice", self, start, end, axis)
- final_splits = loop_context.op("SequenceInsert", final_splits, slice)
- # Loop outputs
- cond_out = loop_context.op("Identity", loop_condition)
- utils._add_output_to_block(loop_block, cond_out)
- utils._add_output_to_block(loop_block, final_splits)
- loop_out = loop.node().output()
- start = g.op(
- "Gather",
- indices_or_sections,
- g.op("Constant", value_t=torch.tensor(-1, dtype=torch.long)),
- axis_i=0,
- )
- start = opset11.unsqueeze(g, start, 0)
- end = symbolic_helper._size_helper(g, self, axis)
- last_slice = g.op("Slice", self, start, end, axis)
- return g.op("SequenceInsert", loop_out, last_slice)
- else: # scalar tensor
- dim_size = symbolic_helper._size_helper(g, self, axis)
- min_split_size = g.op("Div", dim_size, indices_or_sections)
- min_split_size_plus_1 = g.op(
- "Add",
- min_split_size,
- const_1,
- )
- num_splits_one_extra = g.op("Mod", dim_size, indices_or_sections)
- splits = g.op("Tile", min_split_size_plus_1, num_splits_one_extra)
- leftover = g.op(
- "Tile",
- min_split_size,
- g.op(
- "Sub",
- opset11.unsqueeze(g, indices_or_sections, 0),
- num_splits_one_extra,
- ),
- )
- splits = g.op("Concat", splits, leftover, axis_i=0)
- if _outputs is None:
- return g.op("SplitToSequence", self, splits, axis_i=dim)
- return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
- @_onnx_symbolic("aten::unbind")
- @symbolic_helper.parse_args("v", "i", "i")
- @_beartype.beartype
- def unbind(g: jit_utils.GraphContext, self, dim=0, _outputs=None):
- if _outputs is None:
- return g.op(
- "SplitToSequence",
- self,
- g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
- axis_i=dim,
- keepdims_i=0,
- )
- splits = g.op("Constant", value_t=torch.tensor([1] * _outputs))
- outputs = g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
- outputs = [outputs] if _outputs == 1 else outputs
- squeezed_outputs = [
- g.op("Squeeze", out, g.op("Constant", value_t=torch.tensor([dim])))
- for out in outputs
- ]
- return squeezed_outputs
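

# Illustrative sketch: `unbind` above splits into size-1 slices along `dim`
# and squeezes that dim away, so a (3, 2) input yields three (2,) outputs.
# Hypothetical torch-level analogue (not used by the exporter), for intuition.
def _example_unbind_reference(t: torch.Tensor, dim: int = 0) -> list:
    chunks = torch.split(t, 1, dim=dim)  # size-1 slices, like the ONNX Split
    return [c.squeeze(dim) for c in chunks]  # drop the split dim, like Squeeze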
- @_onnx_symbolic("aten::nonzero_numpy")
- # Emitted from `torch.nonzero(x, as_tuple=True)`
- @_beartype.beartype
- def nonzero_numpy(g: jit_utils.GraphContext, input, _outputs=None):
- return unbind(g, opset9.nonzero(g, input), 1, _outputs=_outputs)
- @_onnx_symbolic("aten::where")
- @symbolic_helper.parse_args("v", "v", "v", "i")
- @_beartype.beartype
- def where(g: jit_utils.GraphContext, condition, self=None, other=None, _outputs=None):
- # Assumes that torch.where's first argument takes only Bool and Byte tensors.
- if not symbolic_helper._is_bool(condition):
- condition = g.op("Cast", condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
- if self is None:
- condition = opset9.nonzero(g, condition)
- return symbolic_helper._unbind_helper(
- g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs
- )
- return g.op("Where", condition, self, other)
- @_onnx_symbolic("aten::fake_quantize_per_channel_affine")
- @symbolic_helper.parse_args("v", "v", "v", "i", "i", "i")
- @_beartype.beartype
- def fake_quantize_per_channel_affine(
- g: jit_utils.GraphContext,
- inputs,
- scale,
- zero_point,
- axis,
- quant_min=-128,
- quant_max=127,
- ):
- # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
- # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
- if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
- raise errors.SymbolicValueError(
- "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
- f"Got ({quant_min}, {quant_max})",
- inputs,
- )
- # ONNX defines zero_point to be int8 or uint8
- if quant_min == 0:
- zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
- else:
- zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
- quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
- if (quant_min, quant_max) == (0, 127):
- quantized = g.op(
- "Clip",
- quantized,
- opset9.unused(g),
- g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
- )
- return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)
- @_onnx_symbolic("aten::fake_quantize_per_tensor_affine")
- @symbolic_helper.parse_args("v", "v", "v", "i", "i")
- @_beartype.beartype
- def fake_quantize_per_tensor_affine(
- g: jit_utils.GraphContext,
- inputs,
- scale,
- zero_point,
- quant_min=-128,
- quant_max=127,
- ):
- # NOTE: (0, 127) is allowed as special case. PyTorch restricts activations to be in the range (0, 127).
- # https://github.com/pytorch/pytorch/blob/b34b192d6b97325c9f78e5995c48c8498ede34bd/torch/ao/quantization/observer.py#L1422
- if (quant_min, quant_max) not in [(0, 255), (-128, 127), (0, 127)]:
- raise errors.SymbolicValueError(
- "For (quant_min, quant_max), ONNX allows only (0, 127), (0, 255) and (-128, 127). "
- f"Got ({quant_min}, {quant_max})",
- inputs,
- )
- if quant_min == 0:
- zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.UINT8)
- else:
- zero_point = g.op("Cast", zero_point, to_i=_C_onnx.TensorProtoDataType.INT8)
- if (
- _type_utils.JitScalarType.from_value(scale, _type_utils.JitScalarType.UNDEFINED)
- != _type_utils.JitScalarType.FLOAT
- ):
- scale = g.op("Cast", scale, to_i=_C_onnx.TensorProtoDataType.FLOAT)
- quantized = g.op("QuantizeLinear", inputs, scale, zero_point)
- if (quant_min, quant_max) == (0, 127):
- quantized = g.op(
- "Clip",
- quantized,
- opset9.unused(g),
- g.op("Constant", value_t=torch.tensor(127, dtype=torch.uint8)),
- )
- return g.op("DequantizeLinear", quantized, scale, zero_point)


@_beartype.beartype
def _reduce_op_symbolic(onnx_op_name):
    @_beartype.beartype
    def symbolic(g, self, dim=None, keepdim=None):
        self = opset9._maybe_cast_reduce_op_input(g, self)
        if dim is None:
            # all-reduce path
            return symbolic_helper._handle_reduce_dim_none(g, self, onnx_op_name)
        else:
            keepdim = symbolic_helper._get_const(keepdim, "i", "keepdim")
            return g.op(onnx_op_name, self, dim, keepdims_i=keepdim)

    return symbolic


@_onnx_symbolic(
    "aten::sum",
    decorate=[_apply_params("ReduceSum", "sum")],
)
@_beartype.beartype
def _reduce_with_dtype(onnx_op, name):
    symbolic = _reduce_op_symbolic(onnx_op)

    @opset9.overload_by_arg_count
    @_beartype.beartype
    def reduce(g, *args, **kwargs):
        @symbolic_helper.parse_args("v", "none")
        @_beartype.beartype
        def reduce_nodim(g, self, dtype):
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                self = g.op(
                    "Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()
                )
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            return symbolic(g, self)

        @symbolic_helper.parse_args("v", "v", "i", "none")
        @_beartype.beartype
        def reduce_dim(g, self, dim, keepdim, dtype):
            if dtype.node().kind() == "onnx::Constant":
                dtype = symbolic_helper._get_const(dtype, "i", "dtype")
                self = g.op(
                    "Cast", self, to_i=_type_utils.JitScalarType(dtype).onnx_type()
                )
            elif dtype.node().kind() != "prim::Constant":
                return symbolic_helper._unimplemented(name, "dtype", dtype)
            return symbolic(g, self, dim, keepdim)

        return reduce_nodim, reduce_dim

    return reduce
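

# Illustrative sketch: `opset9.overload_by_arg_count` selects between
# `reduce_nodim` and `reduce_dim` based on how many arguments the ATen overload
# was called with. Below is a hypothetical plain-Python analogue of that
# arity-based dispatch (not used by the exporter), for intuition only.
def _example_dispatch_by_arity(handlers, *args):
    import inspect

    # Pick the handler whose parameter count matches the call site.
    for handler in handlers:
        if len(inspect.signature(handler).parameters) == len(args):
            return handler(*args)
    raise TypeError(f"no overload accepts {len(args)} arguments")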
- @_onnx_symbolic("aten::unsafe_chunk")
- @symbolic_helper.parse_args("v", "i", "i", "i")
- @_beartype.beartype
- def unsafe_chunk(g: jit_utils.GraphContext, self, chunks, dim, _outputs=None):
- if _outputs is None:
- return g.op(
- "SplitToSequence",
- self,
- g.op("Constant", value_t=torch.tensor(1, dtype=torch.long)),
- axis_i=dim,
- keepdims_i=0,
- )
- size = symbolic_helper._get_tensor_dim_size(self, dim)
- if size is None:
- return symbolic_helper._unimplemented("unsafe_chunk", "unknown dimension size")
- split_size = (size + chunks - 1) // chunks
- splits = [split_size] * (size // split_size)
- leftover = size % split_size
- if leftover:
- splits.append(leftover)
- # TODO: So far we don"t have a module using this method. We"ll keep
- # this as a constant unless we see a request of dynamics in any
- # user's modules.
- splits = g.op("Constant", value_t=torch.tensor(splits, dtype=torch.long))
- return g.op("Split", self, splits, axis_i=dim, outputs=_outputs)
- @_onnx_symbolic("aten::repeat_interleave")
- @_beartype.beartype
- def repeat_interleave(
- g: jit_utils.GraphContext, self, repeats, dim=None, output_size=None
- ):
- input = self
- final_dim = dim
- # if dim is None flatten
- # By default, use the flattened input array, and return a flat output array
- if symbolic_helper._is_none(dim):
- input = symbolic_helper._reshape_helper(
- g, self, g.op("Constant", value_t=torch.tensor([-1]))
- )
- dim = 0
- else:
- dim = symbolic_helper._maybe_get_scalar(dim)
- repeats_dim = symbolic_helper._get_tensor_rank(repeats)
- repeats_sizes = symbolic_helper._get_tensor_sizes(repeats)
- input_sizes = symbolic_helper._get_tensor_sizes(input)
- if repeats_dim is None:
- raise errors.SymbolicValueError(
- "Unsupported: ONNX export of repeat_interleave for unknown repeats rank.",
- self,
- )
- if repeats_sizes is None:
- raise errors.SymbolicValueError(
- "Unsupported: ONNX export of repeat_interleave for unknown repeats size.",
- self,
- )
- if input_sizes is None:
- raise errors.SymbolicValueError(
- "Unsupported: ONNX export of repeat_interleave for unknown input size.",
- self,
- )
- # Handle cases where dim is negative
- if dim < 0:
- dim += len(input_sizes)
- output_sizes = input_sizes.copy()
- for idx, input_size in enumerate(input_sizes):
- if input_size is None:
- output_sizes[idx], input_sizes[idx] = 0, -1
- cond_dynamic_repeats = repeats_dim == 1 and repeats_sizes[0] is None
- # If input size is dynamic or repeats vector is dynamic
- if output_sizes[dim] == 0 or cond_dynamic_repeats:
- reps = symbolic_helper._size_helper(g, input, dim)
- reps = opset11.unsqueeze(g, reps, 0)
- # Check if repeats vector is a single integer value
- # or a single dimension tensor with non-dynamic values
- if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1):
- if not symbolic_helper._is_tensor(repeats):
- repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
- repeats = g.op("Expand", repeats, reps)
- # Check if repeats is dynamic
- # As repeats is dynamic, we use a where node as a substitute for the if statement
- # If repests_dim = 1, expand repeats otherwise use original tensor
- elif cond_dynamic_repeats:
- repeat_dim = symbolic_helper._size_helper(
- g, repeats, g.op("Constant", value_t=torch.LongTensor([0]))
- )
- repeat_cond = g.op(
- "Equal", repeat_dim, g.op("Constant", value_t=torch.LongTensor([1]))
- )
- repeats = where(g, repeat_cond, g.op("Expand", repeats, reps), repeats)
- # There are cases when the repeats are 1-d tensor with multiple repeats, but dim
- # provided along one of the dynamic axes provided. A simple example would be
- # input.shape -> [1, 1, *] where * represents the dynamic axes, and dim = 2
- # Now, repeat interleaving can be performed in pytorch when the value of * matches
- # with the number of elements in repeat, for example if * -> 2, number of repeats
- # should be 2 as well.
- else:
- return opset9.repeat_interleave(g, self, repeats, final_dim)
- reps_like = g.op(
- "ConstantOfShape",
- g.op("Shape", repeats),
- value_t=torch.tensor([1], dtype=torch.long),
- )
- r_splits = split(g, repeats, reps_like, 0)
- i_splits = split(g, input, reps_like, dim)
- output_sizes[dim], input_sizes[dim] = -1, 1
- # Create a loop to iterate over each value along the dimension
- # and perform individual interleaving using the repeats tensor
- # Loop is of the following pattern
- # input (trip_count, cond)
- # int trip_count = ...;
- # bool cond = ...;
- # for (int i=0; i < trip_count && cond; ++i) {
- # cond = ...;
- # }
- # Loop conditions
- loop_condition = g.op("Constant", value_t=torch.tensor(1))
- loop_condition = g.op("Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL)
- loop_len = reps
- # Create an empty sequence to store final expansions
- final_splits = g.op("SequenceEmpty")
- # Loop inputs
- loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
- g, "Loop", loop_len, loop_condition, final_splits, n_blocks=1
- )
- loop_block = loop_context.block
- block_input_iter = utils._add_input_to_block(loop_block)
- cond = utils._add_input_to_block(loop_block)
- final_splits = utils._add_input_to_block(loop_block)
- r_split = loop_context.op("SequenceAt", r_splits, block_input_iter)
- i_split = loop_context.op("SequenceAt", i_splits, block_input_iter)
- i_split = opset11.unsqueeze(loop_context, i_split, dim + 1)
- r_concat = [
- loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[: dim + 1])),
- r_split,
- loop_context.op("Constant", value_t=torch.LongTensor(input_sizes[dim + 1 :])),
- ]
- r_concat = loop_context.op("Concat", *r_concat, axis_i=0)
- i_split = opset9.expand(loop_context, i_split, r_concat, None)
- i_split = symbolic_helper._reshape_helper(
- loop_context, i_split, g.op("Constant", value_t=torch.LongTensor(output_sizes))
- )
- final_splits = loop_context.op("SequenceInsert", final_splits, i_split)
- # Loop outputs
- cond_out = loop_context.op(
- "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
- )
- utils._add_output_to_block(loop_block, cond_out)
- utils._add_output_to_block(loop_block, final_splits)
- loop_out = loop.node().output()
- loop_out = g.op("ConcatFromSequence", loop_out, axis_i=dim)
- return loop_out
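

# Illustrative sketch: the loop above reproduces torch.repeat_interleave along
# `dim` by expanding each slice by its repeat count and concatenating. Below is
# a hypothetical flat (dim=None) analogue (not used by the exporter).
def _example_repeat_interleave(values: list, repeats: list) -> list:
    # e.g. values=[1, 2], repeats=[2, 3] -> [1, 1, 2, 2, 2]
    out = []
    for v, r in zip(values, repeats):
        out.extend([v] * r)
    return out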
- @_onnx_symbolic("aten::diagonal")
- @symbolic_helper.parse_args("v", "i", "i", "i")
- @_beartype.beartype
- def diagonal(g: jit_utils.GraphContext, self, offset, dim1, dim2):
- dim1_size = opset9.size(
- g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim1]))
- )
- dim2_size = opset9.size(
- g, self, dim=g.op("Constant", value_t=torch.LongTensor([dim2]))
- )
- # Create appropriate mask
- mask_shape = g.op("Concat", dim1_size, dim2_size, axis_i=0)
- mask = opset9.zeros(g, mask_shape, None, None, None)
- mask = g.op("EyeLike", mask, k_i=offset)
- # dim1 and dim2 appended as a dimension at the end of the shape
- rank = symbolic_helper._get_tensor_rank(self)
- if rank is not None:
- axes = list(range(rank))
- axes.remove(dim1)
- axes.remove(dim2)
- self = g.op("Transpose", self, perm_i=axes + [dim1, dim2])
- else:
- return symbolic_helper._unimplemented("diagonal", "unknown input rank")
- # Multiply input and mask to calculate values along diagonal
- # The mask consists of one values where diagonal values are to be calculated
- # For example:
- # [[1.1, 1.2, 1.3], * [[1, 0, 0] = [[1.1, 0, 0],
- # [2.1, 2.2, 2.3], [0, 1, 0] [0, 2.2, 0],
- # [3.1, 3.2, 3.3]] [0, 0, 1]] [0, 0, 3.3]]
- result = g.op("Mul", self, mask)
- result = symbolic_helper._reducesum_helper(g, result, axes_i=[-1], keepdims_i=0)
- # Calculate gather indices based on offset and dims
- # If offset is greater than zero, set offset to zero as this aids in
- # calculation of selection window
- offset_op = g.op("Constant", value_t=torch.LongTensor([offset]))
- if offset >= 0:
- diag_size = g.op(
- "Max",
- g.op("Min", dim1_size, g.op("Sub", dim2_size, offset_op)),
- g.op("Constant", value_t=torch.LongTensor([0])),
- )
- offset = 0
- else:
- diag_size = g.op(
- "Max",
- g.op("Min", g.op("Add", dim1_size, offset_op), dim2_size),
- g.op("Constant", value_t=torch.LongTensor([0])),
- )
- diag_size = g.op("Concat", diag_size, axis_i=0)
- # Calculate which diagonal values to select
- # For example, in cases with offsets:
- # [[0, 1.1, 0]
- # [0, 0, 2.2]]
- # we need to select the last two columns, so we create a tensor
- # with all columns that are to be selected
- # So in this example, it is [1, 2]
- select_window_ones_fill = opset9.ones(g, diag_size, 4, None, None)
- select_window = g.op(
- "CumSum",
- select_window_ones_fill,
- g.op("Constant", value_t=torch.LongTensor([0])),
- )
- select_window = g.op(
- "Add",
- select_window,
- g.op("Constant", value_t=torch.LongTensor([abs(offset) - 1])),
- )
- gather_shape = [
- opset9.size(g, result, dim=g.op("Constant", value_t=torch.LongTensor([axis])))
- for axis in list(range(rank))[:-2]
- ]
- gather_shape.append(diag_size)
- gather_shape = g.op("Concat", *gather_shape, axis_i=0)
- gather_indices = opset9.zeros(g, gather_shape, 4, None, None)
- # There might be cases where offset value is greater than number of rows/columns
- # and might cause the diagonal to overrun and as a result of this, diag_size would be zero.
- # For example, if
- # offset = 9, dim1_size = 2 (columns), dim2_size = 4 (rows)
- # diag_size = max(min(2, (4-9)), 0) = 0, based on calculation above
- # Cases with diagonal overrun always result in diag_size = max(0, -ve value) = 0
- # In cases without diagonal overrun, we select the appropriate rows/columns along which we
- # are calculating diagonal values. In cases with diagonal overrun, we return a tensor which has
- # the dimension of the row/column where overrun occurred as 0-dim, as we are essentially
- # returning an empty tensor
- overrun_cond = g.op(
- "Not",
- g.op(
- "Equal",
- diag_size,
- g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)),
- ),
- )
- if_op, (if_context, else_context), _ = jit_utils.add_op_with_blocks(
- g, "If", overrun_cond, n_blocks=2
- )
- gather_indices_if_block = if_context.op("Add", gather_indices, select_window)
- gather_indices_if_block = symbolic_helper._unsqueeze_helper(
- if_context, gather_indices_if_block, [rank - 1]
- )
- final_non_overrun = if_context.op(
- "GatherND", result, gather_indices_if_block, batch_dims_i=rank - 2
- )
- final_overrun = opset9.zeros(else_context, gather_shape, 6, None, None)
- utils._add_output_to_block(if_context.block, final_non_overrun)
- utils._add_output_to_block(else_context.block, final_overrun)
- return if_op
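

# Illustrative sketch: the diagonal length computed above is
# min(dim1_size, dim2_size - offset) for offset >= 0 and
# min(dim1_size + offset, dim2_size) otherwise, clamped at zero so an
# overrunning offset yields an empty diagonal. Hypothetical helper
# (not used by the exporter), for intuition only.
def _example_diag_size(dim1_size: int, dim2_size: int, offset: int) -> int:
    # e.g. _example_diag_size(2, 4, 9) == 0 (overrun, empty diagonal)
    if offset >= 0:
        return max(min(dim1_size, dim2_size - offset), 0)
    return max(min(dim1_size + offset, dim2_size), 0)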


# Quantized ops


@_onnx_symbolic("quantized::linear")
@_beartype.beartype
def quantized_linear(
    g: jit_utils.GraphContext, q_input, q_weight, bias, op_scale, op_zero_point
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.linear(g, input, weight, bias)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d")
@_beartype.beartype
def quantized_conv2d(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)


@_onnx_symbolic("quantized::conv2d_relu")
@_beartype.beartype
def quantized_conv2d_relu(
    g: jit_utils.GraphContext,
    q_input,
    q_weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    op_scale,
    op_zero_point,
):
    input, input_scale, _, _ = symbolic_helper.dequantize_helper(g, q_input)
    weight, weight_scale, _, axis = symbolic_helper.dequantize_helper(g, q_weight)
    q_bias = symbolic_helper.requantize_bias_helper(
        g, bias, input_scale, weight_scale, axis
    )
    bias, _, _, _ = symbolic_helper.dequantize_helper(g, q_bias)

    output = opset9.conv2d(g, input, weight, bias, stride, padding, dilation, groups)
    output = opset9.relu(g, output)

    return symbolic_helper.quantize_helper(g, output, op_scale, op_zero_point)
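

# Illustrative sketch: every quantized op above follows the same
# dequantize -> float op -> requantize pattern, since ONNX expresses quantized
# compute through QuantizeLinear/DequantizeLinear pairs. Hypothetical scalar
# analogue of that pattern (not used by the exporter); clamping omitted.
def _example_quantized_op(
    q: int, scale: float, zero_point: int, out_scale: float, out_zero_point: int, fp_op
) -> int:
    x = (q - zero_point) * scale  # DequantizeLinear
    y = fp_op(x)  # float computation, e.g. linear/conv
    return round(y / out_scale) + out_zero_point  # QuantizeLinear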