123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460 |
- import collections
- import enum
- import torch
- toq = torch.ops.quantized
- from torch.fx import GraphModule
- from torch.fx.graph import Graph, Node
- from torch.ao.quantization.utils import getattr_from_fqn
- from .ns_types import NSSubgraph, NSNodeTargetType
- from .mappings import (
- get_base_name_to_sets_of_related_ops,
- get_unmatchable_types_map,
- )
- from .pattern_utils import (
- get_type_a_related_to_b,
- get_reversed_fusions,
- end_node_matches_reversed_fusion,
- )
- from torch.ao.quantization import (
- ObserverBase,
- FakeQuantizeBase,
- )
- from typing import Dict, Tuple, List, Optional, Set, Any
- def _get_output_nodes(g: Graph) -> List[Node]:
- return [n for n in g.nodes if n.op == 'output']
- class _NSGraphMatchableSubgraphsIterator:
- """
- Iterates through the graph of gm, starting with the output nodes
- and continuing backwards.
- 1. Returns matchable subgraphs, in order. A subgraph is defined by
- (start_node, end_node).
- 2. Skips over non-matchable subgraphs
- """
- def __init__(
- self,
- gm: GraphModule,
- non_matchable_functions: Set[NSNodeTargetType],
- non_matchable_modules: Set[NSNodeTargetType],
- non_matchable_methods: Set[NSNodeTargetType],
- ):
- self.gm: GraphModule = gm
- self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
- self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
- self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
- self.seen_nodes: Set[Node] = set()
- self.stack: List[Node] = []
- for start_node in _get_output_nodes(self.gm.graph):
- self.stack.append(start_node)
- def __iter__(self):
- return self
- def __next__(self) -> NSSubgraph:
- """
- Returns the next matchable subgraph.
- """
- while len(self.stack) > 0:
- cur_end_node = self.stack.pop()
- if cur_end_node in self.seen_nodes:
- continue
- # for subgraphs which are single nodes, start_node == end_node
- # for subgraphs with more than one node, start node != end_node
- cur_start_node = cur_end_node
- # Subgraphs like linear-relu have the base node as the start node.
- # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
- # base node as the second node.
- # The cur_base_op_node var will move to the actual node during
- # the fusion matching later in this code block.
- cur_base_op_node = cur_end_node
- # Check for potential fusions. For now, we are greedy
- # and always skip all non-base nodes of a fusion. For example,
- # if we match linear-relu backwards, we will always skip the
- # relu node and attempt to match the linear node. This can
- # be made configurable later if needed.
- for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
- is_match = end_node_matches_reversed_fusion(
- cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes)
- if is_match:
- # navigate to the base node
- for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
- self.seen_nodes.add(cur_start_node)
- # for now, assume that there are no other nodes
- # which need to be added to the stack
- cur_start_node = cur_start_node.args[0] # type: ignore[assignment]
- # if the base op index matches the current node, set it
- rev_base_op_idx = \
- len(_reverse_fusion_ops) - 2 - base_op_idx
- if rev_fusion_idx == rev_base_op_idx:
- cur_base_op_node = cur_start_node
- break
- self.seen_nodes.add(cur_start_node)
- # add args of previous nodes to stack
- for arg in cur_start_node.all_input_nodes:
- self._recursively_add_node_arg_to_stack(arg)
- # skip unmatchable nodes
- # note: this check is done on the start_node, i.e.
- # if we are matching linear-relu in reverse, this would do the matchable
- # check on the linear
- if not self._is_matchable(cur_base_op_node):
- continue
- # If an observer or a fake_quant was not matched as a part of
- # a pattern of multiple nodes, ignore it. One case where this is
- # relevant is an observer on a graph input, which was added because
- # it is necessary for the next node.
- if cur_end_node.op == 'call_module' and cur_start_node is cur_end_node:
- maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target) # type: ignore[arg-type]
- if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
- continue
- return NSSubgraph(
- start_node=cur_start_node, end_node=cur_end_node,
- base_op_node=cur_base_op_node)
- raise StopIteration
- def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
- """
- Adds all of the nodes in this arg to the stack, properly navigating
- through list, dicts and tuples.
- """
- if isinstance(arg, Node):
- self.stack.append(arg)
- elif isinstance(arg, torch.fx.immutable_collections.immutable_list) or type(arg) is tuple:
- for inner_arg in arg:
- self._recursively_add_node_arg_to_stack(inner_arg)
- elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
- for key, value in arg.items():
- self._recursively_add_node_arg_to_stack(value)
- def _is_matchable(self, node: Node) -> bool:
- if node.op == 'call_function':
- return not (node.target in self.non_matchable_functions)
- elif node.op == 'call_module':
- assert isinstance(node.target, str)
- target_mod = getattr_from_fqn(self.gm, node.target)
- return not \
- any(isinstance(target_mod, t) # type: ignore[arg-type]
- for t in self.non_matchable_modules)
- elif node.op == 'call_method':
- return not (node.target in self.non_matchable_methods)
- else:
- return False
class GraphMatchingException(Exception):
    """
    Signals that two graphs could not be matched against each other.
    """
