import torch
import torch.nn as nn
import torch.nn.functional as F
toq = torch.ops.quantized
from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.fx.quantize_handler import _get_pattern_to_quantize_handlers
from torch.ao.quantization.utils import getattr_from_fqn
from .ns_types import NSNodeTargetType
from torch.ao.quantization import (
    ObserverBase,
    FakeQuantizeBase,
)

from typing import Dict, Tuple, Set, Callable, Any, Union, List


def get_type_a_related_to_b(
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
) -> Set[Tuple[NSNodeTargetType, NSNodeTargetType]]:
    # TODO(future PR): allow customizations
    # TODO(future PR): reuse existing quantization mappings
    # TODO(future PR): add the rest of modules and ops here
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]] = set()

    for base_name, s in base_name_to_sets_of_related_ops.items():
        s_list = list(s)
        # add every bidirectional pair
        for idx_0 in range(0, len(s_list)):
            for idx_1 in range(idx_0, len(s_list)):
                type_a_related_to_b.add((s_list[idx_0], s_list[idx_1]))
                type_a_related_to_b.add((s_list[idx_1], s_list[idx_0]))

    return type_a_related_to_b
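
# Illustrative sketch (not part of the original file): given a hypothetical mapping
#   {"conv": {nn.Conv2d, nnq.Conv2d}}
# where `nnq` stands for torch.ao.nn.quantized, get_type_a_related_to_b returns every
# ordered pair drawn from each set, including self-pairs:
#   {(nn.Conv2d, nn.Conv2d), (nn.Conv2d, nnq.Conv2d),
#    (nnq.Conv2d, nn.Conv2d), (nnq.Conv2d, nnq.Conv2d)}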


NSFusionElType = Union[
    Callable,  # call_function or call_module type, example: F.linear or nn.Conv2d
    str,  # call_method name, example: "dequantize"
    Tuple[str, Any],  # call_method name and first argument, example: ("to", torch.float16)
]

NSFusionType = Union[
    Tuple[NSFusionElType, NSFusionElType],
    Tuple[NSFusionElType, NSFusionElType, NSFusionElType, NSFusionElType],
]
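
# Illustrative sketch (assumption, not in the original file): concrete NSFusionType
# values, written end-to-start to match graph traversal order, paired with the
# base_op_idx used by get_reversed_fusions below:
#   ((nn.ReLU, nn.Conv2d), 0)                                      # conv -> relu, base op is the conv
#   ((("to", torch.float16), F.relu, F.linear, "dequantize"), 1)   # base op is the linear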

def get_reversed_fusions() -> List[Tuple[NSFusionType, int]]:
    """
    Set of potential fusions, in reverse order. The order is reversed
    to match how fusion patterns are defined in quantization code.

    Fusion format:
    ((fusion_op_0, fusion_op_1), base_op_idx)

    Where base_op_idx is the idx of the op we should use to match other related
    ops. Note: base_op_idx is specified in non-reverse order, i.e. a base_op_idx
    of 0 represents the first op in regular (non-reverse) order, 1 represents the
    second op, etc.
    """
    results: List[Tuple[NSFusionType, int]] = []

    # Possible syntaxes:
    # * single op: torch.nn.Conv2d
    # * multiple ops: (torch.nn.ReLU, torch.nn.Conv2d)
    # For fusions, we only care about patterns composed of multiple ops.
    # TODO(future PR): allow customizations from default patterns.
    all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config())

    default_base_op_idx = 0
    for quant_pattern, _quant_handler in all_quant_patterns.items():
        # TODO: this is a temporary hack to flatten the patterns from quantization so
        # that it works with the ns matcher function, maybe we should use `_is_match`
        # in torch.ao.quantization.fx.match_utils to match the patterns
        if isinstance(quant_pattern, tuple) and len(quant_pattern) == 2 and \
                isinstance(quant_pattern[1], tuple) and len(quant_pattern[1]) == 2:
            # flatten patterns of the form (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
            quant_pattern = (quant_pattern[0], quant_pattern[1][0], quant_pattern[1][1])

        # Only patterns of multiple ops are fusions; ignore
        # patterns which contain a single op (they get matched
        # without caring about fusions).
        if isinstance(quant_pattern, tuple):
            results.append((quant_pattern, default_base_op_idx))  # type: ignore[arg-type]

        # For each pattern, add additional patterns with observers and
        # fake quants at the end.
        # TODO(future PR): if needed, implement matching for a node
        # having multiple output observers.
        for cls in (ObserverBase, FakeQuantizeBase):
            if isinstance(quant_pattern, tuple):
                new_pattern = (cls, *quant_pattern)
            else:
                new_pattern = (cls, quant_pattern)
            results.append((new_pattern, default_base_op_idx))  # type: ignore[arg-type]

    # After this point, results contains values such as
    # [..., ((torch.nn.ReLU, torch.nn.Conv2d), 0), ...]

    # Patterns for matching fp16 emulation are not specified in the quantization
    # fusion mappings. For now, define them here.
    fp16_em_base_op_idx = 1
    patterns_to_add = [
        # linear-relu fp16 emulation:
        # fp16_to_fp32 -> linear -> relu -> fp32_to_fp16
        ((("to", torch.float16), F.relu, F.linear, "dequantize"), fp16_em_base_op_idx,),
        # Conv-BN fusion (this happens outside of quantization patterns,
        # which is why it is defined separately here).
        ((nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm1d, nn.Conv1d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm2d, nn.Conv2d), default_base_op_idx),
        ((nn.ReLU, nn.BatchNorm3d, nn.Conv3d), default_base_op_idx),
    ]
    for p in patterns_to_add:
        results.append(p)  # type: ignore[arg-type]
        results.append(((ObserverBase, *p[0]), p[1]))  # type: ignore[arg-type]
        results.append(((FakeQuantizeBase, *p[0]), p[1]))  # type: ignore[arg-type]

    return results
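
# Minimal usage sketch (assumption, not part of the original file): callers such as
# the NS graph matcher can scan these reversed patterns against a candidate node, e.g.
#
#   for pattern, base_op_idx in get_reversed_fusions():
#       if end_node_matches_reversed_fusion(node, pattern, gm, seen_nodes):
#           ...  # record the match; base_op_idx picks the op used to relate types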


def end_node_matches_reversed_fusion(
    end_node: Node,
    reversed_fusion: NSFusionType,
    gm: GraphModule,
    seen_nodes: Set[Node],
) -> bool:
    """
    Returns true if a pattern ending with `end_node` matches
    the fusion pattern.
    """
    cur_node = end_node
    for fusion_idx in range(len(reversed_fusion)):
        # each node can only belong to one matched pattern
        if cur_node in seen_nodes:
            return False

        cur_fusion_el = reversed_fusion[fusion_idx]

        if cur_node.op == 'call_function':
            fusion_el_is_fun = (not isinstance(cur_fusion_el, str)) and \
                (not isinstance(cur_fusion_el, type))
            if fusion_el_is_fun:
                if cur_node.target != cur_fusion_el:
                    return False
                # walk backwards through the graph to the next reversed pattern element
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_module':
            fusion_el_is_mod = isinstance(cur_fusion_el, type)
            if fusion_el_is_mod:
                assert isinstance(cur_node.target, str)
                target_mod = getattr_from_fqn(gm, cur_node.target)
                if not isinstance(cur_fusion_el, type):
                    return False
                if not isinstance(target_mod, cur_fusion_el):
                    return False
                # walk backwards through the graph to the next reversed pattern element
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False

        elif cur_node.op == 'call_method':
            fusion_el_is_meth_with_second_arg = \
                isinstance(cur_fusion_el, tuple) and len(cur_fusion_el) == 2
            fusion_el_is_meth_without_args = isinstance(cur_fusion_el, str)
            if fusion_el_is_meth_without_args or fusion_el_is_meth_with_second_arg:
                if fusion_el_is_meth_without_args:
                    if cur_node.target != cur_fusion_el:
                        return False
                else:
                    assert isinstance(cur_fusion_el, tuple)
                    if cur_node.target != cur_fusion_el[0]:
                        return False
                    elif len(cur_node.args) < 2:
                        return False
                    elif cur_node.args[1] != cur_fusion_el[1]:
                        return False

                # walk backwards through the graph to the next reversed pattern element
                if len(cur_node.args) > 0 and isinstance(cur_node.args[0], Node):
                    cur_node = cur_node.args[0]
                else:
                    return False
            else:
                return False
        else:
            return False

    return True
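
# Walkthrough sketch (assumption, for illustration): for a graph traced from
#   x -> nn.Conv2d -> nn.BatchNorm2d -> output
# the call
#   end_node_matches_reversed_fusion(bn_node, (nn.BatchNorm2d, nn.Conv2d), gm, set())
# checks the batchnorm node against the first (reversed) pattern element, follows
# args[0] back to the conv node, checks it against the second element, and returns
# True; a pattern such as (nn.ReLU, nn.Conv2d) would fail at the first element.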