import inspect
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule

__all__ = ["Partition", "split_module"]


@compatibility(is_backward_compatible=True)
class Partition:
    def __init__(self, name: str):
        self.name: str = name
        self.submod_name = f"submod_{name}"
        self.node_names: List[str] = []
        self.inputs: Dict[str, None] = {}
        self.outputs: Dict[str, None] = {}
        self.partitions_dependent_on: Dict[str, None] = {}
        self.partition_dependents: Dict[str, None] = {}
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        self.environment: Dict[torch.fx.node.Node, torch.fx.node.Node] = {}
        self.targets: Dict[str, Any] = {}

    def __repr__(self) -> str:
        return (
            f"name: {self.name},\n"
            f" nodes: {self.node_names},\n"
            f" inputs: {self.inputs},\n"
            f" outputs: {self.outputs},\n"
            f" partitions dependent on: {self.partitions_dependent_on},\n"
            f" partition dependents: {self.partition_dependents}"
        )


# Creates subgraphs out of main graph
@compatibility(is_backward_compatible=True)
def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[torch.fx.node.Node], int],
    qualname_map: Optional[Dict[str, str]] = None,
    keep_original_order: Optional[bool] = False,
):
  38. """
  39. Creates subgraphs out of main graph
  40. Args:
  41. m (GraphModule): Graph module to split
  42. root_m (torch.nn.Module): root nn module. Not currently used. Included
  43. because the root nn module is usually transformed via
  44. torch.fx._symbolic_trace.symbolic_trace (see example below)
  45. split_callback (Callable[[torch.fx.node.Node], int]): Callable function
  46. that maps a given Node instance to a numeric partition identifier.
  47. split_module will use this function as the policy for which operations
  48. appear in which partitions in the output Module.
  49. qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
  50. mapping from new target names in the module after split to old target
  51. names in the original module.
  52. keep_original_order: Optional[bool]: keep the original order of the GraphModule
  53. or use the Topological order of the new constructed GraphModule
  54. Returns:
  55. GraphModule: the module after split.
  56. Example:
  57. This is a sample setup:
  58. import torch
  59. from torch.fx.symbolic_trace import symbolic_trace
  60. from torch.fx.graph_module import GraphModule
  61. from torch.fx.node import Node
  62. from torch.fx.passes.split_module import split_module
  63. class MyModule(torch.nn.Module):
  64. def __init__(self):
  65. super().__init__()
  66. self.param = torch.nn.Parameter(torch.rand(3, 4))
  67. self.linear = torch.nn.Linear(4, 5)
  68. def forward(self, x, y):
  69. z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
  70. w = self.linear(y).clamp(min=0.0, max=1.0)
  71. return z + w
  72. # symbolically trace model
  73. my_module = MyModule()
  74. my_module_traced = symbolic_trace(my_module)
  75. # random mod partitioning
  76. partition_counter = 0
  77. NPARTITIONS = 3
  78. def mod_partition(node: Node):
  79. global partition_counter
  80. partition = partition_counter % NPARTITIONS
  81. partition_counter = (partition_counter + 1) % NPARTITIONS
  82. return partition
  83. # split module in module with submodules
  84. module_with_submodules = split_module(
  85. my_module_traced, my_module, mod_partition
  86. )
  87. Output looks like this. Original graph is broken into partitions
  88. > print(module_with_submodules)
  89. GraphModule(
  90. (submod_0): GraphModule(
  91. (linear): Linear(in_features=4, out_features=5, bias=True)
  92. )
  93. (submod_1): GraphModule(
  94. (linear): Linear(in_features=4, out_features=5, bias=True)
  95. )
  96. (submod_2): GraphModule()
  97. )
  98. def forward(self, x, y):
  99. param = self.param
  100. submod_0 = self.submod_0(x, param, y); x = param = y = None
  101. getitem = submod_0[0]
  102. getitem_1 = submod_0[1]; submod_0 = None
  103. submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
  104. getitem_2 = submod_1[0]
  105. getitem_3 = submod_1[1]; submod_1 = None
  106. submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
  107. return submod_2
  108. Output of split module is the same as output of input traced module.
  109. This is an example within a test setting:
  110. > orig_out = my_module_traced(x, y)
  111. > submodules_out = module_with_submodules(x, y)
  112. > self.assertEqual(orig_out, submodules_out)
  113. True
  114. """
    partitions: Dict[str, Partition] = {}
    orig_nodes: Dict[str, torch.fx.node.Node] = {}
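
    # Helper that records a use of `def_node` from `use_node` when the two live in
    # different partitions: the value becomes an output of the defining partition
    # and an input of (and dependency for) the using partition. A `use_node` of
    # None marks uses coming from the overall graph output.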
    def record_cross_partition_use(
        def_node: torch.fx.node.Node, use_node: Optional[torch.fx.node.Node]
    ):  # noqa: B950
        def_partition_name = getattr(def_node, "_fx_partition", None)
        use_partition_name = getattr(use_node, "_fx_partition", None)
        if def_partition_name != use_partition_name:
            if def_partition_name is not None:
                def_partition = partitions[def_partition_name]
                def_partition.outputs.setdefault(def_node.name)
                if use_partition_name is not None:
                    def_partition.partition_dependents.setdefault(use_partition_name)

            if use_partition_name is not None:
                use_partition = partitions[use_partition_name]
                use_partition.inputs.setdefault(def_node.name)
                if def_partition_name is not None:
                    use_partition.partitions_dependent_on.setdefault(def_partition_name)

    # split nodes into partitions
    for node in m.graph.nodes:
        orig_nodes[node.name] = node

        # TODO currently placeholders/parameters aren't put into random partitions,
        # rather they're added to the graphs where they are used down below
        if node.op in ["placeholder", "get_attr"]:
            continue
        if node.op == "output":
            torch.fx.graph.map_arg(
                node.args[0], lambda n: record_cross_partition_use(n, None)
            )
            continue
        partition_name = str(split_callback(node))

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        node._fx_partition = partition_name

        torch.fx.graph.map_arg(
            node.args, lambda def_node: record_cross_partition_use(def_node, node)
        )
        torch.fx.graph.map_arg(
            node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
        )  # noqa: B950
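
    # Remember the encounter order of partitions (used for keep_original_order),
    # then topologically sort the partition dependency graph (Kahn's algorithm)
    # to detect cycles and obtain a valid execution order for the submodules.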
    original_partition_order = list(partitions.keys())

    # find partitions with no dependencies
    root_partitions: List[str] = []
    for partition_name, partition in partitions.items():
        if not len(partition.partitions_dependent_on):
            root_partitions.append(partition_name)

    # check partitions for circular dependencies and create topological partition ordering
    sorted_partitions: List[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        for dependent in partitions[root_partition].partition_dependents:
            partitions[dependent].partitions_dependent_on.pop(root_partition)
            if not partitions[dependent].partitions_dependent_on:
                root_partitions.append(dependent)
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")
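
    # Each value a partition consumes from another partition (or from the original
    # graph) is materialized as a placeholder in that partition's subgraph,
    # copying over the original node's type and meta.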
    # add placeholders to partitions
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        for input in partition.inputs:
            placeholder = partition.graph.placeholder(
                input,
                type_expr=orig_nodes[input].type,
            )
            placeholder.meta = orig_nodes[input].meta.copy()
            partition.environment[orig_nodes[input]] = placeholder

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, "_fx_partition"):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(
                node.kwargs, lambda n: environment[n]
            )
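
            # For call_module/get_attr nodes the dotted target path is resolved
            # against the original module and flattened (dots replaced by
            # underscores), since the attribute is installed directly on the
            # submodule under that flattened name.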
            if node.op not in ["call_module", "get_attr"]:
                target = node.target
            else:
                target_atoms = node.target.split(".")
                target_attr = m
                for atom in target_atoms:
                    if not hasattr(target_attr, atom):
                        raise RuntimeError(f"Operator target {node.target} not found!")
                    target_attr = getattr(target_attr, atom)
                # target = target_atoms[-1]
                target = "_".join(target_atoms)
                partition.targets[target] = target_attr
                # Fill in the passed-in mapping from new qualname to old qualname
                if qualname_map is not None:
                    # When creating the split module later, the submodules will have
                    # path prefix matching the corresponding partition's submod_name
                    qualname = f"{partition.submod_name}.{target}"
                    qualname_map[qualname] = node.target

            assert isinstance(gathered_args, tuple)
            assert isinstance(gathered_kwargs, dict)
            new_node = partition.graph.create_node(
                op=node.op,
                target=target,
                args=gathered_args,
                kwargs=gathered_kwargs,
                type_expr=node.type,
            )
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Set up values to construct base module
    base_mod_env: Dict[str, torch.fx.node.Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
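    # Copy the original graph's placeholders and get_attr nodes into the base
    # graph, and collect the attributes referenced by get_attr nodes so they can
    # be installed on the base module.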
    for node in m.graph.nodes:
        if node.op == "placeholder":
            default_value = (
                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
            )
            base_mod_env[node.name] = base_mod_graph.placeholder(
                node.target, type_expr=node.type, default_value=default_value
            )
            base_mod_env[node.name].meta = node.meta.copy()
        elif node.op == "get_attr":
            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
            base_mod_env[node.name].meta = node.meta.copy()
            attr_val = m
            for atom in node.target.split("."):
                if not hasattr(attr_val, atom):
                    raise RuntimeError(f"Node target {node.target} not found!")
                attr_val = getattr(attr_val, atom)
            base_mod_attrs[node.target] = attr_val

    # Do some things iterating over the partitions in topological order again:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order

    construct_order_partitions = (
        sorted_partitions if not keep_original_order else original_partition_order
    )

    for partition_name in construct_order_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(
            partition.environment[orig_nodes[name]] for name in partition.outputs
        )
        output_vals = output_vals[0] if len(output_vals) == 1 else output_vals  # type: ignore[assignment]
        partition.graph.output(output_vals)

        # Construct GraphModule for this partition
        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
            partition.targets, partition.graph
        )  # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(
            partition.submod_name,
            tuple(base_mod_env[name] for name in partition.inputs),
        )

        if len(partition.outputs) > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
        else:
            base_mod_env[list(partition.outputs)[0]] = output_val
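
    # Rewrite the original graph's output node in terms of the submodule results
    # recorded in base_mod_env.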
    for node in m.graph.nodes:
        if node.op == "output":
            base_mod_graph.output(
                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
            )  # noqa: B950

    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)