123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237 |
- import sys
- import torch
- from torch.fx.graph import (
- Graph,
- Node,
- )
- from torch.ao.quantization.utils import Pattern
- from .quantize_handler import (
- QuantizeHandler,
- )
- from ..qconfig import (
- QConfigAny,
- )
- from ..utils import (
- MatchAllNode
- )
- from .graph_module import (
- _is_observed_standalone_module,
- )
- from torch.nn.utils.parametrize import type_before_parametrizations
- from typing import Any, Dict, List, Callable, Optional, Tuple, Type, Set, Iterable
# Nothing is exported by `from ... import *`; every helper here is private.
__all__: List[str] = []

# Result of matching one pattern:
#   (last node of the match, matched node pattern, matched pattern,
#    QuantizeHandler instance)
# TODO(future PR): the second element (index 1) is typed as `List[Node]`, but a
# better type would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
_MatchResult = Tuple[
    Node,
    List[Node],
    Optional[Pattern],
    QuantizeHandler,
]

# Same layout as _MatchResult with a trailing QConfigAny element.
_MatchResultWithQConfig = Tuple[
    Node,
    List[Node],
    Optional[Pattern],
    QuantizeHandler,
    QConfigAny,
]
# Note: The order of patterns is important! The match function takes whichever
# pattern matches first, so fusion patterns must be registered before single
# patterns; for example, add_relu should be registered before relu. Decorators
# are applied in the reverse of the order in which they appear. Also, when we
# match the nodes in the graph against these patterns, we start from the last
# node of the graph and traverse backwards.
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
    """Return True if the fx ``node`` (and, recursively, its args) matches ``pattern``.

    A pattern is either a single matcher or a tuple ``(matcher, *arg_patterns)``
    whose ``arg_patterns`` are matched positionally against ``node.args``.
    The matcher may be:

    * a ``MatchAllNode`` subclass, which matches any value;
    * an ``nn.Module`` subclass, which matches ``call_module`` nodes whose
      module (looked up in ``modules``) has that type before parametrization;
    * a callable, which matches ``call_function`` nodes whose target is that
      exact object; for ``(getattr, name)`` the attribute name must also equal
      ``node.args[1]``;
    * a string, which matches ``call_method`` nodes with that method name;
    * anything else, which is compared with ``node.target`` by equality.

    Args:
        modules: mapping from fully qualified module name to module instance.
        node: the value being matched; a ``Node`` at the top level, but may be
            any argument value when recursing into ``node.args``.
        pattern: the pattern described above.
        max_uses: the match fails if ``node`` has more users than this. Every
            argument position is matched with ``max_uses=1``.
    """
    if isinstance(pattern, tuple):
        head, *sub_patterns = pattern
        if head is getattr:
            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
            # The second element of a getattr pattern is an attribute name that
            # is checked against node.args[1] below; it is not a sub-pattern.
            sub_patterns = []
    else:
        head, sub_patterns = pattern, []

    if isinstance(head, type) and issubclass(head, MatchAllNode):
        return True
    if node == pattern:
        return True
    if not isinstance(node, Node):
        return False
    if len(node.users) > max_uses:
        return False

    if isinstance(head, type) and issubclass(head, torch.nn.Module):
        if node.op != 'call_module':
            return False
        if type_before_parametrizations(modules[node.target]) != head:
            return False
    elif callable(head):
        if node.op != 'call_function' or node.target is not head:
            return False
        if node.target is getattr and node.args[1] != pattern[1]:
            return False
    elif isinstance(head, str):
        if node.op != 'call_method' or node.target != head:
            return False
    elif node.target != head:
        return False

    if not sub_patterns:
        return True
    if len(sub_patterns) != len(node.args):
        return False
    # Inputs of a fused pattern must feed only this node, hence max_uses=1.
    return all(
        _is_match(modules, arg, sub_pattern, max_uses=1)
        for arg, sub_pattern in zip(node.args, sub_patterns)
    )
