subgraph_rewriter.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. from .graph_module import GraphModule
  2. from .graph import Graph
  3. from .node import Node
  4. from ._symbolic_trace import symbolic_trace
  5. from ._compatibility import compatibility
  6. import copy
  7. from dataclasses import dataclass
  8. from typing import Callable, Dict, List, NamedTuple, Optional, Set, Union
  9. import torch
  10. __all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]
  11. @compatibility(is_backward_compatible=True)
  12. class Match(NamedTuple):
  13. # Node from which the match was found
  14. anchor: Node
  15. # Maps nodes in the pattern subgraph to nodes in the larger graph
  16. nodes_map: Dict[Node, Node]
  17. @compatibility(is_backward_compatible=False)
  18. @dataclass
  19. class ReplacedPatterns:
  20. # Node from which the match was found
  21. anchor: Node
  22. # Maps nodes in the pattern subgraph to nodes in the larger graph
  23. nodes_map: Dict[Node, Node]
  24. # List of nodes that were added into the graph
  25. replacements: List[Node]
  26. def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
  27. gm.delete_all_unused_submodules()
  28. if isinstance(replacement, GraphModule):
  29. replacement.graph.lint()
  30. def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Module]:
  31. try:
  32. mod_match = mod.get_submodule(target)
  33. return mod_match
  34. except AttributeError:
  35. return None
  36. for node in gm.graph.nodes:
  37. if node.op == "call_module" or node.op == "get_attr":
  38. gm_submod = try_get_submodule(gm, node.target)
  39. replacement_submod = try_get_submodule(replacement, node.target)
  40. # CASE 1: This target already exists as a submodule in our
  41. # result GraphModule. Whether or not it exists in
  42. # `replacement`, the existing submodule takes precedence.
  43. if gm_submod is not None:
  44. continue
  45. # CASE 2: The target exists as a submodule in `replacement`
  46. # only, so we need to copy it over.
  47. elif replacement_submod is not None:
  48. new_submod = copy.deepcopy(getattr(replacement, node.target))
  49. gm.add_submodule(node.target, new_submod)
  50. # CASE 3: The target doesn't exist as a submodule in `gm`
  51. # or `replacement`
  52. else:
  53. raise RuntimeError("Attempted to create a \"", node.op,
  54. "\" node during subgraph rewriting "
  55. f"with target {node.target}, but "
  56. "the referenced submodule does not "
  57. "exist in either the original "
  58. "GraphModule `gm` or the replacement"
  59. " GraphModule `replacement`")
  60. gm.graph.lint()
  61. @compatibility(is_backward_compatible=True)
  62. def replace_pattern(
  63. gm: GraphModule,
  64. pattern: Union[Callable, GraphModule],
  65. replacement: Union[Callable, GraphModule]
  66. ) -> List[Match]:
  67. """
  68. Matches all possible non-overlapping sets of operators and their
  69. data dependencies (``pattern``) in the Graph of a GraphModule
  70. (``gm``), then replaces each of these matched subgraphs with another
  71. subgraph (``replacement``).
  72. Args:
  73. ``gm``: The GraphModule that wraps the Graph to operate on
  74. ``pattern``: The subgraph to match in ``gm`` for replacement
  75. ``replacement``: The subgraph to replace ``pattern`` with
  76. Returns:
  77. List[Match]: A list of ``Match`` objects representing the places
  78. in the original graph that ``pattern`` was matched to. The list
  79. is empty if there are no matches. ``Match`` is defined as:
  80. .. code-block:: python
  81. class Match(NamedTuple):
  82. # Node from which the match was found
  83. anchor: Node
  84. # Maps nodes in the pattern subgraph to nodes in the larger graph
  85. nodes_map: Dict[Node, Node]
  86. Examples:
  87. .. code-block:: python
  88. import torch
  89. from torch.fx import symbolic_trace, subgraph_rewriter
  90. class M(torch.nn.Module):
  91. def __init__(self):
  92. super().__init__()
  93. def forward(self, x, w1, w2):
  94. m1 = torch.cat([w1, w2]).sum()
  95. m2 = torch.cat([w1, w2]).sum()
  96. return x + torch.max(m1) + torch.max(m2)
  97. def pattern(w1, w2):
  98. return torch.cat([w1, w2]).sum()
  99. def replacement(w1, w2):
  100. return torch.stack([w1, w2])
  101. traced_module = symbolic_trace(M())
  102. subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)
  103. The above code will first match ``pattern`` in the ``forward``
  104. method of ``traced_module``. Pattern-matching is done based on
  105. use-def relationships, not node names. For example, if you had
  106. ``p = torch.cat([a, b])`` in ``pattern``, you could match
  107. ``m = torch.cat([a, b])`` in the original ``forward`` function,
  108. despite the variable names being different (``p`` vs ``m``).
  109. The ``return`` statement in ``pattern`` is matched based on its
  110. value only; it may or may not match to the ``return`` statement in
  111. the larger graph. In other words, the pattern doesn't have to extend
  112. to the end of the larger graph.
  113. When the pattern is matched, it will be removed from the larger
  114. function and replaced by ``replacement``. If there are multiple
  115. matches for ``pattern`` in the larger function, each non-overlapping
  116. match will be replaced. In the case of a match overlap, the first
  117. found match in the set of overlapping matches will be replaced.
  118. ("First" here being defined as the first in a topological ordering
  119. of the Nodes' use-def relationships. In most cases, the first Node
  120. is the parameter that appears directly after ``self``, while the
  121. last Node is whatever the function returns.)
  122. One important thing to note is that the parameters of the
  123. ``pattern`` Callable must be used in the Callable itself,
  124. and the parameters of the ``replacement`` Callable must match
  125. the pattern. The first rule is why, in the above code block, the
  126. ``forward`` function has parameters ``x, w1, w2``, but the
  127. ``pattern`` function only has parameters ``w1, w2``. ``pattern``
  128. doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
  129. As an example of the second rule, consider replacing
  130. .. code-block:: python
  131. def pattern(x, y):
  132. return torch.neg(x) + torch.relu(y)
  133. with
  134. .. code-block:: python
  135. def replacement(x, y):
  136. return torch.relu(x)
  137. In this case, ``replacement`` needs the same number of parameters
  138. as ``pattern`` (both ``x`` and ``y``), even though the parameter
  139. ``y`` isn't used in ``replacement``.
  140. After calling ``subgraph_rewriter.replace_pattern``, the generated
  141. Python code looks like this:
  142. .. code-block:: python
  143. def forward(self, x, w1, w2):
  144. stack_1 = torch.stack([w1, w2])
  145. sum_1 = stack_1.sum()
  146. stack_2 = torch.stack([w1, w2])
  147. sum_2 = stack_2.sum()
  148. max_1 = torch.max(sum_1)
  149. add_1 = x + max_1
  150. max_2 = torch.max(sum_2)
  151. add_2 = add_1 + max_2
  152. return add_2
  153. """
  154. match_and_replacements = _replace_pattern(gm, pattern, replacement)
  155. return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]
  156. # Experimental API, not backward compatible
  157. @compatibility(is_backward_compatible=False)
  158. def replace_pattern_with_filters(
  159. gm: GraphModule,
  160. pattern: Union[Callable, GraphModule],
  161. replacement: Union[Callable, GraphModule],
  162. match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]], # type: ignore[name-defined]
  163. ) -> List[ReplacedPatterns]:
  164. """
  165. See replace_pattern for documentation. This function is an overload with an additional match_filter argument.
  166. Args:
  167. ``match_filters``: A list of functions that take in
  168. (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
  169. whether the match satisfies the condition.
  170. See matcher_utils.py for definition of InternalMatch.
  171. """
  172. return _replace_pattern(gm, pattern, replacement, match_filters)
  173. def _replace_pattern(
  174. gm: GraphModule,
  175. pattern: Union[Callable, GraphModule],
  176. replacement: Union[Callable, GraphModule],
  177. match_filters: List[Callable[["InternalMatch", Graph, Graph], bool]] = None, # type: ignore[name-defined]
  178. ) -> List[ReplacedPatterns]:
  179. from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch
  180. if match_filters is None:
  181. match_filters = []
  182. # Get the graphs for `gm`, `pattern`, `replacement`
  183. original_graph: Graph = gm.graph
  184. if isinstance(pattern, GraphModule):
  185. pattern_graph = pattern.graph
  186. else:
  187. pattern_graph = symbolic_trace(pattern).graph
  188. if isinstance(replacement, GraphModule):
  189. replacement_graph = replacement.graph
  190. else:
  191. replacement_graph = symbolic_trace(replacement).graph
  192. matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
  193. remove_overlapping_matches=True)
  194. _matches: List[InternalMatch] = matcher.match(original_graph)
  195. # Filter out matches that don't match the filter
  196. _matches = [
  197. m for m in _matches
  198. if all(match_filter(m, original_graph, pattern_graph)
  199. for match_filter in match_filters)
  200. ]
  201. replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]
  202. # As we progressively replace nodes, we'll need to keep track of how the match results should change
  203. match_changed_node: Dict[Node, Node] = {}
  204. match_and_replacements = []
  205. for match in _matches:
  206. # Build connecting between replacement graph's input and original graph input producer node
  207. # Initialize `val_map` with mappings from placeholder nodes in
  208. # `replacement` to their corresponding node in `original_graph`
  209. assert len(match.placeholder_nodes) == len(replacement_placeholders)
  210. val_map: Dict[Node, Node] = {}
  211. for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
  212. if isinstance(gn, Node):
  213. val_map[rn] = match_changed_node.get(gn, gn)
  214. else:
  215. val_map[rn] = gn
  216. # Copy the replacement graph over
  217. user_nodes: Set[Node] = set()
  218. for n in match.returning_nodes:
  219. for user in n.users:
  220. user_nodes.add(user)
  221. assert user_nodes, "The returning_nodes should have at least one user node"
  222. if len(user_nodes) == 1:
  223. first_user_node = list(user_nodes)[0]
  224. else:
  225. # If there are multiple user nodes, we need to find the first user node
  226. # in the current execution order of the `original_graph`
  227. for n in original_graph.nodes:
  228. if n in user_nodes:
  229. first_user_node = n
  230. break
  231. with original_graph.inserting_before(first_user_node):
  232. copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)
  233. if isinstance(copied_returning_nodes, Node):
  234. copied_returning_nodes = (copied_returning_nodes, )
  235. # Get a list of nodes that have been replaced into the graph
  236. replacement_nodes = []
  237. def get_replacement_nodes(curr_node: Node):
  238. nonlocal replacement_nodes
  239. for arg in curr_node.args:
  240. if isinstance(arg, Node):
  241. if arg not in val_map.values():
  242. get_replacement_nodes(arg)
  243. replacement_nodes.append(curr_node)
  244. for ret_node in copied_returning_nodes:
  245. get_replacement_nodes(ret_node)
  246. # Hook the output Node of the replacement subgraph in to the
  247. # original Graph at the correct location
  248. assert len(match.returning_nodes) == len(copied_returning_nodes)
  249. for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
  250. gn.replace_all_uses_with(copied_node)
  251. match_changed_node[gn] = copied_node
  252. # Remove the original nodes
  253. for node in reversed(pattern_graph.nodes):
  254. if node.op != "placeholder" and node.op != "output":
  255. gn = match.nodes_map[node]
  256. gm.graph.erase_node(gn)
  257. match_and_replacements.append(
  258. ReplacedPatterns(
  259. anchor=match.anchors[0],
  260. nodes_map=match.nodes_map,
  261. replacements=replacement_nodes
  262. )
  263. )
  264. # Update the passed-in GraphModule to reflect the new state of
  265. # `original_graph`
  266. gm.recompile()
  267. # If `replacement` was an nn.Module, we'll need to make sure that
  268. # all the submodules have been copied over correctly
  269. if isinstance(replacement, torch.nn.Module):
  270. _replace_submodules(gm, replacement)
  271. return match_and_replacements