# native.py — BackendConfig for the PyTorch native backend (fbgemm/qnnpack).
  1. import torch
  2. from ._common_operator_config_utils import (
  3. _get_binary_op_configs,
  4. _get_bn_configs,
  5. _get_cat_config,
  6. _get_conv_configs,
  7. _get_default_op_configs,
  8. _get_embedding_op_configs,
  9. _get_fixed_qparams_op_configs,
  10. _get_linear_configs,
  11. _get_ln_configs,
  12. _get_rnn_op_configs,
  13. _get_share_qparams_op_configs,
  14. _get_tensor_info_op_configs,
  15. )
  16. from .backend_config import BackendConfig, DTypeConfig
# Public API of this module: the native BackendConfig getters, their
# dict-form variants, and the DTypeConfig constants reused by other
# backend configs.
__all__ = [
    # BackendConfig with extra fp16 ops, intended for tests only
    "get_test_only_legacy_native_backend_config",
    # DTypeConfig constants
    "default_op_quint8_dtype_config",
    "default_op_fp16_dtype_config",
    "default_dynamic_int8_dtype_config",
    "default_dynamic_float16_dtype_config",
    "input_output_only_quint8_dtype_config",
    "weight_only_quint8_dtype_config",
    "weight_only_quint4x2_dtype_config",
    # Production BackendConfig getters and dict-form variants
    "get_native_backend_config",
    "get_native_backend_config_dict",
    "get_test_only_legacy_native_backend_config_dict",
]
  30. # ===================
  31. # | DTYPE CONFIGS |
  32. # ===================
  33. # weighted op int8 dtype config
  34. # this is config for ops that has quantized weights, like linear, conv
  35. weighted_op_quint8_dtype_config = DTypeConfig(
  36. input_dtype=torch.quint8,
  37. output_dtype=torch.quint8,
  38. weight_dtype=torch.qint8,
  39. bias_dtype=torch.float,
  40. )
  41. default_op_quint8_dtype_config = DTypeConfig(
  42. input_dtype=torch.quint8,
  43. output_dtype=torch.quint8,
  44. )
  45. default_op_fp16_dtype_config = DTypeConfig(
  46. input_dtype=torch.float16,
  47. output_dtype=torch.float16,
  48. weight_dtype=torch.float16,
  49. bias_dtype=torch.float16,
  50. )
  51. default_dynamic_int8_dtype_config = DTypeConfig(
  52. input_dtype=torch.quint8,
  53. output_dtype=torch.float,
  54. weight_dtype=torch.qint8,
  55. bias_dtype=torch.float,
  56. # currently the dtype check is not yet enabled, so we provided the dtype_configs but
  57. # it is not really used yet,
  58. # we will enable it a bit later after we moved everything to backend_config_dict
  59. is_dynamic=True,
  60. )
  61. default_dynamic_float16_dtype_config = DTypeConfig(
  62. input_dtype=torch.float16,
  63. output_dtype=torch.float,
  64. weight_dtype=torch.float16,
  65. bias_dtype=torch.float,
  66. # currently the dtype check is not yet enabled, so we provided the dtype_configs but
  67. # it is not really used yet,
  68. # we will enable it a bit later after we moved everything to backend_config_dict
  69. is_dynamic=True,
  70. )
  71. # Needed for LayerNorm and f.layer_norm, since currently the kernel only supports float weights
  72. input_output_only_quint8_dtype_config = DTypeConfig(
  73. input_dtype=torch.quint8,
  74. output_dtype=torch.quint8,
  75. weight_dtype=torch.float,
  76. bias_dtype=torch.float,
  77. )
  78. weight_only_quint8_dtype_config = DTypeConfig(
  79. input_dtype=torch.float,
  80. output_dtype=torch.float,
  81. weight_dtype=torch.quint8,
  82. )
  83. weight_only_quint4x2_dtype_config = DTypeConfig(
  84. input_dtype=torch.float,
  85. output_dtype=torch.float,
  86. weight_dtype=torch.quint4x2,
  87. )
  88. # =====================
  89. # | BACKEND CONFIGS |
  90. # =====================
  91. def get_test_only_legacy_native_backend_config() -> BackendConfig:
  92. """
  93. Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional fp16 ops.
  94. """
  95. conv_dtype_configs = [weighted_op_quint8_dtype_config]
  96. linear_dtype_configs = [
  97. weighted_op_quint8_dtype_config,
  98. default_dynamic_int8_dtype_config,
  99. default_dynamic_float16_dtype_config,
  100. default_op_fp16_dtype_config,
  101. ]
  102. binary_op_dtype_configs = [
  103. default_op_quint8_dtype_config,
  104. default_op_fp16_dtype_config,
  105. ]
  106. default_op_dtype_configs = [default_op_quint8_dtype_config]
  107. fixed_qparams_op_dtype_configs = [
  108. default_op_quint8_dtype_config,
  109. default_op_fp16_dtype_config,
  110. ]
  111. share_qparams_op_dtype_configs = [
  112. default_op_quint8_dtype_config,
  113. default_op_fp16_dtype_config
  114. ]
  115. tensor_info_op_dtype_configs = [
  116. default_op_quint8_dtype_config,
  117. ]
  118. rnn_op_dtype_configs = [
  119. default_dynamic_int8_dtype_config,
  120. default_dynamic_float16_dtype_config,
  121. ]
  122. embedding_op_dtype_configs = [
  123. weight_only_quint8_dtype_config,
  124. weight_only_quint4x2_dtype_config,
  125. ]
  126. layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
  127. return BackendConfig("_native_and_fp16") \
  128. .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
  129. .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
  130. .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
  131. .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
  132. .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
  133. .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
  134. .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
  135. .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
  136. .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
  137. .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
  138. .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
  139. .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
  140. def get_native_backend_config() -> BackendConfig:
  141. """
  142. Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack).
  143. """
  144. # TODO: express this BackendConfig as a union of the FBGEMM and QNNPACK BackendConfigs
  145. conv_dtype_configs = [weighted_op_quint8_dtype_config]
  146. linear_dtype_configs = [
  147. weighted_op_quint8_dtype_config,
  148. default_dynamic_int8_dtype_config,
  149. default_dynamic_float16_dtype_config,
  150. ]
  151. binary_op_dtype_configs = [default_op_quint8_dtype_config]
  152. default_op_dtype_configs = [default_op_quint8_dtype_config]
  153. fixed_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
  154. share_qparams_op_dtype_configs = [default_op_quint8_dtype_config]
  155. tensor_info_op_dtype_configs = [default_op_quint8_dtype_config]
  156. rnn_op_dtype_configs = [
  157. default_dynamic_int8_dtype_config,
  158. default_dynamic_float16_dtype_config,
  159. ]
  160. embedding_op_dtype_configs = [
  161. weight_only_quint8_dtype_config,
  162. weight_only_quint4x2_dtype_config,
  163. ]
  164. layer_norm_op_dtype_configs = [input_output_only_quint8_dtype_config]
  165. return BackendConfig("native") \
  166. .set_backend_pattern_configs(_get_conv_configs(conv_dtype_configs)) \
  167. .set_backend_pattern_configs(_get_linear_configs(linear_dtype_configs)) \
  168. .set_backend_pattern_configs(_get_binary_op_configs(binary_op_dtype_configs)) \
  169. .set_backend_pattern_config(_get_cat_config(default_op_dtype_configs)) \
  170. .set_backend_pattern_configs(_get_default_op_configs(default_op_dtype_configs)) \
  171. .set_backend_pattern_configs(_get_fixed_qparams_op_configs(fixed_qparams_op_dtype_configs)) \
  172. .set_backend_pattern_configs(_get_share_qparams_op_configs(share_qparams_op_dtype_configs)) \
  173. .set_backend_pattern_configs(_get_tensor_info_op_configs(tensor_info_op_dtype_configs)) \
  174. .set_backend_pattern_configs(_get_bn_configs(default_op_dtype_configs)) \
  175. .set_backend_pattern_configs(_get_ln_configs(layer_norm_op_dtype_configs)) \
  176. .set_backend_pattern_configs(_get_rnn_op_configs(rnn_op_dtype_configs)) \
  177. .set_backend_pattern_configs(_get_embedding_op_configs(embedding_op_dtype_configs))
  178. def get_native_backend_config_dict():
  179. """
  180. Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) in dictionary form.
  181. """
  182. return get_native_backend_config().to_dict()
  183. def get_test_only_legacy_native_backend_config_dict():
  184. """
  185. Return the `BackendConfig` for PyTorch Native backend (fbgemm/qnnpack) with various additional
  186. fp16 ops in dictionary form.
  187. """
  188. return get_test_only_legacy_native_backend_config().to_dict()