import torch
from torch.fx import GraphModule, map_arg
from torch.fx.graph import Graph, Node
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

from .utils import (
    get_node_first_input_and_output_type,
    getattr_from_fqn,
    NodeInputOrOutputType,
    return_first_non_observer_node,
    get_number_of_non_param_args,
    get_target_type_str,
    get_arg_indices_of_inputs_to_log,
    get_node_input_qparams,
    op_type_supports_shadowing,
    get_normalized_nth_input,
)

from .ns_types import (
    NSSingleResultValuesType,
    NSSubgraph,
    NSNodeTargetType,
)

from torch.ao.ns.fx.mappings import (
    get_node_type_to_io_type_map,
)
from torch.ao.quantization.observer import _is_activation_post_process

from typing import Dict, Tuple, Callable, List, Any, Union, Optional, Set

def _maybe_get_fqn(node: Node, gm: GraphModule) -> Optional[str]:
    fqn = None
    if hasattr(gm, '_node_name_to_scope'):
        # fqn on observers is not present, because they do not
        # exist when the fqns are created during tracing. If this is
        # an observer, get the fqn of the node being observed.
        node_to_use_for_fqn = node
        if node.op == 'call_module':
            assert isinstance(node.target, str)
            module = getattr_from_fqn(gm, node.target)
            if _is_activation_post_process(module):
                node_to_use_for_fqn = get_normalized_nth_input(node, gm, 0)
        fqn = gm._node_name_to_scope[node_to_use_for_fqn.name][0]  # type: ignore[index]
    return fqn  # type: ignore[return-value]
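
# A sketch of the data `_maybe_get_fqn` relies on (the exact contents of
# `_node_name_to_scope` are an internal detail of FX tracing, so treat this as
# an assumption): it maps node names to (fqn, module type) pairs, e.g.
# {'conv1': ('features.conv1', torch.nn.Conv2d)}, and this helper returns the
# 'features.conv1' part, walking past an observer to its observed node first.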

def _insert_logger_after_node(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    logger_node_name_suffix: str,
    ref_node_name: str,
    model_name: str,
    ref_name: str,
    ref_node_target_type: str,
    results_type: str,
    index_within_arg: int,
    index_of_arg: int,
    fqn: Optional[str],
) -> Node:
    """
    Given a starting graph of

    prev_node -> node -> next_node

    This function creates a new logger_cls obj and adds it
    after node, resulting in

    prev_node -> node -> logger_obj -> next_node
    """
    # create new name
    logger_node_name = \
        get_new_attr_name_with_prefix(node.name + logger_node_name_suffix)(gm)
    target_type = get_target_type_str(node, gm)
    # create the logger object
    logger_obj = logger_cls(
        ref_node_name, node.name, model_name, ref_name, target_type,
        ref_node_target_type,
        results_type, index_within_arg, index_of_arg, fqn)
    # attach the logger object to the parent module
    setattr(gm, logger_node_name, logger_obj)
    logger_node = node.graph.create_node(
        'call_module', logger_node_name, (node,), {})
    return logger_node
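
# A minimal usage sketch (the names below are illustrative assumptions, not
# part of this module): given a traced GraphModule `gm`, a node `n` in
# `gm.graph`, and a logger class such as
# torch.ao.ns._numeric_suite_fx.OutputLogger, the following call would record
# the output of `n`:
#
#   logger_node = _insert_logger_after_node(
#       n, gm, OutputLogger, '_ns_logger_', n.name, 'model_a', 'ref_0',
#       'torch.nn.Conv2d', NSSingleResultValuesType.NODE_OUTPUT.value,
#       index_within_arg=0, index_of_arg=0, fqn=None)
#
# Callers are expected to recompile (`gm.recompile()`) after editing the graph.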

def add_loggers_to_model(
    gm: GraphModule,
    node_to_instrument_inputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    node_to_instrument_outputs_to_ref_node_name: Dict[Node, Tuple[str, str]],
    logger_cls: Callable,
    model_name: str,
) -> GraphModule:
    """
    Takes the graph of gm, adds loggers to the inputs of each node in
    node_to_instrument_inputs_to_ref_node_name and to the outputs of each
    node in node_to_instrument_outputs_to_ref_node_name. Returns a
    GraphModule with the new graph.
    """

    new_graph = Graph()
    env: Dict[str, Any] = {}
    modules = dict(gm.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env[node.name])

    for node in gm.graph.nodes:
        if node.op == 'output':
            new_graph.output(map_arg(get_normalized_nth_input(node, gm, 0), load_arg))
            continue

        if (
            (node in node_to_instrument_inputs_to_ref_node_name) or
            (node in node_to_instrument_outputs_to_ref_node_name)
        ):
            fqn = _maybe_get_fqn(node, gm)

            if node in node_to_instrument_inputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_inputs_to_ref_node_name[node]
                # Ops such as add and mul are special because either
                # one or two of the first two arguments can be tensors,
                # and if one argument is a tensor it can be first or
                # second (x + 1 versus 1 + x).
                arg_indices_to_log = get_arg_indices_of_inputs_to_log(node)
                for node_arg_idx in arg_indices_to_log:
                    node_arg = get_normalized_nth_input(node, gm, node_arg_idx)
                    if type(node_arg) == Node:
                        # create a single input logger
                        prev_node = env[node_arg.name]
                        env[node_arg.name] = _insert_logger_after_node(
                            prev_node, gm, logger_cls, '_ns_logger_', node.name,
                            model_name, ref_name, ref_node_type,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=node_arg_idx,
                            fqn=fqn)
                    elif type(node_arg) == torch.fx.immutable_collections.immutable_list:
                        # create N input loggers, one for each node
                        for arg_idx, arg in enumerate(node_arg):  # type: ignore[var-annotated, arg-type]
                            prev_node = env[arg.name]
                            env[prev_node.name] = _insert_logger_after_node(
                                prev_node, gm, logger_cls, '_ns_logger_', node.name,
                                model_name, ref_name, ref_node_type,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=node_arg_idx,
                                fqn=fqn)
                    else:
                        # non-Node args (for example, scalars) are not logged
                        pass

            # ensure env is populated with base node
            # Note: runs for both inputs and outputs
            env[node.name] = new_graph.node_copy(node, load_arg)

            if node in node_to_instrument_outputs_to_ref_node_name:
                ref_name, ref_node_type = node_to_instrument_outputs_to_ref_node_name[node]
                # add the logger after the base node
                env[node.name] = _insert_logger_after_node(
                    env[node.name], gm, logger_cls, '_ns_logger_', node.name,
                    model_name, ref_name, ref_node_type,
                    NSSingleResultValuesType.NODE_OUTPUT.value,
                    index_within_arg=0, index_of_arg=0, fqn=fqn)

        else:
            env[node.name] = new_graph.node_copy(node, load_arg)

    new_gm = GraphModule(gm, new_graph)
    return new_gm
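
# A minimal usage sketch (the model and node selection are assumptions for
# illustration): instrument the outputs of every call_module node in a traced
# model, then run it to populate the loggers.
#
#   gm = torch.fx.symbolic_trace(my_model)
#   nodes_to_log = {
#       n: ('ref_' + n.name, get_target_type_str(n, gm))
#       for n in gm.graph.nodes if n.op == 'call_module'
#   }
#   gm_with_loggers = add_loggers_to_model(
#       gm, {}, nodes_to_log, OutputLogger, model_name='model_a')
#   gm_with_loggers(example_input)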

def _insert_quantize_per_tensor_node(
    prev_node_c: Node,
    node_a: Node,
    gm_b: GraphModule,
    graph_c: Graph,
    scale: Union[torch.Tensor, float],
    zero_point: Union[torch.Tensor, int],
    dtype_cast_name: str,
) -> Node:
    # copy scale
    scale_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_scale_')(gm_b)
    setattr(gm_b, scale_node_name, scale)
    scale_node = graph_c.create_node(
        'get_attr', scale_node_name, (), {}, scale_node_name)
    # copy zero_point
    zero_point_node_name = \
        get_new_attr_name_with_prefix(
            node_a.name + '_input_zero_point_')(gm_b)
    setattr(gm_b, zero_point_node_name, zero_point)
    zero_point_node = graph_c.create_node(
        'get_attr', zero_point_node_name, (), {}, zero_point_node_name)
    # create the quantize_per_tensor call
    return graph_c.create_node(
        'call_function', torch.quantize_per_tensor,
        (prev_node_c, scale_node, zero_point_node, torch.quint8), {},
        dtype_cast_name)
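
# The resulting graph C fragment looks like this, with scale and zero_point
# stored as new attributes on gm_b and read via get_attr nodes:
#
#   scale (get_attr)   zero_point (get_attr)
#              \         /
#   prev_node_c -> quantize_per_tensor(..., torch.quint8) -> (consumers)
#
# Note that the output dtype is hardcoded to torch.quint8 above; other dtypes
# would need additional handling.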

def _insert_dtype_cast_after_node(
    node_a: Node,
    node_c: Node,
    prev_node_c: Union[Node, List[Node]],
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
    node_name_prefix: str,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Union[Node, List[Node]]:
    """
    Given a starting graph C (derived from graph B) of

    ... -> prev_node_c -> node_c -> ...

    And a corresponding related node_a, inserts the correct dtype
    cast node after prev_node_c to cast into the dtype expected
    by node_a, resulting in:

                          dtype_cast
                        /
    ... -> prev_node_c -> node_c -> ...

    For example, if node_c is an int8 op and node_a is an fp32 op, this
    function will insert a dequantize.
    """
    dtype_cast_op = None
    dtype_cast_mod_cls = None
    dtype_cast_method = None
    dtype_cast_method_dtype = None
    dtype_cast_scale = None
    dtype_cast_zero_point = None
    node_input_type_a, _node_output_type_a = \
        get_node_first_input_and_output_type(
            node_a, gm_a, logger_cls, node_type_to_io_type_map)
    node_input_type_c, _node_output_type_c = \
        get_node_first_input_and_output_type(
            node_c, gm_b, logger_cls, node_type_to_io_type_map)

    if (
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.INT8) or
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP16) or
        # TODO(future PR): determine the actual dtype of node_c,
        # the current code only works because dequantize works with
        # multiple input dtypes.
        (node_input_type_a == NodeInputOrOutputType.FP32 and
         node_input_type_c == NodeInputOrOutputType.FP32_OR_INT8)
    ):
        dtype_cast_op = torch.dequantize
    elif (
        node_input_type_a == node_input_type_c and
        node_input_type_a != NodeInputOrOutputType.UNKNOWN
    ):
        dtype_cast_mod_cls = torch.nn.Identity
    elif (
        node_input_type_a == NodeInputOrOutputType.INT8 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        # int8 shadows fp32, the dtype cast needs to quantize to int8
        # with the right qparams.
        node_a_input_qparams = get_node_input_qparams(
            node_a, gm_a, node_type_to_io_type_map)
        if node_a_input_qparams is not None:
            dtype_cast_op = torch.quantize_per_tensor  # type: ignore[assignment]
            dtype_cast_scale, dtype_cast_zero_point = node_a_input_qparams
    elif (
        node_input_type_a == NodeInputOrOutputType.FP16 and
        node_input_type_c == NodeInputOrOutputType.FP32
    ):
        dtype_cast_method = 'to'
        dtype_cast_method_dtype = torch.float16
    else:
        raise AssertionError(
            f"dtype cast from {node_input_type_c} {node_c.format_node()} to " +
            f"{node_input_type_a} {node_a.format_node()} needs to be implemented")

    if isinstance(prev_node_c, Node):
        new_dtype_cast_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        if dtype_cast_op:
            if dtype_cast_scale is not None and dtype_cast_zero_point is not None:
                return _insert_quantize_per_tensor_node(
                    prev_node_c, node_a, gm_b, graph_c, dtype_cast_scale,
                    dtype_cast_zero_point, new_dtype_cast_name)
            else:
                return graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c,), {},
                    new_dtype_cast_name)
        elif dtype_cast_method:
            return graph_c.create_node(
                'call_method', dtype_cast_method,
                (prev_node_c, dtype_cast_method_dtype), {}, new_dtype_cast_name)
        else:
            assert dtype_cast_mod_cls
            dtype_cast_mod = dtype_cast_mod_cls()
            setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
            return graph_c.create_node(
                'call_module', new_dtype_cast_name, (prev_node_c,), {},
                new_dtype_cast_name)
    elif isinstance(prev_node_c, list):
        results = []
        for prev_node_c_inner in prev_node_c:
            new_dtype_cast_name = \
                get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
            if dtype_cast_op:
                # TODO(future PR): add handling for quantize_per_tensor
                new_dtype_cast_node = graph_c.create_node(
                    'call_function', dtype_cast_op, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
            else:
                assert dtype_cast_mod_cls
                dtype_cast_mod = dtype_cast_mod_cls()
                setattr(gm_b, new_dtype_cast_name, dtype_cast_mod)
                new_dtype_cast_node = graph_c.create_node(
                    'call_module', new_dtype_cast_name, (prev_node_c_inner,), {},
                    new_dtype_cast_name)
                results.append(new_dtype_cast_node)
        return results
    else:
        raise AssertionError(f"type {type(prev_node_c)} is not handled")

# TODO(future PR): look into using copy_node API instead
def _copy_node_from_a_to_c(
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    graph_c: Graph,
) -> Node:
    """
    Simple copy of node_a to graph_c.
    """
    if node_a.op == 'get_attr':
        node_a_copy_name = \
            get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
        node_a_obj = getattr_from_fqn(gm_a, node_a.target)  # type: ignore[arg-type]
        if torch.is_tensor(node_a_obj):
            node_a_obj = node_a_obj.detach()
        setattr(gm_b, node_a_copy_name, node_a_obj)
        node_a_copy = graph_c.create_node(
            node_a.op, node_a_copy_name, (), {}, node_a_copy_name)
        return node_a_copy
    elif node_a.op == 'call_method':
        assert node_a.target in ('dequantize', 'to'), \
            f"target {node_a.target} is not implemented"
        if node_a.target == 'dequantize':
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0),
                gm_a, gm_b, graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target, (arg_copy,), {}, node_a_copy_name)
            return node_a_copy
        else:  # to
            arg_copy = _copy_node_from_a_to_c(
                get_normalized_nth_input(node_a, gm_a, 0), gm_a, gm_b, graph_c)  # type: ignore[arg-type]
            node_a_copy_name = \
                get_new_attr_name_with_prefix(node_a.name + '_shadow_copy_')(gm_b)
            node_a_copy = graph_c.create_node(
                node_a.op, node_a.target,
                (arg_copy, get_normalized_nth_input(node_a, gm_a, 1)),
                {}, node_a_copy_name)
            return node_a_copy
    else:
        raise AssertionError(
            f"handling of node {node_a.format_node()} with op {node_a.op} is not implemented")

def _can_insert_copy_of_subgraph_a(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    num_non_param_args_node_a: int,
) -> bool:
    """
    This function returns `False` if the input subgraph cannot be copied by
    `_insert_copy_of_subgraph_a_after_input_node_c`. This usually means
    that there is a corner case logic for which copy is not yet implemented.
    """
    # populate the list of nodes we need to check
    nodes = []
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        nodes.append(cur_node)
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
    nodes.append(cur_node)
    nodes.reverse()

    def _can_insert(node_a_arg, gm_a):
        if isinstance(node_a_arg, Node):
            arg_a = return_first_non_observer_node(node_a_arg, gm_a)
            if arg_a.op == 'call_method':
                return arg_a.target in ('dequantize', 'to')
            elif arg_a.op == 'get_attr':
                return True
            else:
                return False
        elif isinstance(node_a_arg, (list, tuple)):
            for el in node_a_arg:
                if not isinstance(el, Node):
                    return False
            return True
        elif isinstance(node_a_arg, (int, float, torch.dtype)):
            # mirror `_copy_arg` in `_insert_copy_of_node_a_after_input_node_c`,
            # which can copy scalar and dtype args
            return True
        else:
            return False

    # For each node, check if we handle the copy behavior. This follows the
    # logic in `_insert_copy_of_subgraph_a_after_input_node_c`.
    for node_a in nodes:

        local_num_non_param_args_node_a = num_non_param_args_node_a \
            if node_a is nodes[0] else 1

        norm_args_kwargs = node_a.normalized_arguments(
            gm_a, normalize_to_only_use_kwargs=True)
        if norm_args_kwargs is not None:
            norm_args, norm_kwargs = norm_args_kwargs
        else:
            norm_args, norm_kwargs = node_a.args, node_a.kwargs

        cur_idx = 0

        while cur_idx < len(norm_args):
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(norm_args[cur_idx], gm_a):
                    return False
            cur_idx += 1

        for kwarg_val in norm_kwargs.values():
            # mirror the input-stitching skip logic from the copy function
            if cur_idx == 0:
                pass
            elif cur_idx == 1 and local_num_non_param_args_node_a == 2:
                pass
            else:
                if not _can_insert(kwarg_val, gm_a):
                    return False
            cur_idx += 1

    return True

def _insert_copy_of_subgraph_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Inserts a copy of `subgraph_a` (from `gm_a`) into the graph containing
    `input_node_c`, node by node from start to end, stitching `input_node_c`
    (and optionally `input_node_c_2`) in as the non-param input(s) of the
    first node. Returns the last node of the inserted copy.
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    # create a sequential list of the subgraphs' nodes from start to end,
    # because we need to add the nodes to graph C in non-reverse order
    nodes_of_a = [subgraph_a.end_node]
    cur_node = subgraph_a.end_node
    while cur_node != subgraph_a.start_node:
        cur_node = get_normalized_nth_input(cur_node, gm_a, 0)  # type: ignore[assignment]
        nodes_of_a.insert(0, cur_node)

    # go through nodes of a in order, and insert them into the graph of c
    # sequentially
    cur_node_a = nodes_of_a[0]
    cur_node_c = _insert_copy_of_node_a_after_input_node_c(
        input_node_c,
        input_node_c_2,
        cur_node_a,
        gm_a,
        gm_b,
        node_name_prefix)
    for cur_idx_a in range(1, len(nodes_of_a)):
        cur_node_a = nodes_of_a[cur_idx_a]
        prev_node_c = cur_node_c  # previous added node is the input to next node
        cur_node_c = _insert_copy_of_node_a_after_input_node_c(
            prev_node_c,
            # TODO(future PR): enable multiple inputs for nodes which are not at start of subgraph
            None,
            cur_node_a,
            gm_a,
            gm_b,
            node_name_prefix)
    # return the last inserted node
    return cur_node_c

def _insert_copy_of_node_a_after_input_node_c(
    input_node_c: Union[Node, List[Node]],
    input_node_c_2: Optional[Union[Node, List[Node]]],
    node_a: Node,
    gm_a: GraphModule,
    gm_b: GraphModule,
    node_name_prefix: str,
) -> Node:
    """
    Assume that node_a from graph_a has
      args (input, (input2)?, arg1, ...), and
      kwargs {kw0: kwarg0, ...}

    Note: input2 is optional. If it is None, we assume that the op
    has a single non-param input. If it is specified, we assume that the op
    has two non-param inputs.

    Copies the underlying values of arg1..argn and kwarg0..kwargn into gm_b,
    and creates the corresponding nodes in graph_c. Note: observers are ignored,
    so if an arg is an observer we navigate up until we find a non-observer parent.

    If node_a is a call_module, points the module pointed to by node_a to gm_b.

    Creates the copy of node_a in graph_c, with input as the first arg,
    and all other args and kwargs pointing to the copies of the objects
    in gm_b created above.

    An example in pictures:

    graph A:
    ========

    input -------------> node_a
                         / / /
    (input_2)?----------/ / /
                         / /
    weight -> weight_obs  /
                         /
    bias ----------------

    graph C (derived from B):
    =========================

    input_node_c --> node_a_copy
                     / / /
    (input_node_c_2)? / /
                     / /
    weight_copy ----/ /
                     /
    bias_copy ------/
    """
    if isinstance(input_node_c, Node):
        graph_c = input_node_c.graph
    else:
        assert isinstance(input_node_c, list)
        graph_c = input_node_c[0].graph

    norm_args_kwargs = node_a.normalized_arguments(
        gm_a, normalize_to_only_use_kwargs=True)
    if norm_args_kwargs is not None:
        norm_args, norm_kwargs = norm_args_kwargs
    else:
        norm_args, norm_kwargs = node_a.args, node_a.kwargs

    new_args = []
    new_kwargs = {}

    def _copy_arg(arg):
        # copy the other inputs from the other graph
        if isinstance(arg, Node):
            arg = return_first_non_observer_node(arg, gm_a)
            arg = _copy_node_from_a_to_c(arg, gm_a, gm_b, graph_c)
            return arg
        elif isinstance(arg, (int, float, torch.dtype)):
            return arg
        elif isinstance(arg, (list, tuple)):
            for el in arg:
                assert not isinstance(el, Node), \
                    "handling of Node inside list is not implemented"
            return arg
        else:
            raise AssertionError(
                f"handling for arg of type {type(arg)} is not implemented")

    cur_idx = 0

    while cur_idx < len(norm_args):
        if cur_idx == 0:
            new_arg = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_arg = input_node_c_2
        else:
            new_arg = _copy_arg(norm_args[cur_idx])
        new_args.append(new_arg)
        cur_idx += 1

    for kwarg_name, kwarg_val in norm_kwargs.items():
        # stitch the inputs from base graph
        if cur_idx == 0:
            new_kwargs[kwarg_name] = input_node_c
        elif cur_idx == 1 and input_node_c_2 is not None:
            new_kwargs[kwarg_name] = input_node_c_2
        else:
            new_kwargs[kwarg_name] = _copy_arg(kwarg_val)
        cur_idx += 1

    new_args = tuple(new_args)  # type: ignore[assignment]

    node_a_shadows_c_name = \
        get_new_attr_name_with_prefix(node_name_prefix)(gm_b)

    if node_a.op == 'call_module':
        # if target is a module, we point to the module from gm_b
        new_mod_copy_name = \
            get_new_attr_name_with_prefix(node_name_prefix)(gm_b)
        # fetch the corresponding module from gm_a
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        setattr(gm_b, new_mod_copy_name, mod_a)
        node_a_shadows_c = graph_c.create_node(
            node_a.op, new_mod_copy_name, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
    else:
        assert node_a.op in ('call_function', 'call_method')
        node_a_shadows_c = graph_c.create_node(
            node_a.op, node_a.target, new_args,
            new_kwargs, node_a_shadows_c_name)
        return node_a_shadows_c
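
# An illustration of the stitching above (hypothetical op): for a node_a of
# `F.linear(input, weight, bias)` normalized to kwargs, position 0 is wired to
# `input_node_c` from graph C, while `weight` and `bias` are deep-copied into
# gm_b via `_copy_node_from_a_to_c`, so the shadow op runs A's parameters on
# B's activations.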

def create_a_shadows_b(
    name_a: str,
    gm_a: GraphModule,
    name_b: str,
    gm_b: GraphModule,
    matched_subgraph_pairs: Dict[str, Tuple[NSSubgraph, NSSubgraph]],
    logger_cls: Callable,
    should_log_inputs: bool,
    node_type_to_io_type_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> GraphModule:
    """
    Creates a new GraphModule consisting of the graph of C, with the meaningful
    nodes of A shadowing the corresponding nodes of B. For example,

    Graph A:
    a0 -> op0_fp32 -> a1 -> op1_fp32 -> a2

    Graph B:
    b0 -> op0_int8 -> b1 -> op1_int8 -> b2

    matched_node_pairs: {'op0': (op0_fp32, op0_int8), 'op1': (op1_fp32, op1_int8)}

    Graph C (A shadows B):

        / dequant0 -> op0_fp32 -> logger_a_0  / dequant_1 -> op1_fp32 -> logger_a_1
       /                                     /
    b0 -------------> op0_int8 -> logger_b_0 --------------> op1_int8 -> logger_b_1

    In a nutshell, this function does the following for each node pair:
    * copies the necessary attributes and modules from gm_a to gm_b,
      keeping names unique
    * adds a dtype cast op (dequant, quant, etc)
    * adds a copy of node_a in gm_b's graph
    * adds loggers to the outputs of node_a and node_b
    """

    if node_type_to_io_type_map is None:
        node_type_to_io_type_map = get_node_type_to_io_type_map()

    # graph_c is the graph created from copying the nodes of graph_b and inserting
    # the shadows with the nodes copied from graph_a
    graph_c = Graph()
    env_c: Dict[str, Any] = {}
    modules = dict(gm_b.named_modules())

    def load_arg(a):
        return map_arg(a, lambda node: env_c[node.name])

    start_node_b_to_matched_subgraph_a_and_name = {}
    end_node_b_to_matched_subgraph_a_and_name = {}
    for match_name, match in matched_subgraph_pairs.items():
        subgraph_a, subgraph_b = match
        ref_node_type_a = get_target_type_str(subgraph_a.base_op_node, gm_a)
        ref_node_type_b = get_target_type_str(subgraph_b.base_op_node, gm_b)
        start_node_b_to_matched_subgraph_a_and_name[subgraph_b.start_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)
        end_node_b_to_matched_subgraph_a_and_name[subgraph_b.end_node] = \
            (subgraph_a, match_name, ref_node_type_a, ref_node_type_b)

    for node_b in gm_b.graph.nodes:
        if node_b.op == 'output':
            graph_c.output(map_arg(node_b.args[0], load_arg))
            continue

        # calculate the flags to determine what to do with this node
        node_b_is_start_node = node_b in start_node_b_to_matched_subgraph_a_and_name
        node_b_is_end_node = node_b in end_node_b_to_matched_subgraph_a_and_name

        if (node_b_is_start_node or node_b_is_end_node):

            if node_b_is_start_node:
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    start_node_b_to_matched_subgraph_a_and_name[node_b]
            else:
                assert node_b_is_end_node
                subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                    end_node_b_to_matched_subgraph_a_and_name[node_b]

            all_op_types_support_shadowing = (
                op_type_supports_shadowing(subgraph_a.start_node) and
                op_type_supports_shadowing(node_b)
            )
            if not all_op_types_support_shadowing:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unsupported')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # For both start_node and end_node verify that we know how to do
            # the dtype cast. If we do not, skip.
            node_input_type_a, node_output_type_a = \
                get_node_first_input_and_output_type(
                    subgraph_a.start_node, gm_a, logger_cls,
                    node_type_to_io_type_map)
            node_input_type_b, node_output_type_b = \
                get_node_first_input_and_output_type(
                    node_b, gm_b, logger_cls,
                    node_type_to_io_type_map)
            node_io_types_known_a_and_b = (
                node_input_type_a != NodeInputOrOutputType.UNKNOWN and
                node_output_type_a != NodeInputOrOutputType.UNKNOWN and
                node_input_type_b != NodeInputOrOutputType.UNKNOWN and
                node_output_type_b != NodeInputOrOutputType.UNKNOWN
            )
            if not node_io_types_known_a_and_b:
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unknown dtype cast')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            # If we are shadowing from fp32 to int8, we need to insert
            # a quantize_per_tensor call with qparams from the previous node.
            # Only do this if we are able to infer these qparams from the graph.
            if (
                node_input_type_a == NodeInputOrOutputType.INT8 and
                node_input_type_b == NodeInputOrOutputType.FP32
            ):
                node_a_input_qparams = get_node_input_qparams(
                    subgraph_a.start_node, gm_a, node_type_to_io_type_map)
                if not node_a_input_qparams:
                    print(
                        f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                        f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                        ', unknown input qparams')
                    env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                    continue

            num_non_param_args_node_a = \
                get_number_of_non_param_args(subgraph_a.start_node, gm_a)
            if not _can_insert_copy_of_subgraph_a(subgraph_a, gm_a, num_non_param_args_node_a):
                print(
                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
                    ', unhandled logic in subgraph copy')
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                continue

            fqn_base_a = _maybe_get_fqn(subgraph_a.base_op_node, gm_a)
            fqn_base_b = _maybe_get_fqn(subgraph_b.base_op_node, gm_b)

            if node_b_is_start_node:

                # if necessary, log the input of node_c
                if should_log_inputs:
                    prev_node_b = get_normalized_nth_input(node_b, gm_b, 0)
                    if isinstance(prev_node_b, Node):
                        prev_node_c = env_c[prev_node_b.name]
                        env_c[prev_node_c.name] = _insert_logger_after_node(
                            prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                            node_b.name, name_b, ref_name, ref_node_type_b,
                            NSSingleResultValuesType.NODE_INPUT.value,
                            index_within_arg=0, index_of_arg=0,
                            fqn=fqn_base_b)
                    elif isinstance(prev_node_b, list):
                        # first, save the prev_node instances, because they
                        # will be overwritten in the env after the first logger
                        # is added
                        prev_node_c_list = [env_c[arg.name] for arg in prev_node_b]

                        for arg_idx, arg in enumerate(prev_node_b):
                            prev_node_c = prev_node_c_list[arg_idx]
                            env_c[prev_node_c.name] = _insert_logger_after_node(
                                prev_node_c, gm_b, logger_cls, '_ns_logger_b_inp_',
                                node_b.name, name_b, ref_name, ref_node_type_b,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=arg_idx, index_of_arg=0,
                                fqn=fqn_base_b)
                    else:
                        # logging of inputs which are neither Nodes nor lists
                        # is not supported yet
                        raise AssertionError(f"type {type(prev_node_b)} is not handled yet")

                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)?

            # Note: this if statement is always True, spelling it out to clarify code
            # intent.
            if node_b_is_start_node or node_b_is_end_node:
                # ensure env_c is populated with base node
                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                node_c = env_c[node_b.name]

                # after this point,
                #
                # node_a is the original node from graph_a, with parent module gm_a
                # node_b is the original node from graph_b, with parent module gm_b
                # node_c is the copy of node_b in graph_c
                #
                # subgraph so far:
                #
                # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_start_node:

                    # cast dtype from the dtype of node_c's input to the dtype of
                    # node_a's input (dequant, etc)
                    # prev_node_c = node_c.args[0]
                    prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                    if should_log_inputs:
                        # skip the input logger when inserting a dtype cast, by
                        # walking one hop up to the logger's own input
                        if isinstance(prev_node_c, Node):
                            prev_node_c = get_normalized_nth_input(prev_node_c, gm_b, 0)
                        elif isinstance(prev_node_c, list):
                            prev_node_c = [get_normalized_nth_input(arg, gm_b, 0) for arg in prev_node_c]
                    dtype_cast_node = _insert_dtype_cast_after_node(
                        subgraph_a.start_node, node_c, prev_node_c, gm_a, gm_b, graph_c,
                        node_b.name + '_dtype_cast_', logger_cls,
                        node_type_to_io_type_map)
                    # note: not inserting to env_c because all nodes which use the dtype
                    # casts are copied from graph_a
                    #
                    # subgraph so far:
                    #
                    #           (dtype_cast_node)+
                    #          /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    # if input logging is enabled, log the input to the subgraph
                    if should_log_inputs:
                        # Set ref_node_name to a temporary empty string; it is
                        # overwritten with the real name below, after the
                        # subgraph copy exists.
                        ref_node_name = ''
                        if isinstance(dtype_cast_node, Node):
                            dtype_cast_node = _insert_logger_after_node(
                                dtype_cast_node, gm_b, logger_cls, '_ns_logger_a_inp_',
                                ref_node_name, name_a, ref_name, ref_node_type_a,
                                NSSingleResultValuesType.NODE_INPUT.value,
                                index_within_arg=0, index_of_arg=0,
                                fqn=fqn_base_a)
                            input_logger: Union[Node, List[Node]] = dtype_cast_node
                        else:
                            assert isinstance(dtype_cast_node, list)
                            new_loggers = []
                            for dtype_cast_idx, dtype_cast_node_inner in enumerate(dtype_cast_node):
                                dtype_cast_logger = _insert_logger_after_node(
                                    dtype_cast_node_inner, gm_b, logger_cls, '_ns_logger_a_inp_',
                                    ref_node_name, name_a, ref_name, ref_node_type_a,
                                    NSSingleResultValuesType.NODE_INPUT.value,
                                    index_within_arg=dtype_cast_idx,
                                    index_of_arg=0,
                                    fqn=fqn_base_a)
                                new_loggers.append(dtype_cast_logger)
                            dtype_cast_node = new_loggers
                            input_logger = dtype_cast_node
                        # subgraph so far:
                        #
                        #       (dtype_cast_node)+ -> (logger_a_input)?
                        #      /
                        # prev_node_c -> (logger_c_input)? -> node_start_c

                    # hook up the new mod_a copy to be in the graph, receiving the
                    # same inputs as mod_b does, with dtype cast to match a
                    # Some ops, such as LSTMs, have two non-param inputs. If we have
                    # such an op, pass the second param as well. Note: dtype casting
                    # for the second param is not implemented yet, it can be added
                    # later if there is a use case.
                    node_c_second_non_param_arg = None
                    num_non_param_args_node_a = get_number_of_non_param_args(subgraph_a.start_node, gm_a)
                    if num_non_param_args_node_a == 2:
                        # node_c_second_non_param_arg = node_c.args[1]
                        node_c_second_non_param_arg = get_normalized_nth_input(node_c, gm_b, 1)
                    node_a_shadows_c = _insert_copy_of_subgraph_a_after_input_node_c(
                        dtype_cast_node, node_c_second_non_param_arg,
                        subgraph_a, gm_a, gm_b, node_c.name + '_shadow_copy_')
                    env_c[node_a_shadows_c.name] = node_a_shadows_c
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy(args/kwargs not shown)
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                    if should_log_inputs:
                        # When we created the input logger, we left the ref_node_name
                        # as an empty string, because the subgraph copy did not exist
                        # yet. Now that the subgraph copy exists, we modify this name
                        # to its true value.
                        # Note: the alternative to this is to create the input logger
                        # after creating the subgraph, which is slightly more
                        # complicated. This is the lesser of two evils.
                        # input_logger = env_c[dtype_cast_node.name]
                        # find the first node in the subgraph copy
                        cur_node = node_a_shadows_c
                        while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger:
                            cur_node = get_normalized_nth_input(cur_node, gm_b, 0)  # type: ignore[assignment]
                        if isinstance(input_logger, Node):
                            input_logger_mod = getattr(gm_b, input_logger.name)
                            input_logger_mod.ref_node_name = cur_node.name
                        else:
                            assert isinstance(input_logger, list)
                            for input_logger_inner in input_logger:
                                input_logger_mod = getattr(gm_b, input_logger_inner.name)
                                input_logger_mod.ref_node_name = cur_node.name

                    # hook up a logger to the mod_a copy
                    env_c[node_a_shadows_c.name] = _insert_logger_after_node(
                        env_c[node_a_shadows_c.name], gm_b, logger_cls, '_ns_logger_a_',
                        node_a_shadows_c.name, name_a, ref_name, ref_node_type_a,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0,
                        fqn=fqn_base_a)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c)+ -> (logger_c_input)? -> node_start_c

                if node_b_is_end_node:

                    # hook up a logger to the mod_b copy
                    env_c[node_b.name] = _insert_logger_after_node(
                        env_c[node_b.name], gm_b, logger_cls, '_ns_logger_b_',
                        node_b.name, name_b, ref_name, ref_node_type_b,
                        NSSingleResultValuesType.NODE_OUTPUT.value,
                        index_within_arg=0, index_of_arg=0,
                        fqn=fqn_base_b)
                    # subgraph so far:
                    #
                    #       dtype_cast_node -> (logger_a_input)? -> subgraph_a_copy -> logger_a
                    #      /
                    # (prev_node_c+) -> (logger_c_input)? -> node_start_c -> ... -> node_end_c -> logger_c
                    #
                    # Note: node_start_c may be the same node as node_end_c, or they
                    # may have nodes in between.

        else:
            env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)

    gm_c = GraphModule(gm_b, graph_c)
    return gm_c
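
# A minimal end-to-end sketch (the matcher import and model names are
# assumptions based on the public Numeric Suite APIs; see
# torch.ao.ns._numeric_suite_fx for the supported entry points):
#
#   from torch.ao.ns.fx.graph_matcher import get_matching_subgraph_pairs
#   matched_pairs = get_matching_subgraph_pairs(gm_fp32, gm_int8)
#   gm_shadow = create_a_shadows_b(
#       'fp32', gm_fp32, 'int8', gm_int8, matched_pairs,
#       OutputLogger, should_log_inputs=False)
#   gm_shadow(example_input)  # B runs normally, with A's subgraphs shadowing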