import argparse
import copy
from collections import defaultdict
from dataclasses import dataclass
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
import logging

import torch
from torch.fx.passes.graph_manipulation import get_size_of_node
from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility

from .operator_support import (
    get_node_target,
    OperatorSupportBase,
)
from .graph_drawer import FxGraphDrawer
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
    FxNetAccFusionsFinder,
    CALLABLE_NODE_OPS,
    Tensors,
    NodeList,
    NodeSet,
    is_node_output_tensor,
)

__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']

_LOGGER = logging.getLogger(__name__)

DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False

class _SplitterSettingBase:
    def __init__(
        self,
        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
        skip_fusion=DEFAULT_SKIP_FUSION,
        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR
    ):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--min-acc-module-size",
            "--min_acc_module_size",
            required=False,
            type=int,
            help="Minimum size limit of an accelerator subgraph.",
        )
        parser.add_argument(
            "--skip-fusion",
            "--skip_fusion",
            default=False,
            action="store_true",
            help="If true, no fusion groups are built. Fusion groups are used to "
            "enforce that no non-tensor data flows between submodules. If that "
            "constraint is not needed, skipping fusion is recommended as it "
            "can reduce overhead.",
        )
        parser.add_argument(
            "--allow-non-tensor",
            "--allow_non_tensor",
            default=False,
            action="store_true",
            help="Some backends don't allow non-tensor data flow between them and "
            "the cpu. Therefore, if a node is supported by the accelerator but has "
            "non-tensor inputs or outputs to a cpu node, we treat it as a cpu node "
            "during splitting. For backends that don't care about non-tensor data "
            "flow, set this option to true to disable that behavior.",
        )
        args, unknown = parser.parse_known_args()

        self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
        self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
        self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
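
# Illustrative sketch (not part of the original module): how the splitter settings are
# typically constructed. Constructor defaults apply unless the corresponding command-line
# flags (parsed via parse_known_args above) override them. The function name is hypothetical.
def _example_splitter_settings() -> _SplitterSettingBase:
    # Require at least 3 nodes per accelerator submodule; keep fusion groups enabled.
    return _SplitterSettingBase(min_acc_module_size=3, skip_fusion=False)
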

@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
    """
    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
    input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.

    I.e. if we have a chain:

    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1

    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.

    This behavior can be turned off by passing allow_non_tensor=True.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: OperatorSupportBase,
        allow_non_tensor: bool,
    ):
        self.module = module
        self.operator_support = operator_support
        self.allow_non_tensor = allow_non_tensor

    def reduce_acc_nodes_non_tensor_input_helper(
        self, cpu_worklist: NodeList
    ):
        """
        Transitively excludes nodes from ACC supported set.
        For every node in the worklist:
        - removes its downstream ACC nodes from ACC supported set,
        - if any downstream ACC node produces non-tensor output,
          then it gets added into the worklist.
        """
        while cpu_worklist:
            node = cpu_worklist.pop(0)

            for user in node.users:
                if user in self.acc_nodes:
                    self.acc_nodes.remove(user)
                    if not is_node_output_tensor(user):
                        cpu_worklist.append(user)

    def reduce_acc_nodes_non_tensor_input(self):
        """
        Excludes nodes from ACC supported set that have direct
        upstream CPU nodes that produce non-tensor outputs.
        """
        non_tensor_cpu_nodes: NodeList = []

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if node in self.acc_nodes:
                continue
            if is_node_output_tensor(node):
                continue
            non_tensor_cpu_nodes.append(node)

        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)

    def reduce_acc_nodes_non_tensor_output(self):
        """
        Excludes nodes from ACC supported set that produce non-tensor
        outputs and have downstream CPU nodes.
        """
        while True:
            new_cpu_nodes: NodeList = []

            for acc_node in self.acc_nodes:
                if is_node_output_tensor(acc_node):
                    continue
                for user in acc_node.users:
                    if user not in self.acc_nodes:
                        new_cpu_nodes.append(acc_node)
                        break

            if not new_cpu_nodes:
                break

            for new_cpu_node in new_cpu_nodes:
                self.acc_nodes.remove(new_cpu_node)

            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)

    def __call__(self) -> NodeSet:
        submodules = dict(self.module.named_modules())
        self.acc_nodes = {
            n
            for n in self.module.graph.nodes
            if n.op in CALLABLE_NODE_OPS
            and self.operator_support.is_node_supported(submodules, n)
        }

        if not self.allow_non_tensor:
            self.reduce_acc_nodes_non_tensor_input()
            self.reduce_acc_nodes_non_tensor_output()

        return self.acc_nodes
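
# Illustrative sketch (not part of the original module): finding accelerator-supported
# nodes for a traced module, using a hypothetical operator-support policy that only
# accepts calls to torch.sin and torch.cos.
def _example_find_acc_nodes(module: torch.fx.GraphModule) -> NodeSet:
    class _SinCosSupport(OperatorSupportBase):
        def is_node_supported(self, submodules, node) -> bool:
            return node.op == "call_function" and node.target in (torch.sin, torch.cos)

    # allow_non_tensor=False keeps non-tensor data flow off the accelerator boundary.
    return FxNetAccNodesFinder(module, _SinCosSupport(), allow_non_tensor=False)()
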

@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
    pass

@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    is_acc: bool
    nodes: NodeList

@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs.
        non_acc_submodule_prefix: the prefix for non-acc submodules. For
            acc submodules the prefix is always "_run_on_acc_".
    """

    split_module: torch.fx.GraphModule
    submodule_inputs: Dict[str, Any]
    non_acc_submodule_prefix: str

@compatibility(is_backward_compatible=False)
def generate_inputs_for_submodules(
    model: torch.nn.Module,
    inputs: Sequence[Any],
    target_submodules: Iterable[str],
    deepcopy: bool = False,
) -> Dict[str, Any]:
    """
    Generate inputs for the targeted submodules in the given model. Note that if two
    submodules refer to the same object, this function doesn't work.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: submodules that we want to generate inputs for.

    Returns:
        A dict that maps from submodule name to its inputs.
    """
    handles = []
    results = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs

    for name, mod in model.named_modules():
        if name in target_submodules:
            handles.append(mod.register_forward_pre_hook(pre_forward))

    def clean_up_handles():
        for h in handles:
            h.remove()

    try:
        with torch.no_grad():
            model(*inputs)
    except Exception as e:
        clean_up_handles()
        raise e

    clean_up_handles()
    return results
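
# Illustrative sketch (not part of the original module): capturing the inputs seen by two
# hypothetical child modules ("encoder" and "decoder") during a single forward pass.
def _example_capture_submodule_inputs(model: torch.nn.Module, sample: Sequence[Any]) -> Dict[str, Any]:
    # Returns e.g. {"encoder": (tensor, ...), "decoder": (tensor, ...)} keyed by module name.
    return generate_inputs_for_submodules(model, sample, ["encoder", "decoder"])
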

class _SplitterBase:
    """
    Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
    Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
    Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.

    Given the following graph:
          ==> b ==>
        //         \\
       a             d
        \\         //
          ==> c ==>

    class SimpleModule(torch.nn.Module):
        def forward(self, a):
            b = torch.sin(a)
            c = torch.cos(a)
            d = b + c
            return d

    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
    we will get the following split result:

    main:
    def forward(self, a):
        run_on_acc_0_0 = self._run_on_acc_0_0(a)
        getitem = run_on_acc_0_0[0]
        getitem_1 = run_on_acc_0_0[1]
        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
        return run_on_cpu_1_1

    _run_on_acc_0_0:
    def forward(self, a):
        sin_1 = torch.sin(a)
        cos_1 = torch.cos(a)
        return (sin_1, cos_1)

    _run_on_cpu_1_1:
    def forward(self, sin_1, cos_1):
        add_1 = sin_1 + cos_1
        return add_1
    """

    # PCIe bandwidth for the backend, default to 100 GB/s
    PCIe_BW = 100 * 2 ** 30

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Sequence[Any],
        operator_support: OperatorSupportBase,
        settings: _SplitterSettingBase,
        non_acc_submodule_name: str = "_run_on_cpu_",
    ):
        """
        Preprocesses graph before splitting:
        - finds nodes supported by ACC,
        - finds fusion groups for ACC nodes having non-tensor IO,
        - builds a graph of direct dependencies,
        - builds a map of fused nodes to their fusions.
        As a result we get self.acc_nodes, self.deps and self.fusions.
        """
        assert isinstance(module, torch.fx.GraphModule)

        self.module = module
        ShapeProp(self.module).propagate(*sample_input)

        self.settings = settings
        self.operator_support = operator_support
        self.sample_input = sample_input
        self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()

        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()

        # Modify deps to add more deps for fused nodes
        self.deps = self.find_deps()
        self.update_deps_for_fusions()

        self.non_acc_submodule_name = non_acc_submodule_name
        self._node_submodule_map: Dict[str, str] = {}

    # ===============================================================
    # Helpers for ctor and initial state
    # ===============================================================

    def get_node_submodule_map(self) -> Dict[str, str]:
        """ Returns a map from node name to submodule name, e.g.
            node: main_module_impl_impl_over_arch_unary_multiple_embedding
                _pooling_embedding_pooling_sparse_entity_equivalence_key
                _proxy_embedding_bag
            maps to submodule name of: _run_on_acc_1
        """
        return self._node_submodule_map

    def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds a graph of node dependencies. Leaf nodes don't have any
        dependencies and the "output" node doesn't have nodes depending on it.

        Resulting graph has only direct dependencies, i.e. there are no
        transitive dependencies.
        """
        deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op != "output":
                    deps[user].add(node)
        return deps

    def update_deps_for_fusions(self):
        """
        Updates graph of dependencies so that:
        - nodes from the same fusion depend on the same set of outer nodes,
        - outer nodes depending on a fusion depend on all nodes in that fusion.
        """
        for node in self.fusions:
            fusion = self.fusions[node]
            for fused_neighbor in fusion:
                self.deps[node].update(self.deps[fused_neighbor] - fusion)

                for user in fused_neighbor.users:
                    if user not in fusion:
                        self.deps[user].add(node)

    # ===============================================================
    # Helpers for preview
    # ===============================================================

    def _lower_model_to_backend(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> torch.nn.Module:
        """
        Lower the model to a backend.
        """
        return mod

    def _find_culprit(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> str:
        """
        When an error occurs during lowering or running the lowered mod, we use this
        function to find the culprits in the `mod` that cause the error.
        """
        return "Unable to find a culprit because the _find_culprit() function is not implemented."

    def _draw_graph_based_on_node_support(
        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
    ):
        color_map = {
            "default": "AliceBlue",
            "supported": "chartreuse1",
            "unsupported": "crimson",
        }

        class CustomDrawer(FxGraphDrawer):
            def _get_node_style(self, node):
                template = super()._get_node_style(node)
                if node in supported_nodes:
                    template["fillcolor"] = color_map["supported"]
                elif node.op in CALLABLE_NODE_OPS:
                    template["fillcolor"] = color_map["unsupported"]
                else:
                    template["fillcolor"] = color_map["default"]

                return template

        drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
        dot_graph = drawer.get_main_dot_graph()
        dot_graph.write_raw("node_support.dot")

    def node_support_preview(self, dump_graph: bool = False):
        submodules = dict(self.module.named_modules())

        supported_nodes: NodeList = []
        supported_node_types = defaultdict(set)
        unsupported_node_types = defaultdict(set)

        def get_dtype(arg):
            tensor_meta = arg.meta.get("tensor_meta")
            return getattr(tensor_meta, "dtype", None)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            target = get_node_target(submodules, node)

            # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
            arg_dtypes = [
                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
                for arg in node.args
            ]

            # Find the last non-None element. If all elements are None, return max_len.
            last_index = len(arg_dtypes) - next(
                (
                    i
                    for i, dtype in enumerate(reversed(arg_dtypes))
                    if dtype is not None
                ),
                len(arg_dtypes),
            )

            # Strip None elements at the end.
            arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
            kwarg_dtypes_tuple = tuple(
                (k, get_dtype(arg))
                for k, arg in node.kwargs.items()
                if isinstance(arg, torch.fx.Node)
            )

            if self.operator_support.is_node_supported(submodules, node):
                supported_nodes.append(node)
                supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
            else:
                unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))

        if dump_graph:
            self._draw_graph_based_on_node_support(self.module, supported_nodes)

        reports = "\nSupported node types in the model:\n"
        for t, dtypes in supported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        reports += "\nUnsupported node types in the model:\n"
        for t, dtypes in unsupported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        print(reports)

        # Return reports for testing purposes
        return reports

    def split_preview(self, dump_graph: bool = False):
        reports = ""
        subgraphs = self.put_nodes_into_subgraphs()
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        for i, subgraph in enumerate(subgraphs):
            reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
            reports += f"{len(subgraph.nodes)} node(s)\n"

        self.tag(subgraphs)
        split_mod = self.split(remove_tag=True)
        split_mod.eval()

        if dump_graph:
            drawer = FxGraphDrawer(
                split_mod, "preview", ignore_getattr=True
            )
            dot_graphs = drawer.get_all_dot_graphs()
            for name, dot_graph in dot_graphs.items():
                dot_graph.write_raw(f"{name}.dot")

        max_qps: float = self.PCIe_BW
        bottleneck_module = ""

        for node in split_mod.graph.nodes:
            if node.op == "call_module" and "acc" in node.target:
                reports += f"\nProcessing acc submodule {node.target}\n"

                submod = getattr(split_mod, node.target)

                def get_submod_inputs(main_mod, submod, example_inputs):
                    sub_inputs = None

                    def get_inputs(self, inputs):
                        nonlocal sub_inputs
                        sub_inputs = inputs

                    handle = submod.register_forward_pre_hook(get_inputs)
                    main_mod(*example_inputs)
                    handle.remove()
                    return sub_inputs

                submod_inputs = get_submod_inputs(
                    split_mod, submod, self.sample_input
                )
                ShapeProp(submod).propagate(*submod_inputs)

                total_input_bytes = 0
                total_output_bytes = 0

                reports += "Checking inputs...\n"
                for n in submod.graph.nodes:
                    if n.op == "placeholder":
                        if not is_node_output_tensor(n):
                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                        else:
                            total_input_bytes += get_size_of_node(submod, n)[0]
                    if n.op == "output":
                        output_node = n

                reports += "Checking outputs...\n"

                def get_bytes(node: torch.fx.Node):
                    nonlocal total_output_bytes
                    nonlocal reports
                    if not is_node_output_tensor(node):
                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_output_bytes += get_size_of_node(submod, node)[0]

                map_arg(output_node.args, get_bytes)
                qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
                reports += f" theoretical max qps (bounded by PCIe bandwidth) for this submodule is {qps}.\n"

                if qps < max_qps:
                    max_qps = qps
                    bottleneck_module = node.target

                try:
                    lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
                except RuntimeError:
                    reports += "Ran into an error during lowering!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                    continue

                try:
                    lowered_submod(*submod_inputs)
                except RuntimeError:
                    reports += "Ran into an error during inference!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                else:
                    reports += "Lowering and running succeeded!\n"

        reports += f"\nTheoretical max qps (bounded by PCIe bandwidth) for this model is {max_qps},"
        reports += f" bottleneck is submodule {bottleneck_module}."
        print(reports)

        # Return the reports for testing purposes
        return reports

    # ===============================================================
    # Helpers for extend_acc_subgraph() method
    # ===============================================================

    def find_reverse_deps(
        self, tag_id: Optional[int] = None
    ) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds reversed topological node dependencies. If tag_id is specified,
        we ignore nodes that are in a later subgraph, i.e. nodes with a greater tag_id.
        """
        result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
                    result[node].add(user)

        return result

    def update_reverse_deps_for_fusions(
        self, deps: Dict[torch.fx.Node, NodeSet]
    ):
        processed_node = set()

        for node, fusion in self.fusions.items():
            if node in processed_node:
                continue

            new_dep = set()

            # Create a new dependency set which includes all the
            # dependencies of the nodes in the fusion group
            for n in fusion:
                new_dep.update(deps[n])

            # Exclude nodes in the fusion
            new_dep.difference_update(fusion)

            # Update dependencies
            for n in fusion:
                deps[n] = new_dep

                for arg in n.all_input_nodes:
                    if arg not in fusion:
                        deps[arg].update(fusion)

                processed_node.add(n)

    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
        """
        Finds parent nodes of the `tag` subgraph.

        Traverse the inputs of nodes in the subgraph. If an input doesn't belong to the subgraph
        and is not a placeholder, we consider it a parent node of the subgraph.
        """
        parent_nodes = set()

        for node in self.module.graph.nodes:
            if node.op in CALLABLE_NODE_OPS and node.tag == tag:
                for arg in node.all_input_nodes:
                    if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
                        parent_nodes.add(arg)

        return parent_nodes

    def extend_acc_subgraph(self, tag: str):
        """
        Extends the acc subgraph with `tag`, going in the reversed topological direction.
        """
        # Dict that maps a node to its users, ignoring users that
        # are in a subgraph with a greater tag
        deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
        self.update_reverse_deps_for_fusions(deps)

        # Parent nodes of the subgraph
        parent_nodes = self.find_parent_nodes_of_subgraph(tag)

        visited_nodes: NodeSet = set()

        while parent_nodes:
            node = None

            # Find an acc node that depends on visited nodes only
            for n in parent_nodes:
                if deps[n] <= visited_nodes and n in self.acc_nodes:
                    node = n
                    break

            if node is None:
                break

            # Put the node into the `tag` subgraph
            node.tag = tag  # type: ignore[attr-defined]
            parent_nodes.remove(node)
            visited_nodes.add(node)

            # If the node is in a fusion group, add all fusion buddies to parent nodes
            if node in self.fusions:
                for fusion_node in self.fusions[node]:
                    if fusion_node not in visited_nodes:
                        parent_nodes.add(fusion_node)

            # Add inputs of the node to parent nodes
            for arg in node.all_input_nodes:
                if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
                    parent_nodes.add(arg)

    # ===============================================================
    # Helpers for split() method
    # ===============================================================

    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
        """
        Finds nodes that consume module inputs or get_attr nodes.
        """
        starter_cpu_nodes: NodeSet = set()
        starter_acc_nodes: NodeSet = set()
        for node in self.module.graph.nodes:
            if node.op not in {"placeholder", "get_attr"}:
                continue
            for user in node.users:
                if user in self.acc_nodes:
                    starter_acc_nodes.add(user)
                else:
                    starter_cpu_nodes.add(user)
        return starter_cpu_nodes, starter_acc_nodes

    def put_nodes_into_subgraphs(self) -> List[Subgraph]:
        # We start graph traversal from leaf nodes
        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
        visited_nodes: NodeSet = set()

        # Determine which subgraph to start from based on which subgraph has
        # 0-dep node
        acc_subgraph: bool = not any([len(self.deps[n]) == 0 for n in current_cpu_nodes])

        current_subgraph_nodes: NodeList = []

        # Result accumulator
        subgraphs: List[Subgraph] = []
        while current_cpu_nodes or current_acc_nodes:
            # Find the first node that should belong to the current subgraph and has all dependencies resolved
            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
            node = next(
                (n for n in current_nodes if self.deps[n] <= visited_nodes),
                None,
            )

            # If nothing was found, then it's time to flip the mode and start a new subgraph
            if node is None:
                if not current_subgraph_nodes:
                    raise FxNetSplitterInternalError("Subgraph can't be empty")

                subgraphs.append(
                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
                )
                acc_subgraph = not acc_subgraph
                current_subgraph_nodes = []
                continue

            current_nodes.remove(node)
            visited_nodes.add(node)
            current_subgraph_nodes.append(node)

            # Add fusion buddies
            if node in self.fusions:
                if node in self.acc_nodes:
                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
                else:
                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)

            # Put depending nodes into the queue
            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                # Add downstream nodes
                if user in self.acc_nodes:
                    current_acc_nodes.add(user)
                else:
                    current_cpu_nodes.add(user)

        # Check if the last subgraph was not created
        if current_subgraph_nodes:
            subgraphs.append(
                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
            )

        if not subgraphs:
            raise FxNetSplitterInternalError("Couldn't create subgraphs")

        return subgraphs
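
    # Illustrative note (not part of the original code): for the SimpleModule example in the
    # class docstring, with sin/cos supported, the traversal above first collects the acc
    # subgraph [sin, cos] (their only dependency is the placeholder `a`, which is not a
    # callable node), then flips mode and collects the cpu subgraph [add], yielding
    # [Subgraph(is_acc=True, nodes=[sin, cos]), Subgraph(is_acc=False, nodes=[add])].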

    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
        """
        This pass finds ACC submodules smaller than the specified size and merges
        them with adjacent CPU submodules.
        """
        result: List[Subgraph] = []
        for subgraph in subgraphs:
            if subgraph.is_acc:
                if len(subgraph.nodes) >= self.settings.min_acc_module_size:
                    result.append(subgraph)
                else:
                    print(
                        "Eliminating acc subgraph because it's smaller than the threshold: "
                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
                    )
                    if result:
                        result[-1].nodes.extend(subgraph.nodes)
                    else:
                        subgraph.is_acc = False
                        result.append(subgraph)
            else:
                if result and not result[-1].is_acc:
                    result[-1].nodes.extend(subgraph.nodes)
                else:
                    result.append(subgraph)
        return result
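
    # Illustrative note (not part of the original code): with min_acc_module_size=2, a
    # sequence [acc(1 node), cpu(3 nodes), acc(4 nodes)] becomes [cpu(4 nodes), acc(4 nodes)]:
    # the 1-node acc subgraph is demoted to cpu and the adjacent cpu subgraph is merged into it.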

    def tag(self, subgraphs: List[Subgraph]):
        self.tags: List[str] = []
        for subgraph in subgraphs:
            tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
            self.tags.append(tag)
            for node in subgraph.nodes:
                if hasattr(node, "tag"):
                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")

                node.tag = tag  # type: ignore[attr-defined]
                self._node_submodule_map[node.name] = tag

    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
        split_module = split_by_tags(self.module, self.tags)
        if remove_tag:
            for node in self.module.graph.nodes:
                if hasattr(node, "tag"):
                    del node.tag
        return split_module

    def __call__(self) -> torch.fx.GraphModule:
        subgraphs = self.put_nodes_into_subgraphs()
        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
        print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
        self.tag(subgraphs)
        return self.split()

    def generate_split_results(self) -> SplitResult:
        split_module = self()
        submodule_names = []
        for name, mod in split_module.named_children():
            submodule_names.append(name)
        submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
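
# Illustrative sketch (not part of the original module): a minimal end-to-end use of the
# splitter. The module, operator-support policy, and sample input below are hypothetical;
# backend-specific splitters normally subclass _SplitterBase instead of using it directly.
def _example_end_to_end_split() -> SplitResult:
    class SimpleModule(torch.nn.Module):
        def forward(self, a):
            return torch.sin(a) + torch.cos(a)

    class _SinCosSupport(OperatorSupportBase):
        def is_node_supported(self, submodules, node) -> bool:
            return node.op == "call_function" and node.target in (torch.sin, torch.cos)

    module = torch.fx.symbolic_trace(SimpleModule())
    sample_input = (torch.randn(4),)
    splitter = _SplitterBase(
        module,
        sample_input,
        _SinCosSupport(),
        _SplitterSettingBase(min_acc_module_size=1),
    )
    # Returns the split root module, the captured inputs of every submodule,
    # and the non-acc submodule prefix ("_run_on_cpu_" by default).
    return splitter.generate_split_results()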