# fuser_utils.py
  1. import copy
  2. from queue import SimpleQueue
  3. from typing import List, Dict, Tuple
  4. import torch.fx
  5. from torch.fx.graph_module import GraphModule
  6. from torch.fx.graph import Graph
  7. from torch.fx.node import Node
  8. from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph
  9. from torch.fx.passes.utils import lift_subgraph_as_module
  10. def topo_sort(nodes: NodeList) -> NodeList:
  11. # sort nodes according to the topological order
  12. indegree_map = {node : 0 for node in nodes}
  13. candidates: SimpleQueue = SimpleQueue()
  14. for node in nodes:
  15. for n in node.all_input_nodes:
  16. if n in indegree_map:
  17. indegree_map[node] += 1
  18. if indegree_map[node] == 0:
  19. candidates.put(node)
  20. sorted_nodes: NodeList = list()
  21. while not candidates.empty():
  22. node = candidates.get()
  23. sorted_nodes.append(node)
  24. for n in node.users:
  25. if n in indegree_map:
  26. indegree_map[n] -= 1
  27. if indegree_map[n] == 0:
  28. candidates.put(n)
  29. assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes"
  30. return sorted_nodes
  31. def validate_partition(partition: NodeList) -> bool:
  32. # verify the partition does't form a dependency cycle in the original graph
  33. # returns True for valid partition, False for invalid
  34. partition_set = set(partition)
  35. outputs: NodeList = list()
  36. for node in partition_set:
  37. for user_node in node.users:
  38. if user_node not in partition_set:
  39. # external user node, need to expose as an output
  40. outputs.append(user_node)
  41. # perform DFS on the parition outputs
  42. # if it reaches a node within the partition, then it found a cycle
  43. visited: NodeSet = set()
  44. def dfs_find_cycle(node):
  45. if node in partition_set:
  46. return True # found cycle, return
  47. visited.add(node)
  48. for user_node in node.users:
  49. if user_node not in visited:
  50. if dfs_find_cycle(user_node):
  51. return True
  52. return False
  53. for output_node in outputs:
  54. if dfs_find_cycle(output_node):
  55. return False
  56. return True
def fuse_as_graphmodule(gm: GraphModule,
                        nodes: NodeList,
                        module_name: str) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]:
    """
    Fuse nodes in graph_module into a GraphModule.

    Args:
        gm (GraphModule): target graph_module
        nodes (List[Node]): list of nodes in `gm` to fuse, where the nodes must be topologically sorted
        module_name: class name for the fused GraphModule

    Returns:
        fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm`
        original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm`
        original_outputs (Tuple[Node, ...]): nodes in `nodes` that have consumers outside the partition in original `gm`
    """
    # assumption: nodes are already sorted in topo order

    # Sanity-check that every node is a live member of gm's graph.
    for node in nodes:
        assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}"
        assert not node._erased, f"{node} has been removed from owning graph"
        assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}"

    # validates partition doesn't introduce dependency circles in the graph
    assert validate_partition(nodes), "Invalid partition, found dependency cycles"

    # Fresh graph that will hold the copies of `nodes`.
    subgraph = Graph()

    node_to_placeholder: Dict[Node, Node] = {}  # mapping of nodes from old graph to placeholder in new graph
    node_map: Dict[Node, Node] = {}  # mapping of nodes from old graph to new graph

    # handles inputs through graph.node_copy's arg_transform functions
    # NOTE: this closure mutates node_to_placeholder / reads node_map as a
    # side effect of the node_copy loop below, so copy order matters.
    def remap_inputs(x):
        if x.op == "get_attr":
            # TODO: do we really need copy the get_attr node into the graph?
            # do something here
            pass

        if x in nodes:
            # x is inside subgraph, return the copied node
            # the node should have been copied already, as we are copying graph in the topological order
            return node_map[x]

        if x not in node_to_placeholder:
            # x is not in subgraph, create a new placeholder for subgraph
            placeholder_node = subgraph.placeholder(x.name, type_expr=x.type)
            # copy all meta fields, even if some fields might be irrelevant for the placeholder node
            placeholder_node.meta = copy.copy(x.meta)
            node_to_placeholder[x] = placeholder_node

        return node_to_placeholder[x]

    # copy nodes in topological order
    for node in nodes:
        new_node = subgraph.node_copy(node, remap_inputs)
        node_map[node] = new_node

    # handles outputs: a node becomes an output when it has at least one
    # consumer outside the partition
    output_mapping: Dict[Node, Node] = {}  # mapping from old output to new outputs

    for node in nodes:
        for user_node in node.users:
            if user_node not in nodes:
                # external user node, need to expose as an output
                output_mapping[node] = node_map[node]

    # outs contain nodes in the new subgraph
    outs = tuple(output_mapping.values())

    # Take care of the args of FX output node. If there's a single
    # output then the output node args is like (output_single), else
    # if there're multiple outputs then the output node args is like
    # ((output_0, output_1, ...)).
    subgraph.output(outs[0] if len(outs) == 1 else outs)

    # lint to ensure correctness
    subgraph.lint()

    # NOTE(review): assumes lift_subgraph_as_module returns a bare
    # GraphModule here — confirm against the installed torch.fx version.
    fused_gm: GraphModule = lift_subgraph_as_module(gm, subgraph, class_name=module_name)

    # sub_gm's input nodes in the original module
    original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys())

    # sub_gm's outputs node in the original module
    original_outputs: Tuple[Node, ...] = tuple(output_mapping.keys())

    return fused_gm, original_inputs, original_outputs
  124. def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]):
  125. # add sub_gm into gm
  126. submodule_name = sub_gm.__class__.__name__
  127. gm.add_submodule(submodule_name, sub_gm)
  128. # Create a call_module node in main graph.
  129. module_node = gm.graph.call_module(
  130. submodule_name,
  131. args=orig_inputs,
  132. kwargs=None)
  133. if len(orig_outputs) == 1:
  134. # main_remapping[comp.orig_outputs[0]] = module_node
  135. orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True)
  136. else:
  137. for i, orig_output in enumerate(orig_outputs):
  138. # Use Proxy to record getitem access.
  139. proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index]
  140. orig_output.replace_all_uses_with(proxy_out, propagate_meta=True)
  141. return gm
  142. def erase_nodes(gm: GraphModule, nodes: NodeList):
  143. # erase original nodes in inversed topological order
  144. for node in reversed(nodes):
  145. gm.graph.erase_node(node)
  146. def fuse_by_partitions(gm: GraphModule, partitions: List[NodeList]) -> GraphModule:
  147. for partition_id, nodes in enumerate(partitions):
  148. sorted_nodes = topo_sort(nodes)
  149. submodule_name = "fused_" + str(partition_id)
  150. sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name)
  151. insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
  152. erase_nodes(gm, sorted_nodes)
  153. # topological sort original gm with newly created sub_gm
  154. legalize_graph(gm)
  155. return gm