quantize_handler.py

import torch
from torch.fx.graph import (
    Node,
)
from .utils import (
    all_node_args_have_no_tensors,
)
from torch.ao.quantization.backend_config import (
    BackendConfig,
    DTypeConfig,
    ObservationType,
)
from torch.ao.quantization.utils import (
    NodePattern,
    Pattern,
    QuantizerCls,
)

from abc import ABC
from typing import Callable, Dict, List, Optional, Type

__all__ = [
    "QuantizeHandler",
    "BinaryOpQuantizeHandler",
    "CatQuantizeHandler",
    "ConvReluQuantizeHandler",
    "LinearReLUQuantizeHandler",
    "BatchNormQuantizeHandler",
    "EmbeddingQuantizeHandler",
    "RNNDynamicQuantizeHandler",
    "DefaultNodeQuantizeHandler",
    "FixedQParamsOpQuantizeHandler",
    "CopyNodeQuantizeHandler",
    "GeneralTensorShapeOpQuantizeHandler",
    "CustomModuleQuantizeHandler",
    "StandaloneModuleQuantizeHandler",
]

def _default_root_node_getter(node_pattern):
    if node_pattern is None:
        return node_pattern
    while not isinstance(node_pattern, Node):
        node_pattern = node_pattern[-1]
    return node_pattern
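
# Illustrative example (not part of the original module): for a fused pattern
# in "complex format" such as (relu_node, (bn_node, conv_node)), the default
# getter repeatedly takes the last element until it reaches an fx Node, so the
# root here is conv_node:
#
#   _default_root_node_getter((relu_node, (bn_node, conv_node)))  # -> conv_node
#
# The node names above are hypothetical placeholders for matched fx Nodes.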

# Base Pattern Handler
class QuantizeHandler(ABC):
    """ Base handler class for the quantizer patterns
    """
    def __init__(
            self,
            node_pattern: NodePattern,
            modules: Dict[str, torch.nn.Module],
            root_node_getter: Optional[Callable] = None,
            is_custom_module=False,
            is_standalone_module=False):
        """ Records pattern information in __init__, which will be used
        in convert
        """
        self.node_pattern = node_pattern
        self.modules = modules
        if root_node_getter is None:
            root_node_getter = _default_root_node_getter
        self.root_node = root_node_getter(node_pattern)
        self.is_custom_module_ = is_custom_module
        self.is_standalone_module_ = is_standalone_module
        self.num_tensor_args = 0
        # determine how many of the root node's args are Tensors (versus
        # scalars); this distinguishes things like "x + y" from "x + 2" or "2 + x"
        if isinstance(self.root_node, Node):
            cache_for_no_tensor_check: Dict[Node, bool] = {}
            for arg in self.root_node.args:
                if isinstance(arg, Node) and (
                        not all_node_args_have_no_tensors(
                            arg, self.modules, cache_for_no_tensor_check)):
                    self.num_tensor_args += 1

    def is_general_tensor_value_op(self) -> bool:
        """
        Returns True if the operator works for both floating point and
        quantized inputs and either computes based on the input Tensor,
        only rearranges the Tensor values, or queries some metadata about
        the Tensor.

        For these ops we need to insert an observer/fake_quant for the
        output of the operator (the same observer instance as the input),
        since the distribution of values differs between the input and
        output Tensors (for HistogramObserver) while they share the same
        quantization parameters.

        Example operators: avgpool2d, reshape, transpose, maxpool2d
        Example observed operator:
            observer_0 - avgpool2d - observer_0 (same observer instance as input)
        """
        return False

    def is_custom_module(self):
        return self.is_custom_module_

    def is_standalone_module(self):
        return self.is_standalone_module_

def _get_quantize_handler_cls(
        observation_type: ObservationType,
        dtype_configs: List[DTypeConfig],
        num_tensor_args_to_observation_type: Dict[int, ObservationType]) -> Type[QuantizeHandler]:
    """
    Return a configurable QuantizeHandler that matches the given specifications from the backend.
    """
    class ConfigurableQuantizeHandler(QuantizeHandler):
        def __init__(
                self,
                node_pattern: NodePattern,
                modules: Dict[str, torch.nn.Module],
                root_node_getter: Optional[Callable] = None):
            super().__init__(node_pattern, modules, root_node_getter)
            if num_tensor_args_to_observation_type:
                assert self.num_tensor_args in num_tensor_args_to_observation_type, \
                    f"Must provide an observation_type config for {self.num_tensor_args} tensor args" \
                    f" in num_tensor_args_to_observation_type for {node_pattern}"
                self.observation_type = num_tensor_args_to_observation_type[self.num_tensor_args]
            else:
                self.observation_type = observation_type
            self.dtype_configs = dtype_configs

        def is_general_tensor_value_op(self) -> bool:
            return self.observation_type == ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT

    return ConfigurableQuantizeHandler
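
# Illustrative usage sketch (not part of the original module): build a handler
# class for an op whose output shares the input's observer, then check the
# resulting behavior. `node_pattern` and `modules` are hypothetical stand-ins
# for a matched fx pattern and the model's module map.
#
#   handler_cls = _get_quantize_handler_cls(
#       ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
#       dtype_configs=[],
#       num_tensor_args_to_observation_type={})
#   handler = handler_cls(node_pattern, modules)
#   assert handler.is_general_tensor_value_op()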

def _get_pattern_to_quantize_handlers(backend_config: BackendConfig) -> Dict[Pattern, QuantizerCls]:
    """
    Note: the quantize handler is just a holder for some check methods like
    (should_insert_observer_for_output); maybe this can be an enum as well.
    We can refactor this after we fully convert the fbgemm/qnnpack path to
    the new path. This is not exposed to backend developers.
    """
    pattern_to_quantize_handlers = {}
    for pattern, config in backend_config._pattern_complex_format_to_config.items():
        observation_type = config.observation_type
        dtype_configs = config.dtype_configs
        num_tensor_args_to_observation_type = config._num_tensor_args_to_observation_type
        pattern_to_quantize_handlers[pattern] = \
            _get_quantize_handler_cls(
                observation_type,
                dtype_configs,
                num_tensor_args_to_observation_type)
    return pattern_to_quantize_handlers
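
# Illustrative usage sketch (not part of the original module): map every
# pattern supported by the native backend to its generated handler class.
#
#   from torch.ao.quantization.backend_config import get_native_backend_config
#   handlers = _get_pattern_to_quantize_handlers(get_native_backend_config())
#   # handlers: one ConfigurableQuantizeHandler subclass per supported pattern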

# TODO: remove this class, this is still exposed in torch.ao.quantization
# but we should be able to break bc
class BinaryOpQuantizeHandler(QuantizeHandler):
    pass

class CatQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class ConvReluQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class LinearReLUQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class BatchNormQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class EmbeddingQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class RNNDynamicQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove this class
class DefaultNodeQuantizeHandler(QuantizeHandler):
    """ Common quantized op, first input and first output will be quantized
    """
    pass

# TODO: remove this class
class FixedQParamsOpQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove
class CopyNodeQuantizeHandler(QuantizeHandler):
    pass

# TODO: remove
class GeneralTensorShapeOpQuantizeHandler(QuantizeHandler):
    pass

# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
class CustomModuleQuantizeHandler(QuantizeHandler):
    pass

# TODO: not used, can be removed after torch.ao.quantization namespace is deprecated
class StandaloneModuleQuantizeHandler(QuantizeHandler):
    pass