@enum.unique
class SubgraphTypeRelationship(enum.Enum):
    """
    Describes how the types of two subgraphs relate to each other.
    """
    # The types are identical and known.
    # Example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d.
    EQUAL = enum.auto()
    # The types are identical, but Numerical Suite does not know
    # the type (user defined type, etc).
    # Note: the spelling of this member is part of the existing interface.
    EQUAL_BUT_UKNOWN = enum.auto()
    # The types are known and belong to the same subgraph_relationship set,
    # but they are not identical.
    # Example: F.linear and toq.linear.
    RELATED_BUT_NOT_EQUAL = enum.auto()
    # The types are not related.
    NOT_RELATED = enum.auto()
def _get_subgraph_relationship_type(
    subgraph_a: NSSubgraph,
    subgraph_b: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
) -> SubgraphTypeRelationship:
    """
    Classifies the relationship between the base op nodes of two subgraphs.

    Args:
        subgraph_a: subgraph taken from `gm_a`
        subgraph_b: subgraph taken from `gm_b`
        gm_a: module owning `subgraph_a`, used to resolve call_module targets
        gm_b: module owning `subgraph_b`, used to resolve call_module targets
        type_a_related_to_b: set of (type_a, type_b) pairs which are
            considered related; identical known types are expected to be
            present as (t, t)

    Returns:
        * EQUAL: same known type (and, for function/method calls, both
          subgraphs agree on whether the base op is the start node)
        * EQUAL_BUT_UKNOWN: same type, but the pair is not in
          `type_a_related_to_b`
        * RELATED_BUT_NOT_EQUAL: the pair is in `type_a_related_to_b` but the
          types differ, or the types are the same but only one of the
          subgraphs has nodes before its base op
        * NOT_RELATED: everything else, including base op nodes which are
          not function, method or module calls

    Raises:
        AssertionError: for call_module base ops, if the base op node of
            either subgraph is not its start node, or a target is not a str.
    """
    node_a = subgraph_a.base_op_node
    node_b = subgraph_b.base_op_node
    fn_or_method = ('call_function', 'call_method')

    # TODO(next): make this code handle matching by what is before the base op
    # A function call may be matched with a method call; any other mismatch
    # of the node kinds means that the subgraphs are not related.
    if node_a.op != node_b.op and not (
        node_a.op in fn_or_method and node_b.op in fn_or_method
    ):
        return SubgraphTypeRelationship.NOT_RELATED

    if node_a.op in fn_or_method:
        same_target = node_a.target == node_b.target

        if (node_a.target, node_b.target) not in type_a_related_to_b:
            if same_target:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            return SubgraphTypeRelationship.NOT_RELATED

        # after this point, we are dealing with known, related types
        if not same_target:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL

        # Same target on both sides. If exactly one of the subgraphs has
        # nodes in front of its base op, the subgraphs are only related.
        a_base_is_start = subgraph_a.base_op_node == subgraph_a.start_node
        b_base_is_start = subgraph_b.base_op_node == subgraph_b.start_node
        if a_base_is_start != b_base_is_start:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
        # TODO(future PR): when both subgraphs have nodes before the base op,
        # check for matches of start_op_node and base_op_node
        return SubgraphTypeRelationship.EQUAL

    if node_a.op == 'call_module':
        assert (subgraph_a.base_op_node == subgraph_a.start_node and
                subgraph_b.base_op_node == subgraph_b.start_node), \
            "Matching call_module patterns where base_op_node != start_node is not supported yet"
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        assert isinstance(node_b.target, str)
        mod_b = getattr_from_fqn(gm_b, node_b.target)

        same_type = type(mod_a) == type(mod_b)

        if (type(mod_a), type(mod_b)) not in type_a_related_to_b:
            if same_type:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            return SubgraphTypeRelationship.NOT_RELATED
        if same_type:
            return SubgraphTypeRelationship.EQUAL
        return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL

    return SubgraphTypeRelationship.NOT_RELATED
