match_utils.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import sys
  2. import torch
  3. from torch.fx.graph import (
  4. Graph,
  5. Node,
  6. )
  7. from torch.ao.quantization.utils import Pattern
  8. from .quantize_handler import (
  9. QuantizeHandler,
  10. )
  11. from ..qconfig import (
  12. QConfigAny,
  13. )
  14. from ..utils import (
  15. MatchAllNode
  16. )
  17. from .graph_module import (
  18. _is_observed_standalone_module,
  19. )
  20. from torch.nn.utils.parametrize import type_before_parametrizations
  21. from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
# This module exports no public names.
__all__: List[str] = []
# TODO(future PR): the element at index 1 (the matched node pattern) is typed
# as `List[Node]`, but a better type would be a recursive
# `List[Union[Node, Tuple[Union[Node, ...]]]]`
# Per-node match result:
#   (node the match was anchored on, matched node pattern,
#    matched pattern or None, QuantizeHandler instance)
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]
# Same as _MatchResult with the node's QConfigAny appended as a fifth element.
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                                QConfigAny]
# Note: The order of patterns is important! The match function takes whatever is
# matched first, so fusion patterns must be registered before single patterns.
# For example, add_relu should be registered before relu.
# Decorators are applied in the reverse of the order in which we see them. Also,
# when we match the nodes in the graph against these patterns, we start from the
# last node of the graph and traverse backwards.
  32. def _is_match(modules, node, pattern, max_uses=sys.maxsize):
  33. """ Matches a node in fx against a pattern
  34. """
  35. if isinstance(pattern, tuple):
  36. self_match, *arg_matches = pattern
  37. if self_match is getattr:
  38. assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
  39. arg_matches = []
  40. else:
  41. self_match = pattern
  42. arg_matches = []
  43. if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
  44. return True
  45. if node == pattern:
  46. return True
  47. if not isinstance(node, Node) or len(node.users) > max_uses:
  48. return False
  49. if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
  50. if node.op != 'call_module':
  51. return False
  52. if not type_before_parametrizations(modules[node.target]) == self_match:
  53. return False
  54. elif callable(self_match):
  55. if node.op != 'call_function' or node.target is not self_match:
  56. return False
  57. elif node.target is getattr:
  58. if node.args[1] != pattern[1]:
  59. return False
  60. elif isinstance(self_match, str):
  61. if node.op != 'call_method' or node.target != self_match:
  62. return False
  63. elif node.target != self_match:
  64. return False
  65. if not arg_matches:
  66. return True
  67. if len(arg_matches) != len(node.args):
  68. return False
  69. return all(_is_match(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
  70. def _find_matches(
  71. graph: Graph,
  72. modules: Dict[str, torch.nn.Module],
  73. patterns: Dict[Pattern, QuantizeHandler],
  74. root_node_getter_mapping: Dict[Pattern, Callable],
  75. standalone_module_names: List[str] = None,
  76. standalone_module_classes: List[Type] = None,
  77. custom_module_classes: List[Any] = None) -> Dict[str, _MatchResult]:
  78. """
  79. Matches the nodes in the input graph to quantization patterns, and
  80. outputs the information needed to quantize them in future steps.
  81. Inputs:
  82. - graph: an fx.Graph object
  83. - modules: a mapping of fully qualified module name to instance,
  84. for example, {'foo': ModuleFoo, ...}
  85. - patterns: a mapping from a tuple of nodes in reverse order to
  86. uninitialized QuantizeHandler subclass.
  87. Outputs a map of
  88. node_name ->
  89. (node, matched_values, matched_pattern, QuantizeHandler instance,
  90. qconfig)
  91. For example, {
  92. 'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
  93. <CopyNodeQuantizeHandler instance>, QConfig(...)),
  94. ...
  95. }
  96. """
  97. if custom_module_classes is None:
  98. custom_module_classes = []
  99. if standalone_module_classes is None:
  100. standalone_module_classes = []
  101. if standalone_module_names is None:
  102. standalone_module_names = []
  103. match_map: Dict[str, _MatchResult] = {}
  104. all_matched : Set[str] = set()
  105. def _recursive_record_node_in_match_map(
  106. last_node,
  107. match_map,
  108. node_pattern,
  109. matched_node_pattern,
  110. pattern,
  111. match_value):
  112. if isinstance(node_pattern, Node):
  113. match_map[node_pattern.name] = (
  114. last_node, matched_node_pattern, pattern, match_value)
  115. elif not isinstance(node_pattern, Iterable):
  116. return
  117. else:
  118. for n in node_pattern:
  119. _recursive_record_node_in_match_map(last_node, match_map, n, matched_node_pattern, pattern, match_value)
  120. # TODO: 1. merge with fuse matcher 2. document the code
  121. def record_match(
  122. pattern,
  123. node,
  124. last_node,
  125. matched_node_pattern,
  126. match_map):
  127. if isinstance(pattern, tuple):
  128. s, *args = pattern
  129. is_single_arg = len(args) == 1
  130. current_node_pattern: List[Node] = []
  131. record_match(
  132. s,
  133. node,
  134. last_node,
  135. matched_node_pattern,
  136. match_map)
  137. if pattern[0] is not getattr:
  138. for subpattern, arg in zip(args, node.args):
  139. record_match(
  140. subpattern,
  141. arg,
  142. node,
  143. current_node_pattern,
  144. match_map)
  145. if len(current_node_pattern) > 1:
  146. # current_node_pattern is the node pattern we get from matching
  147. # the subpattern with arguments of the node
  148. # we use is_single_arg to recover the original structure of the pattern
  149. # if the original pattern has a single argument, we will have
  150. # (original_op, (original_arg, ...))
  151. # otherwise, we'll have a list of arguments
  152. # (original_op, arg0, arg1, arg2, ...)
  153. if is_single_arg:
  154. matched_node_pattern.append(tuple(current_node_pattern))
  155. else:
  156. matched_node_pattern.extend(list(current_node_pattern))
  157. else:
  158. matched_node_pattern.append(current_node_pattern[0])
  159. else:
  160. matched_node_pattern.append(node)
  161. for node in reversed(graph.nodes):
  162. if node.name not in match_map and node.name not in all_matched:
  163. for pattern, quantize_handler_cls in patterns.items():
  164. root_node_getter = root_node_getter_mapping.get(pattern, None)
  165. if _is_match(modules, node, pattern) and node.name not in match_map:
  166. matched_node_pattern: List[Node] = []
  167. record_match(
  168. pattern,
  169. node,
  170. node,
  171. matched_node_pattern,
  172. match_map)
  173. quantize_handler = quantize_handler_cls( # type: ignore[operator]
  174. matched_node_pattern,
  175. modules,
  176. root_node_getter)
  177. last_node = node
  178. # record the match for all nodes in the pattern
  179. _recursive_record_node_in_match_map(
  180. last_node,
  181. match_map,
  182. # we need to record all nodes in the matched pattern in the match_map
  183. matched_node_pattern,
  184. # this is a part of the value corresponding to the node
  185. matched_node_pattern,
  186. pattern,
  187. quantize_handler)
  188. break
  189. # add custom module instances to the match result
  190. assert modules is not None
  191. for node in graph.nodes:
  192. if node.op == 'call_module' and \
  193. type(modules[node.target]) in custom_module_classes:
  194. match_map[node.name] = (
  195. node, node, None, QuantizeHandler(node, modules, is_custom_module=True))
  196. def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
  197. assert modules is not None
  198. return (
  199. node_target in standalone_module_names or # type: ignore[operator]
  200. type(modules[node_target]) in standalone_module_classes # type: ignore[operator]
  201. )
  202. # add standalone modules to the match
  203. for node in graph.nodes:
  204. if node.op == 'call_module' and \
  205. (is_standalone_module(node.target, modules) or
  206. _is_observed_standalone_module(modules[node.target])):
  207. # add node to matched nodes
  208. match_map[node.name] = (
  209. node, node, None,
  210. QuantizeHandler(node, modules, is_standalone_module=True))
  211. return match_map