import torch
import torch.nn as nn
import torch.ao.nn.intrinsic as nni
import torch.nn.functional as F
import torch.ao.nn.quantized.reference as nnqr
from ._common_operator_config_utils import (
    _get_conv_configs,
    _get_linear_configs,
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_ln_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
)
from .backend_config import (
    BackendPatternConfig,
    BackendConfig,
    DTypeConfig,
    ObservationType,
)
from ..fuser_method_mappings import (
    _sequential_wrapper2,
)
import operator
from torch.ao.quantization.utils import MatchAllNode
import itertools

# ===================
# |  DTYPE CONFIGS  |
# ===================

onednn_weighted_op_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

onednn_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

onednn_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

onednn_weight_only_qint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
)

onednn_input_output_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.float,
    bias_dtype=torch.float,
)
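
# The configs above are plain dtype descriptions: each one states which dtypes the
# lowered onednn kernels expect for activations, weights and bias, and whether the
# op is dynamically quantized. A minimal sketch of how they can be inspected
# (illustrative only, not used by the backend config itself):
def _example_inspect_dtype_configs():
    """Return the (input, output, weight) dtypes of the static weighted-op config.

    For onednn_weighted_op_int8_dtype_config this is expected to be
    (torch.quint8, torch.quint8, torch.qint8); the dynamic config additionally
    sets ``is_dynamic=True`` and keeps the output in fp32.
    """
    cfg = onednn_weighted_op_int8_dtype_config
    return cfg.input_dtype, cfg.output_dtype, cfg.weight_dtype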
# ===================
# |  FUSER METHODS  |
# ===================

def _fuse_linear_bn_leaky_relu(is_qat, linear, bn, leaky_relu):
    r"""Given the linear, bn and leaky_relu modules, fuses them and returns the fused module.

    Args:
        is_qat: a flag for whether we are using quantization aware training fusion
            or post training quantization fusion
        linear: Module instance of type Linear
        bn: BatchNorm1d instance that needs to be fused with the linear layer
        leaky_relu: LeakyReLU instance that needs to be fused with the linear layer

    Examples::

        >>> # xdoctest: +SKIP
        >>> m1 = nn.Linear(20, 10).eval()
        >>> b1 = nn.BatchNorm1d(10).eval()
        >>> lr = nn.LeakyReLU(0.01)
        >>> m2 = _fuse_linear_bn_leaky_relu(False, m1, b1, lr)
    """
    assert linear.training == bn.training and bn.training == leaky_relu.training, \
        "Linear, BN and LeakyReLU all must be in the same mode (train or eval)."

    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format((linear, bn, leaky_relu)))
    else:
        map_to_fused_module_eval = {
            nn.Linear: nni.LinearLeakyReLU,
        }
        fused_module = map_to_fused_module_eval.get(type(linear), None)
        if fused_module is not None:
            fused_linear = nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
            fm = fused_module(fused_linear, leaky_relu)
            return fm
        else:
            raise NotImplementedError("Cannot fuse eval modules: {}".format((linear, bn, leaky_relu)))
# ======================
# |  CONFIGS FOR CONV  |
# ======================
observation_type = ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT

conv_dtype_configs = [onednn_weighted_op_int8_dtype_config]
conv_configs = _get_conv_configs(conv_dtype_configs)

# (1) Conv2d + Add

# conv2d   Y
#      \   /
#       add

# include:
# conv2d   conv2d
#      \   /
#       add

def _fuse_conv_add_left(is_qat, add, conv, _):
    return nni.ConvAdd2d(conv, add)

def _conv_add_root_node_getter_left(pattern):
    _, conv, _ = pattern
    return conv

def _conv_add_extra_inputs_getter_left(pattern):
    """ get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    _, conv, extra_input = pattern
    return [extra_input]

# conv2d
#  \
#   bn   Y
#    \   /
#     add

def _fuse_conv_bn_add_left(is_qat, add, bn_conv, _):
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, add)))
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAdd2d(fused_conv, add)

def _conv_bn_add_root_node_getter_left(add_pattern):
    _, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    return conv

def _conv_bn_add_extra_inputs_getter_left(add_pattern):
    """ get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    _, bn_conv, extra_input = add_pattern
    bn, conv = bn_conv
    return [extra_input]

conv_add_left_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_left_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
                ._set_pattern_complex_format((add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))  # noqa: E131
                .set_observation_type(observation_type)
                .set_dtype_configs(conv_dtype_configs)
                .set_fuser_method(_fuse_conv_bn_add_left)
                ._set_root_node_getter(_conv_bn_add_root_node_getter_left)
                ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_left)
                .set_fused_module(nni.ConvAdd2d))
    else:
        conv_configs.append(
            BackendPatternConfig()
                ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode))  # noqa: E131
                .set_observation_type(observation_type)
                .set_dtype_configs(conv_dtype_configs)
                .set_fuser_method(_fuse_conv_add_left)
                ._set_root_node_getter(_conv_add_root_node_getter_left)
                ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_left)
                .set_fused_module(nni.ConvAdd2d))

#  Y   conv2d
#   \   /
#    add

def _fuse_conv_add_right(is_qat, add, _, conv):
    return nni.ConvAdd2d(conv, add)

def _conv_add_root_node_getter_right(pattern):
    add, _, conv = pattern
    return conv

def _conv_add_extra_inputs_getter_right(pattern):
    """ get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    _, extra_input, conv = pattern
    return [extra_input]

#      conv2d
#       /
#  Y   bn
#   \   /
#    add

def _fuse_conv_bn_add_right(is_qat, add, _, bn_conv):
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, add)))
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAdd2d(fused_conv, add)

def _conv_bn_add_root_node_getter_right(pattern):
    add, _, bn_conv = pattern
    bn, conv = bn_conv
    return conv

def _conv_bn_add_extra_inputs_getter_right(pattern):
    """ get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    _, extra_input, bn_conv = pattern
    bn, conv = bn_conv
    return [extra_input]

conv_add_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
                ._set_pattern_complex_format((add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)))  # noqa: E131
                .set_observation_type(observation_type)
                .set_dtype_configs(conv_dtype_configs)
                .set_fuser_method(_fuse_conv_bn_add_right)
                ._set_root_node_getter(_conv_bn_add_root_node_getter_right)
                ._set_extra_inputs_getter(_conv_bn_add_extra_inputs_getter_right)
                .set_fused_module(nni.ConvAdd2d))
    else:
        conv_configs.append(
            BackendPatternConfig()
                ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d))  # noqa: E131
                .set_observation_type(observation_type)
                .set_dtype_configs(conv_dtype_configs)
                .set_fuser_method(_fuse_conv_add_right)
                ._set_root_node_getter(_conv_add_root_node_getter_right)
                ._set_extra_inputs_getter(_conv_add_extra_inputs_getter_right)
                .set_fused_module(nni.ConvAdd2d))

conv_configs.append(
    BackendPatternConfig(nni.ConvAdd2d)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(conv_dtype_configs)
        .set_root_module(nn.Conv2d)
        .set_reference_quantized_module(nnqr.Conv2d))
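
# For orientation, a minimal sketch (illustrative only, not consumed by the backend
# config) of a model whose traced graph contains the conv2d + add shapes matched by
# the "left"/"right" patterns registered above; the module and shapes are arbitrary:
def _example_conv_add_pattern():
    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 3, padding=1)

        def forward(self, x, y):
            # conv output as the left operand of add -> matched by the "left" patterns;
            # `y + self.conv(x)` would be matched by the "right" patterns instead
            return self.conv(x) + y

    m = M().eval()
    return m(torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8))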
# (2) Conv2d + Add + Relu

# conv2d   Y
#      \   /
#       add
#        \
#         relu

def _fuse_conv_add_relu_left(is_qat, relu, add_pattern):
    add, conv, _ = add_pattern
    return nni.ConvAddReLU2d(conv, add, relu)

def _conv_add_relu_root_node_getter_left(pattern):
    relu, add_pattern = pattern
    _, conv, _ = add_pattern
    return conv

def _conv_add_relu_extra_inputs_getter_left(pattern):
    """ get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    relu, add_pattern = pattern
    _, conv, extra_input = add_pattern
    return [extra_input]

# conv2d
#  \
#   bn   Y
#    \   /
#     add
#      \
#       relu

def _fuse_conv_bn_add_relu_left(is_qat, relu, add_pattern):
    add, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, add, relu)))
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAddReLU2d(fused_conv, add, relu)

def _conv_bn_add_relu_root_node_getter_left(pattern):
    relu, add_pattern = pattern
    _, bn_conv, _ = add_pattern
    bn, conv = bn_conv
    return conv

def _conv_bn_add_relu_extra_inputs_getter_left(pattern):
    """ get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    relu, add_pattern = pattern
    _, bn_conv, extra_input = add_pattern
    bn, conv = bn_conv
    return [extra_input]

conv_add_relu_left_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_relu_left_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
                ._set_pattern_complex_format((nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)))  # noqa: E131
                .set_observation_type(observation_type)
                .set_dtype_configs(conv_dtype_configs)
                .set_fuser_method(_fuse_conv_bn_add_relu_left)
                ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_left)
                ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_left)
                .set_fused_module(nni.ConvAddReLU2d))
    else:
        conv_configs.append(
            BackendPatternConfig()
                ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode)))  # noqa: E131
                .set_observation_type(observation_type)
                .set_dtype_configs(conv_dtype_configs)
                .set_fuser_method(_fuse_conv_add_relu_left)
                ._set_root_node_getter(_conv_add_relu_root_node_getter_left)
                ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_left)
                .set_fused_module(nni.ConvAddReLU2d))

#  Y   conv2d
#   \   /
#    add
#     \
#      relu

def _fuse_conv_add_relu_right(is_qat, relu, add_pattern):
    add, _, conv = add_pattern
    return nni.ConvAddReLU2d(conv, add, relu)

def _conv_add_relu_root_node_getter_right(pattern):
    relu, add_pattern = pattern
    _, _, conv = add_pattern
    return conv

def _conv_add_relu_extra_inputs_getter_right(pattern):
    """ get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    relu, add_pattern = pattern
    _, extra_input, conv = add_pattern
    return [extra_input]

#      conv2d
#       /
#  Y   bn
#   \   /
#    add
#     \
#      relu

def _fuse_conv_bn_add_relu_right(is_qat, relu, add_pattern):
    add, _, bn_conv = add_pattern
    bn, conv = bn_conv
    if is_qat:
        raise NotImplementedError("Cannot fuse train modules: {}".format((conv, bn, add, relu)))
    else:
        fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn)
        return nni.ConvAddReLU2d(fused_conv, add, relu)

def _conv_bn_add_relu_root_node_getter_right(pattern):
    relu, add_pattern = pattern
    _, _, bn_conv = add_pattern
    bn, conv = bn_conv
    return conv

def _conv_bn_add_relu_extra_inputs_getter_right(pattern):
    """ get inputs pattern for extra inputs, inputs for root node
    are assumed to be copied over from root node to the fused node
    """
    relu, add_pattern = pattern
    _, extra_input, bn_conv = add_pattern
    bn, conv = bn_conv
    return [extra_input]

conv_add_relu_options = itertools.product(
    [True, False],  # with_bn
    [torch.add, operator.add],  # add_op
)

for with_bn, add_op in conv_add_relu_options:
    if with_bn:
        conv_configs.append(
            BackendPatternConfig()
                ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))))  # noqa: E131
                .set_observation_type(observation_type)
                .set_dtype_configs(conv_dtype_configs)
                .set_fuser_method(_fuse_conv_bn_add_relu_right)
                ._set_root_node_getter(_conv_bn_add_relu_root_node_getter_right)
                ._set_extra_inputs_getter(_conv_bn_add_relu_extra_inputs_getter_right)
                .set_fused_module(nni.ConvAddReLU2d))
    else:
        conv_configs.append(
            BackendPatternConfig()
                ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d)))  # noqa: E131
                .set_observation_type(observation_type)
                .set_dtype_configs(conv_dtype_configs)
                .set_fuser_method(_fuse_conv_add_relu_right)
                ._set_root_node_getter(_conv_add_relu_root_node_getter_right)
                ._set_extra_inputs_getter(_conv_add_relu_extra_inputs_getter_right)
                .set_fused_module(nni.ConvAddReLU2d))

conv_configs.append(
    BackendPatternConfig(nni.ConvAddReLU2d)
        .set_observation_type(observation_type)  # noqa: E131
        .set_dtype_configs(conv_dtype_configs)
        .set_root_module(nn.Conv2d)
        .set_reference_quantized_module(nnqr.Conv2d))
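
# As above, a minimal sketch (illustrative only) of the conv2d + add + relu shape the
# ConvAddReLU2d patterns target; the module and shapes are arbitrary assumptions:
def _example_conv_add_relu_pattern():
    class M(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 3, padding=1)
            self.relu = nn.ReLU()

        def forward(self, x, y):
            # relu consumes the sum of a conv output and an extra input
            return self.relu(self.conv(x) + y)

    m = M().eval()
    return m(torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8))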
# ========================
# |  CONFIGS FOR LINEAR  |
# ========================

linear_dtype_configs = [
    onednn_weighted_op_int8_dtype_config,
    onednn_dynamic_int8_dtype_config,
]
linear_configs = _get_linear_configs(linear_dtype_configs)

def _add_eltwise_fusion_configs(configs, root_module, root_op, post_module, post_op,
                                dtype_configs, fuser_method, fused_module, observation_type,
                                ref_quant_module):
    # (1) base module + post-op module fusion config
    configs.append(
        BackendPatternConfig((root_module, post_module))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuser_method)
            .set_fused_module(fused_module))
    # base module + functional post op
    configs.append(
        BackendPatternConfig((root_module, post_op))
            .set_dtype_configs(dtype_configs)  # noqa: E131
            .set_fuser_method(fuser_method)
            .set_fused_module(fused_module))

    # (2) fused module config
    configs.append(
        BackendPatternConfig(fused_module)
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs)
            .set_root_module(root_module)
            .set_reference_quantized_module(ref_quant_module))

    # (3) functional base op + post op configs
    configs.append(
        BackendPatternConfig((root_op, post_module))
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs))
    configs.append(
        BackendPatternConfig((root_op, post_op))
            .set_observation_type(observation_type)  # noqa: E131
            .set_dtype_configs(dtype_configs))

# Configs for linear + leaky_relu fusion
_add_eltwise_fusion_configs(linear_configs, nn.Linear, F.linear,
                            nn.LeakyReLU, F.leaky_relu, linear_dtype_configs,
                            _sequential_wrapper2(nni.LinearLeakyReLU),
                            nni.LinearLeakyReLU, observation_type, nnqr.Linear)

# Configs for linear module + batchnorm + leaky_relu
linear_configs.append(
    BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU))
        .set_dtype_configs(linear_dtype_configs)  # noqa: E131
        .set_fuser_method(_fuse_linear_bn_leaky_relu)
        .set_fused_module(nni.LinearLeakyReLU))

# Configs for linear + tanh fusion
_add_eltwise_fusion_configs(linear_configs, nn.Linear, F.linear,
                            nn.Tanh, torch.tanh, linear_dtype_configs,
                            _sequential_wrapper2(nni.LinearTanh),
                            nni.LinearTanh, observation_type, nnqr.Linear)
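
# A minimal sketch (illustrative only) of a model covered by the linear + leaky_relu
# fusion configs above; after FX fusion the (Linear, LeakyReLU) pair is expected to be
# replaced by a single nni.LinearLeakyReLU module. Names and sizes are arbitrary:
def _example_linear_leaky_relu_pattern():
    model = nn.Sequential(
        nn.Linear(20, 10),
        nn.LeakyReLU(0.01),
    ).eval()
    return model(torch.randn(1, 20))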
# ===========================
# |  CONFIGS FOR OTHER OPS  |
# ===========================

binary_op_dtype_configs = [onednn_op_quint8_dtype_config]
default_op_dtype_configs = [onednn_op_quint8_dtype_config]
fixed_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
share_qparams_op_dtype_configs = [onednn_op_quint8_dtype_config]
rnn_op_dtype_configs = [onednn_dynamic_int8_dtype_config]
embedding_op_dtype_configs = [onednn_weight_only_qint8_dtype_config]
layer_norm_op_dtype_configs = [onednn_input_output_only_quint8_dtype_config]

# =====================
# |  BACKEND CONFIGS  |
# =====================

def get_onednn_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native ONEDNN backend.
    """
    return BackendConfig("onednn") \
        .set_backend_pattern_configs(conv_configs) \
        .set_backend_pattern_configs(linear_configs) \
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
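
# A minimal sketch (illustrative only) of how this BackendConfig is typically passed to
# the FX graph mode quantization entry points; the toy model, qconfig mapping and example
# inputs below are placeholder assumptions, not part of this file:
def _example_quantize_with_onednn_backend_config():
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    model = nn.Sequential(nn.Linear(20, 10), nn.LeakyReLU(0.01)).eval()
    example_inputs = (torch.randn(1, 20),)
    backend_config = get_onednn_backend_config()
    qconfig_mapping = get_default_qconfig_mapping("onednn")
    prepared = prepare_fx(model, qconfig_mapping, example_inputs, backend_config=backend_config)
    prepared(*example_inputs)  # calibration pass with representative data
    quantized = convert_fx(prepared, backend_config=backend_config)
    return quantized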
__all__ = [
    "get_onednn_backend_config",
]