import copy
import torch
import warnings
from torch.fx import (
    GraphModule,
)
from torch.fx.graph import (
    Graph,
    Node,
)
from torch.fx.node import Argument
from ..quantize import (
    propagate_qconfig_,
)
from ..observer import (
    ObserverBase,
    _is_activation_post_process
)
from ..qconfig import (
    _is_reuse_input_qconfig,
    QConfigAny,
)
from ..qconfig_mapping import (
    QConfigMapping,
)
from .qconfig_mapping_utils import (
    _generate_node_name_to_qconfig,
    _update_qconfig_for_fusion,
    _get_flattened_qconfig_dict,
    _update_qconfig_for_qat,
)
from .quantize_handler import (
    _default_root_node_getter,
    _get_pattern_to_quantize_handlers,
    QuantizeHandler,
)
from torch.ao.quantization.utils import (
    Pattern,
    NodePattern,
)
from ._equalize import (
    is_equalization_observer,
    node_supports_equalization,
)
from .pattern_utils import (
    _sorted_patterns_dict,
)
from .match_utils import (
    _MatchResultWithQConfig,
    _find_matches,
)
from .utils import (
    _insert_dequant_stubs_for_custom_module_lstm_output,
    _is_custom_module_lstm,
    _maybe_get_custom_module_lstm_from_node_arg,
    _qconfig_satisfies_dtype_config_constraints,
    get_custom_module_class_keys,
    all_node_args_have_no_tensors,
    assert_and_get_unique_device,
    get_non_observable_arg_indexes_and_types,
    get_new_attr_name_with_prefix,
    node_arg_is_weight,
    node_arg_is_bias,
    NON_QUANTIZABLE_WEIGHT_OPS,
    ObservedGraphModuleAttrs,
)
from torch.ao.quantization import (
    PlaceholderObserver
)
from torch.ao.quantization.quantize import (
    convert
)
from ..utils import (
    _parent_name,
    get_qconfig_dtypes,
    get_swapped_custom_module_class,
    activation_is_statically_quantized,
)
from ..backend_config.utils import (
    get_pattern_to_dtype_configs,
    get_module_to_qat_module,
    get_fusion_pattern_to_root_node_getter,
)
from ..backend_config import (
    BackendConfig,
    DTypeConfig,
    get_native_backend_config,
)
from .custom_config import (
    PrepareCustomConfig,
    StandaloneModuleConfigEntry,
)
from torch._subclasses import FakeTensor

from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union, Callable


__all__ = [
    "insert_observers_for_model",
    "prepare",
    "propagate_dtypes_for_known_nodes",
]


# list of dtypes to not add observers to
_DO_NOT_OBS_DTYPE_LIST = [int, float, torch.bool, None]

# note: the following default target dtype info dicts are temporary,
# should be moved to the new programmable API class soon
_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig.activation,
}

_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO = {
    "input_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
    "output_act_obs_or_fq_ctr": torch.ao.quantization.qconfig._default_quint8_placeholder_qconfig.activation,
}


def _is_activation_post_process_node(node: Node, named_modules: Dict[str, torch.nn.Module]) -> bool:
    return isinstance(node, torch.fx.Node) and node.op == "call_module" and \
        _is_activation_post_process(named_modules[str(node.target)])


def _get_dtype_and_is_dynamic(obs_or_fq_ctr: Optional[Callable]) -> Tuple[Optional[torch.dtype], bool]:
    """ Given a constructor for an observer or fake quant module, returns
    a Tuple of dtype and is_dynamic
    """
    # TODO: instead of instantiating the instance, we can use inspect to get the default args
    if obs_or_fq_ctr is None:
        return None, False
    else:
        obs_or_fq = obs_or_fq_ctr()
        return obs_or_fq.dtype, getattr(obs_or_fq, "is_dynamic", False)
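
# Illustrative usage sketch for _get_dtype_and_is_dynamic (the observer class is
# real; the specific values below are an assumed example, not exercised by this
# module):
#
#   from torch.ao.quantization.observer import MinMaxObserver
#   ctr = MinMaxObserver.with_args(dtype=torch.quint8)
#   dtype, is_dynamic = _get_dtype_and_is_dynamic(ctr)
#   # dtype == torch.quint8, is_dynamic == False
#   _get_dtype_and_is_dynamic(None)  # (None, False)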


def _is_input_arg_dtype_supported_by_backend(
    arg: Argument,
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
    backend_config: BackendConfig,
) -> bool:
    """ Check if the configured qconfig for the argument
    is supported by the backend or not
    """
    if isinstance(arg, (list, tuple)):
        return all(_is_input_arg_dtype_supported_by_backend(
            a, node, qconfig,
            dtype_config, backend_config) for a in arg)
    if not isinstance(arg, Node):
        return True
    # TODO: support check for standalone module
    is_weight = node_arg_is_weight(node, arg, backend_config)
    is_bias = node_arg_is_bias(node, arg, backend_config)
    is_activation = not is_weight and not is_bias
    if is_activation:
        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr")
        qconfig_dtype, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq_ctr)
        # TODO(future PR): remove the cast to bool below after figuring
        # out why backend_config has is_dynamic set to None in some cases.
        return (dtype_config.input_dtype is None) or (
            dtype_config.input_dtype == qconfig_dtype and
            bool(dtype_config.is_dynamic) == bool(qconfig_is_dynamic) and
            _qconfig_satisfies_dtype_config_constraints(qconfig, dtype_config.input_dtype_with_constraints)
        )
    elif is_weight:
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", None)
        qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq_ctr)
        backend_config_weight_dtype = dtype_config.weight_dtype
        dtype_matches = qconfig_weight_dtype == backend_config_weight_dtype
        qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
            qconfig, dtype_config.weight_dtype_with_constraints, is_activation=False)
        return backend_config_weight_dtype is None or (dtype_matches and qconfig_satisfies_constraints)
    else:  # bias
        # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", None)
        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq_ctr)
        backend_config_bias_dtype = dtype_config.bias_dtype
        return backend_config_bias_dtype is None or qconfig_bias_dtype == backend_config_bias_dtype


def _is_output_dtype_supported_by_backend(
    node: Node,
    qconfig: QConfigAny,
    dtype_config: DTypeConfig,
) -> bool:
    """ Check if the configured qconfig for the output
    is supported by the backend or not
    """
    # TODO: move dtype check into `_qconfig_satisfies_dtype_config_constraints` as well
    backend_config_output_dtype = dtype_config.output_dtype
    # TODO: we should check is_dynamic here as well, the code from _is_input_arg_dtype_supported_by_backend
    # from input activation check can be reused here
    qconfig_output_dtype = None
    output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr")
    qconfig_output_dtype, qconfig_output_is_dynamic = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
    # TODO: this is a hack because we can only specify one activation_obs_or_fq for
    # qconfig (qconfig.activation), and we are only supporting dynamically quantized
    # linear op which has fp32 output dtype, this should be removed if we generalize
    # the structure of qconfig in the future
    if qconfig_output_is_dynamic:
        qconfig_output_dtype = torch.float32
    dtype_matches = qconfig_output_dtype == backend_config_output_dtype
    qconfig_satisfies_constraints = _qconfig_satisfies_dtype_config_constraints(
        qconfig, dtype_config.output_dtype_with_constraints)
    return backend_config_output_dtype is None or (dtype_matches and qconfig_satisfies_constraints)


def _is_observer_in_same_graph(node: Node, named_modules: Dict[str, torch.nn.Module]):
    """ Check if the observer for the node is in the same graph.

    When the node's output is not fp32 and its input is a 'placeholder', the
    input is assumed to be quantized, so it is observed in a different place
    rather than being unobserved.
    """
    node_output_dtype = _get_arg_target_dtype_as_output(node, named_modules)
    if len(node.args) > 0 and isinstance(node.args[0], Node):
        if node_output_dtype == torch.quint8 and node.args[0].op == 'placeholder':
            return False
    return True


def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
    pattern: Optional[Pattern],
    matched_node_pattern: Optional[List[Node]],
    qconfig: QConfigAny,
    backend_config: BackendConfig,
) -> bool:
    """ Check if the dtype configuration of a pattern is supported by
    the backend or not, and whether the qconfig satisfies constraints
    specified in the corresponding dtype config.
    """
    if backend_config is None or pattern is None:
        return True
    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
    pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
    dtype_configs: List[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)

    root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
    root_node = root_node_getter(matched_node_pattern)
    input_node = root_node
    output_node = matched_node_pattern[0]
    for dtype_config in dtype_configs:
        # check if arg dtypes are supported
        supported = True
        for arg in list(input_node.args) + list(input_node.kwargs.values()):
            supported = supported and _is_input_arg_dtype_supported_by_backend(
                arg, input_node, qconfig, dtype_config, backend_config)
        # check if output dtype is supported
        supported = supported and _is_output_dtype_supported_by_backend(
            output_node, qconfig, dtype_config)
        if supported:
            return True
    return False
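
# Illustrative example of the check above (assumed values): for a linear pattern
# with
#
#   DTypeConfig(input_dtype=torch.quint8, output_dtype=torch.quint8,
#               weight_dtype=torch.qint8, bias_dtype=torch.float)
#
# a qconfig whose activation observers report dtype=torch.quint8 and whose weight
# observer reports dtype=torch.qint8 satisfies both the input and output checks,
# so the pattern is considered supported; an fp16 qconfig would fail the dtype
# comparison for every dtype_config and the function would return False.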


def _get_standalone_module_configs(
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    prepare_custom_config: PrepareCustomConfig,
    parent_qconfig: QConfigAny,
    parent_backend_config: Optional[BackendConfig],
) -> Tuple[QConfigMapping, Tuple[Any, ...], PrepareCustomConfig, Optional[BackendConfig]]:
    """
    Returns the standalone module QConfigMapping and PrepareCustomConfig
    for `node`, assuming that the module pointed to by `node` is
    a standalone module.
    """
    module_name = str(node.target)
    module_type = type(named_modules[module_name])  # type: ignore[index]
    # name config has precedence over type config
    config_entry = StandaloneModuleConfigEntry(None, (), None, None)
    config_entry = prepare_custom_config.standalone_module_classes.get(module_type, config_entry)
    config_entry = prepare_custom_config.standalone_module_names.get(module_name, config_entry)
    # fallback to use parent module's qconfig if user didn't specify qconfig dict
    qconfig_mapping = config_entry.qconfig_mapping or QConfigMapping().set_global(parent_qconfig)
    example_inputs = config_entry.example_inputs
    prepare_custom_config = config_entry.prepare_custom_config or PrepareCustomConfig()
    backend_config = config_entry.backend_config or parent_backend_config
    return (qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
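
# Illustrative precedence sketch (the module and mapping names are hypothetical):
# because the name lookup above happens after the class lookup, a name entry wins
# when both are configured:
#
#   prepare_custom_config = (
#       PrepareCustomConfig()
#       .set_standalone_module_class(MySubmodule, class_qconfig_mapping, example_inputs, None, None)
#       .set_standalone_module_name("sub", name_qconfig_mapping, example_inputs, None, None)
#   )
#   # for a node whose target is "sub" (an instance of MySubmodule),
#   # name_qconfig_mapping is the one returned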


def _qat_swap_modules(
    root: torch.nn.Module,
    module_to_qat_module: Dict[Pattern, Type[torch.nn.Module]],
) -> None:
    convert(root, mapping=module_to_qat_module, inplace=True, remove_qconfig=False)


def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: Set[str]):
    if isinstance(matched_node_pattern, Node):
        s.add(matched_node_pattern.name)
    elif isinstance(matched_node_pattern, (list, tuple)):
        for maybe_node in matched_node_pattern:
            _add_matched_node_name_to_set(maybe_node, s)


def _insert_observer(
    node: Node,
    observer: ObserverBase,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> Node:
    """
    Attaches `observer` to `model`, and creates a node which calls
    `observer` on the output of `node`.
    """
    model_device = assert_and_get_unique_device(model)
    if model_device:
        observer.to(model_device)
    # add observer module as attribute
    if is_equalization_observer(observer):
        prefix = node.name + '_equalization_process_'
    else:
        prefix = 'activation_post_process_'
    get_new_observer_name = get_new_attr_name_with_prefix(prefix)
    observer_name = get_new_observer_name(model)
    setattr(model, observer_name, observer)
    named_modules[observer_name] = observer
    with graph.inserting_after(node):
        new_obs = graph.create_node(
            'call_module', observer_name, (node,), {})
    return new_obs
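
# Illustrative graph rewrite performed by _insert_observer, assuming a fragment
# `x -> relu -> y` and an already-constructed observer instance:
#
#   obs_node = _insert_observer(relu_node, MinMaxObserver(), model, named_modules, graph)
#
# registers the observer on `model` under a fresh name such as
# "activation_post_process_0" and creates `relu -> activation_post_process_0`.
# Note that users of `relu` are not repointed here; callers in this file do that
# separately, e.g. via Node.replace_input_with.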


def _set_target_dtype_info_for_matched_node_pattern(
    matched_node_pattern: NodePattern,
    last_node: Node,
    qconfig: QConfigAny,
    backend_config: BackendConfig,
    named_modules: Dict[str, torch.nn.Module],
    cache_for_no_tensor_check: Dict[Node, bool],
    processed_nodes: Set[Node],
) -> None:
    """ Sets the target_dtype_info for each node in matched_node_pattern
    Note: processed_nodes is used to ensure we only process each node once
    """
    if isinstance(matched_node_pattern, (list, tuple)):
        for node_pattern in matched_node_pattern:
            _set_target_dtype_info_for_matched_node_pattern(
                node_pattern,
                last_node,
                qconfig,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes
            )

    # set target_dtype_info if matched_node_pattern is a Node
    # other types of matched object, e.g. int, float literals, are ignored
    elif isinstance(matched_node_pattern, Node):
        # for pyre
        assert isinstance(matched_node_pattern, Node)
        node = matched_node_pattern
        if node in processed_nodes:
            return
        processed_nodes.add(node)

        if qconfig is None:
            return
        # TODO: refactor the following code in terms of applying a qconfig to a pattern
        # e.g. for a pattern with op1 -> op2 -> op3, and qconfig = QConfig(input_act=obs0, output_act=obs1)
        # we set the input_obs_or_fq_ctr for the arguments of op1 based on qconfig.input_act,
        # and set output_obs_or_fq_ctr based on qconfig.output_act
        # this also requires us to extend the structure of QConfig to support more fine
        # grained configurations
        target_dtype_info: Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]] = (
            _get_target_activation_dtype_for_node(
                node,
                qconfig,
                named_modules,
                cache_for_no_tensor_check,
            )
        )
        node.meta["target_dtype_info"] = target_dtype_info


def _get_target_activation_dtype_for_node(
    node: Node,
    qconfig: QConfigAny,
    named_modules: Dict[str, torch.nn.Module],
    cache_for_no_tensor_check: Dict[Node, bool],
) -> Dict[str, Optional[Tuple[Union[torch.dtype, type], bool]]]:
    """
    For each op attribute in the op's input activation, output activation,
    weight, bias - returns the settings of dtype and is_dynamic we expect
    for the `quantize` call in the reference model representation, or None
    if there is no `quantize` call needed.

    For example, if we have a node corresponding to `op0` in

      x0 -> op0 -> x1

    And we want a reference quantized representation to be

      x0 -> quant_static -> dequant -> op0 -> quant_dynamic -> dequant -> x1

    Then this function will return

      {
        "input_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
        "output_act_obs_or_fq_ctr": MinMaxObserver.with_args(dtype=torch.quint8, is_dynamic=False),
      }

    TODO(future PR, if needed): explicitly spell out the non-Tensor
    dtypes.
    """
    args_have_no_tensors = \
        all_node_args_have_no_tensors(
            node, named_modules, cache_for_no_tensor_check)
    if args_have_no_tensors:
        return {
            "input_act_obs_or_fq_ctr": None,
            "output_act_obs_or_fq_ctr": None,
        }
    # get qconfig to determine the eventual dtype of this node
    if qconfig is not None:
        act_dtype, weight_dtype, input_act_is_dynamic = \
            get_qconfig_dtypes(qconfig)

        # Currently `QConfig` only has one `activation` field.
        # For static quantization, it is reused for both input
        # and output activation. For dynamic quantization, this
        # field is currently only used for the input activation,
        # with the output activation being in fp32.
        # In the future this may change as we add more fields
        # to the `QConfig` object.
        output_act_dtype = act_dtype \
            if (not input_act_is_dynamic) else torch.float

        bias_dtype = torch.float16 \
            if (
                act_dtype == torch.float16
                and weight_dtype == torch.float16
                and (not input_act_is_dynamic)
            ) else torch.float
        return {
            "input_act_obs_or_fq_ctr": qconfig.activation,
            "weight_obs_or_fq_ctr": qconfig.weight,
            "bias_obs_or_fq_ctr": PlaceholderObserver.with_args(dtype=bias_dtype),
            "output_act_obs_or_fq_ctr": qconfig.activation,
        }
    return copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)


def _get_arg_target_dtype_as_output(
    arg: Node,
    named_modules: Dict[str, torch.nn.Module],
) -> Optional[Union[torch.dtype, type]]:
    """ Get the target output activation dtype for
    the argument in the original graph, skipping inserted observers.
    We are assuming that the observers are inserted correctly, and that the
    dtype for the argument in the quantized graph will match what is
    specified by the qconfig.
    """
    assert isinstance(arg, Node)
    # Custom module LSTM output is a tuple that we broke down into the internal nodes in order
    # to insert DeQuantStubs (see `_insert_dequant_stubs_for_custom_module_lstm_output`).
    # Since we modified the graph in this case, we must trace back from the args through
    # the specific nodes we added in order to reach the original LSTM node. Otherwise, we would
    # not be able to accurately detect whether this node is a consumer of custom module LSTM.
    custom_module_lstm_node = _maybe_get_custom_module_lstm_from_node_arg(arg, named_modules)
    output_act_obs_or_fq_ctr = None
    if custom_module_lstm_node is not None:
        output_act_obs_or_fq_ctr = custom_module_lstm_node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
    elif _is_activation_post_process_node(arg, named_modules):
        observed_arg = arg.args[0]
        assert isinstance(observed_arg, Node), "Currently we only support observing Node"
        output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
    else:
        output_act_obs_or_fq_ctr = \
            arg.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"]
    output_act_dtype, _ = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
    # TODO: should support is_dynamic here as well
    return output_act_dtype


def _get_arg_target_dtype_as_input_to_node(
    arg: Node,
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    backend_config: BackendConfig,
) -> Optional[Union[torch.dtype, type]]:
    """ Get the target argument dtype for the argument `arg`, as input
    to node `node`
    """
    assert isinstance(arg, Node)
    is_weight = node_arg_is_weight(node, arg, backend_config)
    is_bias = node_arg_is_bias(node, arg, backend_config)
    is_activation = not is_weight and not is_bias
    if is_activation:
        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr")
        qconfig_dtype, _ = _get_dtype_and_is_dynamic(input_act_obs_or_fq_ctr)
        return qconfig_dtype
    elif is_weight:
        if node.target in NON_QUANTIZABLE_WEIGHT_OPS:
            return None
        else:
            weight_obs_or_fq_ctr = node.meta["target_dtype_info"].get("weight_obs_or_fq_ctr", None)
            qconfig_weight_dtype, _ = _get_dtype_and_is_dynamic(weight_obs_or_fq_ctr)
            return qconfig_weight_dtype
    else:
        bias_obs_or_fq_ctr = node.meta["target_dtype_info"].get("bias_obs_or_fq_ctr", None)
        qconfig_bias_dtype, _ = _get_dtype_and_is_dynamic(bias_obs_or_fq_ctr)
        return qconfig_bias_dtype


def _get_arg_target_is_dynamic_as_input_to_node(
    arg: Node,
    node: Node,
    named_modules: Dict[str, torch.nn.Module],
    backend_config: BackendConfig,
) -> bool:
    """ Get whether the argument `arg` is dynamically quantized, as input
    to node `node`
    """
    assert isinstance(arg, Node)
    is_weight = node_arg_is_weight(node, arg, backend_config)
    is_bias = node_arg_is_bias(node, arg, backend_config)
    is_activation = not is_weight and not is_bias
    if is_activation and "input_act_obs_or_fq_ctr" in node.meta["target_dtype_info"]:
        input_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("input_act_obs_or_fq_ctr")
        _, qconfig_is_dynamic = _get_dtype_and_is_dynamic(input_act_obs_or_fq_ctr)
        return qconfig_is_dynamic
    else:
        return False


def _maybe_insert_input_observer_for_arg_or_kwarg(
    node: Union[Node, Any],
    arg: Argument,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    backend_config: BackendConfig,
) -> Argument:
    """
    Given a `node` and an `arg`, inserts an input observer between
    `node` and `arg` if necessary.
    """
    # for ops such as torch.cat([x0, x1]),
    # traverse through the list
    if isinstance(arg, (list, tuple)):
        new_arg_to_return = []
        for inner_arg in arg:
            new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
                node, inner_arg, qconfig, model, named_modules,
                graph,
                qhandler,
                prepare_custom_config,
                backend_config)
            new_arg_to_return.append(new_inner_arg)
        return type(arg)(new_arg_to_return)

    if not isinstance(arg, Node):
        return arg
    assert isinstance(arg, Node)
    # default (no observer)
    new_arg = arg

    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()
    assert qconfig is not None
    if not is_standalone_module:
        # regular flow for most nodes, except standalone modules
        is_weight = node_arg_is_weight(node, arg, backend_config)

        _is_reuse_input_qconfig_ = _is_reuse_input_qconfig(qconfig)

        act_post_process_ctr = qconfig.weight if is_weight else \
            qconfig.activation

        arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules)
        arg_as_input_target_dtype = _get_arg_target_dtype_as_input_to_node(
            arg, node, named_modules, backend_config)
        arg_as_input_target_is_dynamic = \
            _get_arg_target_is_dynamic_as_input_to_node(
                arg, node, named_modules, backend_config)  # type: ignore[arg-type]
        needs_obs = \
            (
                # the following code block is for static quantization
                (not arg_as_input_target_is_dynamic) and
                # if the dtypes are different, we need an observer
                (arg_as_output_target_dtype != arg_as_input_target_dtype) and
                # except if the second dtype is float, a dequant will be inserted
                # without an observer in convert
                # TODO(future PR): change this so a placeholder is inserted for
                # future dequants, to make the logic easier to understand
                (arg_as_input_target_dtype != torch.float) and
                # if arg output dtype is in _DO_NOT_OBS_DTYPE_LIST do not insert observer
                (arg_as_output_target_dtype not in _DO_NOT_OBS_DTYPE_LIST) and
                # if qconfig is reuse_input qconfig, we won't insert extra observer for input
                not _is_reuse_input_qconfig_
            ) or (
                # need to add input observer for dynamic quantization
                # only add observer for first input for now, we may need to extend
                # qconfig_dict and backend_config to support more general configurations
                # of dynamic quantization, e.g. dynamically quantizing second input, third
                # input etc.
                arg_as_input_target_is_dynamic and arg is node.args[0]
            )

    else:
        # custom flow for standalone modules
        _, _, sm_prepare_custom_config, _ = \
            _get_standalone_module_configs(
                node, named_modules, prepare_custom_config, qconfig, backend_config)
        sm_input_quantized_idxs = sm_prepare_custom_config.input_quantized_indexes

        # for args, this is set to the index of the current arg
        # for kwargs, this is left at None
        cur_input_idx = None
        for arg_idx, arg_to_check in enumerate(node.args):
            if arg_to_check is arg:
                cur_input_idx = arg_idx
                break

        if cur_input_idx is None:
            needs_obs = False
        else:
            arg_as_output_target_dtype = _get_arg_target_dtype_as_output(arg, named_modules)
            arg_as_input_target_dtype = torch.quint8 if cur_input_idx in sm_input_quantized_idxs \
                else torch.float
            needs_obs = (
                (arg_as_output_target_dtype != arg_as_input_target_dtype) and
                (arg_as_input_target_dtype != torch.float)
            )

        act_post_process_ctr = qconfig.activation

    if needs_obs:

        new_obs_mod = act_post_process_ctr()
        existing_obs_node = None

        # Before using the new observer, check if an observer
        # of the correct type already exists. If it does, use it.
        # This prevents duplicate observer insertions if a node is
        # used by multiple nodes.
        # TODO: this is looking into how the value is used in the future
        # we should remove this
        # removing this means we insert one observer for each use, even if they
        # have the same dtype, we can have an extra pass that removes the extra observers
        for maybe_obs_node, _ in arg.users.items():
            if maybe_obs_node.op == 'call_module':
                maybe_obs_mod = named_modules[maybe_obs_node.target]  # type: ignore[index]
                if (
                    type(maybe_obs_mod) == type(new_obs_mod) and
                    maybe_obs_mod.dtype == arg_as_input_target_dtype
                ):
                    existing_obs_node = maybe_obs_node
                    break

        if existing_obs_node is None:
            new_obs_node = _insert_observer(
                arg, new_obs_mod, model, named_modules, graph)
            # override this arg to be the observed arg
            new_arg = new_obs_node
        else:
            new_arg = existing_obs_node

    return new_arg
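
# Illustrative example of the static branch of the needs_obs decision above
# (assumed dtypes): if `arg` is produced as fp32 (arg_as_output_target_dtype ==
# torch.float) and `node` consumes it as quint8 (arg_as_input_target_dtype ==
# torch.quint8), needs_obs is True and an observer is inserted between them.
# If producer and consumer already agree on torch.quint8, or the consumer wants
# torch.float (convert inserts a dequant instead), the arg is left unobserved.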


def _maybe_insert_input_observers_for_node(
    node: Node,
    qconfig: QConfigAny,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    qhandler: Optional[QuantizeHandler],
    prepare_custom_config: PrepareCustomConfig,
    backend_config: BackendConfig,
) -> None:
    """
    If needed, inserts observers to the input args and kwargs of `node`.
    Note: modifies `node` inplace.

    For example, if cur_node needs an observer after prev_node, we change from

      prev_node -> cur_node

    to

      prev_node -> obs -> cur_node
    """
    if qconfig is None:
        # if quantization is turned off for this node, we do not need
        # to insert input observers
        return
    assert qconfig is not None

    # Look through every input arg. If that arg's target dtype does not
    # match the current node's target dtype, insert an observer.
    new_args = []
    for arg in node.args:
        new_arg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, arg, qconfig, model, named_modules, graph,
            qhandler,
            prepare_custom_config,
            backend_config)
        new_args.append(new_arg)

    new_kwargs = {}
    for k, kwarg in node.kwargs.items():
        new_kwarg = _maybe_insert_input_observer_for_arg_or_kwarg(
            node, kwarg, qconfig, model, named_modules, graph,
            qhandler,
            prepare_custom_config,
            backend_config)
        new_kwargs[k] = new_kwarg

    # assign the new args and kwargs to the node, inplace
    node.args = tuple(new_args)
    node.kwargs = new_kwargs


def _maybe_insert_input_equalization_observers_for_node(
    node: Node,
    equalization_qconfig: Any,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    is_branch: bool,
    backend_config: BackendConfig,
) -> None:
    """
    If `node` needs to be equalized, finds the input/weight observers it needs
    in `equalization_qconfig`, creates them, and inserts them into `graph`.

    If `node` does not need an equalization observer, returns None.
    """
    if equalization_qconfig is None or not node_supports_equalization(node, named_modules):
        return

    if is_branch:
        warnings.warn(
            f"Cannot equalize {node} because it is part of a branch."
        )
        return

    new_args = []
    for arg in node.args:
        if not isinstance(arg, Node) or node_arg_is_bias(node, arg, backend_config):
            new_args.append(arg)
            continue

        is_weight = node_arg_is_weight(node, arg, backend_config)

        act_eq_process_ctr = equalization_qconfig.weight if is_weight else \
            equalization_qconfig.input_activation

        new_eq_obs_mod = act_eq_process_ctr()
        new_eq_obs_node = _insert_observer(
            arg, new_eq_obs_mod, model, named_modules, graph)

        new_args.append(new_eq_obs_node)

    # assign the new args to the node, inplace
    node.args = tuple(new_args)


def _maybe_insert_output_observer_for_node(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
    matched_pattern: Any,
    qhandler: Optional[QuantizeHandler],
    is_qat: bool,
) -> Optional[Node]:
    """
    If `node` needs an output observer, creates it, inserts it into `graph`
    and returns it.

    If `node` does not need an output observer, returns None.
    """
    root_node, _, pattern, qhandler, qconfig = node_name_to_match_result_with_qconfig.get(
        node.name, (None, None, None, None, None))
    if qhandler is None:
        return None

    assert qconfig is not None
    assert node.op != 'output', 'observer insertion for outputs is handled elsewhere'

    is_standalone_module = qhandler is not None and qhandler.is_standalone_module()

    output_act_obs_or_fq_ctr = node.meta["target_dtype_info"].get("output_act_obs_or_fq_ctr")
    qconfig_dtype, _ = _get_dtype_and_is_dynamic(output_act_obs_or_fq_ctr)
    should_insert_observer = qconfig_dtype not in _DO_NOT_OBS_DTYPE_LIST + [torch.float]
    # TODO(future PR): move the following logic to
    # should_insert_observer_for_output
    should_insert_observer = should_insert_observer and \
        activation_is_statically_quantized(qconfig)

    # we never insert observers to output of standalone module, we assume
    # if needed, they are inserted inside the standalone module
    should_insert_observer = should_insert_observer and \
        (not is_standalone_module)

    if should_insert_observer:
        observer = qconfig.activation()
        return _insert_observer(node, observer, model, named_modules, graph)
    else:
        return None


def _maybe_insert_observers_before_graph_output(
    graph_output_node: Node,
    output_quantized_idxs: List[int],
    node_name_to_qconfig: Dict[str, QConfigAny],
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
    graph: Graph,
) -> None:
    """
    If the output needs to be quantized and there are any nodes
    in the output which are not already observed, inserts observers
    for those nodes.
    """

    # TODO(future PR): update the output_quantized_idxs API to match
    # arbitrary data structures. There is always a single output, and
    # that output can have arbitrary nesting of values. List[int] is
    # not the right data type for this.
    assert output_quantized_idxs == [0] or output_quantized_idxs == [], \
        'unrecognized format of output_quantized_idxs'

    # Currently dequants are inserted in the convert step. So, we only
    # have to do anything if the output is hardcoded to be quantized
    if output_quantized_idxs == []:
        return

    # TODO(future PR): support more dtypes in model outputs, if necessary
    output_target_dtype = torch.quint8

    def _recursive_maybe_replace_node_with_obs(
        maybe_node: Argument,
        target_dtype: torch.dtype,
        node_name_to_qconfig: Dict[str, QConfigAny],
        model: torch.nn.Module,
        named_modules: Dict[str, torch.nn.Module],
        graph: Graph,
    ) -> Argument:
        """
        Navigate an arbitrary data structure of lists, tuples, dicts.
        For each container type, recurse on all inputs. Once any Node
        is found, insert an observer if needed and do not recurse further.

        For example, given a structure of

          {'foo1': [[bar1]], 'foo2': {'foo3': [[[bar3]]]}}

        we recurse down to bar1 and bar3, observe them if necessary,
        and if we inserted an observer then replace the original node
        with its observer.

        Returns the data structure with all nodes needing observation being
        replaced by their observers.
        """
        if isinstance(maybe_node, Node):
            # check dtype of this node
            this_node_dtype = _get_arg_target_dtype_as_output(
                maybe_node, named_modules)
            if this_node_dtype != target_dtype:
                # insert observer
                qconfig = node_name_to_qconfig.get(maybe_node.name)
                # TODO(future PR): see if we need to allow specifying qconfig
                # on output nodes, to remove the restriction below.
                assert qconfig is not None, \
                    'Quantizing the output node without a qconfig is not supported'
                observer_mod = qconfig.activation()
                observer_node = _insert_observer(
                    maybe_node, observer_mod, model, named_modules, graph)
                return observer_node
            else:
                return maybe_node
        elif isinstance(maybe_node, (list, tuple)):
            results = []
            for inner_node in maybe_node:
                results.append(_recursive_maybe_replace_node_with_obs(
                    inner_node, target_dtype, node_name_to_qconfig, model, named_modules, graph))
            if isinstance(maybe_node, list):
                return results
            else:
                return tuple(results)
        elif isinstance(maybe_node, dict):
            results_dict = {}
            for k, inner_v in maybe_node.items():
                results_dict[k] = _recursive_maybe_replace_node_with_obs(
                    inner_v, target_dtype, node_name_to_qconfig, model, named_modules, graph)
            return results_dict
        else:
            # non-container, non-Node values (e.g. int literals) are returned as is
            return maybe_node

    new_args = []
    for old_arg in graph_output_node.args:
        new_args.append(
            _recursive_maybe_replace_node_with_obs(
                old_arg, output_target_dtype, node_name_to_qconfig, model, named_modules, graph))

    graph_output_node.args = tuple(new_args)  # type: ignore[assignment]


def _maybe_propagate_dtype_for_node(
    node: Node,
    target_dtype: Union[torch.dtype, type],
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Assigns `target_dtype` to `node`, setting `is_dynamic` to False. If `node`
    is a general tensor shape op, also call this function recursively on
    the first argument, to propagate the dtype to the caller.
    """
    node.meta["target_dtype_info"]["input_act_obs_or_fq_ctr"] = None
    node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"] = None
    # if this is a copy node, propagate to first arg
    root_node, _, pattern, qhandler, qconfig = node_name_to_match_result_with_qconfig.get(
        node.name, (None, None, None, None, None))
    # TODO: probably need to remove `is_general_tensor_value_op`
    if qhandler is not None and qhandler.is_general_tensor_value_op():
        prev_node = node.args[0]
        if isinstance(prev_node, Node):
            _maybe_propagate_dtype_for_node(
                prev_node, target_dtype, node_name_to_match_result_with_qconfig)


def propagate_dtypes_for_known_nodes(
    graph: Graph,
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
) -> None:
    """
    Currently we assume that inputs to the graph are either `torch.float` or
    `torch.quint8`, which is not always correct. For ops such as
    `x.masked_fill(mask, value)`, we know that the dtype of `mask` is a
    `BoolTensor`. Propagate this information throughout the graph.

    Note: not all dtypes in the graph will be correct after this pass, but a
    higher percentage of them will be correct. Hopefully in the future we can
    replace this with a better way to reason about dtypes of tensors.
    """
    for node in graph.nodes:
        non_observable_arg_dict = get_non_observable_arg_indexes_and_types(node)

        for arg_type in non_observable_arg_dict:
            non_observable_indices = non_observable_arg_dict[arg_type](node)

            for index in non_observable_indices:
                arg = node.args[index]

                # when an argument is a tuple, it does not show up as another node so we need to go through
                # all elements of the tuple manually
                if isinstance(arg, (tuple, list)):
                    arg_list = list(arg)
                else:
                    arg_list = [arg]

                for cur_arg in arg_list:
                    # hard coded arguments show up but aren't `Node` typed and do not need dtype propagated
                    if isinstance(cur_arg, torch.fx.node.Node):
                        _maybe_propagate_dtype_for_node(
                            cur_arg, arg_type, node_name_to_match_result_with_qconfig)
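
# Illustrative trace (assumed graph): for
#
#   x1 = x0.masked_fill(mask, 1.0)
#
# get_non_observable_arg_indexes_and_types reports the index of `mask` under
# torch.bool, so _maybe_propagate_dtype_for_node clears the input/output observer
# constructors on the node producing `mask`, preventing an observer from being
# inserted on a boolean tensor.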


def _maybe_make_input_output_share_observers(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module],
) -> bool:
    """
    Ensures that we share an observer
    for all input arguments as well as the output argument. In detail, given
    a graph of

      x0 -> obs0 -> op -> x2
                   /
      x1 -> obs1 /

    where node obs0 points to observer instance observer0,
    obs1 points to observer1 and obs2 points to observer2, we make nodes obs1
    and obs2 point to observer0.
    Returns: whether the operation succeeded or not
    """
    first_arg = None
    # find the first arg that is a Node (or a list/tuple of Nodes)
    for i in range(len(node.args)):
        if isinstance(node.args[i], (Node, list, tuple)):
            first_arg = node.args[i]
            break

    # if there is no such arg, return directly
    if first_arg is None:
        return False

    if isinstance(first_arg, (list, tuple)):
        first_arg_arg = first_arg[0]
    elif isinstance(first_arg, Node):
        first_arg_arg = first_arg
    else:
        return False

    # if we have a graph such as
    #   observed_node -> non_observed_node -> cat
    # we need to navigate up to the first observer
    iteration_guard = 0
    while not _is_activation_post_process_node(first_arg_arg, named_modules):
        if not isinstance(first_arg_arg, Node):
            return False
        # did not find an activation_post_process for the op
        if first_arg_arg.op == "placeholder":
            return False
        # trace back the args until we found the first Tensor/Node
        trace_back_node = None
        for i in range(len(first_arg_arg.args)):
            trace_back_node = first_arg_arg.args[i]
            if isinstance(trace_back_node, Node):
                break
        if trace_back_node is None:
            return False
        first_arg_arg = trace_back_node

        iteration_guard += 1
        if iteration_guard > 10000:
            raise AssertionError('Unable to find observer of previous node')

    assert isinstance(first_arg_arg, Node)
    target_to_use = first_arg_arg.target
    assert isinstance(target_to_use, str)
    obs_mod_to_use = named_modules[target_to_use]

    if isinstance(first_arg, (list, tuple)):
        # set all other input observer nodes to use that module
        for input_idx, input_arg in enumerate(first_arg):
            if input_idx == 0:
                continue
            iteration_guard = 0
            while not _is_activation_post_process_node(input_arg, named_modules):
                # failed to trace back since no input arg for the current node
                if len(input_arg.args) < 1:
                    return False
                input_arg = input_arg.args[0]
                iteration_guard += 1
                if iteration_guard > 10000:
                    raise AssertionError('Unable to find observer of previous node')

            parent_name, name = _parent_name(input_arg.target)
            setattr(named_modules[parent_name], name, obs_mod_to_use)

    # set the output observer node to use that module
    for output_obs_node, _ in node.users.items():
        assert _is_activation_post_process_node(output_obs_node, named_modules)
        parent_name, name = _parent_name(output_obs_node.target)
        setattr(named_modules[parent_name], name, obs_mod_to_use)

    # TODO(future PR): delete the orphaned observer modules
    return True
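
# Illustrative example (assumed graph): for
#
#   x0 -> obs0 \
#               cat -> cat_obs
#   x1 -> obs1 /
#
# this pass repoints the modules behind obs1 and cat_obs at the observer instance
# behind obs0, so all inputs and the output of `cat` share one set of
# quantization parameters, as ops like torch.cat require.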


def _remove_output_observer(
    node: Node,
    model: torch.nn.Module,
    named_modules: Dict[str, torch.nn.Module]):
    items = list(node.users.items())
    for output_obs_node, _ in items:
        assert _is_activation_post_process_node(output_obs_node, named_modules)
        output_obs_node.replace_all_uses_with(node)
        model.graph.erase_node(output_obs_node)  # type: ignore[union-attr, operator]


def _swap_custom_module_to_observed(
    node: Node,
    qconfig: QConfigAny,
    named_modules: Dict[str, torch.nn.Module],
    prepare_custom_config: PrepareCustomConfig):
    custom_module = named_modules[node.target]  # type: ignore[index]
    custom_module_class_mapping = prepare_custom_config.float_to_observed_mapping
    observed_custom_module_class = \
        get_swapped_custom_module_class(
            custom_module, custom_module_class_mapping, qconfig)
    observed_custom_module = \
        observed_custom_module_class.from_float(custom_module)
    parent_name, name = _parent_name(node.target)
    setattr(named_modules[parent_name], name, observed_custom_module)
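
# Illustrative configuration sketch (the class names are hypothetical): the swap
# above is driven by PrepareCustomConfig.float_to_observed_mapping, e.g.
#
#   prepare_custom_config = PrepareCustomConfig().set_float_to_observed_mapping(
#       FloatCustomModule, ObservedCustomModule)
#
# a call_module node targeting an instance of FloatCustomModule is then replaced
# in the module hierarchy by ObservedCustomModule.from_float(instance).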


def insert_observers_for_model(
    model: GraphModule,
    node_name_to_match_result_with_qconfig: Dict[str, _MatchResultWithQConfig],
    node_name_to_qconfig: Dict[str, QConfigAny],
    prepare_custom_config: PrepareCustomConfig,
    equalization_config_map: Dict[str, Any],
    backend_config: BackendConfig,
    observed_node_names: Set[str],
    is_qat: bool,
) -> Optional[Node]:
    """
    Inserts observers, using the following high level algorithm:

    For each node in the graph:
      1. determine the target dtype of this node in the quantized graph, and save
         it for future steps
      2. determine the target dtype of all args and kwargs of this node
      3. if any arg or kwarg's target dtype does not match the current node's
         dtype, insert an observer
      4. if the current node needs an output observer, insert it

    For example:

    - starting graph:
        x0 -> linear -> x1

    - observed graph after processing x0:
        x0(fp32)

    - observed graph after processing linear:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8)

    - observed graph after processing x1:
        x0(fp32) -> x0_obs0(int8) -> linear(int8) -> linear_obs0(int8) -> x1

    After a node is processed, the naive observer placement is guaranteed to be
    complete for that node and all of its predecessors. There can be future
    passes which optimize the graph by deduplicating observers, etc.
    """
    # node.meta["target_dtype_info"] stores the target dtype information
    # that's derived from qconfig for the Node, for example, if we have
    # a conv2d node that has a qconfig
    #
    #   qconfig = QConfig(activation=..., weight=...)
    #   # information for input and bias node omitted
    #   # for getattr node
    #   # weight = getattr(self, 'weight')
    #   weight.meta["target_dtype_info"] = {
    #       'output_act_obs_or_fq_ctr': qconfig.weight,
    #   }
    #   # for conv2d node
    #   # conv2d = call_function[target=torch.nn.functional.conv2d](
    #   #     args=(input, weight, bias))
    #   conv2d.meta["target_dtype_info"] = {
    #       'input_act_obs_or_fq_ctr': qconfig.activation,
    #       'weight_obs_or_fq_ctr': qconfig.weight,
    #       'bias_obs_or_fq_ctr': PlaceholderObserver.with_args(dtype=torch.float32),
    #       'output_act_obs_or_fq_ctr': qconfig.activation,
    #   }
    #
    cache_for_no_tensor_check: Dict[Node, bool] = {}

    # first, populate the dtype map based only on qconfig and qhandler
    # this assumes:
    #   graph inputs are fp32 by default, and int8 where overridden
    #   other nodes' output dtype is specified by the qconfig
    named_modules = dict(model.named_modules(remove_duplicate=False))

    input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
    output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes

    processed_nodes: Set[Node] = set()
    # initialize target_dtype_info
    for node in model.graph.nodes:
        node.meta["target_dtype_info"] = copy.copy(_DEFAULT_FP32_QCONFIG_FOR_TARGET_DTYPE_INFO)

    inputs_seen_counter = 0
    outputs_seen_counter = 0
    placeholder_node_to_input_index: Dict[Node, int] = {}
    # TODO: we probably don't need this counter since each graph will only have
    # one output node?
    output_node_to_output_index: Dict[Node, int] = {}
    for node in model.graph.nodes:
        if node.op == "placeholder":
            placeholder_node_to_input_index[node] = inputs_seen_counter
            inputs_seen_counter += 1
        if node.op == "output":
            output_node_to_output_index[node] = outputs_seen_counter
            outputs_seen_counter += 1

    # Step 1, set the observer or fake quantize module constructor for each node in the
    # matched_node_pattern

    for node_name, match_res_with_qconfig in node_name_to_match_result_with_qconfig.items():
        last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
        assert qhandler is not None
        _set_target_dtype_info_for_matched_node_pattern(
            matched_node_pattern,
            last_node,
            qconfig,
            backend_config,
            named_modules,
            cache_for_no_tensor_check,
            processed_nodes
        )

    # Step 2. Special cases for some operators, we might be able to remove them
    # in the future if we know dtype information of each node better

    # Step 2.1. some settings are not based on patterns, we need to process each node
    # instead
    for node in model.graph.nodes:
        if node.op == "placeholder" and placeholder_node_to_input_index[node] in input_quantized_idxs:
            # users are not supposed to call calculate_qparams on PlaceholderObserver, and
            # this is OK because we are using this as a way to encode the dtypes of input
            # tensor, we won't actually insert these observers in the graph and won't
            # actually call calculate_qparams
            node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)
        elif node.op in ("call_module", "call_method", "call_function"):
            args_have_no_tensors = \
                all_node_args_have_no_tensors(
                    node, named_modules, cache_for_no_tensor_check)
            if args_have_no_tensors:
                node.meta["target_dtype_info"] = {
                    "input_act_obs_or_fq_ctr": None,
                    "output_act_obs_or_fq_ctr": None,
                }
        elif node.op == "output" and output_node_to_output_index[node] in output_quantized_idxs:
            node.meta["target_dtype_info"] = copy.copy(_DEFAULT_QUINT8_QCONFIG_FOR_TARGET_DTYPE_INFO)

    # Step 2.2, for nodes with known input dtypes, propagate them throughout the
    # graph. For example, if there is a call such as
    #   x1 = x0.masked_fill(mask, 1)
    # we propagate the type of mask to be torch.bool
    propagate_dtypes_for_known_nodes(model.graph, node_name_to_match_result_with_qconfig)

    # Step 3, check if the requested target_dtype_info is supported by backend or not
    # if not, we'll reset the target_dtype_info to use the default (float Tensor)

    # reset the counters and set of processed_nodes
    processed_nodes = set()
    for node_name, match_res_with_qconfig in node_name_to_match_result_with_qconfig.items():
        last_node, matched_node_pattern, pattern, qhandler, qconfig = match_res_with_qconfig
        is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
            pattern, matched_node_pattern, qconfig, backend_config)
        assert qhandler is not None

        # get output_act_dtype so that we don't also reset the special typed nodes
        # TODO: we might want to handle these more uniformly with the default path
        # this can be improved if we can use node.meta["val"]
        output_act_dtype, _ = _get_dtype_and_is_dynamic(node.meta["target_dtype_info"]["output_act_obs_or_fq_ctr"])
        if not is_supported_by_backend and output_act_dtype not in [None, int, float, torch.bool]:
            # restore target_dtype_info to default if it is not supported by backend
            _set_target_dtype_info_for_matched_node_pattern(
                matched_node_pattern,
                last_node,
                torch.ao.quantization.qconfig._default_fp32_placeholder_qconfig,
                backend_config,
                named_modules,
                cache_for_no_tensor_check,
                processed_nodes
            )

    # After this point, the current node and all of its arguments
    # have a target_dtype_info assigned. Now, we insert observers for inputs
    # of this node (if needed for this node), and the output of this node
    # (if needed for this node).

    # Since we are mutating the graph as we go, we iterate over the original
    # nodes before observer insertion, instead of model.graph.nodes.
    nodes_before_observation = list(model.graph.nodes)

    # Avoid duplicate custom module swaps for multiple nodes with the same target.
    custom_module_names_already_swapped: Set[str] = set()

    # TODO: reuse placeholder_node_to_input_index and output_node_to_output_index
    # reset inputs/outputs counters
    inputs_seen_counter = 0
    outputs_seen_counter = 0
    results_node = None
- # TODO: change this to insert obs/fq by pattern instead of by node
- for node in nodes_before_observation:
- if node.op == 'placeholder':
- # if a graph input is in fp32, it does not need observation
- # if a graph input is in int8, we assume the observation happens
- # outside of the graph, and no additional observation is needed
- pass
- elif node.op in ('call_module', 'call_method', 'call_function', 'output'):
- # check for matches
- last_node, matched_node_pattern, pattern, qhandler, qconfig = (
- node_name_to_match_result_with_qconfig.get(node.name, (None, None, None, None, None)) # type: ignore[assignment]
- )
- equalization_qconfig = equalization_config_map.get(node.name, None)
- this_node_dtype_info = node.meta["target_dtype_info"]
- if "val" in node.meta:
- output_is_a_tensor = (
- this_node_dtype_info is not None and
- isinstance(node.meta["val"], FakeTensor)
- )
- else:
- output_is_a_tensor = this_node_dtype_info is not None
- skip_inserting_observers = (
- (qconfig is None) or
- not output_is_a_tensor
- ) and (
- not node.op == 'output'
- )
- # TODO: take a closer look to see if we can remove this check
- # right now it is here because of `observed_node_names`, we are using
            # it as an indicator for swapping the modules to reference modules in
            # convert
            is_supported_by_backend = _is_pattern_dtype_config_and_qconfig_supported_by_backend(
                pattern, matched_node_pattern, qconfig, backend_config)

            if not skip_inserting_observers and is_supported_by_backend:
                named_modules = dict(model.named_modules(remove_duplicate=False))
                if node.op != 'output':
                    assert matched_node_pattern is not None
                    # add matched nodes to the observed node name set
                    _add_matched_node_name_to_set(matched_node_pattern, observed_node_names)

                    # This is currently only used for equalization.
                    # Checks if the current node is in a branch in which the two
                    # first layers are both being quantized.
                    #
                    # ex.       conv2
                    #          /
                    #      x -> conv1
                    #
                    # If this is the case, we will not apply equalization to the
                    # initial two layers.
                    is_quantized_branch = False
                    if (
                        len(node.args) > 0 and
                        isinstance(node.args[0], Node) and
                        len(node.args[0].users) > 1
                    ):
                        for user in node.args[0].users:
                            # Checks if there exists another user being quantized
                            is_user_quantized = (
                                node_name_to_qconfig.get(user.name, None) is not None or
                                (user.op == 'call_module' and isinstance(named_modules[str(user.target)], ObserverBase))
                            )
                            if user != node and is_user_quantized:
                                is_quantized_branch = True

                    pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
                    root_node_getter = pattern_to_root_node_getter.get(pattern, _default_root_node_getter)
                    root_node = root_node_getter(matched_node_pattern)
                    is_input_node_of_the_pattern = node is root_node
                    if is_input_node_of_the_pattern:
                        # this modifies node inplace
                        _maybe_insert_input_observers_for_node(
                            node, qconfig, model, named_modules, model.graph,
                            qhandler,
                            prepare_custom_config,
                            backend_config)

                        # insert equalization input observers if needed
                        _maybe_insert_input_equalization_observers_for_node(
                            node, equalization_qconfig, model, named_modules, model.graph,
                            is_quantized_branch, backend_config)

                    is_last_node_of_pattern = node is last_node
                    is_general_tensor_value_op = \
                        (qhandler is not None and qhandler.is_general_tensor_value_op())
                    _is_reuse_input_qconfig_ = _is_reuse_input_qconfig(qconfig)

                    if is_last_node_of_pattern:
                        if _is_custom_module_lstm(node, named_modules, qconfig, qhandler):
                            # Currently custom module outputs are assumed to be already quantized,
                            # so we need to insert a DeQuantStub after the output. For custom module
                            # LSTM specifically, the outputs are also a nested tuple, so we must first
                            # break down the tuple to insert DeQuantStubs after the internal nodes.
                            # TODO: This currently diverges from how custom modules are handled today,
                            # where we insert observers after the output instead of DeQuantStubs, and
                            # replace these observers with "dequantize" nodes during convert. Conceptually,
                            # these output observers are the same as DeQuantStubs. In the future, we
                            # should resolve this inconsistency by inserting DeQuantStubs for all custom
                            # modules, not just for LSTM.
                            _insert_dequant_stubs_for_custom_module_lstm_output(node, model, named_modules, model.graph)
                            if node.target not in custom_module_names_already_swapped:
                                custom_module_names_already_swapped.add(node.target)
                                _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)
                        else:
                            # this returns the new observer node if it was needed
                            maybe_output_obs_node = _maybe_insert_output_observer_for_node(
                                node, model, named_modules, model.graph, node_name_to_match_result_with_qconfig,
                                pattern, qhandler, is_qat)

                            if maybe_output_obs_node is not None:
                                # Update users of original node to use the output observer
                                # instead. For example, change
                                #
                                #           next_node
                                #          /
                                # cur_node -> obs
                                #
                                # to
                                #
                                #                 next_node
                                #                 /
                                # cur_node -> obs
                                #
                                # We need to save orig users before updating uses because
                                # the list of users will change as we update uses
                                orig_users = list(node.users.keys())
                                for user_node in orig_users:
                                    if user_node is maybe_output_obs_node:
                                        continue
                                    user_node.replace_input_with(node, maybe_output_obs_node)

                                _is_observer_in_same_graph_ = _is_observer_in_same_graph(
                                    node, named_modules)

                                # for general tensor value ops, we modify the graph
                                # to make all inputs and outputs use the first input's
                                # observer
                                if (is_general_tensor_value_op and _is_observer_in_same_graph_) or \
                                        _is_reuse_input_qconfig_:
                                    if not _maybe_make_input_output_share_observers(node, model, named_modules):
                                        _remove_output_observer(node, model, named_modules)

                                if qhandler is not None and qhandler.is_custom_module():
                                    if node.target not in custom_module_names_already_swapped:
                                        custom_module_names_already_swapped.add(node.target)
                                        _swap_custom_module_to_observed(node, qconfig, named_modules, prepare_custom_config)

                else:  # output
                    _maybe_insert_observers_before_graph_output(
                        node, output_quantized_idxs,
                        node_name_to_qconfig,
                        model, named_modules, model.graph)

        #
        # After this point, the current node has input and output observers
        # that it needs for itself inserted.
        #

        # increment the counters, so future inputs and outputs are assigned
        # correct dtypes
        if node.op == 'placeholder':
            inputs_seen_counter += 1
        elif node.op == 'output':
            outputs_seen_counter += 1
            results_node = node

    return results_node
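
# Illustrative sketch, not used by this module: the loop above rewrites the
# graph so that quantizable nodes are wrapped in activation_post_process_*
# observer modules. The hypothetical `_demo` helper below shows the effect
# through the public prepare_fx API only; its name is ours, not part of the API.
def _demo_insert_observers():
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import prepare_fx

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 4),)
    observed = prepare_fx(model, get_default_qconfig_mapping(), example_inputs)
    # observers appear as call_module nodes named activation_post_process_*
    return [n.name for n in observed.graph.nodes if "activation_post_process" in n.name]
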
def _run_prepare_fx_on_standalone_modules(
    model: torch.nn.Module,
    is_qat: bool,
    named_modules: Dict[str, torch.nn.Module],
    node_name_to_match_result_with_qconfig: Any,
    prepare_custom_config: PrepareCustomConfig,
    backend_config: BackendConfig,
) -> None:
    """
    Runs prepare_fx on each standalone module. Note: this does
    not modify the graph, it just replaces the unobserved modules with
    their observed versions.
    """
    for (
        node_name,
        (root_node, _, pattern, qhandler, qconfig),
    ) in node_name_to_match_result_with_qconfig.items():
        if qhandler is None:
            continue
        elif not qhandler.is_standalone_module():
            continue

        sm_qconfig_mapping, sm_example_inputs, sm_prepare_custom_config, \
            sm_backend_config = _get_standalone_module_configs(
                root_node, named_modules, prepare_custom_config, qconfig, backend_config)

        standalone_module = named_modules[root_node.target]
        prepare = \
            torch.ao.quantization.quantize_fx._prepare_standalone_module_fx  # type: ignore[attr-defined]
        observed_standalone_module = \
            prepare(
                standalone_module,
                sm_qconfig_mapping,
                is_qat,
                example_inputs=sm_example_inputs,
                prepare_custom_config=sm_prepare_custom_config,
                backend_config=sm_backend_config)
        parent_name, name = _parent_name(root_node.target)
        setattr(named_modules[parent_name], name, observed_standalone_module)
        named_modules[root_node.target] = observed_standalone_module
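
# Illustrative sketch (hypothetical helper, not part of this module's API):
# a submodule is routed through the standalone-module path above by
# registering it on the PrepareCustomConfig passed to prepare.
def _demo_standalone_module_config() -> PrepareCustomConfig:
    from torch.ao.quantization import get_default_qconfig_mapping

    return PrepareCustomConfig().set_standalone_module_name(
        "submodule",                    # fully qualified name of the submodule
        get_default_qconfig_mapping(),  # qconfig_mapping used inside the submodule
        (torch.randn(1, 4),),           # example inputs for tracing the submodule
        PrepareCustomConfig(),          # nested prepare_custom_config
        None,                           # backend_config; None falls back to the native one
    )
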
def _save_state(
    observed: GraphModule,
    node_name_to_qconfig: Dict[str, QConfigAny],
    node_name_to_scope: Dict[str, Tuple[str, type]],
    prepare_custom_config: PrepareCustomConfig,
    equalization_node_name_to_qconfig: Dict[str, Any],
    qconfig_mapping: QConfigMapping,
    is_qat: bool,
    observed_node_names: Set[str],
) -> None:
    observed.meta["_observed_graph_module_attrs"] = (
        ObservedGraphModuleAttrs(
            node_name_to_qconfig=node_name_to_qconfig,
            node_name_to_scope=node_name_to_scope,
            prepare_custom_config=prepare_custom_config,
            equalization_node_name_to_qconfig=equalization_node_name_to_qconfig,
            qconfig_mapping=qconfig_mapping,
            is_qat=is_qat,
            observed_node_names=observed_node_names,
        )
    )
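
# Illustrative sketch (hypothetical helper): the attributes stashed by
# _save_state can be read back from the observed module's meta dict, which is
# how the convert step recovers them.
def _demo_read_saved_state(observed: GraphModule) -> Set[str]:
    attrs = observed.meta["_observed_graph_module_attrs"]
    assert isinstance(attrs, ObservedGraphModuleAttrs)
    # node names recorded here determine which float modules are converted
    # to reference quantized modules during convert
    return attrs.observed_node_names
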
def prepare(
        model: GraphModule,
        qconfig_mapping: Union[QConfigMapping, Dict[str, Any]],
        is_qat: bool,
        node_name_to_scope: Dict[str, Tuple[str, type]],
        example_inputs: Tuple[Any, ...],
        prepare_custom_config: Union[PrepareCustomConfig, Dict[str, Any], None] = None,
        _equalization_config: Union[QConfigMapping, Dict[str, Any], None] = None,
        backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
        is_standalone_module: bool = False) -> GraphModule:
    """ standalone_module means it is a submodule that is not inlined in the
    parent module, and will be quantized separately as one unit.

    How the standalone module is observed is specified by `input_quantized_idxs` and
    `output_quantized_idxs` in the prepare_custom_config for the standalone module

    Args:
        node_name_to_scope: mapping from node name to the scope of the module which contains the node.
            The scope is a tuple of the fully qualified path of the module and the type of the module.

    Returns:
        model(GraphModule): prepared standalone module, with the following
        attributes related to the standalone module stored in
        model.meta["_observed_graph_module_attrs"]:
            is_observed_standalone_module (bool): whether the
                current model is an observed standalone module or not
            standalone_module_input_quantized_idxs (List[int]): a list of
                indexes for the graph inputs that are expected to be quantized,
                same as the input_quantized_idxs configuration provided
                for the standalone module
            standalone_module_output_quantized_idxs (List[int]): a list of
                indexes for the graph outputs that are quantized,
                same as the output_quantized_idxs configuration provided
                for the standalone module
    """
    if prepare_custom_config is None:
        prepare_custom_config = PrepareCustomConfig()
    if _equalization_config is None:
        _equalization_config = QConfigMapping()

    if isinstance(qconfig_mapping, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a QConfigMapping instead.")
        qconfig_mapping = QConfigMapping.from_dict(qconfig_mapping)

    if isinstance(_equalization_config, Dict):
        warnings.warn(
            "Passing a QConfig dictionary to prepare for equalization is deprecated and will not "
            "be supported in a future version. Please pass in a QConfigMapping instead.")
        _equalization_config = QConfigMapping.from_dict(_equalization_config)

    if isinstance(prepare_custom_config, Dict):
        warnings.warn(
            "Passing a prepare_custom_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a PrepareCustomConfig instead.")
        prepare_custom_config = PrepareCustomConfig.from_dict(prepare_custom_config)

    if isinstance(backend_config, Dict):
        warnings.warn(
            "Passing a backend_config_dict to prepare is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.")
        backend_config = BackendConfig.from_dict(backend_config)

    assert isinstance(qconfig_mapping, QConfigMapping)
    assert isinstance(_equalization_config, QConfigMapping)
    qconfig_mapping = copy.deepcopy(qconfig_mapping)
    _equalization_config = copy.deepcopy(_equalization_config)

    # mapping from a tuple of nodes in reverse order to uninitialized
    # QuantizeHandler subclass. For example,
    # {
    #   # match a single node
    #   (<class 'torch.nn.modules.conv.Conv3d'>:
    #     <class 'torch.ao.quantization.fx.quantize.ConvRelu'>),
    #   # match multiple nodes in reverse order
    #   ((<function relu at 0x7f766a7360d0>, <built-in function add>):
    #     <class 'torch.ao.quantization.fx.quantize.Add'>),
    # }
    pattern_to_quantize_handler: Dict[Pattern, QuantizeHandler] = {}
    if backend_config is None:
        backend_config = get_native_backend_config()
    pattern_to_quantize_handler = _get_pattern_to_quantize_handlers(backend_config)
    pattern_to_quantize_handler = _sorted_patterns_dict(pattern_to_quantize_handler)

    root_node_getter_mapping = \
        get_fusion_pattern_to_root_node_getter(backend_config)

    _update_qconfig_for_fusion(model, qconfig_mapping)
    _update_qconfig_for_fusion(model, _equalization_config)
    flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
    # TODO: support regex as well
    propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())

    if is_qat:
        module_to_qat_module = get_module_to_qat_module(backend_config)
        _qat_swap_modules(model, module_to_qat_module)
        _update_qconfig_for_qat(qconfig_mapping, backend_config)

    # mapping from fully qualified module name to module instance
    # for example,
    # {
    #   '': Model(...),
    #   'linear': Linear(...),
    #   'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
    # }
    named_modules = dict(model.named_modules(remove_duplicate=False))

    # fill node_name_to_qconfig, a map from node name to qconfig, used in _find_matches
    equalization_node_name_to_qconfig = _generate_node_name_to_qconfig(
        model, named_modules, model.graph, _equalization_config, node_name_to_scope)
    node_name_to_qconfig = _generate_node_name_to_qconfig(
        model, named_modules, model.graph, qconfig_mapping, node_name_to_scope)

    # match the patterns that will get quantized
    standalone_module_names = list(prepare_custom_config.standalone_module_names.keys())
    standalone_module_classes = list(prepare_custom_config.standalone_module_classes.keys())

    custom_module_classes = get_custom_module_class_keys(prepare_custom_config.float_to_observed_mapping)
    matches_without_qconfig = _find_matches(
        model.graph, named_modules, pattern_to_quantize_handler, root_node_getter_mapping,
        standalone_module_names, standalone_module_classes, custom_module_classes)

    # map qconfig instances to matches
    node_name_to_match_result_with_qconfig = {}
    for node_name, match_without_qconfig in matches_without_qconfig.items():
        match_with_qconfig = (*match_without_qconfig, node_name_to_qconfig[node_name])
        node_name_to_match_result_with_qconfig[node_name] = match_with_qconfig

    _run_prepare_fx_on_standalone_modules(
        model, is_qat, named_modules, node_name_to_match_result_with_qconfig, prepare_custom_config, backend_config)

    # record names for the set of observed nodes, so that in the convert step
    # we know whether we need to convert a floating point module to a reference
    # quantized module or not
    observed_node_names: Set[str] = set()

    result_node = insert_observers_for_model(
        model,
        node_name_to_match_result_with_qconfig,
        node_name_to_qconfig,
        prepare_custom_config,
        equalization_node_name_to_qconfig,
        backend_config,
        observed_node_names,
        is_qat,
    )

    model = GraphModule(model, model.graph)

    _save_state(model, node_name_to_qconfig, node_name_to_scope,
                prepare_custom_config, equalization_node_name_to_qconfig,
                qconfig_mapping, is_qat, observed_node_names)

    if is_standalone_module:
        assert result_node is not None
        assert isinstance(result_node.args[0], Node), \
            "standalone module only supports returning simple value currently "\
            "(not tuple, dict etc.)"
        # these inputs are observed in the parent module
        input_quantized_idxs: List[int] = prepare_custom_config.input_quantized_indexes
        output_quantized_idxs: List[int] = prepare_custom_config.output_quantized_indexes
        observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
        # inplace modification
        observed_graph_module_attrs.is_observed_standalone_module = True
        observed_graph_module_attrs.standalone_module_input_quantized_idxs = \
            input_quantized_idxs
        observed_graph_module_attrs.standalone_module_output_quantized_idxs = \
            output_quantized_idxs
    return model
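
# Illustrative end-to-end sketch (hypothetical helper): `prepare` above is
# normally reached through the public prepare_fx wrapper, followed by
# calibration and convert_fx.
def _demo_prepare_then_convert() -> torch.nn.Module:
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 4),)
    observed = prepare_fx(model, get_default_qconfig_mapping(), example_inputs)
    observed(*example_inputs)    # calibration: observers record activation ranges
    return convert_fx(observed)  # lower the observed model to quantized ops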