graph_matcher.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. import collections
  2. import enum
  3. import torch
  4. toq = torch.ops.quantized
  5. from torch.fx import GraphModule
  6. from torch.fx.graph import Graph, Node
  7. from torch.ao.quantization.utils import getattr_from_fqn
  8. from .ns_types import NSSubgraph, NSNodeTargetType
  9. from .mappings import (
  10. get_base_name_to_sets_of_related_ops,
  11. get_unmatchable_types_map,
  12. )
  13. from .pattern_utils import (
  14. get_type_a_related_to_b,
  15. get_reversed_fusions,
  16. end_node_matches_reversed_fusion,
  17. )
  18. from torch.ao.quantization import (
  19. ObserverBase,
  20. FakeQuantizeBase,
  21. )
  22. from typing import Dict, Tuple, List, Optional, Set, Any
  23. def _get_output_nodes(g: Graph) -> List[Node]:
  24. return [n for n in g.nodes if n.op == 'output']
  25. class _NSGraphMatchableSubgraphsIterator:
  26. """
  27. Iterates through the graph of gm, starting with the output nodes
  28. and continuing backwards.
  29. 1. Returns matchable subgraphs, in order. A subgraph is defined by
  30. (start_node, end_node).
  31. 2. Skips over non-matchable subgraphs
  32. """
  33. def __init__(
  34. self,
  35. gm: GraphModule,
  36. non_matchable_functions: Set[NSNodeTargetType],
  37. non_matchable_modules: Set[NSNodeTargetType],
  38. non_matchable_methods: Set[NSNodeTargetType],
  39. ):
  40. self.gm: GraphModule = gm
  41. self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
  42. self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
  43. self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
  44. self.seen_nodes: Set[Node] = set()
  45. self.stack: List[Node] = []
  46. for start_node in _get_output_nodes(self.gm.graph):
  47. self.stack.append(start_node)
  48. def __iter__(self):
  49. return self
  50. def __next__(self) -> NSSubgraph:
  51. """
  52. Returns the next matchable subgraph.
  53. """
  54. while len(self.stack) > 0:
  55. cur_end_node = self.stack.pop()
  56. if cur_end_node in self.seen_nodes:
  57. continue
  58. # for subgraphs which are single nodes, start_node == end_node
  59. # for subgraphs with more than one node, start node != end_node
  60. cur_start_node = cur_end_node
  61. # Subgraphs like linear-relu have the base node as the start node.
  62. # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
  63. # base node as the second node.
  64. # The cur_base_op_node var will move to the actual node during
  65. # the fusion matching later in this code block.
  66. cur_base_op_node = cur_end_node
  67. # Check for potential fusions. For now, we are greedy
  68. # and always skip all non-base nodes of a fusion. For example,
  69. # if we match linear-relu backwards, we will always skip the
  70. # relu node and attempt to match the linear node. This can
  71. # be made configurable later if needed.
  72. for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
  73. is_match = end_node_matches_reversed_fusion(
  74. cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes)
  75. if is_match:
  76. # navigate to the base node
  77. for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
  78. self.seen_nodes.add(cur_start_node)
  79. # for now, assume that there are no other nodes
  80. # which need to be added to the stack
  81. cur_start_node = cur_start_node.args[0] # type: ignore[assignment]
  82. # if the base op index matches the current node, set it
  83. rev_base_op_idx = \
  84. len(_reverse_fusion_ops) - 2 - base_op_idx
  85. if rev_fusion_idx == rev_base_op_idx:
  86. cur_base_op_node = cur_start_node
  87. break
  88. self.seen_nodes.add(cur_start_node)
  89. # add args of previous nodes to stack
  90. for arg in cur_start_node.all_input_nodes:
  91. self._recursively_add_node_arg_to_stack(arg)
  92. # skip unmatchable nodes
  93. # note: this check is done on the start_node, i.e.
  94. # if we are matching linear-relu in reverse, this would do the matchable
  95. # check on the linear
  96. if not self._is_matchable(cur_base_op_node):
  97. continue
  98. # If an observer or a fake_quant was not matched as a part of
  99. # a pattern of multiple nodes, ignore it. One case where this is
  100. # relevant is an observer on a graph input, which was added because
  101. # it is necessary for the next node.
  102. if cur_end_node.op == 'call_module' and cur_start_node is cur_end_node:
  103. maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type]
  104. if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
  105. continue
  106. return NSSubgraph(
  107. start_node=cur_start_node, end_node=cur_end_node,
  108. base_op_node=cur_base_op_node)
  109. raise StopIteration
  110. def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
  111. """
  112. Adds all of the nodes in this arg to the stack, properly navigating
  113. through list, dicts and tuples.
  114. """
  115. if isinstance(arg, Node):
  116. self.stack.append(arg)
  117. elif isinstance(arg, torch.fx.immutable_collections.immutable_list) or type(arg) is tuple:
  118. for inner_arg in arg:
  119. self._recursively_add_node_arg_to_stack(inner_arg)
  120. elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
  121. for key, value in arg.items():
  122. self._recursively_add_node_arg_to_stack(value)
  123. def _is_matchable(self, node: Node) -> bool:
  124. if node.op == 'call_function':
  125. return not (node.target in self.non_matchable_functions)
  126. elif node.op == 'call_module':
  127. assert isinstance(node.target, str)
  128. target_mod = getattr_from_fqn(self.gm, node.target)
  129. return not \
  130. any(isinstance(target_mod, t) # type: ignore[arg-type]
  131. for t in self.non_matchable_modules)
  132. elif node.op == 'call_method':
  133. return not (node.target in self.non_matchable_methods)
  134. else:
  135. return False
  136. class GraphMatchingException(Exception):
  137. """
  138. Exception raised when two graphs cannot be matched.
  139. """
  140. pass
  141. class SubgraphTypeRelationship(enum.Enum):
  142. # same type, known
  143. # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
  144. EQUAL = enum.auto()
  145. # same type, but the type is not known to Numerical Suite
  146. # (user defined type, etc).
  147. EQUAL_BUT_UKNOWN = enum.auto()
  148. # known, same subgraph_relationship set, but not the same type
  149. # example: F.linear and toq.linear
  150. RELATED_BUT_NOT_EQUAL = enum.auto()
  151. # not related
  152. NOT_RELATED = enum.auto()
  153. def _get_subgraph_relationship_type(
  154. subgraph_a: NSSubgraph,
  155. subgraph_b: NSSubgraph,
  156. gm_a: GraphModule,
  157. gm_b: GraphModule,
  158. type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
  159. ) -> SubgraphTypeRelationship:
  160. node_a = subgraph_a.base_op_node
  161. node_b = subgraph_b.base_op_node
  162. # TODO(next): make this code handle matching by what is before the base op
  163. if node_a.op != node_b.op:
  164. if not (
  165. node_a.op in ('call_function', 'call_method') and
  166. node_b.op in ('call_function', 'call_method')
  167. ):
  168. return SubgraphTypeRelationship.NOT_RELATED
  169. if node_a.op in ('call_function', 'call_method'):
  170. key = (node_a.target, node_b.target)
  171. if key not in type_a_related_to_b:
  172. if node_a.target == node_b.target:
  173. return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
  174. else:
  175. return SubgraphTypeRelationship.NOT_RELATED
  176. # after this point, we are dealing with known types
  177. if node_a.target == node_b.target:
  178. node_a_has_prev = subgraph_a.base_op_node == subgraph_a.start_node
  179. node_b_has_prev = subgraph_b.base_op_node == subgraph_b.start_node
  180. if node_a_has_prev and (not node_b_has_prev):
  181. return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
  182. elif (not node_a_has_prev) and node_b_has_prev:
  183. return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
  184. elif (not node_a_has_prev) and (not node_b_has_prev):
  185. return SubgraphTypeRelationship.EQUAL
  186. else:
  187. # TODO(future PR): check for matches start_op_node and base_op_node
  188. return SubgraphTypeRelationship.EQUAL
  189. if key in type_a_related_to_b:
  190. return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
  191. else:
  192. return SubgraphTypeRelationship.NOT_RELATED
  193. elif node_a.op == 'call_module':
  194. assert (subgraph_a.base_op_node == subgraph_a.start_node and
  195. subgraph_b.base_op_node == subgraph_b.start_node), \
  196. "Matching call_module patterns where base_op_node != start_node is not supported yet"
  197. # for call_module, we need to look up the modules to do the type check
  198. assert isinstance(node_a.target, str)
  199. mod_a = getattr_from_fqn(gm_a, node_a.target)
  200. assert isinstance(node_b.target, str)
  201. mod_b = getattr_from_fqn(gm_b, node_b.target)
  202. key = (type(mod_a), type(mod_b))
  203. if key not in type_a_related_to_b:
  204. if type(mod_a) == type(mod_b):
  205. return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
  206. else:
  207. return SubgraphTypeRelationship.NOT_RELATED
  208. elif type(mod_a) == type(mod_b):
  209. return SubgraphTypeRelationship.EQUAL
  210. else:
  211. return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
  212. return SubgraphTypeRelationship.NOT_RELATED
  213. def _get_name_for_subgraph(
  214. subgraph_a: NSSubgraph,
  215. gm_a: GraphModule,
  216. base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
  217. existing_names: Set[str],
  218. ) -> str:
  219. """
  220. Returns a unique name for a subgraph. This name is based on two things:
  221. 1. the name of the set containing the underlying type of the base op in the
  222. subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
  223. 2. the number of previous subgraphs with related underlying type of the base op
  224. For example, in the graph
  225. linear0 -> relu0 -> linear1 -> relu1
  226. The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
  227. from the output node backwards, the name given to (linear1, relu1) will be
  228. `base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
  229. will be `base_op_torch.nn.functional.linear_1`.
  230. Why are we not just using the node name? Answer: because of two requirements:
  231. A. fusions must be supported
  232. B. some Numeric Suite APIs can be called without having all of the models in memory
  233. For example, let's say we need to match nodes of
  234. (1) ... -> linear0 -> relu0 -> ...
  235. And
  236. (2) ... -> linear_relu0 -> ...
  237. Without being able to inspect them together. With the current naming scheme, if
  238. we iterate through both of these graphs in the same order, and assuming the rest
  239. of the graphs match, both of these subgraphs will get the same name without
  240. (1) and (2) knowing anything about each other.
  241. """
  242. target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
  243. target_base_type = None
  244. for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
  245. if target_type in sets_of_related_ops:
  246. target_base_type = base_name
  247. target_base_name = 'base_op_' + str(target_base_type)
  248. counter = 0
  249. proposed_name = target_base_name + '_' + str(counter)
  250. while proposed_name in existing_names:
  251. counter += 1
  252. proposed_name = target_base_name + '_' + str(counter)
  253. existing_names.add(proposed_name)
  254. return proposed_name
  255. def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
  256. if node.op in ('call_function', 'call_method'):
  257. return node.target
  258. elif node.op == 'call_module':
  259. assert isinstance(node.target, str)
  260. mod = getattr_from_fqn(gm, node.target)
  261. return type(mod)
  262. return None
  263. def get_matching_subgraph_pairs(
  264. gm_a: GraphModule,
  265. gm_b: GraphModule,
  266. base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  267. unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
  268. ) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
  269. """
  270. Matches matchable subgraphs of graph_a to graph_b.
  271. For a node, "matchable" is defined as a node which is not an observer,
  272. fake_quants, quant or dequant.
  273. A subgraph can contain one or more nodes. A subgraph is matchable if
  274. at least one node inside of it is matchable. Currently, all nodes in
  275. a subgraph must be matchable (because we assume no observers will be
  276. inserted in the middle of a fusion).
  277. A subgraph is defined by (start_node, end_node). We assume that only
  278. start_node and end_node are linked with the surrounding graph, all other
  279. nodes in a subgraph are self-contained.
  280. A pair of nodes is "related" if both nodes represent the same mathematical
  281. operation across different quantization flavors. For example,
  282. `F.linear` and `torch.ops.quantized.linear` are related, and
  283. `F.linear` and `torch.nn.Conv` are not related.
  284. For each matchable pair of nodes node_a and node_b, they will match
  285. if node_a and node_b are related.
  286. For graphs A and B, they will match iff:
  287. 1. the number of matchable subgraphs in A and B is equivalent
  288. 2. when iterating through the matchable subgraphs of A and B in the same order, each
  289. corresponding pair of base nodes is related.
  290. This enables us to find the corresponding subgraphs between
  291. graphs of related models. For example, if we had two graphs such as:
  292. graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
  293. w -/
  294. b -/
  295. graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
  296. packed_params_0 -/
  297. This function will return the following result:
  298. {
  299. 'conv_0': ( # the name of the node in graph_b
  300. (conv_0, conv_0), # (start_node_a, end_node_a)
  301. (qconv_0, qconv_0), # (start_node_b, end_node_b)
  302. ),
  303. }
  304. Or, if we have a fusion pattern,
  305. graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
  306. w -/
  307. b -/
  308. graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
  309. packed_params_0 -/
  310. This function will return the following result:
  311. {
  312. 'linear_relu_0': ( # the name of the node in graph_b
  313. (linear_0, relu_0), # (start_node_a, end_node_a)
  314. (linear_relu_0, linear_relu_0), # (start_node_b, end_node_b)
  315. ),
  316. }
  317. """
  318. if unmatchable_types_map is None:
  319. unmatchable_types_map = get_unmatchable_types_map()
  320. non_matchable_functions = unmatchable_types_map['funs_unmatchable']
  321. non_matchable_modules = unmatchable_types_map['mods_unmatchable']
  322. non_matchable_methods = unmatchable_types_map['meths_unmatchable']
  323. graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
  324. gm_a, non_matchable_functions, non_matchable_modules,
  325. non_matchable_methods)
  326. graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
  327. gm_b, non_matchable_functions, non_matchable_modules,
  328. non_matchable_methods)
  329. results = collections.OrderedDict()
  330. if base_name_to_sets_of_related_ops is None:
  331. base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
  332. type_a_related_to_b = \
  333. get_type_a_related_to_b(base_name_to_sets_of_related_ops)
  334. existing_names_a: Set[str] = set()
  335. existing_names_b: Set[str] = set()
  336. while True:
  337. # fetch the next subgraphs from a and b
  338. cur_subgraph_a, cur_subgraph_b = None, None
  339. try:
  340. cur_subgraph_a = next(graph_a_iterator)
  341. except StopIteration:
  342. pass
  343. try:
  344. cur_subgraph_b = next(graph_b_iterator)
  345. except StopIteration:
  346. pass
  347. # look up types of a and b for useful error messages
  348. type_start_a, type_start_b = None, None
  349. if cur_subgraph_a is not None:
  350. type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
  351. if cur_subgraph_b is not None:
  352. type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)
  353. # check for results and determine what to do next
  354. if cur_subgraph_a is not None and cur_subgraph_b is not None:
  355. # both nodes were fetched, check for subgraph_relationship
  356. # note: subgraph_relationship is checked on the start node, i.e.
  357. # if a linear-relu pattern is checked, we would check for subgraph_relationship
  358. # of the linear
  359. subgraph_relationship = _get_subgraph_relationship_type(
  360. cur_subgraph_a, cur_subgraph_b,
  361. gm_a, gm_b, type_a_related_to_b)
  362. if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
  363. msg = f"""
  364. The subgraphs
  365. ({cur_subgraph_a}, {type_start_a}) and
  366. ({cur_subgraph_b}, {type_start_b})
  367. are not related. Please ensure that the two models you pass in have the same number
  368. of subgraphs, and each pair of subgraphs is related to each other."""
  369. raise GraphMatchingException(msg)
  370. elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
  371. # skip matching but unknown types
  372. continue
  373. key_name_a = _get_name_for_subgraph(
  374. cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops,
  375. existing_names_a)
  376. key_name_b = _get_name_for_subgraph(
  377. cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops,
  378. existing_names_b)
  379. assert key_name_a == key_name_b, \
  380. f"Subgraph names {key_name_a} and {key_name_b} do not match"
  381. results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
  382. continue
  383. elif cur_subgraph_a is None and cur_subgraph_b is None:
  384. # we reached the end of both graphs
  385. break
  386. else:
  387. # only one node was fetched, no match possible, throw error
  388. msg = f"""
  389. Attempting to match
  390. ({cur_subgraph_a}, {type_start_a}) and
  391. ({cur_subgraph_b}, {type_start_b}),
  392. one of which is empty. Please ensure that the two models you pass in have the same number
  393. of subgraphs."""
  394. raise GraphMatchingException(msg)
  395. # The subgraph pairs are originally created by traversing the two graphs
  396. # from the outputs to the inputs. Reverse the results to return the
  397. # subgraphs in their order of execution.
  398. results = collections.OrderedDict(reversed(list(results.items())))
  399. return results