def _get_name_for_subgraph(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    existing_names: Set[str],
) -> str:
    """
    Builds a unique name for a subgraph and records it in `existing_names`.

    The name has the form `base_op_<base name>_<counter>`, where

    1. `<base name>` is the key of the set of related ops which contains the
       type of the base op of the subgraph (i.e. 'torch.nn.functional.linear'
       if this is related to a linear op). If several sets contain the type,
       the last one in iteration order is used; if none does, the base name
       is 'None'.
    2. `<counter>` is the smallest non-negative integer for which the name is
       not yet in `existing_names`, i.e. the number of previously named
       subgraphs with the same base name.

    For example, in the graph

      linear0 -> relu0 -> linear1 -> relu1

    the subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
    from the output node backwards, (linear1, relu1) is named
    `base_op_torch.nn.functional.linear_0` and (linear0, relu0) is named
    `base_op_torch.nn.functional.linear_1`.

    Node names are not used, because of two requirements:

    A. fusions must be supported
    B. some Numeric Suite APIs can be called without having all of the
       models in memory

    For example, let's say we need to match nodes of

      (1) ... -> linear0 -> relu0 -> ...

    and

      (2) ... -> linear_relu0 -> ...

    without being able to inspect them together. With this naming scheme, if
    both graphs are iterated in the same order, and assuming the rest of the
    graphs match, both subgraphs get the same name without (1) and (2)
    knowing anything about each other.
    """
    base_op_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)

    # every base name whose set of related ops contains the type, in
    # iteration order; the last one is the one which is used
    matching_base_names = [
        candidate_name
        for candidate_name, related_ops in base_name_to_sets_of_related_ops.items()
        if base_op_type in related_ops
    ]
    chosen_base_name = matching_base_names[-1] if matching_base_names else None

    name_prefix = 'base_op_' + str(chosen_base_name) + '_'
    index = 0
    while name_prefix + str(index) in existing_names:
        index += 1

    unique_name = name_prefix + str(index)
    existing_names.add(unique_name)
    return unique_name
