matcher_utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. from dataclasses import dataclass, field
  2. from collections import defaultdict
  3. import copy
  4. import torch
  5. from torch.fx.graph import Graph
  6. from torch.fx.node import Node
  7. from torch.fx._compatibility import compatibility
  8. from typing import Dict, List, Set, Any, Union, Tuple
  9. import logging
  10. import os
  11. __all__ = ['SubgraphMatcher', 'InternalMatch']
  12. # Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
  13. def _init_logger():
  14. logger = logging.getLogger(__name__)
  15. level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
  16. logger.setLevel(level)
  17. console = logging.StreamHandler()
  18. formatter = logging.Formatter("%(filename)s > %(message)s")
  19. console.setFormatter(formatter)
  20. console.setLevel(level)
  21. # add the handlers to the logger
  22. logger.addHandler(console)
  23. logger.propagate = False
  24. return logger
  25. logger = _init_logger()
  26. @compatibility(is_backward_compatible=False)
  27. @dataclass
  28. class InternalMatch():
  29. # Nodes from which the match was found
  30. anchors: List[Node]
  31. # Maps nodes in the pattern subgraph to nodes in the larger graph
  32. nodes_map: Dict[Node, Node] = field(default_factory=dict)
  33. # nodes in target graph that are matched placeholder in pattern
  34. placeholder_nodes: List[Node] = field(default_factory=list)
  35. # nodes in matched subgraph returned by output
  36. returning_nodes: List[Node] = field(default_factory=list)
  37. def __copy__(self):
  38. return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(),
  39. placeholder_nodes=self.placeholder_nodes.copy(),
  40. returning_nodes=self.returning_nodes.copy())
  41. @compatibility(is_backward_compatible=False)
  42. class SubgraphMatcher:
  43. def __init__(self, pattern: Graph,
  44. match_output: bool = False,
  45. match_placeholder: bool = False,
  46. remove_overlapping_matches: bool = True) -> None:
  47. """
  48. Args:
  49. pattern: the targeted matching pattern, represented in fx.Graph.
  50. match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern.
  51. If False, output node is ignored during match.
  52. match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of
  53. the targeted pattern. If False, placeholder nodes will be used a wildcard.
  54. remove_overlapping_matches: If True, in the case of overlapping matches, only the first match
  55. will be returned.
  56. """
  57. self.pattern = pattern
  58. self.match_output = match_output
  59. self.match_placeholder = match_placeholder
  60. self.remove_overlapping_matches = remove_overlapping_matches
  61. if len(pattern.nodes) == 0:
  62. raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern")
  63. for node in pattern.nodes:
  64. if node.op != "output":
  65. assert len(node.users) > 0, \
  66. "SubgraphMatcher cannot be initialized with an pattern with dead code"
  67. # TODO: assert pattern is a connected graph
  68. self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"]
  69. output_node = next(iter(reversed(pattern.nodes)))
  70. # nodes returned by outputs
  71. self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes
  72. self.pattern_anchors: List[Node] = []
  73. if match_output:
  74. self.pattern_anchors = [output_node]
  75. else:
  76. # If a node has output_node as the ONLY user, then this node is a graph sink,
  77. # and should be matched against as an anchor
  78. self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1]
  79. def _match_attributes(self, pn: Node, gn: Node) -> bool:
  80. # Attributes matching is compilcated. Right now we only support matching constant tensor
  81. assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string."
  82. assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string."
  83. pn_value = getattr(pn.graph.owning_module, pn.target)
  84. gn_value = getattr(gn.graph.owning_module, gn.target)
  85. if type(pn_value) != type(gn_value):
  86. return False
  87. # Don't require exact match on tensor values.
  88. if isinstance(pn_value, torch.Tensor):
  89. return isinstance(gn_value, torch.Tensor)
  90. else:
  91. raise RuntimeError(f"Unsupported type {pn_value} when matching attributes")
  92. return False
  93. def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
  94. # if exact match for placeholder is not required, then use placeholder as a wildcard
  95. if not self.match_placeholder and pn.op == "placeholder":
  96. return True
  97. if pn.op == gn.op:
  98. if pn.op == "placeholder" or pn.op == "output":
  99. return True
  100. elif pn.op == "get_attr":
  101. return self._match_attributes(pn, gn)
  102. return pn.target == gn.target
  103. return False
  104. def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
  105. # `lookup` represents all the nodes in `original_graph`
  106. # that are part of `pattern`
  107. # Placeholders can be used by other nodes in the graphs
  108. lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}
  109. for gn, pn in lookup.items():
  110. # nodes returned by output are allowed to be used in other areas of the graph
  111. if pn in self.pattern_returning_nodes:
  112. continue
  113. for user in gn.users:
  114. # If this node has users that were not in `lookup`, then it must leak out of the
  115. # pattern subgraph
  116. if user not in lookup:
  117. return False
  118. return True
  119. def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]:
  120. non_overlapping_matches: List[InternalMatch] = list()
  121. nodes_matched: Set[Node] = set()
  122. for match in matches:
  123. found_overlap = False
  124. for pn, gn in match.nodes_map.items():
  125. if pn.op not in {"placeholder", "output"} and gn in nodes_matched:
  126. found_overlap = True
  127. break
  128. if not found_overlap:
  129. non_overlapping_matches.append(match)
  130. for pn, gn in match.nodes_map.items():
  131. if pn.op not in {"placeholder", "output"}:
  132. nodes_matched.add(gn)
  133. return non_overlapping_matches
  134. def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
  135. assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"
  136. if isinstance(pn, Node) and not isinstance(gn, Node):
  137. if pn.op == "placeholder":
  138. # Check if we've already matched these nodes in the current
  139. # traversal
  140. if pn in match.nodes_map:
  141. return match.nodes_map[pn] == gn
  142. match.nodes_map[pn] = gn
  143. return True
  144. else:
  145. return False
  146. elif not isinstance(pn, Node) and isinstance(gn, Node):
  147. return False
  148. else:
  149. return type(gn) == type(pn) and gn == pn
  150. def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
  151. logger.info(f" matching {pn} to {gn}")
  152. assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}")
  153. # Check if we've already matched these nodes in the current
  154. # traversal
  155. if pn in match.nodes_map:
  156. return match.nodes_map[pn] == gn
  157. # TODO: use a more efficienty way to check if gn is matched before: two-way dict
  158. if gn in match.nodes_map.values():
  159. return False
  160. if not self._nodes_are_equal(pn, gn):
  161. return False
  162. # Optimistically mark `pn` as a match for `gn`, and save a local copy of match
  163. saved_match = copy.copy(match)
  164. match.nodes_map[pn] = gn
  165. # Placeholder is a wildcard and can be matched with any python object
  166. # (including list/tuple)
  167. if pn.op == "placeholder":
  168. return True
  169. # Recursively traverse upwards to check if `pn` is a true
  170. # match for `gn`
  171. match_found = True
  172. def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
  173. if len(args1) != len(args2):
  174. return False
  175. for a1, a2 in zip(args1, args2):
  176. if isinstance(a1, Node) and isinstance(a2, Node):
  177. matched = self._match_nodes(a1, a2, match)
  178. elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)):
  179. matched = _match_args(a1, a2)
  180. else:
  181. matched = self._match_literals(a1, a2, match)
  182. if not matched:
  183. return False
  184. return True
  185. match_found = match_found and _match_args(pn.args, gn.args)
  186. pn_kwargs, gn_kwargs = [], []
  187. if pn.kwargs.keys() == gn.kwargs.keys():
  188. for key in pn.kwargs.keys():
  189. pn_kwargs.append(pn.kwargs[key])
  190. gn_kwargs.append(gn.kwargs[key])
  191. else:
  192. match_found = False
  193. match_found = match_found and _match_args(pn_kwargs, gn_kwargs)
  194. if not match_found:
  195. # revert to saved_match before matching with current node
  196. match = copy.copy(saved_match)
  197. return False
  198. return True
  199. def match(self, graph: Graph) -> List[InternalMatch]:
  200. """
  201. Returns:
  202. The matched subgraphs.
  203. Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder
  204. and nodes returned by output) can only be consumed by nodes within the matched subgraph.
  205. Subgraph pattern matcher is implemented with the backtracking style in the following steps:
  206. 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes
  207. are the "sinks" (nodes with no user other than the output node) of the pattern graph.
  208. One pattern graph could have multiple anchors if it has multiple return values.
  209. 2. In the target graph, we identify the potential candidate nodes that can be matched
  210. with each anchor. These anchor-candidate pairs are the starting points for
  211. pairwise per-node matching.
  212. 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both
  213. pattern and target graphs. For every pattern nodes along traversal path, we compare it
  214. against the target nodes. In case any comparison failed, the match for this anchor-candidate
  215. pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes`
  216. for more details.
  217. 4. In the case of multiple anchors, every anchor will need to find a match using step 3.
  218. In addition, the matches found between anchors need to have a common intersection node
  219. in order for the match to be valid. This is implemented with backtracking. See `backtracking`
  220. for more details.
  221. Notice: graph traversal must be done in the reverser order because a tensor can have multiple
  222. consumers, but can only have a single producer. Only with reverser order, we can we jointly
  223. traverse the pattern and target graph in a deterministic path.
  224. Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However,
  225. in practice, it's unlikely to blow up.
  226. """
  227. from torch.fx.passes.utils.fuser_utils import validate_partition
  228. # find candidate nodes to match with pattern anchors
  229. match_candidates: Dict[Node, List[Node]] = defaultdict(list)
  230. for pattern_anchor in self.pattern_anchors:
  231. for node in graph.nodes:
  232. if self._nodes_are_equal(pattern_anchor, node):
  233. match_candidates[pattern_anchor].append(node)
  234. match_candidates_list = list(match_candidates.items())
  235. logger.info(f"Initial match_candidates_list: {match_candidates_list}\n")
  236. matches: List[InternalMatch] = []
  237. def backtracking(anchor_index, match):
  238. if anchor_index == len(match_candidates_list):
  239. match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes]
  240. match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes]
  241. matches.append(match)
  242. logger.info(f"Found a match: {match}\n")
  243. return
  244. pattern_anchor, candidate_nodes = match_candidates_list[anchor_index]
  245. saved_match = copy.copy(match)
  246. for node in candidate_nodes:
  247. logger.info(f"Trying to match anchor {pattern_anchor} to {node}")
  248. match_found = self._match_nodes(pattern_anchor, node, match)
  249. if match_found:
  250. # match next anchor
  251. backtracking(anchor_index + 1, match)
  252. else:
  253. logger.info(f"Failed to match anchor {pattern_anchor} to {node}\n")
  254. # revert to saved_match before matching with current anchor
  255. match = copy.copy(saved_match)
  256. match = InternalMatch(anchors=self.pattern_anchors)
  257. if match_candidates_list:
  258. backtracking(0, match)
  259. # filter out the matches where the subgraph is not fully_contained
  260. before = len(matches)
  261. matches = [match for match in matches if self._is_contained(match.nodes_map)]
  262. after = len(matches)
  263. if before != after:
  264. logger.info(f"Filtered out {before - after} matches because they are not fully contained")
  265. # filter out the matches that that forms a cycle if the subgraph is fused
  266. valid_matches = []
  267. for match in matches:
  268. matched_compute_nodes = \
  269. [gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}]
  270. if validate_partition(matched_compute_nodes):
  271. valid_matches.append(match)
  272. if len(valid_matches) != len(matches):
  273. logger.info(f"Filtered out {len(matches) - len(valid_matches)} matches because \
  274. matched subgraph would form a cycle if fused")
  275. if self.remove_overlapping_matches:
  276. before = len(valid_matches)
  277. matches = self._remove_overlapping_matches(valid_matches)
  278. after = len(matches)
  279. if before != after:
  280. logger.info(f"Filtered out {before - after} matches because matched subgraphs are overlapping")
  281. logger.info(f"Matches returned: {matches}")
  282. return matches