import torch import torch.fx from torch.fx import ( Node, GraphModule, Graph, ) from torch.ao.ns.fx.utils import ( # TODO(future PR): make this work correctly for methods get_target_type_str, get_normalized_nth_input, ) from torch.ao.ns.fx.ns_types import ( NSSingleResultValuesType, NSResultsType, ) from torch.ao.ns.fx.graph_passes import _maybe_get_fqn from torch.ao.quantization import QConfigMapping from torch.ao.quantization.qconfig import QConfigAny from torch.ao.quantization.utils import getattr_from_fqn from torch.ao.quantization.fx.match_utils import _MatchResult from torch.utils._pytree import tree_map import collections import copy from typing import List, Dict, Set, Tuple, Callable, Any, Optional import operator SHADOW_NODE_NAME_PREFIX = 'shadow' SHADOW_WRAPPER_NODE_NAME_PREFIX = 'shadow_wrapper' # TODO(future PR): reuse existing mapping instead of creating a new one BINARY_FUNCTIONS = { torch.add, torch.Tensor.add, operator.add, torch.mul, torch.Tensor.mul, operator.mul, } def _get_attr_name(subgraph_idx, subgraph_candidate_idx): return f"{SHADOW_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}" def _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx): return f"{SHADOW_WRAPPER_NODE_NAME_PREFIX}_{subgraph_idx}_{subgraph_candidate_idx}" class OutputProp: """ Output propagation (modeled from shape propagation). Given a GraphModule and an example input, saves the output flowing through each node on `node.traced_result`. Code based on the example from https://pytorch.org/docs/stable/fx.html#the-interpreter-pattern """ def __init__(self, mod): self.mod = mod self.graph = mod.graph self.modules = dict(self.mod.named_modules()) def propagate(self, *args): args_iter = iter(args) env : Dict[str, Node] = {} def load_arg(a): return torch.fx.graph.map_arg(a, lambda n: env[n.name]) def fetch_attr(target : str): target_atoms = target.split('.') attr_itr = self.mod for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") attr_itr = getattr(attr_itr, atom) return attr_itr for node in self.graph.nodes: if node.op == 'placeholder': result = next(args_iter) elif node.op == 'get_attr': result = fetch_attr(node.target) elif node.op == 'call_function': result = node.target(*load_arg(node.args), **load_arg(node.kwargs)) elif node.op == 'call_method': self_obj, *args = load_arg(node.args) kwargs = load_arg(node.kwargs) result = getattr(self_obj, node.target)(*args, **kwargs) elif node.op == 'call_module': result = self.modules[node.target](*load_arg(node.args), **load_arg(node.kwargs)) if isinstance(result, torch.Tensor): node.traced_result = result env[node.name] = result return None def _get_dedup_subgraphs( matches: Dict[str, _MatchResult] ) -> Dict[str, List[Node]]: # the original matches variable is unique by node, make it unique by subgraph # instead seen_nodes = set() subgraphs_dedup = {} # Dict items are not reversible until Python 3.8, so we hack it # to be compatible with previous Python versions # TODO(future PR): try reversed(list(matches.items())) matches_items_reversed: List[Tuple[str, _MatchResult]] = [] for name, cur_match in matches.items(): matches_items_reversed.insert(0, (name, cur_match)) # Note: the order is important. `matches` currently provides the matches # in reverse order. We would like to process the matches in non-reverse # order, so that we can create an intuitive naming scheme, such as # naming the first op's submodules `shadow_0_0` through `shadow_0_(n-1)` for name, cur_match in matches_items_reversed: # type: ignore[call-overload] was_seen = False for node_or_tuple in cur_match[1]: # Cur_match[1] has an unusual type. It says that it's a `List[Node]`, # but it is really not. Furthermore, the contents of this field # can change from match results of multiple nodes of the same pattern # # For example, for conv -> bn -> relu, we see # match_results = { # 'conv': (relu, [(bn, conv), relu], ...), # 'bn': (relu, [(bn, conv), relu], ...), # 'relu': (relu, [(bn, conv), relu], ...), # } # # Ideally we should clean up the `find_matches` function to make # this more intuitive. For the purposes of this prototype, we hack # around it. if isinstance(node_or_tuple, Node): if node_or_tuple in seen_nodes: was_seen = True seen_nodes.add(node_or_tuple) else: assert isinstance(node_or_tuple, tuple) for node in node_or_tuple: assert isinstance(node, Node) if node in seen_nodes: was_seen = True seen_nodes.add(node) if was_seen: continue # Start with the unusual type, convert it to [op_0, ..., op_n] list_of_nodes = [] if len(cur_match[1]) == 1: list_of_nodes = cur_match[1] else: assert len(cur_match[1]) == 2 # either (a, b), or ((a, b), c) or (c, (a, b)) # cannot make any assumptions on order, not clear what the # _find_matches function is doing to populate this # TODO(future PR): make this code less confusing, see discussion # in https://github.com/pytorch/pytorch/pull/80521/files#r975918836 def _order_nodes(node_a, node_b, node_c) -> List[Node]: nodes = [node_a, node_b, node_c] first_node = None mid_node = None last_node = None for n in nodes: prev_n = n.args[0] next_n = list(n.users)[0] if prev_n not in nodes: first_node = n elif next_n not in nodes: last_node = n else: mid_node = n assert first_node is not None and mid_node is not None and \ last_node is not None assert mid_node.args[0] is first_node assert last_node.args[0] is mid_node return [last_node, mid_node, first_node] if isinstance(cur_match[1][0], Node) and isinstance(cur_match[1][1], Node): # (a, b) list_of_nodes = cur_match[1] elif isinstance(cur_match[1][0], tuple): # ((a, b), c) node_a, node_b = cur_match[1][0] node_c = cur_match[1][1] list_of_nodes = _order_nodes(node_a, node_b, node_c) elif isinstance(cur_match[1][1], tuple): # (a, (b, c)) node_a, node_b = cur_match[1][1] node_c = cur_match[1][0] list_of_nodes = _order_nodes(node_a, node_b, node_c) # [node_n, ..., node_0], note that the order is reversed # to make it chronological for simple subgraphs list_of_nodes.reverse() subgraphs_dedup[name] = list_of_nodes return subgraphs_dedup def _get_logger_for_subgraph( model: GraphModule, first_node: Node, last_node: Node, subgraph_idx: int, subgraph_candidate_idx: int, qconfig_str: str, logger_cls: Callable, fqn: Optional[str], ) -> torch.nn.Module: """ Given a model and a linear subgraph starting from `first_node` and ending with `last_node`, creates a logger for the end of this subgraph. """ if fqn is None: fqn = '' logger_mod_orig = logger_cls( first_node.name, # ref_node_name last_node.name, # prev_node_name f'subgraph_{subgraph_idx}_{subgraph_candidate_idx}', # model_name 'model', # ref_name get_target_type_str(last_node, model), # prev_node_target_type get_target_type_str(first_node, model), # ref_node_target_type NSSingleResultValuesType.NODE_OUTPUT.value, # results_type 0, # index_within_arg 0, # index_of_arg fqn, # fqn qconfig_str, ) # Usually we expect the user to add loggers, then calibrate, then convert, # and then populate loggers. This is why the loggers start disabled. # TODO(future PR): reconsider the design to make this more intuitive. logger_mod_orig.enabled = False return logger_mod_orig def create_submodule_from_subgraph( model: torch.nn.Module, first_node: Node, last_node: Node, ) -> GraphModule: """ Input: a model, and a linear subgraph within the model from first_node to last_node. Output: a new submodule containing a copy of the subgraph, with the inputs to the first node becoming the inputs to the submodule, and all other nodes in the subgraph being copied. Example inputs: `model`: a module with graph x0 -> op1 -> x1 -> op2 -> x2 | arg1 `first_node`: op1 `last_node`: op2 Example output: a new module with graph input1 -> op1_copy -> x1 -> op2_copy -> output1 | arg1 """ # # create a blank GraphModule with an empty graph # class M(torch.nn.Module): def forward(self, x): pass m = M() gm = torch.fx.symbolic_trace(m) g = gm.graph for node in reversed(gm.graph.nodes): g.erase_node(node) # # modify the graph to have a copy of our subgraph # cur_node_orig = first_node cur_args_orig = cur_node_orig.args cur_kwargs_orig = cur_node_orig.kwargs cur_name_idx = 0 iteration_limit = 100 cur_iteration = 0 while True: if cur_node_orig is first_node: # we are at the first node, we need to set up graph inputs # TODO(future): some graphs could have placeholders which are unrelated # to the first node, need to handle this cur_args_copy = [] cur_kwargs_copy = {} seen_names: Set[str] = set() old_name_to_new_node: Dict[str, Node] = {} def _add_placeholder( g: Graph, node: Node, seen_names, old_name_to_new_node ): # note: for graphs starting with patterns such as `y = x + x`, we # need to ensure we do not add multiple placeholders with the # same name counter = 0 while node.name + '_' + str(counter) in seen_names: counter += 1 cur_name = node.name + '_' + str(counter) seen_names.add(cur_name) placeholder = g.placeholder(cur_name) old_name_to_new_node[node.name] = placeholder return placeholder for arg in cur_node_orig.args: if isinstance(arg, Node): p = _add_placeholder( g, arg, seen_names, old_name_to_new_node) cur_args_copy.append(p) elif isinstance(arg, (list, tuple)): new_arg = [] for inner_arg in arg: if isinstance(inner_arg, Node): new_arg.append(_add_placeholder( g, inner_arg, seen_names, old_name_to_new_node)) else: new_arg.append(inner_arg) cur_args_copy.append(new_arg) else: cur_args_copy.append(arg) # TODO(future PR): handle non-normalized kwargs for kwarg_name, kwarg in cur_node_orig.kwargs.items(): if isinstance(kwarg, Node): cur_kwargs_copy[kwarg_name] = _add_placeholder( g, kwarg, seen_names, old_name_to_new_node) elif isinstance(kwarg, (list, tuple)): new_kwarg = [] for inner_kwarg in kwarg: p = _add_placeholder( g, inner_kwarg, seen_names, old_name_to_new_node) new_kwarg.append(p) cur_kwargs_copy[kwarg_name] = new_kwarg else: cur_kwargs_copy[kwarg_name] = kwarg cur_args_copy = tuple(cur_args_copy) # type: ignore[assignment] else: # we are not at first node, first arg is from the previous node, # and all other args are copied # the current implementation is simplistic and cannot handle # ops with two or more arguments which need to be passed from # the previous op, so we assert them out assert cur_node_orig.target not in BINARY_FUNCTIONS # at this point in the code, cur_node_copy is pointing to the copy # of the previous node # TODO(future PR): this is not handling complicated graphs correctly, need to # look at actual relationships instead of assuming sequential graph # TODO(future PR): this is ignoring kwargs, will need to support kwargs # for any fusion pattern which has them for a node that is not the # first node. cur_args_copy = [cur_node_copy] # type: ignore[has-type] if len(cur_node_orig.args) > 1: for arg in cur_node_orig.args[1:]: if isinstance(arg, torch.nn.Parameter): new_arg = arg.clone().detach() # type: ignore[assignment] mod_name = f"mod_{cur_name_idx}" cur_name_idx += 1 setattr(gm, mod_name, new_arg) new_arg_placeholder = gm.placeholder(mod_name) cur_args_copy.append(new_arg_placeholder) elif isinstance(arg, (float, int, torch.dtype)): cur_args_copy.append(arg) else: raise AssertionError(f'arg of type {type(arg)} not handled yet') cur_args_copy = tuple(cur_args_copy) # type: ignore[assignment] # copy the node if cur_node_orig.op == 'call_module': orig_mod = getattr_from_fqn(model, cur_node_orig.target) # type: ignore[arg-type] orig_mod_copy = copy.deepcopy(orig_mod) mod_name = f"mod_{cur_name_idx}" setattr(gm, mod_name, orig_mod_copy) cur_name_idx += 1 cur_node_copy = g.call_module(mod_name, cur_args_copy, cur_kwargs_copy) elif cur_node_orig.op == 'call_function': cur_node_copy = g.call_function( cur_node_orig.target, cur_args_copy, cur_kwargs_copy) elif cur_node_orig.op == 'call_method': cur_node_copy = g.call_method( cur_node_orig.target, cur_args_copy, cur_kwargs_copy) else: raise AssertionError(f'{cur_node_orig.op} not supported yet') if cur_node_orig is last_node: break # go to next node assert len(cur_node_orig.users.keys()) == 1, \ f'{cur_node_orig} has more than 1 users, not supported yet' cur_node_orig = list(cur_node_orig.users.keys())[0] cur_args_orig = cur_node_orig.args cur_kwargs_orig = cur_node_orig.kwargs cur_iteration += 1 if cur_iteration > iteration_limit: raise AssertionError('iteration limit exceeded') # set up outputs g.output(cur_node_copy) gm.recompile() return gm def create_one_transformed_and_logged_copy_of_subgraph( mt: GraphModule, subgraph_idx: int, subgraph_candidate_idx: int, first_node: Node, last_node: Node, fqn: Optional[str], list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]], example_inputs: Any, last_added_shadow_node_list: List[Optional[Node]], custom_prepare_fn: Optional[Callable] = None, custom_prepare_kwargs: Dict[str, Any] = None, ) -> None: """ Given a subgraph in `mt` and a subgraph candidate idx, inserts the subgraph candidate copy and instruments it with loggers. If subgraph_candidate_idx is 0, this is the baseline fp32 subgraph and we just add a logger to the end. If subgraph_candidate_idx is not 0, we create a copy of the subgraph and prepare it with `prepare_fx`. """ # TODO(future PR): move logger classes to utils to remove circular dependency from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger if subgraph_candidate_idx == 0: # idx = 0 is the floating point (original) version of the subgraph # We keep the subgraph as is, and add a logger at the end qconfig_str = '' logger_mod_orig = _get_logger_for_subgraph( mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx, qconfig_str, OutputLogger, fqn) attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx) assert not hasattr(mt, attr_name) setattr(mt, attr_name, logger_mod_orig) with mt.graph.inserting_after(last_node): new_node = mt.graph.call_module(attr_name, args=(last_node,), kwargs={}) last_added_shadow_node_list[0] = new_node else: # idx > 0 means we have a candidate qconfig to try, so we need # to make a copy of the subgraph, feed it with the right inputs, # and add a logger at the end # get the qconfig # subtract one because the first candidate is the floating point # version of the subgraph node_name_to_qconfig = \ list_of_node_name_to_qconfig[subgraph_candidate_idx - 1] qconfig = node_name_to_qconfig[first_node.name] # if no quantization is requested, skip # TODO(future PR): deduplicate equivalent qconfigs that come from # different qconfig mapping objects if qconfig is None: return qconfig_mapping = QConfigMapping().set_global(qconfig) # create a copy of the submodule, wrapped in a separate module orig_mod_copy_wrapped = create_submodule_from_subgraph( mt, first_node, last_node) # add a call to prepare_fx on the wrapper module if custom_prepare_fn is None: orig_mod_copy_wrapped = torch.ao.quantization.quantize_fx.prepare_fx( orig_mod_copy_wrapped, qconfig_mapping, example_inputs=example_inputs) else: if custom_prepare_kwargs is None: custom_prepare_kwargs = {} for kwarg_name in ["example_inputs", "prepare_custom_config", "qconfig_mapping"]: assert kwarg_name not in custom_prepare_kwargs, f"cannot specify {kwarg_name} in custom_prepare_kwargs" prepare_kwargs: Dict[str, Any] = { "example_inputs": example_inputs, "qconfig_mapping": qconfig_mapping } prepare_kwargs.update(custom_prepare_kwargs) orig_mod_copy_wrapped = custom_prepare_fn( orig_mod_copy_wrapped, **prepare_kwargs) # attach the wrapper to the model attr_name = _get_attr_wrapper_name(subgraph_idx, subgraph_candidate_idx) assert not hasattr(mt, attr_name) setattr(mt, attr_name, orig_mod_copy_wrapped) # add a call to the wrapper module from the parent graph insert_after_node = last_added_shadow_node_list[0] with mt.graph.inserting_after(insert_after_node): # TODO(future PR): handle fusion patterns where non-first nodes # need inputs # pass in all node args and kwargs new_args = [] for arg in first_node.args: if isinstance(arg, Node): new_args.append(arg) elif isinstance(arg, (list, tuple)) and len(arg) and isinstance(arg[0], Node): for inner_arg in arg: if isinstance(inner_arg, Node): new_args.append(inner_arg) new_kwargs = {} for name, old_kwarg in first_node.kwargs.items(): if isinstance(old_kwarg, Node): new_kwargs[name] = old_kwarg elif isinstance(old_kwarg, (list, tuple)) and len(old_kwarg): for inner_old_kwarg in old_kwarg: # TODO(future PR): clarify why we are adding kwargs to args new_args.append(inner_old_kwarg) new_args = tuple(new_args) # type: ignore[assignment] new_node = mt.graph.call_module( attr_name, args=new_args, kwargs=new_kwargs) # add a logger to parent graph to observe the shadow wrapper logger_mod_orig = _get_logger_for_subgraph( mt, first_node, last_node, subgraph_idx, subgraph_candidate_idx, str(qconfig), OutputComparisonLogger, fqn) attr_name = _get_attr_name(subgraph_idx, subgraph_candidate_idx) assert not hasattr(mt, attr_name) setattr(mt, attr_name, logger_mod_orig) with mt.graph.inserting_after(new_node): logger = mt.graph.call_module(attr_name, args=(new_node, last_node), kwargs={}) last_added_shadow_node_list[0] = logger mt.recompile() def create_n_transformed_and_logged_copies_of_subgraph( mt: GraphModule, subgraph_idx: int, match_name: str, nodes_in_this_subgraph: List[Any], qconfig_mappings: List[QConfigMapping], list_of_node_name_to_qconfig: List[Dict[str, QConfigAny]], custom_prepare_fn: Optional[Callable] = None, custom_prepare_kwargs: Dict[str, Any] = None, ) -> None: """ Given a model `mt` and a subgraph_idx, creates the needed copies of the subgraph for all qconfigs, and instruments them with loggers. """ # for now, assume that # 1. the first node has one input # 2. the last node has one output # for now, ignore all subgraphs that contain non-nodes (tuples, etc) # TODO(future PR): implement this if any( not isinstance(node, Node) for node in nodes_in_this_subgraph ): return first_node = nodes_in_this_subgraph[0] last_node = nodes_in_this_subgraph[-1] # We used output propagation to populate example values on each # node. Use the example values from the previous node as the input # to the current node. prev_node = get_normalized_nth_input(first_node, mt, 0) if isinstance(prev_node, list): example_inputs = [x.traced_result for x in prev_node] elif isinstance(prev_node, tuple): example_inputs = (x.traced_result for x in prev_node) # type: ignore[assignment] else: # currently some customer models do not have a traced_result in # every node, so we have to guard for this case since we cannot # quantize without an example input # TODO(future PR): add a test case for this once we have an easy # repro, see https://github.com/pytorch/pytorch/pull/80521/files#r975940489 # for additional context if hasattr(prev_node, 'traced_result'): example_inputs = (prev_node.traced_result,) # type: ignore[attr-defined, assignment] else: print( 'unable to get example input for node ' + f'{first_node.format_node()}, skipping') return # If there are no quantization configs for this subgraph, skip adding # loggers. This reduces memory usage for models where not all layers are # quantized. # TODO(future): consider making this configurable found_at_least_one_qconfig = False for subgraph_candidate_idx in range(len(qconfig_mappings) + 1): if subgraph_candidate_idx == 0: # fp32 baseline does not need a qconfig continue # a. we have N shadows, so len(qconfig_mappings) is N # b. we will have the fp32 layer + N shadows, so overall number of # (original_op) + (*shadows) will be N+1 # c. since `subgraph_candidate_idx` represents (b), we need # to subtract 1 to query from (a) node_name_to_qconfig = \ list_of_node_name_to_qconfig[subgraph_candidate_idx - 1] qconfig = node_name_to_qconfig[first_node.name] if qconfig is not None: found_at_least_one_qconfig = True break if not found_at_least_one_qconfig: print('unable to find at least one qconfig for node ' + f'{first_node.format_node()}, skipping') return fqn = _maybe_get_fqn(first_node, mt) # We want the results to contain the subgraphs in natural order, # and the graph to also contain shadow wrappers and shadow loggers # in natural order. # If we just iterate in reverse, the graph will be in natural # order but the eventual results will be in reverse order. # So, we keep track of the last shadow logger we added and # always insert after it. last_added_shadow_node_list: List[Optional[Node]] = [None] for subgraph_candidate_idx in range(len(qconfig_mappings) + 1): create_one_transformed_and_logged_copy_of_subgraph( mt, subgraph_idx, subgraph_candidate_idx, first_node, last_node, fqn, list_of_node_name_to_qconfig, example_inputs, last_added_shadow_node_list, custom_prepare_fn, custom_prepare_kwargs) def create_add_loggers_graph( model: GraphModule, subgraphs_dedup: Dict[str, List[Node]], qconfig_mapping: QConfigMapping, node_name_to_qconfig: Dict[str, QConfigAny], ) -> None: """ Given a model, a model graph partition (currently a set of matched subgraphs) and instructions how to transform each subgraph (currently quantizing it according to qconfig_mapping), modifies the model graph to create an alternate path through the original graph, with each of the subgraphs quantized. This is useful to compare propagation error of a transformation such as quantization. For example, given layer op0 and op1, there are four cases when handling op1: 1. op0 and op1 quantized 2. op0 and op1 unquantized 3. op0 quantized, op1 unquantized 4. op0 unquantized, op1 quantized Example input, case 1: .. code:: x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log \ \ \ \ # noqa: W605 ---> op0_1 -> x1_1 ----> clog op1_1 -> x2_1 ----> clog Example output, case 1: .. code:: x0_0 -> op0_0 -> x1_0 -> log -----> op1_0 -> x2_0 -> log \ \ \ # noqa: W605 ---> op0_1 -> x1_1 ----> clog -> op1_1 -> x2_1 ----> clog """ # TODO(future PR): move logger classes to utils to remove circular dependency from torch.ao.ns._numeric_suite_fx import OutputLogger, OutputComparisonLogger def _get_subgraph_containing_node(node, subgraphs_dedup): for name, subgraph in subgraphs_dedup.items(): if node in subgraph: return subgraph return None # First, we need to create shadow branches, going from # # x0 -> op0 -> x1 -> ... # # # to # # x0 -> op0_0 -> x1_0 -> log -> ... # \ \ # -> op0_1 -> x1_1 -> clog # # Later, the outputs of each shadow will be rerouted to calculate # propagation error. # Note: we cannot iterate over matched subgraphs because some nodes # may not be matched. So, we iterate over nodes in the graph, and # associate them to matched subgraphs if possible. nodes_to_skip = set() # for each subgraph, save a mapping from first node of subgraph # to first and last node of the shadow of this subgraph orig_first_node_to_shadow_in_node = {} orig_first_node_to_shadow_out_node = {} # need to record original list because we will mutate the graph as we go orig_nodes = list(model.graph.nodes) # type: ignore[union-attr, arg-type] cur_subgraph_idx = 0 for n in orig_nodes: if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip: continue maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup) insert_submodule_copy = False if maybe_subgraph is not None: first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1] for node_to_skip in maybe_subgraph: nodes_to_skip.add(node_to_skip) qconfig = node_name_to_qconfig[first_node.name] if qconfig is not None: insert_submodule_copy = True else: first_node, last_node = n, n if insert_submodule_copy: match_name = first_node.name create_n_transformed_and_logged_copies_of_subgraph( model, cur_subgraph_idx, match_name, maybe_subgraph, [qconfig_mapping], [node_name_to_qconfig], None, None ) # find the created shadow module and record it so we # can find it easily in step 2 expected_shadow_target = f"shadow_wrapper_{cur_subgraph_idx}_1" new_shadow_mod = None for maybe_shadow_mod in model.graph.nodes: if maybe_shadow_mod.op == 'call_module' and \ maybe_shadow_mod.target == expected_shadow_target: new_shadow_mod = maybe_shadow_mod break assert new_shadow_mod is not None orig_first_node_to_shadow_in_node[first_node] = new_shadow_mod orig_first_node_to_shadow_out_node[first_node] = new_shadow_mod else: # create a copy of the subgraph by only copying FX nodes # but not copying any parameters, to minimize memory usage subgraph_to_use = maybe_subgraph if maybe_subgraph is not None \ else [first_node] # add a regular logger after last_node qconfig_str = '' subgraph_candidate_idx = 0 fqn = _maybe_get_fqn(first_node, model) logger_mod_orig = _get_logger_for_subgraph( model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx, qconfig_str, OutputLogger, fqn) attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx) assert not hasattr(model, attr_name) setattr(model, attr_name, logger_mod_orig) insertion_point = last_node with model.graph.inserting_after(insertion_point): logger = model.graph.call_module( attr_name, args=(last_node,), kwargs={}) insertion_point = logger # create a copy of the subgraph cur_node_orig = first_node cur_node_copy = None first_node_copy = None while cur_node_orig in subgraph_to_use: # TODO(future PR): make this support all possible args/kwargs if cur_node_orig is first_node: new_args = cur_node_orig.args new_kwargs = cur_node_orig.kwargs else: first_arg_for_copy = cur_node_copy new_args = tuple([first_arg_for_copy, *cur_node_orig.args[1:]]) # noqa: C409 new_kwargs = cur_node_orig.kwargs # make a copy of cur_node_orig with model.graph.inserting_after(insertion_point): cur_node_copy = model.graph.create_node( cur_node_orig.op, cur_node_orig.target, new_args, new_kwargs, # cur_node_orig.name, # TODO(future PR): set name explicitly ) if first_node_copy is None: first_node_copy = cur_node_copy # since now only linear subgraphs are supported, all nodes # except the last one must have only one user if cur_node_orig != last_node: assert len(cur_node_orig.users.keys()) == 1 cur_node_orig = list(cur_node_orig.users.keys())[0] assert not cur_node_orig.name.startswith(SHADOW_NODE_NAME_PREFIX) insertion_point = cur_node_copy # add a comparison logger after last_node's copy subgraph_candidate_idx = 1 logger_mod_orig = _get_logger_for_subgraph( model, first_node, last_node, cur_subgraph_idx, subgraph_candidate_idx, qconfig_str, OutputComparisonLogger, fqn) attr_name = _get_attr_name(cur_subgraph_idx, subgraph_candidate_idx) assert not hasattr(model, attr_name) setattr(model, attr_name, logger_mod_orig) with model.graph.inserting_after(insertion_point): logger = model.graph.call_module( attr_name, args=(cur_node_copy, last_node), kwargs={}) # save the final node so we can use it in step 2 orig_first_node_to_shadow_in_node[first_node] = first_node_copy orig_first_node_to_shadow_out_node[first_node] = cur_node_copy cur_subgraph_idx += 1 model.recompile() # Now, we go from # # x0 -> op0_0 -> x1_0 -> log -> x1 -> op1_0 -> ... # \ \ \ # -> op0_1 -> x1_1 -> clog -> op1_1 -> ... # # to # # x0 -> op0_0 -> x1_0 -> log --> x1_0 -> op1_0 -> ... # \ \ # -> op0_1 -> x1_1 -> clog -> x1_1 -> op1_1 -> ... # # sample values of key internal variables for the example above: # # orig_first_node_to_shadow_in_node = {op0_0: op0_1, op1_0: op1_1} # orig_first_node_to_shadow_out_node = {op0_0: op0_1, op1_0: op1_1} # # note: for subgraphs with more than one node, in_node will be different # compared to out_node nodes_to_skip = set() for n in orig_nodes: if n.op in ('placeholder', 'get_attr', 'output') or n in nodes_to_skip: continue maybe_subgraph = _get_subgraph_containing_node(n, subgraphs_dedup) if maybe_subgraph is not None: first_node, last_node = maybe_subgraph[0], maybe_subgraph[-1] for node_to_skip in maybe_subgraph: nodes_to_skip.add(node_to_skip) else: first_node, last_node = n, n def maybe_remap_node_to_shadow(node): """ If unshadowed `node` has a shadow version, return that. If not, return `node`. """ if not isinstance(node, Node): # handle scalars return node if node.op in ('placeholder', 'get_attr'): return node # Find the shadowed version of this arg from the previous # subgraph. For this, we need to: # 1. navigate to the first node of the previous subgraph # 2. get the output of the shadow wrapper which has (1) as an input # For now, assume the arg is in matched subgraphs. In the # future we may have to handle the case where this is not true. prev_subgraph = _get_subgraph_containing_node( node, subgraphs_dedup) if prev_subgraph is None: prev_subgraph = [node] prev_first_node = prev_subgraph[0] prev_shadow_output = \ orig_first_node_to_shadow_out_node[prev_first_node] return prev_shadow_output cur_shadow_input = \ orig_first_node_to_shadow_in_node[first_node] assert cur_shadow_input is not None cur_shadow_input.args = tree_map( maybe_remap_node_to_shadow, cur_shadow_input.args) cur_shadow_input.kwargs = tree_map( maybe_remap_node_to_shadow, cur_shadow_input.kwargs) model.recompile() def _get_weight_info_from_shadow_wrapper(shadow_wrapper: torch.nn.Module): # input: shadow wrapper module # output if shadow wrapper module has a weighted op: # (quantize_fn, (quantize_fn_args)) # output if shadow wrapper module doesn't have a weighted op: # None # For now, assume that the weight is the second input # to the shadow module. If that changes, we can fix it later. placeholders_seen = 0 for shadow_n in shadow_wrapper.graph.nodes: # type: ignore[union-attr] if shadow_n.op != 'placeholder': continue placeholders_seen += 1 if placeholders_seen != 2: continue # the subgraph looks like # # _input_scale_1 = self._input_scale_1 # _input_zero_point_1 = self._input_zero_point_1 # quantize_per_channel = torch.quantize_per_channel( # w2_0, _input_scale_1, _input_zero_point_1, # 0, torch.qint8) # # we have `w2_0`, and are navigating this subgraph # to get `_input_scale_1` and `_input_zero_point_1` assert len(shadow_n.users) == 1 quant_node = list(shadow_n.users.keys())[0] new_args: Any = None if quant_node.target == torch.quantize_per_channel: _weight, scale_node, zp_node, axis, dtype = quant_node.args scale_val = getattr_from_fqn( shadow_wrapper, scale_node.target) zp_val = getattr_from_fqn( shadow_wrapper, zp_node.target) new_args = (scale_val, zp_val, axis, dtype) else: assert quant_node.target == torch.quantize_per_tensor _weight, scale_node, zp_node, dtype = quant_node.args scale_val = getattr_from_fqn( shadow_wrapper, scale_node.target) zp_val = getattr_from_fqn( shadow_wrapper, zp_node.target) new_args = (scale_val, zp_val, dtype) return (quant_node.target, new_args) return None def extract_weight_comparison(m: GraphModule) -> NSResultsType: # example graph: # # w1 = self.w1 # b1 = self.b1 # linear = torch._C._nn.linear(x, w1, b1) # shadow_0_0 = self.shadow_0_0(linear) # shadow_wrapper_0_1 = self.shadow_wrapper_0_1(x, w1, b1) # shadow_0_1 = self.shadow_0_1(shadow_wrapper_0_1, linear) # # algorithm: # 1. for each call_function node matching our allowlist: # 2. if corresponding shadow wrapper exists, extract the weight pair # # Note: this is not super robust, but that's ok because this is # just for legacy customers who depend on the previous two-model version # of this API. TBD if we need to make this robust. # Note: modules are not supported, since existing customers only # use functions. # TODO(future PR): move this to config weighted_ops = { torch.nn.functional.linear, } results: NSResultsType = { 'model': {NSSingleResultValuesType.WEIGHT.value: {}} } for n in m.graph.nodes: # type: ignore[union-attr] if not (n.op == 'call_function' and n.target in weighted_ops): continue # Check if we have a corresponding shadow wrapper # TODO(future PR, if needed): support kwargs # TODO(future PR, if needed): support multiple shadow users first_arg = n.args[0] shadow_wrapper_node = None for user in first_arg.users: # TODO(before land): fix string match if user.op == 'call_module' and \ user.target.startswith('shadow_wrapper'): shadow_wrapper_node = user break if shadow_wrapper_node is None: continue shadow_wrapper = getattr_from_fqn( m, shadow_wrapper_node.target) # type: ignore[arg-type] weight_info = _get_weight_info_from_shadow_wrapper( shadow_wrapper) if weight_info is None: continue # get weight w_node = n.args[1] w_obj = getattr_from_fqn(m, w_node.target).detach() # get a quantized version of weight quant_fn, quant_fn_args_except_first = weight_info new_args = (w_obj, *quant_fn_args_except_first) w_obj_q = quant_fn(*new_args) # add a comparison ref_node_name = n.name prev_node_name = n.name ref_node_type = get_target_type_str(n, m) prev_node_type = ref_node_type fqn = None if hasattr(m, '_node_name_to_scope'): fqn = m._node_name_to_scope[n.name][0] # type: ignore[index] comparison = torch.ao.ns.fx.utils.compute_sqnr(w_obj, w_obj_q) result_fp32 = { 'res_type': NSSingleResultValuesType.WEIGHT.value, 'values': [w_obj], 'prev_node_name': prev_node_name, 'prev_node_target_type': prev_node_type, 'ref_node_name': ref_node_name, 'ref_node_target_type': ref_node_type, 'index_within_arg': 0, 'index_of_arg': 0, 'fqn': fqn, 'qconfig_str': '', 'comparisons': [comparison], 'comparison_fn_name': 'sqnr', } result_q = { 'res_type': NSSingleResultValuesType.WEIGHT.value, 'values': [w_obj_q], 'prev_node_name': prev_node_name, 'prev_node_target_type': prev_node_type, 'ref_node_name': ref_node_name, 'ref_node_target_type': ref_node_type, 'index_within_arg': 0, 'index_of_arg': 0, 'fqn': fqn, 'qconfig_str': '', 'comparisons': [comparison], 'comparison_fn_name': 'sqnr', } # go from subgraph_n_1 to subgraph_n_0 _1, _2, node_idx, _3 = shadow_wrapper_node.target.split('_') name_fp32 = f"subgraph_{node_idx}_0" name_q = f"subgraph_{node_idx}_1" results['model'][NSSingleResultValuesType.WEIGHT.value][name_fp32] = \ [result_fp32] results['model'][NSSingleResultValuesType.WEIGHT.value][name_q] = \ [result_q] return results # TODO(future PR): redesign this to make it easier to consume outputs def group_results_by_subgraph(results: NSResultsType) -> Any: """ Creates a comparison of results Input: { 'model': { 'node_output': { 'subgraph_0_0': [ 'values': [torch.tensor(...), ...], ... 'ref_node_name': ..., 'ref_node_target_type': ..., 'qconfig_str': ..., 'comparisons': [], ... 'comparison_fn_name': '', 'fqn': '...', ], 'subgraph_0_1': [ 'values': [torch.tensor(...), ...], ... 'ref_node_name': ..., 'ref_node_target_type': ..., 'qconfig_str': ..., 'comparisons': [torch.tensor(...), ...], ... 'comparison_fn_name': '...', 'fqn': '...', ], ... }, }, } Output: { 'subgraph_0': { '0': { 'ref_node_name': '...', 'ref_node_target_type': ..., 'values': [torch.tensor(...), ...], 'qconfig_str': None, 'comparisons': [torch.tensor(...), ...], ... 'comparison_fn_name': '...', 'fqn': '...', }, '1': { 'ref_node_name': '...', 'ref_node_target_type': ..., 'values': [torch.tensor(...), ...], 'qconfig_str': '...', 'comparisons': [torch.tensor(...), ...], ... 'comparison_fn_name': '...', 'fqn': '...', }, }, } """ subgraph_name_to_subgraph_results: Any = collections.defaultdict(dict) # node_output or weight key_to_use = list(results['model'].keys())[0] for subgraph_name_with_idx, subgraph_candidate_results in \ results['model'][key_to_use].items(): # convert from `subgraph_m_n` to `subgraph_m` and `n` subgraph_str, subgraph_idx, subgraph_candidate_idx = \ subgraph_name_with_idx.split('_') subgraph_name = f'{subgraph_str}_{subgraph_idx}' subgraph_results = { 'ref_node_name': subgraph_candidate_results[0]['ref_node_name'], 'ref_node_target_type': subgraph_candidate_results[0]['ref_node_target_type'], 'fqn': subgraph_candidate_results[0]['fqn'], 'values': subgraph_candidate_results[0]['values'], 'qconfig_str': subgraph_candidate_results[0]['qconfig_str'], 'comparisons': subgraph_candidate_results[0]['comparisons'], 'comparison_fn_name': subgraph_candidate_results[0]['comparison_fn_name'], } subgraph_name_to_subgraph_results[subgraph_name][subgraph_candidate_idx] = \ subgraph_results return dict(subgraph_name_to_subgraph_results) # TODO(future PR): redesign this to make it easier to consume outputs def create_results_comparison( results_grouped, ) -> Any: """ Input: { 'subgraph_0': { '0': { 'ref_node_name': '...', 'ref_node_target_type': ..., 'values': [torch.tensor(...), ...], 'qconfig_str': '', 'comparisons': [], 'comparison_fn_name': '', 'fqn': '...', }, '1': { 'ref_node_name': '...', 'ref_node_target_type': ..., 'values': [torch.tensor(...), ...], 'qconfig_str': '...', 'comparisons': [torch.tensor(...), ...], 'comparison_fn_name': 'sqnr', 'fqn': '...', }, }, } Output: { 'subgraph_0': { 'ref_node_name': '...', 'ref_node_target_type': '...', 'fqn': '...', 'candidates': { '1': { 'qconfig_str': ..., 'comparison_fn_name': 'sqnr', 'cmp_raw': [..., ...], 'cmp_mean': ..., }, ..., }, }, } """ results_comparison = {} for subgraph_name, subgraph_results in results_grouped.items(): candidates = {} for subgraph_inner_name, subgraph_inner_result in subgraph_results.items(): # skip comparing baseline to baseline if subgraph_inner_name == '0': continue # we expect the comparisons to be precalculated from # calibration, so we just fetch them here cmp_raw = subgraph_inner_result['comparisons'] cmp_raw_tensor = torch.stack(cmp_raw) candidates[subgraph_inner_name] = { 'qconfig_str': subgraph_inner_result['qconfig_str'], 'comparison_fn_name': subgraph_inner_result['comparison_fn_name'], 'cmp_raw': cmp_raw_tensor, 'cmp_mean': torch.mean(cmp_raw_tensor), } results_comparison[subgraph_name] = { 'ref_node_name': subgraph_results['0']['ref_node_name'], 'ref_node_target_type': subgraph_results['0']['ref_node_target_type'], 'fqn': subgraph_results['0']['fqn'], 'candidates': candidates, } return results_comparison # TODO(future PR): redesign this to make it easier to consume outputs def print_n_shadows_summary( results_comparison, ) -> None: """ Input: { 'subgraph_0': { 'ref_node_name': 'linear1', 'ref_node_target_type': '...', 'fqn': '...', 'candidates': { '1': { 'qconfig_str': ..., 'comparison_fn_name': ..., 'cmp_raw': [45.0, 55.0], 'cmp_mean': 50.0, }, ..., }, }, } Prints: node_name | node_type | fqn | 0 | 1 | ... linear1 | ... | ... | 45.0 | 50.0 | ... """ try: from tabulate import tabulate except ImportError: print("`print_tabular` relies on the library `tabulate`, " "which could not be found on this machine. Run `pip " "install tabulate` to install the library.") return results = [] for subgraph_name, subgraph_data in results_comparison.items(): mean_all_candidates = [ candidate['cmp_mean'] for candidate_name, candidate in subgraph_data['candidates'].items() ] data_row = [ subgraph_data['ref_node_name'], subgraph_data['ref_node_target_type'], subgraph_data['fqn'], *mean_all_candidates, ] results.append(data_row) max_candidate_idx_len = -1 for data_row in results: max_candidate_idx_len = max(max_candidate_idx_len, len(data_row[1])) candidate_idx_headers = [str(x) for x in range(max_candidate_idx_len)] headers = ['node_name', 'node_type', 'fqn', *candidate_idx_headers] print(tabulate(results, headers=headers))