# mappings.py

import operator

import torch
import torch.nn as nn
import torch.nn.functional as F

toq = torch.ops.quantized

import torch.ao.nn.quantized as nnq
import torch.ao.nn.quantized.dynamic as nnqd
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.quantized.dynamic as nniqd
import torch.ao.nn.intrinsic.qat as nniqat
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.qat as nnqat
import torch.ao.nn.qat.dynamic as nnqatd

from torch.ao.quantization.backend_config import get_native_backend_config
import torch.ao.quantization.fx._lower_to_native_backend as \
    _lower_to_native_backend
import torch.ao.quantization.quantization_mappings as quantization_mappings

from .ns_types import NSNodeTargetType

from typing import Callable, Dict, List, Optional, Set, Tuple
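
# This module defines the op mappings used by the FX Numeric Suite
# (torch.ao.ns): which fp32 modules/functions/methods are related to which
# fused, QAT, reference and quantized counterparts, what the expected
# input/output dtype of each node target is, and which node targets are
# not matchable.
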
def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
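    """
    Returns a mapping from an arbitrary base name (a string counter) to a set
    of related op targets, e.g. nn.Conv2d together with its fused, QAT,
    reference and quantized counterparts, as derived from the native
    backend_config and the default lowering maps.
    """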
    # note: this set is modified below by items from backend_config
    sets_of_related_ops: List[Set[NSNodeTargetType]] = [
        # conv modules
        {
            nn.Conv1d,
        },
        {
            nn.Conv2d,
        },
        {
            nn.Conv3d,
        },
        # conv functionals
        {
            F.conv1d,
        },
        {
            F.conv2d,
        },
        {
            F.conv3d,
        },
        # linear modules
        {
            nn.Linear,
        },
        # linear functionals
        {
            F.linear,
        },
        # average pool
        {
            nn.AvgPool1d,
            torch.avg_pool1d,
        },
        {
            nn.AvgPool2d,
            torch._C._nn.avg_pool2d,
        },
        {
            nn.AvgPool3d,
            torch._C._nn.avg_pool3d,
        },
        # adaptive average pool
        {
            nn.AdaptiveAvgPool1d,
            F.adaptive_avg_pool1d,
        },
        {
            nn.AdaptiveAvgPool2d,
            F.adaptive_avg_pool2d,
        },
        {
            nn.AdaptiveAvgPool3d,
            F.adaptive_avg_pool3d,
        },
        # LSTM
        {
            nn.LSTM,
        },
        # add
        {
            torch.add,
            operator.add,  # x + y
        },
        # cat
        {
            torch.cat,
        },
        # mul
        {
            torch.mul,
            operator.mul,
        },
        # relu
        {
            F.relu,
            nn.ReLU,
            'relu',
            'relu_',
            torch.relu,
        },
        # maxpool
        {
            nn.MaxPool1d,
            F.max_pool1d,
        },
        {
            nn.MaxPool2d,
            F.max_pool2d,
        },
        {
            nn.MaxPool3d,
            F.max_pool3d,
        },
        # sigmoid
        {
            torch.sigmoid,
            'sigmoid',
            'sigmoid_',
            nn.Sigmoid,
            F.sigmoid,
        },
        # BatchNorm
        {
            nn.BatchNorm2d,
        },
        {
            nn.BatchNorm3d,
        },
        # ConvTranspose
        {
            nn.ConvTranspose1d,
        },
        {
            nn.ConvTranspose2d,
        },
        {
            nn.ConvTranspose3d,
        },
        # ELU
        {
            nn.ELU,
        },
        # Embedding
        {
            nn.Embedding,
        },
        # EmbeddingBag
        {
            nn.EmbeddingBag,
        },
        # GroupNorm
        {
            nn.GroupNorm,
        },
        # Hardswish
        {
            nn.Hardswish,
        },
        # InstanceNorm
        {
            nn.InstanceNorm1d,
        },
        {
            nn.InstanceNorm2d,
        },
        {
            nn.InstanceNorm3d,
        },
        # LayerNorm
        {
            nn.LayerNorm,
        },
        # LeakyReLU
        {
            nn.LeakyReLU,
        },
        # ReLU6
        {
            nn.ReLU6,
            F.relu6,
        },
        # F.elu
        {
            F.elu,
        },
        # F.hardswish
        {
            F.hardswish,
        },
        # F.group_norm
        {
            F.group_norm,
        },
        # F.instance_norm
        {
            F.instance_norm,
        },
        # F.layer_norm
        {
            F.layer_norm,
        },
        # F.leaky_relu
        {
            F.leaky_relu,
        },
        # F.silu
        {
            nn.SiLU,
            F.silu,
        },
        # F.mish
        {
            nn.Mish,
            F.mish,
        },
        # F.tanh
        {
            nn.Tanh,
            F.tanh,
            torch.tanh,
            'tanh_',
            'tanh',
        },
        # F.hardsigmoid
        {
            'hardsigmoid_',
            'hardsigmoid',
            F.hardsigmoid,
            nn.Hardsigmoid,
        },
        # F.hardtanh
        {
            nn.Hardtanh,
            F.hardtanh,
            F.hardtanh_,
        },
        # floordiv
        {
            operator.floordiv,
        },
        # unsqueeze
        {
            torch.unsqueeze,
        },
        # stack
        {
            torch.stack,
        },
        # squeeze
        {
            torch.squeeze,
        },
        # sort
        {
            torch.sort,
        },
        # repeat_interleave
        {
            torch.repeat_interleave,
        },
        # min
        {
            torch.min,
        },
        # mean
        {
            torch.mean,
        },
        # max
        {
            torch.max,
        },
        # transpose
        {
            torch.transpose,
        },
        # flatten
        {
            torch.flatten,
        },
        # clamp
        {
            torch.clamp,
        },
        # chunk
        {
            torch.chunk,
        },
        # interpolate
        {
            torch.nn.functional.interpolate,
        },
        # dropout
        {
            nn.Dropout,
        },
        # F.dropout
        {
            F.dropout,
        },
        # matmul
        {
            torch.matmul,
        },
        # Softmax
        {
            nn.Softmax,
        },
        # PReLU
        {
            nn.PReLU,
            nnq.PReLU,
        },
        # F.prelu
        {
            F.prelu,
            toq.prelu,
        },
    ]
    # for each floating point op, add the versions of the op created by
    # backend_config
    backend_config = get_native_backend_config()

    new_connections: List[Tuple[Callable, Callable]] = [
        # technical debt edge case
        (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear),
    ]

    for pattern, config in backend_config._pattern_complex_format_to_config.items():

        # pattern format: (c, (b, a))
        first_element = pattern
        # look from the end, because pattern is in reverse order
        while isinstance(first_element, (list, tuple)):
            first_element = first_element[-1]

        if config.fused_module is not None:
            # case 1: pattern fuses a pattern of ops into an op
            # example: nn.Conv1d, nn.ReLU fused into nni.ConvReLU1d
            new_connections.append((first_element, config.fused_module))

        if config.qat_module is not None:
            # case 2: pattern swaps a module into a QAT module
            # example: nni.ConvReLU1d swapped into nniqat.ConvReLU1d
            new_connections.append((first_element, config.qat_module))

        if config.reference_quantized_module is not None:
            # case 3: reference version of floating point module, such as
            # nn.Conv2d and nnqr.Conv2d
            new_connections.append((first_element, config.reference_quantized_module))

    #
    # Add reference module swaps from the default lowering path
    #

    for source_to_target in (
        _lower_to_native_backend.STATIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_MODULE_MAP,
        _lower_to_native_backend.WEIGHT_ONLY_LOWER_MODULE_MAP,
        _lower_to_native_backend.SPECIAL_PATTERN_LOWER_MODULE_MAP,
    ):
        for source, target in source_to_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target))

    for source_to_double_target in (
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_MAP,
        _lower_to_native_backend.STATIC_LOWER_FUSED_MODULE_TWO_INPUTS_MAP,
        _lower_to_native_backend.DYNAMIC_LOWER_FUSED_MODULE_MAP,
    ):
        for source, (target1, target2) in source_to_double_target.items():  # type: ignore[attr-defined]
            new_connections.append((source, target1))
            new_connections.append((source, target2))

    #
    # Add function swaps from the default lowering path
    #

    for source, (target1, target2) in \
            _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
        new_connections.append((source, target1))
        new_connections.append((source, target2))

    for source_to_target in (
        _lower_to_native_backend.QBIN_OP_MAPPING,
        _lower_to_native_backend.QBIN_RELU_OP_MAPPING,
        quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    #
    # Add other swaps; ideally in the future this could be removed
    # after the lowering code stops using these.
    #
    for source_to_target in (
        quantization_mappings.DEFAULT_DYNAMIC_QUANT_MODULE_MAPPINGS,
    ):
        for source, target in source_to_target.items():
            new_connections.append((source, target))

    # add the new connections from backend_config
    for item1, item2 in new_connections:
        for set_of_related_ops in sets_of_related_ops:
            if item1 in set_of_related_ops or item2 in set_of_related_ops:
                set_of_related_ops.add(item1)
                set_of_related_ops.add(item2)
                break

    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {}

    counter = 0
    for set_of_related_ops in sets_of_related_ops:
        base_name = str(counter)
        counter += 1
        base_name_to_sets_of_related_ops[base_name] = set_of_related_ops

    return base_name_to_sets_of_related_ops

def get_base_name_for_op(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
) -> Optional[str]:
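    """
    Returns the base name of the set of related ops which contains `op`,
    or None if `op` is not present in any of the sets.
    """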
    for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
        if op in set_of_related_ops:
            return base_name
    return None

def add_op_to_sets_of_related_ops(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    op: NSNodeTargetType,
    related_op: Optional[NSNodeTargetType],
) -> None:
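    """
    Adds `op` to the set which already contains `related_op`. If `related_op`
    is None, creates a new set containing only `op` under a fresh base name.
    Raises an AssertionError if `related_op` is given but not found.
    """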
    if related_op is not None:
        for base_name, set_of_related_ops in base_name_to_sets_of_related_ops.items():
            if related_op in set_of_related_ops:
                set_of_related_ops.add(op)
                return
        # if we got here, related_op was not found
        raise AssertionError(f"{related_op} was not found")
    else:
        counter = 0
        while str(counter) in base_name_to_sets_of_related_ops:
            counter += 1
        base_name_to_sets_of_related_ops[str(counter)] = {op}

# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
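    """
    Returns sets of function, module and method node targets grouped by the
    numeric type of their inputs/outputs: fp32, fp16, int8, or
    "fp32 or int8" for ops which accept either dtype and preserve it.
    """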
    FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        F.linear,
        F.conv1d,
        F.conv2d,
        F.conv3d,
        torch.cat,
        F.elu,
        F.hardswish,
        F.instance_norm,
        F.layer_norm,
        F.leaky_relu,
        F.dropout,
        F.silu,
        F.mish,
        operator.add,
        torch.add,
        operator.mul,
        torch.mul,
        torch.sum,
        F.prelu,
    }

    FUNS_IO_TYPE_FP16: Set[NSNodeTargetType] = set()

    FUNS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        toq.linear,
        toq.linear_relu,
        toq.conv1d,
        toq.conv1d_relu,
        toq.conv2d,
        toq.conv2d_relu,
        toq.conv3d,
        toq.conv3d_relu,
        toq.cat,
        toq.elu,
        toq.hardswish,
        toq.instance_norm,
        toq.layer_norm,
        toq.leaky_relu,
        toq.dropout,
        toq.prelu,
        # TODO(future PR): implement shadowing for binary ops and
        # uncomment below
        # toq.add,
        # toq.mul,
    }

    FUNS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        F.relu,
        F.tanh,
        torch.tanh,
        F.sigmoid,
        torch.sigmoid,
        F.hardsigmoid,
        operator.floordiv,
        torch.adaptive_avg_pool1d,
        F.adaptive_avg_pool2d,
        F.adaptive_avg_pool3d,
        F.dropout,
        F.hardtanh,
        F.hardtanh_,
        F.interpolate,
        F.max_pool1d,
        F.max_pool2d,
        F.max_pool3d,
        F.relu6,
        torch.avg_pool1d,
        torch._C._nn.avg_pool2d,
        torch._C._nn.avg_pool3d,
        torch.cat,
        torch.chunk,
        torch.clamp,
        torch.flatten,
        torch.transpose,
        torch.max,
        torch.mean,
        torch.min,
        torch.repeat_interleave,
        torch.sort,
        torch.squeeze,
        torch.stack,
        torch.unsqueeze,
        operator.add,
    }

    MODS_IO_TYPE_FP32: Set[NSNodeTargetType] = {
        nn.Linear,
        nnqat.Linear,
        nnqatd.Linear,
        nnqd.Linear,
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
        nn.Conv1d,
        nn.Conv2d,
        nn.Conv3d,
        nnqat.Conv1d,
        nnqat.Conv2d,
        nnqat.Conv3d,
        nnqat.Embedding,
        nnqat.EmbeddingBag,
        nn.LSTM,
        # note: nnqd.Linear is an instance of nnq.Linear, so this
        # check has to happen before the int8 module check
        nnqd.LSTM,
        nn.BatchNorm2d,
        nn.BatchNorm3d,
        nn.Dropout,
        nn.ConvTranspose1d,
        nn.ConvTranspose2d,
        nn.ConvTranspose3d,
        nn.ELU,
        nn.GroupNorm,
        nn.InstanceNorm1d,
        nn.InstanceNorm2d,
        nn.InstanceNorm3d,
        nn.LayerNorm,
        nn.Hardswish,
        nn.LeakyReLU,
        nn.ReLU6,
        nn.SiLU,
        nn.Mish,
        nn.Softmax,
        nn.PReLU,
        nni.BNReLU2d,
        nni.BNReLU3d,
        nni.ConvReLU1d,
        nni.ConvReLU2d,
        nni.ConvReLU3d,
        nni.LinearReLU,
        nni.LinearBn1d,
        nni.ConvBn1d,
        nni.ConvBn2d,
        nni.ConvBn3d,
        nniqat.ConvBn1d,
        nniqat.ConvBn2d,
        nniqat.ConvBn3d,
        nniqat.ConvBnReLU1d,
        nniqat.ConvBnReLU2d,
        nniqat.ConvBnReLU3d,
        nniqat.ConvReLU1d,
        nniqat.ConvReLU2d,
        nniqat.ConvReLU3d,
        nniqat.LinearReLU,
        nniqat.LinearBn1d,
        nniqd.LinearReLU,
        nni.LinearLeakyReLU,
        nni.LinearTanh,
        nni.ConvAdd2d,
        nni.ConvAddReLU2d,
    }

    MODS_IO_TYPE_INT8: Set[NSNodeTargetType] = {
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        nnq.Dropout,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.Embedding,
        nnq.EmbeddingBag,
        nnq.Dropout,
        nnq.Softmax,
        nnq.PReLU,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
        nniq.LinearLeakyReLU,
        nniq.LinearTanh,
        nniq.ConvAdd2d,
        nniq.ConvAddReLU2d,
    }

    MODS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        nn.ReLU,
        nn.Tanh,
        nn.Sigmoid,
        nn.Hardsigmoid,
        nn.AdaptiveAvgPool1d,
        nn.AdaptiveAvgPool2d,
        nn.AdaptiveAvgPool3d,
        nn.AvgPool1d,
        nn.AvgPool2d,
        nn.AvgPool3d,
        nn.Dropout,
        nn.Hardtanh,
        nn.Identity,
        nn.MaxPool1d,
        nn.MaxPool2d,
        nn.MaxPool3d,
        nn.ReLU6,
    }

    METHS_IO_TYPE_FP32_OR_INT8: Set[NSNodeTargetType] = {
        'sigmoid_',
        'sigmoid',
        'tanh_',
        'tanh',
        'hardsigmoid_',
        'hardsigmoid',
        'relu_',
        'relu',
    }

    return {
        'funs_io_type_fp32': FUNS_IO_TYPE_FP32,
        'funs_io_type_fp16': FUNS_IO_TYPE_FP16,
        'funs_io_type_int8': FUNS_IO_TYPE_INT8,
        'funs_io_type_fp32_or_int8': FUNS_IO_TYPE_FP32_OR_INT8,
        'mods_io_type_fp32': MODS_IO_TYPE_FP32,
        'mods_io_type_int8': MODS_IO_TYPE_INT8,
        'mods_io_type_fp32_or_int8': MODS_IO_TYPE_FP32_OR_INT8,
        'meths_io_type_fp32_or_int8': METHS_IO_TYPE_FP32_OR_INT8,
    }

def get_unmatchable_types_map() -> Dict[str, Set[NSNodeTargetType]]:
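    """
    Returns the sets of function, module and method node targets which the
    numeric suite graph matching treats as unmatchable.
    """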
    FUNS_UNMATCHABLE: Set[NSNodeTargetType] = {
        torch.quantize_per_tensor,
        operator.getitem,
    }

    MODS_UNMATCHABLE: Set[NSNodeTargetType] = {
        nn.Identity,
    }

    METHS_UNMATCHABLE: Set[NSNodeTargetType] = {
        'to',
        'dequantize',
        'reshape',
        'view',
        'unsqueeze_',
        'unsqueeze',
        'transpose',
        'squeeze_',
        'squeeze',
        'size',
        'shape',
        'resize_',
        'repeat_interleave',
        'repeat',
        'permute',
        'numel',
        'mean',
        'detach_',
        'detach',
        'contiguous',
        'clamp',
        'chunk',
    }

    return {
        'funs_unmatchable': FUNS_UNMATCHABLE,
        'mods_unmatchable': MODS_UNMATCHABLE,
        'meths_unmatchable': METHS_UNMATCHABLE,
    }
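
# Example usage (illustrative sketch only; these functions are consumed by the
# numeric suite internals, and nn.Conv2d below is just a convenient example):
#
#   base_name_to_sets = get_base_name_to_sets_of_related_ops()
#   base_name = get_base_name_for_op(base_name_to_sets, nn.Conv2d)
#   # `base_name` keys the set of ops related to nn.Conv2d (its fused, QAT,
#   # reference and quantized variants); it is None for unknown ops.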