import sys
import torch
from torch.fx.graph import (
    Graph,
    Node,
)
from torch.ao.quantization.utils import Pattern
from .quantize_handler import (
    QuantizeHandler,
)
from ..qconfig import (
    QConfigAny,
)
from ..utils import (
    MatchAllNode
)
from .graph_module import (
    _is_observed_standalone_module,
)
from torch.nn.utils.parametrize import type_before_parametrizations
from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable

# This module exports no public names.
__all__: List[str] = []

# Value type of the dict returned by `_find_matches`:
#   (node the pattern was matched at, matched node structure,
#    matched pattern (or None), QuantizeHandler instance)
# TODO(future PR): the element at index 1 (matched node structure) is typed as
# `List[Node]`, but a better type would be a recursive
# `List[Union[Node, Tuple[Union[Node, ...]]]]`
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]

# `_MatchResult` with a `QConfigAny` appended as a fifth element.
_MatchResultWithQConfig = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                                QConfigAny]

# Note: The order of patterns is important! The match function takes whatever
# is matched first, so fusion patterns need to be placed before single
# patterns. For example, add_relu should be registered before relu.
# Decorators are applied in the reverse order in which they appear. Also, when
# we match the nodes in the graph against these patterns, we start from the
# last node of the graph and traverse backwards.
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
    """Return True if the fx ``node`` matches ``pattern``.

    ``pattern`` is either a single "self" matcher or a tuple
    ``(self_matcher, *arg_patterns)``. Checks are applied in this order:

    1. A ``MatchAllNode`` subclass as self matcher matches anything.
    2. A ``node`` that compares equal to ``pattern`` matches (this is how
       literal argument values are matched).
    3. Anything that is not an fx ``Node``, or a node with more than
       ``max_uses`` users, does not match.
    4. The self matcher is compared against the node:

       * ``nn.Module`` subclass: a ``call_module`` node whose module in
         ``modules`` has exactly that type before parametrization;
       * callable: a ``call_function`` node with exactly that target; for a
         ``(getattr, name)`` pattern ``node.args[1]`` must also equal ``name``
         (``name`` is not treated as an argument sub-pattern);
       * string: a ``call_method`` node with that method name;
       * anything else: compared against ``node.target``.
    5. Argument sub-patterns, if any, must be as many as ``node.args`` and each
       must match the corresponding argument with ``max_uses=1`` (i.e. a
       matched argument node may only be used by this node).
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
        if self_match is getattr:
            # (getattr, attr_name): attr_name is compared with node.args[1]
            # below instead of being matched as an argument sub-pattern.
            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
            arg_matches = []
    else:
        self_match = pattern
        arg_matches = []

    if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
        return True

    if node == pattern:
        return True

    if not isinstance(node, Node) or len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module':
            return False
        if not type_before_parametrizations(modules[node.target]) == self_match:
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
            return False
        if node.target is getattr and node.args[1] != pattern[1]:
            return False
    elif isinstance(self_match, str):
        if node.op != 'call_method' or node.target != self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(
        _is_match(modules, arg, arg_match, max_uses=1)
        for arg, arg_match in zip(node.args, arg_matches))


def _find_matches(
        graph: Graph,
        modules: Dict[str, torch.nn.Module],
        patterns: Dict[Pattern, QuantizeHandler],
        root_node_getter_mapping: Dict[Pattern, Callable],
        standalone_module_names: Optional[List[str]] = None,
        standalone_module_classes: Optional[List[Type]] = None,
        custom_module_classes: Optional[List[Any]] = None) -> Dict[str, _MatchResult]:
    """
    Matches the nodes in the input graph to quantization patterns, and
    outputs the information needed to quantize them in future steps.

    Inputs:
      - graph: an fx.Graph object; its nodes are visited from last to first
      - modules: a mapping of fully qualified module name to instance,
          for example, {'foo': ModuleFoo, ...}
      - patterns: a mapping from a pattern (an op, or a tuple
          ``(op, *arg_patterns)`` listing the root op first, i.e. nodes in
          reverse execution order) to the uninitialized QuantizeHandler
          subclass instantiated for each match. Patterns are tried in dict
          iteration order and the first one that matches a node wins.
      - root_node_getter_mapping: per-pattern callable passed to the
          QuantizeHandler constructor (``None`` for patterns not in the map)
      - standalone_module_names / standalone_module_classes: call_module
          targets (by name, or by exact module type) to treat as standalone
          modules
      - custom_module_classes: exact module types to treat as custom modules

    Outputs a map of
      node_name ->
        (matched_node, matched_node_pattern, matched_pattern,
         QuantizeHandler instance)
    where every node covered by a pattern match is keyed by its own name and
    maps to the same tuple. For example, {
      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                 <QuantizeHandler instance>),
      ...
    }

    Custom-module and then standalone-module ``call_module`` nodes are
    recorded last, as ``(node, node, None, QuantizeHandler(...))``, overriding
    any earlier entry for that node.
    """
    if custom_module_classes is None:
        custom_module_classes = []

    if standalone_module_classes is None:
        standalone_module_classes = []

    if standalone_module_names is None:
        standalone_module_names = []

    match_map: Dict[str, _MatchResult] = {}

    def _recursive_record_node_in_match_map(
            last_node,
            match_map,
            node_pattern,
            matched_node_pattern,
            pattern,
            match_value):
        # Record (last_node, matched_node_pattern, pattern, match_value) for
        # every Node found (recursively) inside node_pattern.
        if isinstance(node_pattern, Node):
            match_map[node_pattern.name] = (
                last_node, matched_node_pattern, pattern, match_value)
        elif isinstance(node_pattern, (str, bytes)) or \
                not isinstance(node_pattern, Iterable):
            # Non-Node leaves (e.g. constant args) are not recorded. Strings are
            # Iterable, but each element is again a non-empty string, so
            # recursing into them would never terminate.
            return
        else:
            for n in node_pattern:
                _recursive_record_node_in_match_map(
                    last_node, match_map, n, matched_node_pattern, pattern, match_value)

    # TODO: merge with fuse matcher
    def record_match(pattern, node, matched_node_pattern):
        # Append the graph values matched by `pattern` rooted at `node` to
        # `matched_node_pattern`, mirroring the nesting of `pattern`.
        if isinstance(pattern, tuple):
            s, *args = pattern
            is_single_arg = len(args) == 1
            current_node_pattern: List[Node] = []
            # The op at the root of the (sub)pattern matches `node` itself.
            record_match(s, node, matched_node_pattern)
            # For getattr patterns args[0] is an attribute name, not a sub-pattern.
            if pattern[0] is not getattr:
                for subpattern, arg in zip(args, node.args):
                    record_match(subpattern, arg, current_node_pattern)
            if len(current_node_pattern) > 1:
                # current_node_pattern is the node pattern we get from matching
                # the subpattern with arguments of the node
                # we use is_single_arg to recover the original structure of the pattern
                # if the original pattern has a single argument, we will have
                # (original_op, (original_arg, ...))
                # otherwise, we'll have a list of arguments
                # (original_op, arg0, arg1, arg2, ...)
                if is_single_arg:
                    matched_node_pattern.append(tuple(current_node_pattern))
                else:
                    matched_node_pattern.extend(current_node_pattern)
            elif current_node_pattern:
                matched_node_pattern.append(current_node_pattern[0])
            # else: nothing was recorded for the arguments (a getattr pattern or
            # a pattern tuple with no argument sub-patterns); indexing the empty
            # list here used to raise IndexError.
        else:
            matched_node_pattern.append(node)

    # Walk the graph backwards so a pattern is matched starting from its root
    # (last executed) node.
    for node in reversed(graph.nodes):
        if node.name in match_map:
            # already covered by a previously recorded (larger) match
            continue
        for pattern, quantize_handler_cls in patterns.items():
            if not _is_match(modules, node, pattern):
                continue
            root_node_getter = root_node_getter_mapping.get(pattern, None)
            matched_node_pattern: List[Node] = []
            record_match(pattern, node, matched_node_pattern)
            quantize_handler = quantize_handler_cls(  # type: ignore[operator]
                matched_node_pattern,
                modules,
                root_node_getter)
            # record the match for all nodes in the pattern
            _recursive_record_node_in_match_map(
                node,
                match_map,
                # we need to record all nodes in the matched pattern in the match_map
                matched_node_pattern,
                # this is a part of the value corresponding to the node
                matched_node_pattern,
                pattern,
                quantize_handler)
            break  # first matching pattern wins

    # add custom module instances to the match result
    assert modules is not None
    for node in graph.nodes:
        if node.op == 'call_module' and \
           type(modules[node.target]) in custom_module_classes:
            match_map[node.name] = (
                node, node, None, QuantizeHandler(node, modules, is_custom_module=True)
            )

    def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
        # True if the target is configured as standalone by name or exact type.
        assert modules is not None
        return (
            node_target in standalone_module_names or  # type: ignore[operator]
            type(modules[node_target]) in standalone_module_classes  # type: ignore[operator]
        )

    # add standalone modules (configured or already observed) to the match
    for node in graph.nodes:
        if node.op == 'call_module' and \
           (is_standalone_module(node.target, modules) or
                _is_observed_standalone_module(modules[node.target])):
            # add node to matched nodes
            match_map[node.name] = (
                node, node, None,
                QuantizeHandler(node, modules, is_standalone_module=True)
            )

    return match_map