symbolic_caffe2.py

import importlib
import inspect

from torch.onnx import symbolic_helper, symbolic_opset9 as opset9
from torch.onnx._internal import jit_utils, registration


def register_quantized_ops(domain: str, version: int):
    # Register all quantized ops
    module = importlib.import_module("torch.onnx.symbolic_caffe2")
    quant_version_ops = inspect.getmembers(module)
    aten_q_ops = {
        "relu",
        "_empty_affine_quantized",
        "dequantize",
        "quantize_per_tensor",
        "upsample_nearest2d",
        "avg_pool2d",
        "reshape",
        "slice",
        "cat",
        "max_pool2d",
        "sigmoid",
    }
    for op, func in quant_version_ops:
        name = f"{domain}::{op}"
        if inspect.isfunction(func) and not registration.registry.is_registered_op(
            name, version
        ):
            if op in aten_q_ops:
                # Override the builtin aten ops
                registration.registry.register(
                    f"aten::{op}", version, func, custom=True
                )
            registration.registry.register(name, version, func)
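

# Usage sketch (an assumption about the surrounding export machinery, not part
# of this module): the Caffe2-flavored quantized export path is expected to call
#
#     register_quantized_ops("caffe2", opset_version)
#
# after which names such as "caffe2::linear" resolve to the Int8* symbolics
# below, and the aten ops listed in `aten_q_ops` (e.g. "aten::relu") are
# overridden by their quantization-aware counterparts defined here.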


def _permute_helper(g: jit_utils.GraphContext, input, axes):
    quant_args = {
        "axes_i": axes,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Transpose", input, **quant_args)
    symbolic_helper._quantized_ops.add(output)
    return output


def nchw2nhwc(g: jit_utils.GraphContext, input):
    axes = [0, 2, 3, 1]
    return _permute_helper(g, input, axes)


def nhwc2nchw(g: jit_utils.GraphContext, input):
    axes = [0, 3, 1, 2]
    return _permute_helper(g, input, axes)


def linear_prepack(g: jit_utils.GraphContext, weight, bias):
    # Mapping to a dummy caffe2 prepack node.
    # During the onnx -> c2 conversion we can look up original weight and bias
    # from this node
    output = g.op("_caffe2::WeightPrepack", weight, bias)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "f", "i")
def linear(g: jit_utils.GraphContext, input, weight, bias, scale, zero_point):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8FC", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def conv_prepack(
    g: jit_utils.GraphContext, input, weight, bias, stride, padding, dilation, groups
):
    # Mapping to a dummy caffe2 prepack node.
    # During the onnx -> c2 conversion we can look up original weight and bias
    # from this node
    output = g.op("_caffe2::WeightPrepack", input, weight, bias)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    # Kernel spatial dims pulled from the weight constant's "shape" attribute.
    kernel_size = weight.node()["shape"][1:3]
    kwargs = {
        "strides_i": stride,
        # Repeat the symmetric padding so both begin and end pads are emitted.
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Conv", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "is", "is", "is", "i", "f", "i")
def conv2d_relu(
    g: jit_utils.GraphContext,
    input,
    weight,
    bias,
    stride,
    padding,
    dilation,
    groups,
    scale,
    zero_point,
):
    kernel_size = weight.node()["shape"][1:3]
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "dilations_i": dilation,
        "group_i": groups,
        "kernels_i": kernel_size,
        "order_s": "NHWC",
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8ConvRelu", input, weight, bias, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "f", "i")
def add(g: jit_utils.GraphContext, input_a, input_b, scale, zero_point):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Add", input_a, input_b, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def relu(g: jit_utils.GraphContext, input):
    if input not in symbolic_helper._quantized_ops:
        return opset9.relu(g, input)
    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Relu", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "f", "i", "t")
def quantize_per_tensor(g: jit_utils.GraphContext, input, scale, zero_point, dtype):
    kwargs = {
        "Y_scale_f": scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Quantize", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def dequantize(g: jit_utils.GraphContext, input):
    return g.op("_caffe2::Int8Dequantize", input)


@symbolic_helper.parse_args("v", "t", "t", "t", "t", "t", "t", "t")
def _empty_affine_quantized(
    g: jit_utils.GraphContext,
    input,
    shape,
    scale,
    zero_point,
    dtype,
    pin_memory,
    memory_format,
    layout,
):
    return input


def upsample_nearest2d(
    g: jit_utils.GraphContext,
    input,
    output_size,
    align_corners=None,
    scales_h=None,
    scales_w=None,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.upsample_nearest2d(g, input, output_size, align_corners)  # type: ignore[attr-defined]

    output_size = symbolic_helper._parse_arg(output_size, "is")
    kwargs = {
        "output_size_i": output_size,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    # The Caffe2 Int8 op consumes NHWC data, so transpose around it.
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8ResizeNearest", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "is", "is", "is", "is", "i")
def max_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.max_pool2d(  # type: ignore[attr-defined]
            g, input, kernel_size, stride, padding, dilation, ceil_mode
        )
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8MaxPool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "is", "is", "is", "i", "i", "none")
def avg_pool2d(
    g: jit_utils.GraphContext,
    input,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    if input not in symbolic_helper._quantized_ops:
        return opset9.avg_pool2d(  # type: ignore[attr-defined]
            g,
            input,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )
    kwargs = {
        "strides_i": stride,
        "pads_i": padding + padding,
        "kernel_i": kernel_size[0],
        "order_s": "NHWC",
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    input = nchw2nhwc(g, input)
    output = g.op("_caffe2::Int8AveragePool", input, **kwargs)
    output = nhwc2nchw(g, output)
    symbolic_helper._quantized_ops.add(output)
    return output


def reshape(g: jit_utils.GraphContext, input, shape):
    if input not in symbolic_helper._quantized_ops:
        return opset9.reshape(g, input, shape)

    kwargs = {
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Reshape", input, shape, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v", "v", "v", "v", "i")
def slice(g: jit_utils.GraphContext, input, dim, start, end, step):
    if input not in symbolic_helper._quantized_ops:
        return opset9.slice(g, input, dim, start, end, step)

    if step != 1:
        raise RuntimeError("ONNX quantized slice export only works for step 1.")
    start = symbolic_helper._parse_arg(start, "i")
    end = symbolic_helper._parse_arg(end, "i")
    dim = symbolic_helper._parse_arg(dim, "i")

    kwargs = {
        "start_idx_i": start,
        "end_idx_i": end,
        "dim_i": dim,
        "Y_scale_f": symbolic_helper._node_get(input.node(), "Y_scale"),
        "Y_zero_point_i": symbolic_helper._node_get(input.node(), "Y_zero_point"),
    }
    output = g.op("_caffe2::Int8Slice", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


def cat(g: jit_utils.GraphContext, tensor_list, dim, scale=None, zero_point=None):
    tensors = symbolic_helper._unpack_list(tensor_list)
    input = tensors[0]
    if input not in symbolic_helper._quantized_ops:
        return opset9.cat(g, tensor_list, dim)

    dim = symbolic_helper._parse_arg(dim, "i")
    kwargs = {
        "Y_scale_f": tensors[0].node()["Y_scale"],
        "Y_zero_point_i": tensors[0].node()["Y_zero_point"],
    }
    output = g.op("_caffe2::Int8Concat", *tensors, axis_i=dim, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output


@symbolic_helper.parse_args("v")
def sigmoid(g: jit_utils.GraphContext, input):
    if input not in symbolic_helper._quantized_ops:
        return opset9.sigmoid(g, input)
    # Caffe2 expects the output scale to be 1/2^8
    # and output zero_point to be 0 (quint8 type)
    out_scale = 1.0 / 256
    zero_point = 0
    kwargs = {
        "Y_scale_f": out_scale,
        "Y_zero_point_i": zero_point,
    }
    output = g.op("_caffe2::Int8Sigmoid", input, **kwargs)
    symbolic_helper._quantized_ops.add(output)
    return output
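

# Minimal usage sketch (assumptions: a Caffe2-enabled PyTorch build and a model
# already quantized with torch.ao.quantization; exact flags vary by release):
#
#     torch.onnx.export(
#         quantized_model,
#         example_input,
#         "model.onnx",
#         operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
#     )
#
# With the Caffe2 ATen fallback active, the exporter is expected to call
# register_quantized_ops("caffe2", opset_version), so quantized ops in the
# traced graph are lowered to the _caffe2::Int8* operators defined above.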