# symbolic_opset8.py

  1. """
  2. Note [ONNX operators that are added/updated from opset 8 to opset 9]
  3. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  4. New operators:
  5. Compress
  6. ConstantOfShape
  7. EyeLike
  8. MaxUnpool
  9. OneHot
  10. Sinh
  11. Cosh
  12. Asinh
  13. Acosh
  14. Atanh
  15. Shrink
  16. IsNaN
  17. Sign
  18. Erf
  19. Scatter
  20. Where
  21. NonZero
  22. TfIdfVectorizer
  23. MeanVarianceNormalization
  24. Updated operators:
  25. BatchNormalization: removed spatial attribute.
  26. Greater, Less, Constant, MatMul, PRelu, Gemm, Flatten: more data types{integers} supported.
  27. Cast: more data types{string} supported.
  28. Upsample: moved scales from attribute to input.
  29. Scan
  30. """

import functools
import warnings

import torch
from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils, errors, symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=8)

block_listed_operators = (
    "nonzero",
    "where",
    "scatter",
    "scatter_add",
    "erf",
    "sign",
    "isnan",
    "gather",
    "arange",
    "masked_fill",
    "index_fill",
    "index_copy",
    "repeat_interleave",
    "any",
    "all",
)

for block_listed_op in block_listed_operators:
    _onnx_symbolic(f"aten::{block_listed_op}")(
        symbolic_helper._block_list_in_opset(block_listed_op)
    )
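
# The ops above rely on ONNX operators that only exist from opset 9 onward
# (see the note at the top of this file), so the registrations above make the
# exporter report them as unsupported at opset 8 instead of emitting a wrong
# graph. Hypothetical repro sketch, for illustration only:
#
#     class M(torch.nn.Module):
#         def forward(self, x):
#             return torch.nonzero(x)
#
#     torch.onnx.export(M(), torch.randn(3), "m.onnx", opset_version=8)  # fails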


def _apply_params(*args, **kwargs):
    """Returns a decorator that calls the decorated (higher-order) function with the given parameters."""

    def _apply(fn):
        return fn(*args, **kwargs)

    return _apply
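
# Illustration (mirrors the decorators directly below): registering
#
#     @_onnx_symbolic(
#         "aten::upsample_nearest2d",
#         decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
#     )
#     def _interpolate(name, dim, interpolate_mode): ...
#
# is roughly equivalent to registering the concrete symbolic_fn produced by
# _interpolate("upsample_nearest2d", 4, "nearest"); _apply_params only forwards
# the given parameters to the decorated factory.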

@_onnx_symbolic(
    "aten::upsample_nearest1d",
    decorate=[_apply_params("upsample_nearest1d", 3, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest2d",
    decorate=[_apply_params("upsample_nearest2d", 4, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_nearest3d",
    decorate=[_apply_params("upsample_nearest3d", 5, "nearest")],
)
@_onnx_symbolic(
    "aten::upsample_linear1d",
    decorate=[_apply_params("upsample_linear1d", 3, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_bilinear2d",
    decorate=[_apply_params("upsample_bilinear2d", 4, "linear")],
)
@_onnx_symbolic(
    "aten::upsample_trilinear3d",
    decorate=[_apply_params("upsample_trilinear3d", 5, "linear")],
)
def _interpolate(name, dim, interpolate_mode):
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = symbolic_helper._get_interpolate_attributes(
            g, interpolate_mode, args
        )
        symbolic_helper._interpolate_warning(interpolate_mode)
        align_corners = symbolic_helper._maybe_get_scalar(align_corners)
        if align_corners:
            return symbolic_helper._unimplemented(name, "align_corners == True", input)
        output_size = symbolic_helper._maybe_get_const(output_size, "is")
        if symbolic_helper._is_value(output_size):
            return symbolic_helper._unimplemented(
                name, "torch._C.Value (output_size) indexing"
            )
        if scales is None:
            scales = [
                1.0
                if i < 2
                else float(output_size[-(dim - i)])
                / float(input.type().sizes()[-(dim - i)])
                for i in range(0, dim)
            ]
        return g.op("Upsample", input, mode_s=interpolate_mode, scales_f=scales)

    return symbolic_fn
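
# Worked example of the scales computation in _interpolate above (assumed
# shapes, for illustration only): a 4-D NCHW input of size [1, 3, 4, 4]
# upsampled with output_size = [8, 8] yields
# scales = [1.0, 1.0, 8/4, 8/4] = [1.0, 1.0, 2.0, 2.0],
# i.e. batch and channel keep scale 1 and each spatial dim gets
# output_size / input_size.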


@_onnx_symbolic("aten::__interpolate")
def __interpolate(
    g: jit_utils.GraphContext,
    input,
    size,
    scale_factor,
    mode,
    align_corners,
    recompute_scale_factor,
    antialias,
):
    align_corners = symbolic_helper._maybe_get_const(align_corners, "b")
    if not symbolic_helper._is_none(align_corners) and align_corners:
        return symbolic_helper._unimplemented("interpolate", "align_corners == True")

    if not symbolic_helper._is_none(scale_factor) and symbolic_helper._is_value(
        scale_factor
    ):
        return symbolic_helper._unimplemented(
            "interpolate", "dynamic scales in opset 8"
        )

    if not symbolic_helper._is_none(size) and symbolic_helper._is_value(size):
        return symbolic_helper._unimplemented("interpolate", "dynamic size in opset 8")

    scales, mode = symbolic_helper._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners
    )
    return g.op("Upsample", input, mode_s=mode, scales_f=scales)


# NOTE: We should create a wrapper for this kind of operation, after resolving the shape/type propagation
#   issue for "cast" operators. Some symbolic functions depend on shape information of input tensor, which
#   is lost after casting.
def _try_cast_integer_to_float(g: jit_utils.GraphContext, *args):
    floating_scalar_types = {
        _type_utils.JitScalarType.HALF,
        _type_utils.JitScalarType.FLOAT,
        _type_utils.JitScalarType.DOUBLE,
    }
    old_type = None
    # Cast the input tensor to Float if its scalarType is known and is not floating number.
    # If casting is performed, return the old scalarType, otherwise return None.
    arg0_type = _type_utils.JitScalarType.from_value(
        args[0], _type_utils.JitScalarType.UNDEFINED
    )
    if arg0_type != _type_utils.JitScalarType.UNDEFINED:
        old_type = arg0_type
        if old_type not in floating_scalar_types:
            old_type = old_type.scalar_name()
            args = tuple(
                g.op("Cast", arg, to_i=_C_onnx.TensorProtoDataType.FLOAT)
                for arg in args
            )
        else:
            return (None,) + args
    else:
        warnings.warn(
            "Only floating datatype is supported for these operators: "
            "{Greater, Less, MatMul, PRelu, Gemm, Flatten}. This might cause "
            "the onnx model to be incorrect, if inputs have integer datatypes."
        )
    return (old_type,) + args


def _cast_to_type(g: jit_utils.GraphContext, input, to_type):
    if to_type is None:
        return input
    return getattr(opset9, f"_cast_{to_type}")(g, input, False)


def _comparison_operator(g: jit_utils.GraphContext, input, other, op_name):
    other = symbolic_helper._maybe_get_scalar(other)
    other = symbolic_helper._if_scalar_type_as(other, input)
    _, input, other = _try_cast_integer_to_float(g, input, other)
    return g.op(op_name, input, other)
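
# Illustration of the cast-compute-cast-back pattern used below (for
# exposition only): comparing two int64 tensors with aten::gt at opset 8 emits
# roughly
#
#     Cast(to=FLOAT) -> Cast(to=FLOAT) -> Greater
#
# because opset-8 Greater does not accept integer inputs (see the note at the
# top of this file). Comparison results are already boolean, so the returned
# old type is discarded; for MatMul/Gemm/PRelu/Flatten the float result is
# cast back to the original type via _cast_to_type.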


# NOTE: For symbolics {gt, lt, bmm, matmul, prelu, mm, addmm, view, flatten},
#   integer input type not supported in opset8. Cast to float if possible.
@_onnx_symbolic("aten::gt")
def gt(g: jit_utils.GraphContext, input, other):
    return _comparison_operator(g, input, other, "Greater")


@_onnx_symbolic("aten::lt")
def lt(g: jit_utils.GraphContext, input, other):
    return _comparison_operator(g, input, other, "Less")


@_onnx_symbolic("aten::bmm")
def bmm(g: jit_utils.GraphContext, self, other):
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, other = _try_cast_integer_to_float(g, self, other)
        return _cast_to_type(g, g.op("MatMul", self, other), old_type)
    else:
        return g.op("MatMul", self, other)


@_onnx_symbolic("aten::matmul")
def matmul(g: jit_utils.GraphContext, self, other):
    return bmm(g, self, other)


@_onnx_symbolic("aten::prelu")
def prelu(g: jit_utils.GraphContext, self, weight):
    self_rank = symbolic_helper._get_tensor_rank(self)
    weight_sizes = symbolic_helper._get_tensor_sizes(weight)
    if self_rank is not None and self_rank > 2:
        weight = g.op("Unsqueeze", weight, axes_i=list(range(1, self_rank - 1)))
    elif self_rank == 0 and weight_sizes == [1]:
        # self and weight are both scalar but weight has rank == 1, squeeze weight.
        weight = symbolic_helper._squeeze_helper(g, weight, [0])
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
        return _cast_to_type(g, g.op("PRelu", self, weight), old_type)
    else:
        return g.op("PRelu", self, weight)


@_onnx_symbolic("aten::mm")
def mm(g: jit_utils.GraphContext, self, other):
    # Create a dummy C tensor. Only needed for API purposes; the value is
    # never used since beta = 0.
    scalar_type = symbolic_helper._try_get_scalar_type(self, other)
    if scalar_type is None:
        raise errors.SymbolicValueError(
            "mm can only operate on tensors with known types", self
        )
    zero_constant = g.op(
        "Constant",
        value_t=torch.tensor([0], dtype=scalar_type.dtype()),
    )

    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, other, zero_constant = _try_cast_integer_to_float(
            g, self, other, zero_constant
        )
        return _cast_to_type(
            g,
            g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0),
            old_type,
        )
    return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)


@_onnx_symbolic("aten::addmm")
@symbolic_helper.parse_args("v", "v", "v", "t", "t")
def addmm(g: jit_utils.GraphContext, self, mat1, mat2, beta, alpha):
    if symbolic_helper._try_get_scalar_type(self):
        old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
        return _cast_to_type(
            g,
            g.op(
                "Gemm",
                mat1,
                mat2,
                self,
                beta_f=symbolic_helper._scalar(beta),
                alpha_f=symbolic_helper._scalar(alpha),
            ),
            old_type,
        )
    else:
        return g.op(
            "Gemm",
            mat1,
            mat2,
            self,
            beta_f=symbolic_helper._scalar(beta),
            alpha_f=symbolic_helper._scalar(alpha),
        )


@_onnx_symbolic("aten::flatten")
def flatten(g: jit_utils.GraphContext, input, start_dim, end_dim):
    start_dim_i = symbolic_helper._get_const(start_dim, "i", "start_dim")
    end_dim_i = symbolic_helper._get_const(end_dim, "i", "end_dim")

    dim = input.type().dim()
    if end_dim_i < 0:
        end_dim_i = dim + end_dim_i
    # use ONNX's Flatten operator for cases where the output shape is 2D
    if start_dim_i == 1 and end_dim_i == dim - 1:
        if symbolic_helper._try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(
                g, g.op("Flatten", input, axis_i=start_dim_i), old_type
            )
        else:
            return g.op("Flatten", input, axis_i=start_dim_i)
    if start_dim_i == 0 and end_dim_i == dim - 2:
        if symbolic_helper._try_get_scalar_type(input):
            old_type, input = _try_cast_integer_to_float(g, input)
            return _cast_to_type(
                g, g.op("Flatten", input, axis_i=end_dim_i + 1), old_type
            )
        else:
            return g.op("Flatten", input, axis_i=end_dim_i + 1)

    return opset9.flatten(g, input, start_dim, end_dim)


def _constant_fill(g: jit_utils.GraphContext, sizes, dtype: int, const_value):
    if dtype is None:
        scalar_type = _type_utils.JitScalarType.FLOAT
    else:
        scalar_type = _type_utils.JitScalarType(dtype)
    if not scalar_type.dtype().is_floating_point:
        result = g.op(
            "ConstantFill",
            sizes,
            dtype_i=_type_utils.JitScalarType.FLOAT.onnx_type(),
            input_as_shape_i=1,
            value_f=const_value,
        )
        return g.op("Cast", result, to_i=scalar_type.onnx_type())
    else:
        return g.op(
            "ConstantFill",
            sizes,
            dtype_i=scalar_type.onnx_type(),
            input_as_shape_i=1,
            value_f=const_value,
        )
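
# Illustration (for exposition only): _constant_fill is the shared backend for
# the zeros/ones/full families below. For a non-floating dtype such as int64,
# aten::zeros lowers to ConstantFill(dtype=FLOAT, value_f=0) followed by a
# Cast back to INT64, presumably because the fill value is passed as a float
# attribute (value_f) on the experimental ConstantFill op.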


@_onnx_symbolic("aten::empty")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty(
    g: jit_utils.GraphContext,
    sizes,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    return zeros(g, sizes, dtype, layout, device, pin_memory)


@_onnx_symbolic("aten::empty_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def empty_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    return zeros_like(g, input, dtype, layout, device, pin_memory)


@_onnx_symbolic("aten::zeros")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def zeros(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    # NOTE: no way to set device and layout in ONNX, so we ignore it
    return _constant_fill(g, sizes, dtype, 0)


@_onnx_symbolic("aten::zeros_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def zeros_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 0)


@_onnx_symbolic("aten::ones")
@symbolic_helper.parse_args("v", "i", "v", "v", "v")
def ones(g: jit_utils.GraphContext, sizes, dtype, layout, device, pin_memory=False):
    return _constant_fill(g, sizes, dtype, 1)


@_onnx_symbolic("aten::ones_like")
@symbolic_helper.parse_args("v", "i", "v", "v", "v", "v")
def ones_like(
    g: jit_utils.GraphContext,
    input,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, 1)


@_onnx_symbolic("aten::full")
def full(
    g: jit_utils.GraphContext, sizes, value, dtype, layout, device, pin_memory=False
):
    const_value = symbolic_helper._maybe_get_const(value, "t")
    if symbolic_helper._is_value(const_value):
        tmp = zeros(g, sizes, dtype, layout, device)
        return opset9.add(g, tmp, value, g.op("Constant", value_t=torch.tensor(1)))
    else:
        dtype = symbolic_helper._get_const(dtype, "i", "dtype")
        return _constant_fill(g, sizes, dtype, const_value)


@_onnx_symbolic("aten::full_like")
@symbolic_helper.parse_args("v", "f", "i", "v", "v", "v", "v")
def full_like(
    g: jit_utils.GraphContext,
    input,
    fill_value,
    dtype,
    layout,
    device,
    pin_memory=False,
    memory_format=None,
):
    shape = g.op("Shape", input)
    return _constant_fill(g, shape, dtype, fill_value)


@_onnx_symbolic("aten::repeat")
def repeat(g: jit_utils.GraphContext, self, repeats):
    if not symbolic_helper._is_value(repeats):
        repeats = g.op("Constant", value_t=torch.LongTensor(repeats))
    if symbolic_helper._is_packed_list(repeats):
        repeat_size_len = len(symbolic_helper._unpack_list(repeats))
    else:
        const_repeats = symbolic_helper._maybe_get_const(repeats, "is")
        repeat_size_len = len(const_repeats)
    if self.isCompleteTensor():
        sizes = self.type().sizes()
        diff_dims = repeat_size_len - len(sizes)
        if diff_dims > 0:
            self = opset9.view(
                g, self, g.op("Constant", value_t=torch.tensor([1] * diff_dims + sizes))
            )
    return g.op("Tile", self, repeats)
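
# Worked example (assumed shapes, for illustration only): repeating a tensor of
# shape [3] with repeats=[2, 2] gives repeat_size_len=2 > len(sizes)=1, so the
# input is first reshaped to [1, 3] and Tile then produces a [2, 6] result,
# matching torch.Tensor.repeat semantics.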