- def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
- if node.op in ('call_function', 'call_method'):
- return node.target
- elif node.op == 'call_module':
- assert isinstance(node.target, str)
- mod = getattr_from_fqn(gm, node.target)
- return type(mod)
- return None
def get_matching_subgraph_pairs(
    gm_a: GraphModule,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
    """
    Matches matchable subgraphs of graph_a to graph_b.

    For a node, "matchable" is defined as a node which is not an observer,
    fake_quants, quant or dequant.

    A subgraph can contain one or more nodes.  A subgraph is matchable if
    at least one node inside of it is matchable.  Currently, all nodes in
    a subgraph must be matchable (because we assume no observers will be
    inserted in the middle of a fusion).

    A subgraph is defined by (start_node, end_node).  We assume that only
    start_node and end_node are linked with the surrounding graph, all other
    nodes in a subgraph are self-contained.

    A pair of nodes is "related" if both nodes represent the same mathematical
    operation across different quantization flavors. For example,
    `F.linear` and `torch.ops.quantized.linear` are related, and
    `F.linear` and `torch.nn.Conv` are not related.

    For each matchable pair of nodes node_a and node_b, they will match
    if node_a and node_b are related.

    For graphs A and B, they will match iff:
    1. the number of matchable subgraphs in A and B is equivalent
    2. when iterating through the matchable subgraphs of A and B in the same
       order, each corresponding pair of base nodes is related.

    This enables us to find the corresponding subgraphs between
    graphs of related models.  For example, if we had two graphs such as:

    graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
             w  -/
             b  -/

    graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
           packed_params_0 -/

    This function will return a result shaped like:

    {
        # key: 'base_op_<name of the set of related ops>_<counter>',
        # generated identically for both graphs
        'base_op_<conv base name>_0': (
          NSSubgraph(start_node=conv_0, end_node=conv_0, base_op_node=conv_0),
          NSSubgraph(start_node=qconv_0, end_node=qconv_0, base_op_node=qconv_0),
        ),
    }

    Or, if we have a fusion pattern,

    graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
             w  -/
             b  -/

    graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
           packed_params_0 -/

    This function will return a result shaped like:

    {
        'base_op_<linear base name>_0': (
          NSSubgraph(start_node=linear_0, end_node=relu_0, base_op_node=linear_0),
          NSSubgraph(start_node=linear_relu_0, end_node=linear_relu_0,
                     base_op_node=linear_relu_0),
        ),
    }

    Args:
        gm_a: first model
        gm_b: second model
        base_name_to_sets_of_related_ops: optional override of the mapping
            from a base name to the set of op types related to each other;
            defaults to `get_base_name_to_sets_of_related_ops()`
        unmatchable_types_map: optional override of the unmatchable types,
            must contain the keys 'funs_unmatchable', 'mods_unmatchable' and
            'meths_unmatchable'; defaults to `get_unmatchable_types_map()`

    Returns:
        An ordered dict from subgraph name to (subgraph_a, subgraph_b). The
        graphs are traversed from the outputs to the inputs, and the entries
        are reversed before returning, so that they are in order of
        execution. Pairs whose types are equal but unknown are left out.

    Raises:
        GraphMatchingException: if a pair of corresponding subgraphs is not
            related, or if one graph runs out of matchable subgraphs before
            the other one.
    """
    if unmatchable_types_map is None:
        unmatchable_types_map = get_unmatchable_types_map()
    non_matchable_functions = unmatchable_types_map['funs_unmatchable']
    non_matchable_modules = unmatchable_types_map['mods_unmatchable']
    non_matchable_methods = unmatchable_types_map['meths_unmatchable']

    graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_a, non_matchable_functions, non_matchable_modules,
        non_matchable_methods)
    graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_b, non_matchable_functions, non_matchable_modules,
        non_matchable_methods)
    results = collections.OrderedDict()
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = \
        get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    existing_names_a: Set[str] = set()
    existing_names_b: Set[str] = set()

    def _start_node_type(subgraph, gm):
        # The type of the start node is only needed for error messages, so
        # it is looked up on demand instead of on every iteration.
        if subgraph is None:
            return None
        return _get_node_target_type(subgraph.start_node, gm)

    while True:
        # fetch the next subgraphs from a and b, None means "exhausted";
        # both iterators are always advanced together
        cur_subgraph_a = next(graph_a_iterator, None)
        cur_subgraph_b = next(graph_b_iterator, None)

        if cur_subgraph_a is None and cur_subgraph_b is None:
            # we reached the end of both graphs
            break

        if cur_subgraph_a is None or cur_subgraph_b is None:
            # only one subgraph was fetched, no match possible, throw error
            type_start_a = _start_node_type(cur_subgraph_a, gm_a)
            type_start_b = _start_node_type(cur_subgraph_b, gm_b)
            msg = f"""
Attempting to match
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b}),
one of which is empty. Please ensure that the two models you pass in have the same number
of subgraphs."""
            raise GraphMatchingException(msg)

        # both subgraphs were fetched, check for subgraph_relationship
        # note: subgraph_relationship is checked on the base op node, i.e.
        # if a linear-relu pattern is checked, we would check for
        # subgraph_relationship of the linear
        subgraph_relationship = _get_subgraph_relationship_type(
            cur_subgraph_a, cur_subgraph_b,
            gm_a, gm_b, type_a_related_to_b)

        if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
            type_start_a = _start_node_type(cur_subgraph_a, gm_a)
            type_start_b = _start_node_type(cur_subgraph_b, gm_b)
            msg = f"""
The subgraphs
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b})
are not related. Please ensure that the two models you pass in have the same number
of subgraphs, and each pair of subgraphs is related to each other."""
            raise GraphMatchingException(msg)

        if subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
            # skip matching but unknown types
            continue

        # The names are generated independently for each graph, and have to
        # agree because related subgraphs share the same base name.
        key_name_a = _get_name_for_subgraph(
            cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops,
            existing_names_a)
        key_name_b = _get_name_for_subgraph(
            cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops,
            existing_names_b)
        assert key_name_a == key_name_b, \
            f"Subgraph names {key_name_a} and {key_name_b} do not match"
        results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)

    # The subgraph pairs are originally created by traversing the two graphs
    # from the outputs to the inputs. Reverse the results to return the
    # subgraphs in their order of execution.
    results = collections.OrderedDict(reversed(list(results.items())))
    return results
|