def _find_matches(
    graph: Graph,
    modules: Dict[str, torch.nn.Module],
    patterns: Dict[Pattern, QuantizeHandler],
    root_node_getter_mapping: Dict[Pattern, Callable],
    standalone_module_names: Optional[List[str]] = None,
    standalone_module_classes: Optional[List[Type]] = None,
    custom_module_classes: Optional[List[Any]] = None,
) -> Dict[str, _MatchResult]:
    """
    Matches the nodes in the input graph to quantization patterns, and
    outputs the information needed to quantize them in future steps.

    Inputs:
      - graph: an fx.Graph object
      - modules: a mapping of fully qualified module name to instance,
          for example, {'foo': ModuleFoo, ...}
      - patterns: a mapping from a tuple of nodes in reverse order to
          uninitialized QuantizeHandler subclass. Patterns are tried in the
          mapping's iteration order and the first match wins.
      - root_node_getter_mapping: per-pattern callable passed as the third
          argument to the QuantizeHandler constructor; patterns absent from
          this mapping get ``None``.
      - standalone_module_names / standalone_module_classes: ``call_module``
          nodes whose target name is in the former, or whose module type is in
          the latter, are recorded as standalone modules.
      - custom_module_classes: ``call_module`` nodes whose module type is in
          this list are recorded as custom modules.

    Outputs a map of
      node_name ->
        (last_node, matched_values, matched_pattern, QuantizeHandler instance)

    where ``last_node`` is the node at which the pattern was matched. Every
    node that participates in a pattern match gets an entry with the same
    values. Custom and standalone module nodes get ``(node, node, None,
    handler)``, which replaces any pattern match recorded for that node.

    For example, {
      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                 <CopyNodeQuantizeHandler instance>),
      ...
    }
    """
    if custom_module_classes is None:
        custom_module_classes = []
    if standalone_module_classes is None:
        standalone_module_classes = []
    if standalone_module_names is None:
        standalone_module_names = []

    match_map: Dict[str, _MatchResult] = {}

    def _recursive_record_node_in_match_map(
            last_node,
            match_map,
            node_pattern,
            matched_node_pattern,
            pattern,
            match_value):
        # Walk the (possibly nested) matched node pattern and record every
        # Node found in it under its name.
        if isinstance(node_pattern, Node):
            match_map[node_pattern.name] = (
                last_node, matched_node_pattern, pattern, match_value)
        elif isinstance(node_pattern, str) or not isinstance(node_pattern, Iterable):
            # Non-Node leaves (e.g. constants captured by MatchAllNode) have
            # nothing to record. Strings are excluded explicitly: they are
            # Iterable, and iterating a one-character string yields itself, so
            # recursing into one would never terminate.
            return
        else:
            for n in node_pattern:
                _recursive_record_node_in_match_map(
                    last_node, match_map, n, matched_node_pattern, pattern, match_value)

    # TODO: merge with fuse matcher
    def record_match(pattern, node, matched_node_pattern):
        # Append the nodes matched by `pattern` (rooted at `node`) to
        # `matched_node_pattern`, mirroring the nesting of `pattern`.
        # Assumes _is_match(modules, node, pattern) already succeeded, so the
        # pattern's arguments line up positionally with node.args.
        if not isinstance(pattern, tuple):
            matched_node_pattern.append(node)
            return
        head, *args = pattern
        is_single_arg = len(args) == 1
        record_match(head, node, matched_node_pattern)
        if head is getattr:
            # (getattr, name): the second element is an attribute name, not a
            # sub-pattern, so there are no argument nodes to record.
            return
        current_node_pattern: List[Node] = []
        for subpattern, arg in zip(args, node.args):
            record_match(subpattern, arg, current_node_pattern)
        if len(current_node_pattern) > 1:
            # current_node_pattern is the node pattern we get from matching
            # the subpattern with arguments of the node.
            # We use is_single_arg to recover the original structure of the
            # pattern: if the original pattern has a single argument, we get
            #   (original_op, (original_arg, ...))
            # otherwise we get a flat list of arguments
            #   (original_op, arg0, arg1, arg2, ...)
            if is_single_arg:
                matched_node_pattern.append(tuple(current_node_pattern))
            else:
                matched_node_pattern.extend(current_node_pattern)
        elif current_node_pattern:
            matched_node_pattern.append(current_node_pattern[0])
        # A one-element pattern such as `(op,)` has no argument nodes to add.

    # Visit nodes from the output backwards so that a pattern rooted at a later
    # node claims its input nodes before those inputs can match on their own.
    for node in reversed(graph.nodes):
        if node.name in match_map:
            # already claimed by a pattern rooted at a later node
            continue
        for pattern, quantize_handler_cls in patterns.items():
            if not _is_match(modules, node, pattern):
                continue
            matched_node_pattern: List[Node] = []
            record_match(pattern, node, matched_node_pattern)
            quantize_handler = quantize_handler_cls(  # type: ignore[operator]
                matched_node_pattern,
                modules,
                root_node_getter_mapping.get(pattern, None))
            # record the match for all nodes in the pattern, each pointing
            # back at this (last) node
            _recursive_record_node_in_match_map(
                node,
                match_map,
                # we need to record all nodes in the matched pattern in the match_map
                matched_node_pattern,
                # this is a part of the value corresponding to the node
                matched_node_pattern,
                pattern,
                quantize_handler)
            break

    # add custom module instances to the match result
    assert modules is not None
    for node in graph.nodes:
        if node.op == 'call_module' and \
                type(modules[node.target]) in custom_module_classes:
            match_map[node.name] = (
                node, node, None, QuantizeHandler(node, modules, is_custom_module=True))

    def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
        assert modules is not None
        return (
            node_target in standalone_module_names or  # type: ignore[operator]
            type(modules[node_target]) in standalone_module_classes  # type: ignore[operator]
        )

    # add standalone modules to the match (overrides custom-module entries)
    for node in graph.nodes:
        if node.op == 'call_module' and \
                (is_standalone_module(node.target, modules) or
                 _is_observed_standalone_module(modules[node.target])):
            # add node to matched nodes
            match_map[node.name] = (
                node, node, None,
                QuantizeHandler(node, modules, is_standalone_module=True))

    return match_map