import torch

from ._common_operator_config_utils import (
    _get_binary_op_configs,
    _get_bn_configs,
    _get_cat_config,
    _get_conv_configs,
    _get_default_op_configs,
    _get_embedding_op_configs,
    _get_fixed_qparams_op_configs,
    _get_linear_configs,
    _get_rnn_op_configs,
    _get_share_qparams_op_configs,
)
from .backend_config import BackendConfig, DTypeConfig, DTypeWithConstraints

__all__ = [
    "get_qnnpack_backend_config",
]

# ===================
# |  DTYPE CONFIGS  |
# ===================
# static quantization for weighted ops (e.g. conv, linear):
# quint8 activations, qint8 weights, float bias
qnnpack_weighted_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
)

# static quantization for ops without weights: quint8 in, quint8 out
qnnpack_default_op_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.quint8,
)

# pure float16 compute
qnnpack_default_op_fp16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float16,
    weight_dtype=torch.float16,
    bias_dtype=torch.float16,
)

# dynamic quantization: activations quantized to quint8 at runtime,
# qint8 weights, float outputs
qnnpack_default_dynamic_int8_dtype_config = DTypeConfig(
    input_dtype=torch.quint8,
    output_dtype=torch.float,
    weight_dtype=torch.qint8,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# dynamic float16: inputs converted to float16 at runtime,
# float16 weights, float outputs
qnnpack_default_dynamic_float16_dtype_config = DTypeConfig(
    input_dtype=torch.float16,
    output_dtype=torch.float,
    weight_dtype=torch.float16,
    bias_dtype=torch.float,
    is_dynamic=True,
)

# weight-only quantization: float compute with quantized weights
# (used for embedding ops)
qnnpack_weight_only_quint8_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint8,
)

qnnpack_weight_only_quint4x2_dtype_config = DTypeConfig(
    input_dtype=torch.float,
    output_dtype=torch.float,
    weight_dtype=torch.quint4x2,
)
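
# ---------------------------------------------------------------------------
# Note: each DTypeConfig above is one (input, output, weight, bias) dtype
# combination that the backend's kernels accept for a pattern. During FX
# graph mode quantization, a pattern is only quantized when the dtypes
# implied by the user's QConfig match one of the pattern's DTypeConfigs.
# An illustrative pairing (a sketch; the exact observers come from the
# user's QConfigMapping):
#
#     from torch.ao.quantization import default_dynamic_qconfig
#     # default_dynamic_qconfig -> quint8 dynamic activations + qint8 weights,
#     # which lines up with qnnpack_default_dynamic_int8_dtype_config above
# ---------------------------------------------------------------------------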

# xnnpack-compatible dtype configs

# We restrict scale values to be at least 2 ** -12 so that the
# requantization scale never falls below the xnnpack lower
# threshold. Additionally, for qint8 weights, we restrict
# the quantization values to [-127, +127], excluding -128.
# For more detail, refer to the description of
# `default_symmetric_qnnpack_qconfig`.

# TODO: add an additional restriction on the qscheme to ensure it
# is either per_tensor_symmetric or per_channel_symmetric
qnnpack_act_qint8_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    scale_min_lower_bound=2 ** -12,
)

qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12 = DTypeWithConstraints(
    dtype=torch.qint8,
    quant_min_lower_bound=-127,
    quant_max_upper_bound=127,
    scale_min_lower_bound=2 ** -12,
)

qnnpack_weighted_op_qint8_symmetric_dtype_config = DTypeConfig(
    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    weight_dtype=qnnpack_weight_qint8_neg_127_to_127_scale_min_2_neg_12,
    bias_dtype=torch.float,
)

qnnpack_default_op_qint8_symmetric_dtype_config = DTypeConfig(
    input_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
    output_dtype=qnnpack_act_qint8_scale_min_2_neg_12,
)
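
# ---------------------------------------------------------------------------
# Why a scale floor of 2 ** -12 (a rough sketch, not the authoritative
# derivation -- see `default_symmetric_qnnpack_qconfig` for that): for a
# quantized conv/linear kernel, the requantization multiplier is roughly
#
#     requant_scale = input_scale * weight_scale / output_scale
#
# xnnpack rejects kernels whose requantization scale falls below a small
# internal threshold, so bounding each observed scale from below keeps the
# product in the supported range. For example, with
# input_scale = weight_scale = 2 ** -12 and output_scale = 1.0, the
# multiplier is 2 ** -24, which stays above the floor.
# ---------------------------------------------------------------------------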

# =====================
# |  BACKEND CONFIGS  |
# =====================

def get_qnnpack_backend_config() -> BackendConfig:
    """
    Return the `BackendConfig` for PyTorch's native QNNPACK backend.
    """
    conv_dtype_configs = [
        qnnpack_weighted_op_qint8_symmetric_dtype_config,
        qnnpack_weighted_op_quint8_dtype_config,
    ]
    linear_dtype_configs = [
        qnnpack_weighted_op_qint8_symmetric_dtype_config,
        qnnpack_weighted_op_quint8_dtype_config,
        qnnpack_default_dynamic_int8_dtype_config,
        qnnpack_default_dynamic_float16_dtype_config,
    ]
    binary_op_dtype_configs = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    default_op_dtype_configs = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    fixed_qparams_op_dtype_configs = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    share_qparams_op_dtype_configs = [
        qnnpack_default_op_qint8_symmetric_dtype_config,
        qnnpack_default_op_quint8_dtype_config,
    ]
    rnn_op_dtype_configs = [
        qnnpack_default_dynamic_int8_dtype_config,
        qnnpack_default_dynamic_float16_dtype_config,
    ]
    embedding_op_dtype_configs = [
        qnnpack_weight_only_quint8_dtype_config,
        qnnpack_weight_only_quint4x2_dtype_config,
    ]
    return BackendConfig("qnnpack") \
        .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
        .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
        .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
        .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
        .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
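
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; `model` and `example_inputs` are placeholders
# standing in for a user's float module and sample inputs): the returned
# BackendConfig is meant to be threaded through both FX graph mode
# quantization entry points, so that observer insertion and the final
# lowering respect the dtype constraints defined above.
#
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx
#
#     backend_config = get_qnnpack_backend_config()
#     qconfig_mapping = get_default_qconfig_mapping("qnnpack")
#     prepared = prepare_fx(
#         model, qconfig_mapping, example_inputs, backend_config=backend_config
#     )
#     # ... calibrate `prepared` on representative data ...
#     quantized = convert_fx(prepared, backend_config=backend_config)
# ---------------------------------------------------------------------------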