# split_utils.py

from dataclasses import dataclass, field
from typing import List, Optional, Dict

import torch.fx
from torch.fx.graph import map_arg
from .tools_common import NodeList
from torch.fx._compatibility import compatibility
from torch.fx.passes.utils import lift_subgraph_as_module, HolderModule

__all__ = ['getattr_recursive', 'setattr_recursive', 'Component', 'split_by_tags']


@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
    """
    Resolve a dotted attribute path such as "a.b.c" on `obj`.
    Returns None if any attribute along the path is missing.
    """
    for layer in name.split("."):
        if hasattr(obj, layer):
            obj = getattr(obj, layer)
        else:
            return None
    return obj
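
# A hedged usage sketch (hypothetical names, not part of this module): for a model with
# `model.encoder.layer = torch.nn.Linear(4, 4)`, the following would hold:
#
#     getattr_recursive(model, "encoder.layer.weight")  # -> the Linear's weight Parameter
#     getattr_recursive(model, "encoder.missing")       # -> None, instead of raising AttributeError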


@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
    """
    Set an attribute given a dotted path such as "a.b.c", descending into the
    intermediate objects until the final attribute can be assigned.
    """
    if "." not in attr:
        setattr(obj, attr, value)
    else:
        layer = attr.split(".")
        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
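
# A hedged usage sketch (hypothetical names): `setattr_recursive(model, "encoder.layer.bias", None)`
# descends to `model.encoder.layer` and assigns its `bias` attribute, whereas a plain
# setattr would treat the whole dotted string as a single attribute name.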


@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.
    """

    graph: torch.fx.Graph
    order: int
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: List = field(default_factory=list)

    # Stores the nodes in the original graph that are placeholders in `graph`.
    orig_inputs: List = field(default_factory=list)

    # Stores the nodes in the original graph that are outputs in `graph`.
    orig_outputs: List = field(default_factory=list)

    # Mapping from get_attr node in the original graph to get_attr node in `graph`.
    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    constructor_args: List[str] = field(default_factory=list)
    gm: Optional[torch.fx.GraphModule] = None
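
# A hedged illustration (names borrowed from the split_by_tags docstring below): when a
# component is built around `r1 = self.linear1(in1)`, `orig_inputs` would hold the
# original `in1` node, `input_placeholders` the fresh placeholder created inside the
# component's own `graph`, and `orig_outputs` the original `r1` node once something
# outside the component (or the output node) consumes it.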


@compatibility(is_backward_compatible=False)
def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule:
    """
    Splits a GraphModule using tags on its graph nodes. We honor the order of
    tags. For example, with tags = ["a", "b", "c"], the function creates the
    initial submodules in the order "a_0", "b_1", "c_2".

    To set a tag:
        gm.graph.nodes[idx].tag = "mytag"

    This will result in all nodes with the same tag being extracted and placed in their
    own submodule. For placeholder, output and get_attr nodes, the tag is ignored:
    placeholder and output nodes are created when needed, while get_attr nodes get
    copied to the submodules where they are used.

    Given the following module def:

        class SimpleModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(...)
                self.linear2 = torch.nn.Linear(...)
                self.linear3 = torch.nn.Linear(...)

            def forward(self, in1, in2):
                r1 = self.linear1(in1)
                r2 = self.linear2(in2)
                r3 = torch.cat([r1, r2])
                return self.linear3(r3)

    Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results
    in the following split:

        ro_0:
            def forward(self, in1):
                self = self.root
                linear1 = self.linear1(in1)
                return linear1

        main_1:
            def forward(self, in2, linear1):
                self = self.root
                linear2 = self.linear2(in2)
                cat_1 = torch.cat([linear1, linear2])
                linear3 = self.linear3(cat_1)
                return linear3

        main_0:
            def forward(self, in1, in2):
                self = self.root
                ro_0 = self.ro_0(in1)
                main_1 = self.main_1(in2, ro_0)
                return main_1
    """

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Stores nodes in x to a list and returns the list.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r
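
    # A hedged illustration (hypothetical nodes n1, n2, n3): for an args value such as
    # (n1, [n2, {"k": n3}], 1.0), flatten returns [n1, n2, n3]; map_arg applies the
    # callback only to the torch.fx.Node leaves of the nested structure and skips
    # other constants.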

    # Mapping from node in original module to node in created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in original module or created submodules to
    # corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in main graph.
    used_in_main: Dict[torch.fx.Node, None] = {}

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in original module to node in main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag; we don't expect to create other components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp

    # Traverse the nodes in the original graph and take care of them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to the main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            continue

        # Get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly to the submodules that use them.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes, i.e. nodes with op call_module,
        # call_function or call_method. Every callable node should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components.
        mx = max((c.order for c in upstream_components), default=0)

        # Expect the component for `node` to have an order at least as high as those
        # of its upstream components.
        assert comp.order >= mx

        # Map an input of `node` to a node in the component's graph.
        def remap_func(x):
            # If the input is a get_attr node, copy it to the current component's graph.
            # Returns the get_attr node in the current component's graph.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                return comp.getattr_maps[x]

            # If the input is not a placeholder, it should have been put into a component
            # already. If it's in the current component, we return the corresponding
            # node in that component.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If the input is a placeholder or lives in another component, we want to
            # make it a placeholder in the current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                comp.input_placeholders.append(
                    comp.graph.placeholder(x.name, type_expr=x.type)
                )
                used_in_main[x] = None

            return comp.input_placeholders[
                next(i for i, y in enumerate(comp.orig_inputs) if x is y)
            ]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            # We don't need a component mapping for "get_attr" nodes that are consumed
            # by the output. We only need to make sure the corresponding counterparts
            # are created in the resulting graph.
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
            # All component results consumed by the output node should be
            # marked as "used in main".
            used_in_main[x] = None

    # If a node is used in the main graph, we mark it as an output of the component
    # it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)

    # Now we create a GraphModule for each component.
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        # Take care of the args of the FX output node. If there's a single output,
        # the output node's args look like (output_single,); if there are multiple
        # outputs, they look like ((output_0, output_1, ...),).
        comp.graph.output(outs[0] if len(outs) == 1 else outs)

        comp.gm = lift_subgraph_as_module(gm, comp.graph)

        # Create a call_module node in the main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use Proxy to record getitem access.
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]

    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))

    main_root = HolderModule({comp.name: comp.gm for comp in all_components})

    # If the output node consumes a get_attr directly in the original graph, then we
    # need to make sure that get_attr is copied to the new graph module.
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]

    return torch.fx.GraphModule(main_root, main_g)
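

# A minimal, hedged usage sketch (not part of the original module): it rebuilds the
# SimpleModule from the split_by_tags docstring, tags every callable node, and prints
# the resulting main graph. The tag names "ro"/"main" and the rule "send linear1 to ro,
# everything else to main" are illustrative assumptions. Because of the relative import
# above, this block only runs when the file is executed as part of its package
# (e.g. `python -m torch.fx.passes.split_utils`, assuming that module path).
if __name__ == "__main__":
    class SimpleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(4, 4)
            self.linear2 = torch.nn.Linear(4, 4)
            self.linear3 = torch.nn.Linear(4, 2)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    traced = torch.fx.symbolic_trace(SimpleModule())

    # Only callable nodes need tags; placeholder/output/get_attr tags are ignored.
    for node in traced.graph.nodes:
        if node.op in {"call_module", "call_function", "call_method"}:
            node.tag = "ro" if node.target == "linear1" else "main"

    split = split_by_tags(traced, ["ro", "main"])
    print(split.graph)  # main graph: calls self.ro(...) and then self.main(...)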