- import argparse
- import copy
- from collections import defaultdict
- from dataclasses import dataclass
- from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple
- import logging
- import torch
- from torch.fx.passes.graph_manipulation import get_size_of_node
- from torch.fx.node import map_arg
- from torch.fx._compatibility import compatibility
- from .operator_support import (
- get_node_target,
- OperatorSupportBase,
- )
- from .graph_drawer import FxGraphDrawer
- from .shape_prop import ShapeProp
- from .split_utils import split_by_tags
- from .tools_common import (
- FxNetAccFusionsFinder,
- CALLABLE_NODE_OPS,
- Tensors,
- NodeList,
- NodeSet,
- is_node_output_tensor,
- )
- __all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules']
- _LOGGER = logging.getLogger(__name__)
- DEFAULT_MIN_ACC_MODULE_SIZE = 1
- DEFAULT_SKIP_FUSION = False
- DEFAULT_ALLOW_NON_TENSOR = False
- class _SplitterSettingBase:
- def __init__(
- self,
- min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
- skip_fusion=DEFAULT_SKIP_FUSION,
- allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR
- ):
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--min-acc-module-size",
- "--min_acc_module_size",
- required=False,
- type=int,
- help="Minimum size limit of an accelerator subgraph.",
- )
- parser.add_argument(
- "--skip-fusion",
- "--skip_fusion",
- default=False,
- action="store_true",
- help="If true then no fusion groups. Fusion group is used to "
- "enforce no non-tensor data flow between submodules. If we don't "
- "have this constrain, setting this to false is recommended as it "
- "can reduce overhead.",
- )
- parser.add_argument(
- "--allow-non-tensor",
- "--allow_non_tensor",
- default=False,
- action="store_true",
- help="For some backends non-tensor data flow between cpu and them "
- "are not allowed. Therefore, if a node supported by accelerator but "
- "it has non-tensor inputs or outputs to a cpu node we would want to "
- "consider it as a cpu node during splitting. However, for some backends "
- "we might not care about non-tensor data flow and we can set this option "
- "to true to disable the functionality that prevent non-tensor data flow.",
- )
- args, unknown = parser.parse_known_args()
- self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size
- self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
- self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
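- # Illustrative usage sketch (not part of the original module): the settings can be
- # constructed directly, and any recognized command-line flags override the
- # constructor defaults, e.g.
- #
- #     settings = _SplitterSettingBase(min_acc_module_size=3)
- #     # running with "--min-acc-module-size 10" yields settings.min_acc_module_size == 10;
- #     # otherwise the constructor default of 3 is kept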
- @compatibility(is_backward_compatible=False)
- class FxNetAccNodesFinder:
- """
- Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
- input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.
- I.e. if we have a chain:
- ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1
- where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.
- This behavior can be turned off by passing allow_non_tensor=True.
- """
- def __init__(
- self,
- module: torch.fx.GraphModule,
- operator_support: OperatorSupportBase,
- allow_non_tensor: bool,
- ):
- self.module = module
- self.operator_support = operator_support
- self.allow_non_tensor = allow_non_tensor
- def reduce_acc_nodes_non_tensor_input_helper(
- self, cpu_worklist: NodeList
- ):
- """
- Transitively excludes nodes from ACC supported set.
- For every node in the worklist:
- - removes its downstream ACC nodes from ACC supported set,
- - if any downstream ACC node produces non-tensor output,
- then it gets added into the worklist.
- """
- while cpu_worklist:
- node = cpu_worklist.pop(0)
- for user in node.users:
- if user in self.acc_nodes:
- self.acc_nodes.remove(user)
- if not is_node_output_tensor(user):
- cpu_worklist.append(user)
- def reduce_acc_nodes_non_tensor_input(self):
- """
- Excludes nodes from ACC supported set that have direct
- upstream CPU nodes that produce non-tensor outputs.
- """
- non_tensor_cpu_nodes: NodeList = []
- for node in self.module.graph.nodes:
- if node.op not in CALLABLE_NODE_OPS:
- continue
- if node in self.acc_nodes:
- continue
- if is_node_output_tensor(node):
- continue
- non_tensor_cpu_nodes.append(node)
- self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)
- def reduce_acc_nodes_non_tensor_output(self):
- """
- Excludes nodes from ACC supported set that produce non-tensor
- outputs and have downstream CPU nodes.
- """
- while True:
- new_cpu_nodes: NodeList = []
- for acc_node in self.acc_nodes:
- if is_node_output_tensor(acc_node):
- continue
- for user in acc_node.users:
- if user not in self.acc_nodes:
- new_cpu_nodes.append(acc_node)
- break
- if not new_cpu_nodes:
- break
- for new_cpu_node in new_cpu_nodes:
- self.acc_nodes.remove(new_cpu_node)
- self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)
- def __call__(self) -> NodeSet:
- submodules = dict(self.module.named_modules())
- self.acc_nodes = {
- n
- for n in self.module.graph.nodes
- if n.op in CALLABLE_NODE_OPS
- and self.operator_support.is_node_supported(submodules, n)
- }
- if not self.allow_non_tensor:
- self.reduce_acc_nodes_non_tensor_input()
- self.reduce_acc_nodes_non_tensor_output()
- return self.acc_nodes
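- # Illustrative usage sketch (assumes a traced GraphModule `gm` and an
- # OperatorSupportBase implementation `op_support`); this mirrors how
- # _SplitterBase invokes the finder below:
- #
- #     finder = FxNetAccNodesFinder(gm, op_support, allow_non_tensor=False)
- #     acc_nodes = finder()  # set of torch.fx.Node deemed runnable on the accelerator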
- @compatibility(is_backward_compatible=False)
- class FxNetSplitterInternalError(Exception):
- pass
- @compatibility(is_backward_compatible=False)
- @dataclass
- class Subgraph:
- is_acc: bool
- nodes: NodeList
- @compatibility(is_backward_compatible=False)
- class SplitResult(NamedTuple):
- """
- Stores the results of the splitter.
- Attributes:
- split_module: root module after splitting.
- submodule_inputs: a dict that maps submodule name to its inputs.
- non_acc_submodule_prefix: the prefix for non-acc submodules. For
- acc submodules the prefix is always "_run_on_acc_".
- """
- split_module: torch.fx.GraphModule
- submodule_inputs: Dict[str, Any]
- non_acc_submodule_prefix: str
- @compatibility(is_backward_compatible=False)
- def generate_inputs_for_submodules(
- model: torch.nn.Module,
- inputs: Sequence[Any],
- target_submodules: Iterable[str],
- deepcopy: bool = False,
- ) -> Dict[str, Any]:
- """
- Generate inputs for the target submodules in the given model. Note that if two submodules refer to the same object, this
- function doesn't work.
- Args:
- model: root model.
- inputs: inputs to the root model.
- target_submodules: submodules that we want to generate inputs for.
- Returns:
- A dict that maps from submodule name to its inputs.
- """
- handles = []
- results = {}
- submodule_to_names = {mod: name for name, mod in model.named_modules()}
- def pre_forward(module, module_inputs):
- results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs
- for name, mod in model.named_modules():
- if name in target_submodules:
- handles.append(mod.register_forward_pre_hook(pre_forward))
- def clean_up_handles():
- for h in handles:
- h.remove()
- try:
- with torch.no_grad():
- model(*inputs)
- except Exception as e:
- clean_up_handles()
- raise e
- clean_up_handles()
- return results
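- # Illustrative usage sketch (hypothetical input shape and submodule names):
- #
- #     submodule_inputs = generate_inputs_for_submodules(
- #         split_mod, (torch.randn(2, 3),), ["_run_on_acc_0", "_run_on_cpu_1"]
- #     )
- #     # submodule_inputs["_run_on_acc_0"] holds the positional inputs captured by the
- #     # forward pre-hook during the single forward pass of split_mod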
- class _SplitterBase:
- """
- Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
- Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
- Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.
- Given the following graph:
- ==> b ==>
- // \\
- a d
- \\ //
- ==> c ==>
- class SimpleModule(torch.nn.Module):
- def forward(self, a):
- b = torch.sin(a)
- c = torch.cos(a)
- d = b + c
- return d
- and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
- we will get the following split result:
- main:
- def forward(self, a):
- run_on_acc_0_0 = self._run_on_acc_0_0(a)
- getitem = run_on_acc_0_0[0]
- getitem_1 = run_on_acc_0_0[1]
- run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
- return run_on_cpu_1_1
- _run_on_acc_0_0:
- def forward(self, a):
- sin_1 = torch.sin(a)
- cos_1 = torch.cos(a)
- return (sin_1, cos_1)
- _run_on_cpu_1_1:
- def forward(self, sin_1, cos_1):
- add_1 = sin_1 + cos_1
- return add_1
- """
- # PCIe bandwidth for the backend, defaults to 100 GB/s
- PCIe_BW = 100 * 2 ** 30
- def __init__(
- self,
- module: torch.fx.GraphModule,
- sample_input: Sequence[Any],
- operator_support: OperatorSupportBase,
- settings: _SplitterSettingBase,
- non_acc_submodule_name: str = "_run_on_cpu_",
- ):
- """
- Preprocesses graph before splitting:
- - finds nodes supported by ACC,
- - finds fusion groups for ACC nodes having non-tensor IO,
- - builds a graph of direct dependencies,
- - builds a map of fused nodes to their fusions.
- As a result we get self.acc_nodes, self.deps and self.fusions.
- """
- assert isinstance(module, torch.fx.GraphModule)
- self.module = module
- ShapeProp(self.module).propagate(*sample_input)
- self.settings = settings
- self.operator_support = operator_support
- self.sample_input = sample_input
- self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()
- if self.settings.skip_fusion:
- self.fusions = {}
- else:
- self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()
- # Modify deps to add more deps for fused nodes
- self.deps = self.find_deps()
- self.update_deps_for_fusions()
- self.non_acc_submodule_name = non_acc_submodule_name
- self._node_submodule_map: Dict[str, str] = {}
- # ===============================================================
- # Helpers for ctor and initial state
- # ===============================================================
- def get_node_submodule_map(self) -> Dict[str, str]:
- """ Returns a map from node name to submodule name, e.g.
- node: main_module_impl_impl_over_arch_unary_multiple_embedding
- _pooling_embedding_pooling_sparse_entity_equivalence_key
- _proxy_embedding_bag
- maps to submodule name of: _run_on_acc_1
- """
- return self._node_submodule_map
- def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
- """
- Builds a graph of node dependencies. Leaf nodes don't have any
- dependencies and the "output" node doesn't have nodes depending on it.
- Resulting graph has only direct dependencies, i.e. there are no
- transitive dependencies.
- """
- deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
- for node in self.module.graph.nodes:
- if node.op not in CALLABLE_NODE_OPS:
- continue
- for user in node.users:
- if user.op != "output":
- deps[user].add(node)
- return deps
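- # Illustrative sketch: for the SimpleModule graph in the class docstring
- # (a -> b, a -> c, {b, c} -> d), find_deps() records only direct dependencies
- # between callable nodes:
- #
- #     deps = self.find_deps()
- #     # deps[d] == {b, c}; b and c have no recorded dependencies because their
- #     # only input, the placeholder a, is not a callable node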
- def update_deps_for_fusions(self):
- """
- Updates graph of dependencies so that:
- - nodes from the same fusion depend on the same set of outer nodes,
- - outer nodes depending on a fusion depend on all nodes in that fusion.
- """
- for node in self.fusions:
- fusion = self.fusions[node]
- for fused_neighbor in fusion:
- self.deps[node].update(self.deps[fused_neighbor] - fusion)
- for user in fused_neighbor.users:
- if user not in fusion:
- self.deps[user].add(node)
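- # Illustrative sketch (hypothetical nodes): if x and y form a fusion group,
- # deps[x] == {p} and deps[y] == {q} beforehand, and z consumes y, then:
- #
- #     self.update_deps_for_fusions()
- #     # after: deps[x] == deps[y] == {p, q} and {x, y} <= deps[z],
- #     # so the fusion group is scheduled as a single unit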
- # ===============================================================
- # Helpers for preview
- # ===============================================================
- def _lower_model_to_backend(
- self, mod: torch.fx.GraphModule, inputs: Tensors
- ) -> torch.nn.Module:
- """
- Lower the model to a backend.
- """
- return mod
- def _find_culprit(
- self, mod: torch.fx.GraphModule, inputs: Tensors
- ) -> str:
- """
- When an error occurs during lowering or running the lowered mod, we use this
- function to find the culprit nodes in `mod` that cause the error.
- """
- return "Unable to find a culprit because _find_culprit() function is not implemented."
- def _draw_graph_based_on_node_support(
- self, mod: torch.fx.GraphModule, supported_nodes: NodeList
- ):
- color_map = {
- "default": "AliceBlue",
- "supported": "chartreuse1",
- "unsupported": "crimson",
- }
- class CustomDrawer(FxGraphDrawer):
- def _get_node_style(self, node):
- template = super()._get_node_style(node)
- if node in supported_nodes:
- template["fillcolor"] = color_map["supported"]
- elif node.op in CALLABLE_NODE_OPS:
- template["fillcolor"] = color_map["unsupported"]
- else:
- template["fillcolor"] = color_map["default"]
- return template
- drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
- dot_graph = drawer.get_main_dot_graph()
- dot_graph.write_raw("node_support.dot")
- def node_support_preview(self, dump_graph: bool = False):
- submodules = dict(self.module.named_modules())
- supported_nodes: NodeList = []
- supported_node_types = defaultdict(set)
- unsupported_node_types = defaultdict(set)
- def get_dtype(arg):
- tensor_meta = arg.meta.get("tensor_meta")
- return getattr(tensor_meta, "dtype", None)
- for node in self.module.graph.nodes:
- if node.op not in CALLABLE_NODE_OPS:
- continue
- target = get_node_target(submodules, node)
- # Store the dtype of each arg in node.args. If an arg doesn't have a dtype (i.e. it's not a tensor), store None.
- arg_dtypes = [
- get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
- for arg in node.args
- ]
- # Find the index just past the last non-None element. If all elements are None, last_index will be 0.
- last_index = len(arg_dtypes) - next(
- (
- i
- for i, dtype in enumerate(reversed(arg_dtypes))
- if dtype is not None
- ),
- len(arg_dtypes),
- )
- # Strip None elements at the end.
- arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
- kwarg_dtypes_tuple = tuple(
- (k, get_dtype(arg))
- for k, arg in node.kwargs.items()
- if isinstance(arg, torch.fx.Node)
- )
- if self.operator_support.is_node_supported(submodules, node):
- supported_nodes.append(node)
- supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
- else:
- unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
- if dump_graph:
- self._draw_graph_based_on_node_support(self.module, supported_nodes)
- reports = "\nSupported node types in the model:\n"
- for t, dtypes in supported_node_types.items():
- for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
- reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
- reports += "\nUnsupported node types in the model:\n"
- for t, dtypes in unsupported_node_types.items():
- for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
- reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"
- print(reports)
- # Return reports for testing purposes
- return reports
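- # Illustrative usage sketch: on a concrete splitter instance this prints the
- # report and, with dump_graph=True, writes "node_support.dot" via
- # _draw_graph_based_on_node_support():
- #
- #     report = splitter.node_support_preview(dump_graph=True)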
- def split_preview(self, dump_graph: bool = False):
- reports = ""
- subgraphs = self.put_nodes_into_subgraphs()
- acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
- cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
- reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
- reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
- subgraphs = self.remove_small_acc_subgraphs(subgraphs)
- acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
- cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
- reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
- reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"
- for i, subgraph in enumerate(subgraphs):
- reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
- reports += f"{len(subgraph.nodes)} node(s)\n"
- self.tag(subgraphs)
- split_mod = self.split(remove_tag=True)
- split_mod.eval()
- if dump_graph:
- drawer = FxGraphDrawer(
- split_mod, "preview", ignore_getattr=True
- )
- dot_graphs = drawer.get_all_dot_graphs()
- for name, dot_graph in dot_graphs.items():
- dot_graph.write_raw(f"{name}.dot")
- max_qps: float = self.PCIe_BW
- bottleneck_module = ""
- for node in split_mod.graph.nodes:
- if node.op == "call_module" and "acc" in node.target:
- reports += f"\nProcessing acc submodule {node.target}\n"
- submod = getattr(split_mod, node.target)
- def get_submod_inputs(main_mod, submod, example_inputs):
- sub_inputs = None
- def get_inputs(self, inputs):
- nonlocal sub_inputs
- sub_inputs = inputs
- handle = submod.register_forward_pre_hook(get_inputs)
- main_mod(*example_inputs)
- handle.remove()
- return sub_inputs
- submod_inputs = get_submod_inputs(
- split_mod, submod, self.sample_input
- )
- ShapeProp(submod).propagate(*submod_inputs)
- total_input_bytes = 0
- total_output_bytes = 0
- reports += "Checking inputs...\n"
- for n in submod.graph.nodes:
- if n.op == "placeholder":
- if not is_node_output_tensor(n):
- reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
- else:
- total_input_bytes += get_size_of_node(submod, n)[0]
- if n.op == "output":
- output_node = n
- reports += "Checking outputs...\n"
- def get_bytes(node: torch.fx.Node):
- nonlocal total_output_bytes
- nonlocal reports
- if not is_node_output_tensor(node):
- reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
- else:
- total_output_bytes += get_size_of_node(submod, node)[0]
- map_arg(output_node.args, get_bytes)
- qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
- reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
- reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"
- if qps < max_qps:
- max_qps = qps
- bottleneck_module = node.target
- try:
- lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
- except RuntimeError:
- reports += "Run into an error during lowering!\n"
- reports += self._find_culprit(submod, submod_inputs)
- continue
- try:
- lowered_submod(*submod_inputs)
- except RuntimeError:
- reports += "Run into an error during inference!\n"
- reports += self._find_culprit(submod, submod_inputs)
- else:
- reports += "Lowering and running succeed!\n"
- reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
- reports += f" bottleneck is submodule {bottleneck_module}."
- print(reports)
- # return the reports for testing purposes
- return reports
- # ===============================================================
- # Helpers for extend_acc_subgraph() method
- # ===============================================================
- def find_reverse_deps(
- self, tag_id: Optional[int] = None
- ) -> Dict[torch.fx.Node, NodeSet]:
- """
- Builds reversed topological node dependencies. If tag_id is specified,
- we ignore nodes that are in later subgraphs, i.e. nodes with a greater tag_id.
- """
- result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
- for node in self.module.graph.nodes:
- if node.op not in CALLABLE_NODE_OPS:
- continue
- for user in node.users:
- if user.op not in CALLABLE_NODE_OPS:
- continue
- if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
- result[node].add(user)
- return result
- def update_reverse_deps_for_fusions(
- self, deps: Dict[torch.fx.Node, NodeSet]
- ):
- processed_node = set()
- for node, fusion in self.fusions.items():
- if node in processed_node:
- continue
- new_dep = set()
- # Create a new dependency set which includes all the
- # dependencies of the nodes in the fusion group
- for n in fusion:
- new_dep.update(deps[n])
- # Exclude nodes in the fusion
- new_dep.difference_update(fusion)
- # Update dependency
- for n in fusion:
- deps[n] = new_dep
- for arg in n.all_input_nodes:
- if arg not in fusion:
- deps[arg].update(fusion)
- processed_node.add(n)
- def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
- """
- Finds parent nodes of the `tag` subgraph.
- Traverses the inputs of nodes in the subgraph; if an input doesn't belong to the subgraph
- and is not a placeholder, we consider it a parent node of the subgraph.
- """
- parent_nodes = set()
- for node in self.module.graph.nodes:
- if node.op in CALLABLE_NODE_OPS and node.tag == tag:
- for arg in node.all_input_nodes:
- if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
- parent_nodes.add(arg)
- return parent_nodes
- def extend_acc_subgraph(self, tag: str):
- """
- Extend the acc subgraph with `tag`, walking in the reverse topological direction.
- """
- # Dict that maps a node to its users, ignoring users that
- # are in subgraphs with a greater tag
- deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
- self.update_reverse_deps_for_fusions(deps)
- # Parent nodes of the subgraph
- parent_nodes = self.find_parent_nodes_of_subgraph(tag)
- visited_nodes: NodeSet = set()
- while parent_nodes:
- node = None
- # Find an acc node that depends only on visited nodes
- for n in parent_nodes:
- if deps[n] <= visited_nodes and n in self.acc_nodes:
- node = n
- break
- if node is None:
- break
- # Put the node into `tag` subgraph
- node.tag = tag # type: ignore[attr-defined]
- parent_nodes.remove(node)
- visited_nodes.add(node)
- # If node is in a fusion group, add all fusion buddies to parent nodes
- if node in self.fusions:
- for fusion_node in self.fusions[node]:
- if fusion_node not in visited_nodes:
- parent_nodes.add(fusion_node)
- # Add inputs of the node to parent nodes
- for arg in node.all_input_nodes:
- if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
- parent_nodes.add(arg)
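- # Illustrative sketch (hypothetical tag): an acc-supported parent node whose
- # users all live in the "_run_on_acc_1" subgraph or later subgraphs (i.e. none
- # in earlier subgraphs) satisfies deps[n] <= visited_nodes and is re-tagged
- # into the acc subgraph:
- #
- #     splitter.extend_acc_subgraph("_run_on_acc_1")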
- # ===============================================================
- # Helpers for split() method
- # ===============================================================
- def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
- """
- Finds nodes that consume module inputs or get_attr nodes.
- """
- starter_cpu_nodes: NodeSet = set()
- starter_acc_nodes: NodeSet = set()
- for node in self.module.graph.nodes:
- if node.op not in {"placeholder", "get_attr"}:
- continue
- for user in node.users:
- if user in self.acc_nodes:
- starter_acc_nodes.add(user)
- else:
- starter_cpu_nodes.add(user)
- return starter_cpu_nodes, starter_acc_nodes
- def put_nodes_into_subgraphs(self) -> List[Subgraph]:
- # We start graph traversal from leaf nodes
- current_cpu_nodes, current_acc_nodes = self.starter_nodes()
- visited_nodes: NodeSet = set()
- # Determine which subgraph to start from: start with a CPU subgraph if any
- # starter CPU node has zero dependencies, otherwise start with an acc subgraph
- acc_subgraph: bool = not any([len(self.deps[n]) == 0 for n in current_cpu_nodes])
- current_subgraph_nodes: NodeList = []
- # Result accumulator
- subgraphs: List[Subgraph] = []
- while current_cpu_nodes or current_acc_nodes:
- # Find the first node that should belong to the current subgraph and has all dependencies resolved
- current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
- node = next(
- (n for n in current_nodes if self.deps[n] <= visited_nodes),
- None,
- )
- # If nothing was found, then it's time to flip the mode and start a new subgraph
- if node is None:
- if not current_subgraph_nodes:
- raise FxNetSplitterInternalError("Subgraph can't be empty")
- subgraphs.append(
- Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
- )
- acc_subgraph = not acc_subgraph
- current_subgraph_nodes = []
- continue
- current_nodes.remove(node)
- visited_nodes.add(node)
- current_subgraph_nodes.append(node)
- # Add fusion buddies
- if node in self.fusions:
- if node in self.acc_nodes:
- current_acc_nodes.update(self.fusions[node] - visited_nodes)
- else:
- current_cpu_nodes.update(self.fusions[node] - visited_nodes)
- # Put depending nodes into the queue
- for user in node.users:
- if user.op not in CALLABLE_NODE_OPS:
- continue
- # Add downstream nodes
- if user in self.acc_nodes:
- current_acc_nodes.add(user)
- else:
- current_cpu_nodes.add(user)
- # Append the last subgraph if it hasn't been added yet
- if current_subgraph_nodes:
- subgraphs.append(
- Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
- )
- if not subgraphs:
- raise FxNetSplitterInternalError("Couldn't create subgraphs")
- return subgraphs
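- # Illustrative sketch: for the SimpleModule example in the class docstring,
- # b and c are acc-supported starter nodes and d is not, so the traversal
- # produces two subgraphs (node order within a subgraph may vary):
- #
- #     subgraphs = splitter.put_nodes_into_subgraphs()
- #     # [Subgraph(is_acc=True, nodes=[b, c]), Subgraph(is_acc=False, nodes=[d])]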
- def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
- """
- This pass finds ACC submodules with fewer nodes than the specified minimum size and merges
- them with adjacent CPU submodules.
- """
- result: List[Subgraph] = []
- for subgraph in subgraphs:
- if subgraph.is_acc:
- if len(subgraph.nodes) >= self.settings.min_acc_module_size:
- result.append(subgraph)
- else:
- print(
- "Eliminating acc subgraph because it's smaller than the threshold: "
- f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
- )
- if result:
- result[-1].nodes.extend(subgraph.nodes)
- else:
- subgraph.is_acc = False
- result.append(subgraph)
- else:
- if result and not result[-1].is_acc:
- result[-1].nodes.extend(subgraph.nodes)
- else:
- result.append(subgraph)
- return result
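- # Illustrative sketch (hypothetical sizes): with min_acc_module_size == 2 and
- # subgraphs [cpu(3 nodes), acc(1 node), cpu(2 nodes)], the lone acc node is
- # merged into the preceding CPU subgraph and the trailing CPU subgraph is then
- # appended to it, collapsing the result to a single CPU subgraph of 6 nodes.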
- def tag(self, subgraphs: List[Subgraph]):
- self.tags: List[str] = []
- for subgraph in subgraphs:
- tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
- self.tags.append(tag)
- for node in subgraph.nodes:
- if hasattr(node, "tag"):
- raise FxNetSplitterInternalError(f"Node {node} was already tagged")
- node.tag = tag # type: ignore[attr-defined]
- self._node_submodule_map[node.name] = tag
- def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
- split_module = split_by_tags(self.module, self.tags)
- if remove_tag:
- for node in self.module.graph.nodes:
- if hasattr(node, "tag"):
- del node.tag
- return split_module
- def __call__(self) -> torch.fx.GraphModule:
- subgraphs = self.put_nodes_into_subgraphs()
- subgraphs = self.remove_small_acc_subgraphs(subgraphs)
- acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
- non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
- print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
- self.tag(subgraphs)
- return self.split()
- def generate_split_results(self) -> SplitResult:
- split_module = self()
- submodule_names = []
- for name, mod in split_module.named_children():
- submodule_names.append(name)
- submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
- return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
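- # Illustrative end-to-end sketch (hypothetical model and operator support;
- # concrete splitters normally subclass _SplitterBase):
- #
- #     gm = torch.fx.symbolic_trace(MyModule())
- #     splitter = _SplitterBase(
- #         gm, (torch.randn(2, 3),), MyOperatorSupport(), _SplitterSettingBase()
- #     )
- #     split_result = splitter.generate_split_results()
- #     # split_result.split_module is the root GraphModule whose children are the
- #     # _run_on_acc_* / _run_on_cpu_* submodules; split_result.submodule_inputs
- #     # maps each child's name to the inputs captured for it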