import functools
import sys
from typing import Optional, Tuple

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import (
    _type_utils,
    errors,
    symbolic_helper,
    symbolic_opset9 as opset9,
    utils,
)
from torch.onnx._internal import _beartype, jit_utils, registration

# EDITING THIS FILE? READ THIS FIRST!
# See Note [Edit Symbolic Files] in README.md

# This file exports ONNX ops for opset 12.

__all__ = [
    "argmax",
    "argmin",
    "binary_cross_entropy_with_logits",
    "celu",
    "cross_entropy_loss",
    "dropout",
    "einsum",
    "ge",
    "le",
    "native_dropout",
    "nll_loss",
    "nll_loss2d",
    "nll_loss_nd",
    "outer",
    "pow",
    "tensordot",
    "unfold",
]

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=12)


@_beartype.beartype
def _einsum_helper(g: jit_utils.GraphContext, equation, tensors):
    if not tensors:
        raise RuntimeError("Einsum inputs are empty.")
    # ONNX does not support bool for Einsum inputs.
    if symbolic_helper._is_bool(tensors[0]):
        tensors = [
            g.op("Cast", tensor, to_i=_C_onnx.TensorProtoDataType.INT64)
            for tensor in tensors
        ]
        return g.op(
            "Cast",
            g.op("Einsum", *tensors, equation_s=equation),
            to_i=_C_onnx.TensorProtoDataType.BOOL,
        )
    else:
        return g.op("Einsum", *tensors, equation_s=equation)


@_onnx_symbolic("aten::einsum")
@symbolic_helper.parse_args("s", "v", "is")
@_beartype.beartype
def einsum(g: jit_utils.GraphContext, equation, tensor_list, path=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    return _einsum_helper(g, equation, tensors)
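
# Usage sketch (illustrative only, not part of the exporter): a model that calls
# torch.einsum is lowered through the `einsum` symbolic above when exported at
# opset 12. Boolean operands are routed through `_einsum_helper`, which casts them
# to INT64 for the ONNX Einsum node and casts the result back to BOOL. The module
# and file names below are placeholders:
#
#     class PairwiseOuter(torch.nn.Module):  # hypothetical example module
#         def forward(self, a, b):
#             return torch.einsum("bi,bj->bij", a, b)
#
#     torch.onnx.export(
#         PairwiseOuter(),
#         (torch.randn(2, 3), torch.randn(2, 4)),
#         "pairwise_outer.onnx",
#         opset_version=12,
#     )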
- @_onnx_symbolic("aten::outer")
- @symbolic_helper.parse_args("v", "v")
- @_beartype.beartype
- def outer(g: jit_utils.GraphContext, input, other):
- # make sure to cast other to self's type
- if _type_utils.JitScalarType.from_value(
- other, _type_utils.JitScalarType.UNDEFINED
- ) != _type_utils.JitScalarType.from_value(input):
- other = g.op(
- "Cast",
- other,
- to_i=_type_utils.JitScalarType.from_value(input).onnx_type(),
- )
- return _einsum_helper(g, "i,j->ij", [input, other])
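
# Note: aten::outer has no dedicated ONNX operator, so it is expressed above as an
# Einsum with equation "i,j->ij" after casting `other` to the dtype of `input`.
# Rough eager-mode equivalent, for intuition only:
#
#     # torch.outer(a, b) == torch.einsum("i,j->ij", a, b.to(a.dtype))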


@_beartype.beartype
def _dropout_returns_masked_input_and_mask(
    g: jit_utils.GraphContext, input: torch._C.Value, p: float, train: bool
) -> Tuple[torch._C.Value, Optional[torch._C.Value]]:
    symbolic_helper.check_training_mode(train, "dropout")
    # In eval mode, dropout is a no-op. That is, if the node's
    # train param is set to False, dropout just returns its input.
    if not train:
        return input, None
    p = g.op("Constant", value_t=torch.tensor(p))
    t = g.op("Constant", value_t=torch.tensor(train, dtype=torch.bool))
    r, mask = g.op("Dropout", input, p, t, outputs=2)
    return r, mask


@_onnx_symbolic("aten::dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def dropout(g: jit_utils.GraphContext, input, p, train):
    masked, _ = _dropout_returns_masked_input_and_mask(g, input, p, train)
    return masked


@_onnx_symbolic("aten::native_dropout")
@symbolic_helper.parse_args("v", "f", "b")
@_beartype.beartype
def native_dropout(g: jit_utils.GraphContext, input, p, train):
    return _dropout_returns_masked_input_and_mask(g, input, p, train)
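
# Behavior sketch (for intuition only): in eval mode the lowering is a pass-through
# and no Dropout node is emitted; in training mode an ONNX Dropout node is emitted
# with the ratio and training-mode flag as inputs, and `native_dropout` additionally
# exposes the boolean mask as a second output. Illustrative export call; the file
# name is a placeholder:
#
#     model = torch.nn.Dropout(p=0.3)
#     torch.onnx.export(
#         model,
#         (torch.randn(4, 8),),
#         "dropout.onnx",
#         opset_version=12,
#         training=torch.onnx.TrainingMode.TRAINING,  # keep dropout active in the graph
#     )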
- @_onnx_symbolic("aten::nll_loss")
- @_beartype.beartype
- def nll_loss(g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index):
- # none reduction : onnx::Constant[value={0}]
- # mean reduction : onnx::Constant[value={1}]
- # sum reduction : onnx::Constant[value={2}]
- reduction = symbolic_helper._maybe_get_const(reduction, "i")
- reduction_vals = ["none", "mean", "sum"]
- reduction = reduction_vals[reduction]
- # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value.
- # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
- ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
- if weight.node().mustBeNone():
- nllloss = g.op(
- "NegativeLogLikelihoodLoss",
- self,
- target,
- reduction_s=reduction,
- ignore_index_i=ignore_index,
- )
- else:
- nllloss = g.op(
- "NegativeLogLikelihoodLoss",
- self,
- target,
- weight,
- reduction_s=reduction,
- ignore_index_i=ignore_index,
- )
- return nllloss
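
# Reduction mapping sketch: the integer reduction constant recorded in the TorchScript
# graph is translated to the string attribute expected by ONNX
# NegativeLogLikelihoodLoss. For example (illustrative only):
#
#     # F.nll_loss(log_probs, target, reduction="sum", ignore_index=-100)
#     # -> NegativeLogLikelihoodLoss(..., reduction_s="sum", ignore_index_i=-100)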
- @_onnx_symbolic("aten::nll_loss2d")
- @_beartype.beartype
- def nll_loss2d(
- g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
- ):
- return nll_loss(g, self, target, weight, reduction, ignore_index)
- @_onnx_symbolic("aten::nll_loss_nd")
- @_beartype.beartype
- def nll_loss_nd(
- g: jit_utils.GraphContext, self, target, weight, reduction, ignore_index
- ):
- return nll_loss(g, self, target, weight, reduction, ignore_index)
- @_onnx_symbolic("aten::cross_entropy_loss")
- @_beartype.beartype
- def cross_entropy_loss(
- g: jit_utils.GraphContext,
- self,
- target,
- weight,
- reduction,
- ignore_index,
- label_smoothing,
- ):
- # none reduction : onnx::Constant[value={0}]
- # mean reduction : onnx::Constant[value={1}]
- # sum reduction : onnx::Constant[value={2}]
- reduction = symbolic_helper._maybe_get_const(reduction, "i")
- reduction_vals = ["none", "mean", "sum"]
- reduction = reduction_vals[reduction]
- label_smoothing = symbolic_helper._maybe_get_const(label_smoothing, "f")
- if label_smoothing is not None and label_smoothing > 0.0:
- raise errors.SymbolicValueError(
- "Unsupported: ONNX does not support label_smoothing", self
- )
- # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value.
- # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
- ignore_index = symbolic_helper._maybe_get_const(ignore_index, "i")
- if weight.node().mustBeNone():
- celoss = g.op(
- "SoftmaxCrossEntropyLoss",
- self,
- target,
- reduction_s=reduction,
- ignore_index_i=ignore_index,
- )
- else:
- celoss = g.op(
- "SoftmaxCrossEntropyLoss",
- self,
- target,
- weight,
- reduction_s=reduction,
- ignore_index_i=ignore_index,
- )
- return celoss
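
# Limitation note: ONNX SoftmaxCrossEntropyLoss has no label-smoothing attribute at
# opset 12, so any nonzero label_smoothing raises SymbolicValueError during export.
# Illustrative repro (not part of the exporter):
#
#     criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
#     # Exporting a module that calls this criterion fails with
#     # "Unsupported: ONNX does not support label_smoothing";
#     # label_smoothing=0.0 exports fine.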
- @_onnx_symbolic("aten::binary_cross_entropy_with_logits")
- @symbolic_helper.parse_args("v", "v", "v", "v", "i")
- @_beartype.beartype
- def binary_cross_entropy_with_logits(
- g: jit_utils.GraphContext, input, target, weight, pos_weight, reduction
- ):
- p = g.op("Constant", value_t=torch.tensor([1]))
- sig_x = opset9.sigmoid(g, input)
- log_sig_x = opset9.log(g, sig_x)
- sub_1_x = opset9.sub(g, p, sig_x)
- sub_1_y = opset9.sub(g, p, target)
- log_1_x = opset9.log(g, sub_1_x)
- if pos_weight is None or symbolic_helper._is_none(pos_weight):
- output = opset9.neg(
- g,
- opset9.add(
- g, opset9.mul(g, target, log_sig_x), opset9.mul(g, sub_1_y, log_1_x)
- ),
- )
- else:
- output = opset9.neg(
- g,
- opset9.add(
- g,
- opset9.mul(g, opset9.mul(g, target, log_sig_x), pos_weight),
- opset9.mul(g, sub_1_y, log_1_x),
- ),
- )
- if weight is not None and not symbolic_helper._is_none(weight):
- output = opset9.mul(g, weight, output)
- reduction = symbolic_helper._maybe_get_const(reduction, "i")
- if reduction == 0:
- return output
- elif reduction == 1:
- return g.op("ReduceMean", output, keepdims_i=0)
- elif reduction == 2:
- return g.op("ReduceSum", output, keepdims_i=0)
- else:
- return symbolic_helper._onnx_unsupported(
- "binary_cross_entropy_with_logits with reduction other than none, mean, or sum",
- input,
- )
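
# Math sketch of the decomposition above (for reference only). With x = input,
# y = target, w = weight, and pw = pos_weight, the emitted subgraph computes
#
#     loss = -(pw * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))
#     loss = w * loss                      # only when weight is provided
#
# followed by ReduceMean, ReduceSum, or no reduction depending on the reduction mode.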
- @_onnx_symbolic("aten::celu")
- @_beartype.beartype
- def celu(g: jit_utils.GraphContext, self, alpha):
- alpha = symbolic_helper._maybe_get_const(alpha, "f")
- # if the input is of type double cast it to float
- if (
- _type_utils.JitScalarType.from_value(self, _type_utils.JitScalarType.UNDEFINED)
- == _type_utils.JitScalarType.DOUBLE
- ):
- self = g.op("Cast", self, to_i=_C_onnx.TensorProtoDataType.FLOAT)
- out = g.op("Celu", self, alpha_f=alpha)
- return g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.DOUBLE)
- return g.op("Celu", self, alpha_f=alpha)
- @_onnx_symbolic("aten::argmax")
- @symbolic_helper.parse_args("v", "v", "b")
- @_beartype.beartype
- def argmax(
- g: jit_utils.GraphContext,
- input: torch._C.Value,
- dim: torch._C.Value,
- keepdim: bool,
- ):
- return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMax")
- @_onnx_symbolic("aten::argmin")
- @symbolic_helper.parse_args("v", "v", "b")
- @_beartype.beartype
- def argmin(
- g: jit_utils.GraphContext,
- input: torch._C.Value,
- dim: torch._C.Value,
- keepdim: bool,
- ):
- return symbolic_helper._argmin_argmax_helper(g, input, dim, keepdim, "ArgMin")
- @_onnx_symbolic("aten::pow")
- @_beartype.beartype
- def pow(g: jit_utils.GraphContext, self, exponent):
- return g.op("Pow", self, exponent)
- @_onnx_symbolic("aten::ge")
- @_beartype.beartype
- def ge(g: jit_utils.GraphContext, input, other):
- return g.op("GreaterOrEqual", input, other)
- @_onnx_symbolic("aten::le")
- @_beartype.beartype
- def le(g: jit_utils.GraphContext, input, other):
- return g.op("LessOrEqual", input, other)
- @_onnx_symbolic("aten::unfold")
- @symbolic_helper.parse_args("v", "i", "v", "v")
- @_beartype.beartype
- def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
- const_size = symbolic_helper._maybe_get_const(size, "i")
- const_step = symbolic_helper._maybe_get_const(step, "i")
- if not symbolic_helper._is_value(const_size) and not symbolic_helper._is_value(
- const_step
- ):
- return opset9.unfold(g, input, dimension, const_size, const_step)
- if symbolic_helper.is_caffe2_aten_fallback():
- return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
- sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
- if sizedim is not None:
- low_start = g.op("Constant", value_t=torch.tensor(0))
- low_end = g.op("Constant", value_t=torch.tensor(sizedim))
- hi_end = g.op("Constant", value_t=torch.tensor(sizedim + 1))
- low_indices = g.op("Range", low_start, low_end, step)
- hi_indices = g.op("Range", size, hi_end, step)
- low_size = symbolic_helper._size_helper(
- g, low_indices, g.op("Constant", value_t=torch.tensor(0))
- )
- hi_size = symbolic_helper._size_helper(
- g, hi_indices, g.op("Constant", value_t=torch.tensor(0))
- )
- ndim = symbolic_helper._get_tensor_rank(input)
- assert ndim is not None
- perm = list(range(0, ndim))
- perm.append(perm.pop(dimension))
- unsqueeze_list = []
- loop_condition = g.op("Constant", value_t=torch.tensor(1))
- loop_condition = g.op(
- "Cast", loop_condition, to_i=_C_onnx.TensorProtoDataType.BOOL
- )
- loop_len = g.op("Min", low_size, hi_size)
- loop, (loop_context,), _ = jit_utils.add_op_with_blocks(
- g, "Loop", loop_len, loop_condition, n_blocks=1
- )
- loop_block = loop_context.block
- block_input_iter = utils._add_input_to_block(loop_block)
- # FIXME(justinchuby): cond is unused?
- cond = utils._add_input_to_block(loop_block)
- starts = loop_context.op("Gather", low_indices, block_input_iter)
- ends = loop_context.op("Gather", hi_indices, block_input_iter)
- axes = loop_context.op("Constant", value_t=torch.tensor([2]))
- starts = symbolic_helper._unsqueeze_helper(loop_context, starts, [0])
- ends = symbolic_helper._unsqueeze_helper(loop_context, ends, [0])
- stack = loop_context.op("Slice", input, starts, ends, axes)
- unsqueeze = symbolic_helper._unsqueeze_helper(
- loop_context, loop_context.op("Transpose", stack, perm_i=perm), [dimension]
- )
- unsqueeze_list.append(unsqueeze)
- concat = loop_context.op("Concat", *unsqueeze_list, axis_i=0)
- cond_out = loop_context.op(
- "Cast", loop_condition, _C_onnx.TensorProtoDataType.BOOL
- )
- utils._add_output_to_block(loop_block, cond_out)
- utils._add_output_to_block(loop_block, concat)
- loop_output = loop.node().output()
- perm = [0, 1, 2, 3, 4]
- perm[0], perm[dimension + 1] = perm[dimension + 1], perm[0]
- transpose = g.op("Transpose", loop_output, perm_i=perm)
- squeeze = symbolic_helper._squeeze_helper(g, transpose, [0])
- return squeeze
- return symbolic_helper._unimplemented("Unfold", "input size not accessible")
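
# Algorithm sketch for the path above (for intuition only): for unfold(dimension,
# size, step), an ONNX Loop gathers each window's start and end index, slices the
# input along `dimension`, and concatenates the windows; the final Transpose and
# Squeeze restore aten::unfold's layout, where the window elements become a trailing
# dimension. Eager-mode reference behavior:
#
#     t = torch.arange(7.0)
#     t.unfold(0, 2, 2)  # dimension=0, size=2, step=2
#     # tensor([[0., 1.],
#     #         [2., 3.],
#     #         [4., 5.]])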
- @_onnx_symbolic("aten::tensordot")
- @symbolic_helper.parse_args("v", "v", "is", "is", "v")
- @_beartype.beartype
- def tensordot(g: jit_utils.GraphContext, input_a, input_b, dims_a, dims_b, out=None):
- if out is not None:
- symbolic_helper._unimplemented(
- "Tensordot", "Out parameter is not supported for tensordot."
- )
- dim_count_a = symbolic_helper._get_tensor_rank(input_a)
- if dim_count_a is None:
- raise errors.SymbolicValueError(
- "Unsupported: ONNX export of tensordot for tensor(input_a) of unknown rank.",
- input_a,
- )
- dim_count_b = symbolic_helper._get_tensor_rank(input_b)
- if dim_count_b is None:
- raise errors.SymbolicValueError(
- "Unsupported: ONNX export of tensordot for tensor(input_b) of unknown rank.",
- input_b,
- )
- dims_a = [
- (dims_a[i] + dim_count_a) if (dims_a[i] < 0) else dims_a[i]
- for i in range(len(dims_a))
- ]
- dims_b = [
- (dims_b[i] + dim_count_b) if (dims_b[i] < 0) else dims_b[i]
- for i in range(len(dims_b))
- ]
- left_dims_a = [i for i in range(dim_count_a) if (i not in dims_a)]
- left_dims_b = [i for i in range(dim_count_b) if (i not in dims_b)]
- new_input_a = opset9.permute(g, input_a, left_dims_a + dims_a)
- new_input_b = opset9.permute(g, input_b, dims_b + left_dims_b)
- input_shape = g.op("Shape", new_input_a)
- left_sizes_a = symbolic_helper._slice_helper(
- g, input_shape, axes=[0], starts=[0], ends=[len(left_dims_a)]
- )
- shape_sizes = [
- left_sizes_a,
- g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
- ]
- output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
- input_shape = g.op("Shape", output_a)
- slices = symbolic_helper._slice_helper(
- g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
- )
- shape_sizes = [
- g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
- slices,
- ]
- output_a = opset9._reshape_from_tensor(g, new_input_a, shape_sizes)
- input_shape = g.op("Shape", new_input_b)
- left_sizes_b = symbolic_helper._slice_helper(
- g, input_shape, axes=[0], starts=[len(dims_b)], ends=[sys.maxsize]
- )
- slices = symbolic_helper._slice_helper(
- g, input_shape, axes=[0], starts=[0], ends=[len(dims_b)]
- )
- shape_sizes = [
- slices,
- g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
- ]
- output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
- input_shape = g.op("Shape", output_b)
- slices = symbolic_helper._slice_helper(
- g, input_shape, axes=[0], starts=[-1], ends=[sys.maxsize]
- )
- shape_sizes = [
- g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)),
- slices,
- ]
- output_b = opset9._reshape_from_tensor(g, new_input_b, shape_sizes)
- output = einsum(g, "ij,jk->ik", g.op("prim::ListConstruct", *[output_a, output_b]))
- shape_sizes = [left_sizes_a, left_sizes_b]
- return opset9._reshape_from_tensor(g, output, shape_sizes)
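
# Decomposition sketch (for reference only): tensordot is exported by permuting the
# contracted dims of input_a to the back and those of input_b to the front,
# flattening both operands into 2-D matrices, contracting them with
# Einsum("ij,jk->ik"), and reshaping the result back to the concatenation of the
# non-contracted sizes. Eager-mode reference behavior:
#
#     a = torch.randn(3, 4, 5)
#     b = torch.randn(4, 5, 6)
#     torch.tensordot(a, b, dims=([1, 2], [0, 1])).shape  # torch.Size([3, 6])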