symbolic_opset7.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. """
  2. Note [ONNX operators that are added/updated from opset 7 to opset 8]
  3. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
New operators:
  Expand
  Scan

Updated operators:
  Min, Max, Sum, Mean: supports multidirectional broadcasting.
  MaxPool: added optional indices output.
  10. """
  11. import functools
  12. import warnings
  13. from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
  14. from torch.onnx._internal import jit_utils, registration
  15. _onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=7)
  16. block_listed_operators = (
  17. "scan",
  18. "expand",
  19. "expand_as",
  20. "meshgrid",
  21. "adaptive_max_pool1d",
  22. "adaptive_max_pool2d",
  23. "adaptive_max_pool3d",
  24. "max_pool1d_with_indices",
  25. "max_pool2d_with_indices",
  26. "max_pool3d_with_indices",
  27. )
  28. # NOTE: max, min, sum, mean: broadcasting is not supported in opset 7.
  29. # torch.max (same for torch.min) actually has two interfaces smashed together:
  30. # torch.max(x, dim, keepdim) and torch.max(x, y)
  31. @_onnx_symbolic("aten::max")
  32. def max(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
  33. # torch.max(input, other)
  34. if keepdim is None and dim_or_y is not None:
  35. warnings.warn(
  36. "Multidirectional broadcasting is not supported in opset 7. "
  37. "This might cause the onnx model to be incorrect, if inputs to max operators "
  38. "have different shapes"
  39. )
  40. return opset9.max(g, self, dim_or_y, keepdim)
  41. @_onnx_symbolic("aten::min")
  42. def min(g: jit_utils.GraphContext, self, dim_or_y=None, keepdim=None):
  43. # torch.min(input, other)
  44. if keepdim is None and dim_or_y is not None:
  45. warnings.warn(
  46. "Multidirectional broadcasting is not supported in opset 7. "
  47. "This might cause the onnx model to be incorrect, if inputs to min operators "
  48. "have different shapes"
  49. )
  50. return opset9.min(g, self, dim_or_y, keepdim)
  51. for block_listed_op in block_listed_operators:
  52. _onnx_symbolic(f"aten::{block_listed_op}")(
  53. symbolic_helper._block_list_in_opset(block_listed_op)
  54. )