import re
from collections import defaultdict, OrderedDict
from typing import Any, Callable, Dict, List, Set, Tuple, Union

import torch
from torch.ao.quantization import QConfig
from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
from torch.ao.quantization.observer import (
    _is_activation_post_process,
)
from torch.ao.quantization.backend_config import (
    BackendConfig,
    DTypeConfig,
)
from torch.ao.quantization.backend_config.utils import (
    get_module_to_qat_module,
)
from torch.fx import (
    GraphModule,
)
from torch.fx.graph import (
    Graph,
)
from torch.ao.nn.intrinsic import _FusedModule

from ..utils import (
    _parent_name,
    get_qconfig_dtypes,
)
from ..qconfig_mapping import (
    _OBJECT_TYPE_DICT_KEY,
    _MODULE_NAME_DICT_KEY,
    _MODULE_NAME_REGEX_DICT_KEY,
    QConfigMapping,
)

__all__: List[str] = []
def _maybe_adjust_qconfig_for_module_name_object_type_order(
    qconfig_mapping: QConfigMapping,
    cur_module_path: str,
    cur_object_type: Callable,
    cur_object_type_idx: int,
    fallback_qconfig: QConfigAny,
) -> QConfigAny:
    """Return the qconfig registered for the (module path, object type,
    occurrence index) triple, or `fallback_qconfig` if no entry matches.
    """
    for (module_name, object_type, index), qconfig in \
            qconfig_mapping.module_name_object_type_order_qconfigs.items():
        if (
            (module_name == cur_module_path) and
            (object_type == cur_object_type) and
            (index == cur_object_type_idx)
        ):
            return qconfig
    return fallback_qconfig
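# Illustrative sketch (not part of this module's API): with a mapping such as
#
#     qconfig_mapping = QConfigMapping().set_module_name_object_type_order(
#         "foo.bar", torch.nn.functional.linear, 0, qconfig)
#
# the helper above returns `qconfig` only for the first F.linear call traced
# inside submodule "foo.bar"; every other (path, type, index) triple falls
# through to `fallback_qconfig`.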
def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping) -> None:
    """
    Update the QConfigMapping in place to account for fused modules such as LinearReLU.
    This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
    """
    object_type_dict = qconfig_mapping.object_type_qconfigs
    if len(object_type_dict) == 0:
        return

    modules = dict(model.named_modules())
    for node in model.graph.nodes:
        if node.op == "call_module" and node.target in modules:
            maybe_fused_module = modules[str(node.target)]
            if not isinstance(maybe_fused_module, _FusedModule):
                continue

            ops = list(maybe_fused_module._modules.values())
            fused_qconfig = object_type_dict.get(type(ops[0]), None)

            # Raise an error if the modules in the fused module have
            # different qconfigs specified in the qconfig_dict
            # TODO: currently it only works for modules,
            # need to make this work for torch.nn.functional.relu
            # TODO: currently it only works for object_type configurations,
            # ideally it should work for different types of configurations,
            # maybe we want to redesign this part
            for op in ops[1:]:
                if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
                    raise LookupError(
                        "During fusion, we need to specify the same "
                        f"qconfigs for all module types in {type(maybe_fused_module)} "
                        f"offending type: {type(op)}")

            if fused_qconfig is not None:
                object_type_dict[type(maybe_fused_module)] = fused_qconfig
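# Illustrative sketch (assumes the standard intrinsic fused types): when every
# member of a fused module carries the same qconfig, that qconfig is copied
# onto the fused type; mismatched qconfigs raise the LookupError above.
#
#     import torch.ao.nn.intrinsic as nni
#     qconfig_mapping = (QConfigMapping()
#         .set_object_type(torch.nn.Linear, qconfig)
#         .set_object_type(torch.nn.ReLU, qconfig))
#     _update_qconfig_for_fusion(model, qconfig_mapping)
#     # object_type_qconfigs now also maps nni.LinearReLU to `qconfig`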
def _generate_node_name_to_qconfig(
        root: torch.nn.Module,
        modules: Dict[str, torch.nn.Module],
        input_graph: Graph,
        qconfig_mapping: QConfigMapping,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]:
    global_qconfig = qconfig_mapping.global_qconfig
    node_name_to_qconfig: Dict[str, QConfigAny] = {}

    # example:
    #
    #   {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
    #
    # meaning in submodule 'foo.bar' we have seen 0 F.linear and
    # 1 F.conv2d invocations so far.
    submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = \
        defaultdict(lambda: defaultdict(int))
    for node in input_graph.nodes:
        qconfig = None
        if node.op == "get_attr":
            module_name, _ = _parent_name(node.target)
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, type(modules[module_name]), module_name, global_qconfig)
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
        elif node.op == "call_function":
            # precedence: module_name_qconfig > function_qconfig > global_qconfig
            function_qconfig = _get_object_type_qconfig(
                qconfig_mapping, node.target, global_qconfig)
            module_path, module_type = node_name_to_scope[node.name]
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, module_type, module_path, function_qconfig)

            cur_object_type_idx = \
                submodule_to_object_type_to_cur_idx[module_path][node.target]
            submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig)
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
        elif node.op == "call_method":
            module_path, module_type = node_name_to_scope[node.name]
            # first use node.target (string) to get the qconfig,
            # this is to support configs like
            # "object_type": [("reshape", qconfig)]
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, node.target, module_path, global_qconfig)
            # if there is no special config for the method, we'll fall back to the
            # config for the module that contains the call_method node
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, module_type, module_path, qconfig)
            # currently call_method does not support modifying qconfig
            # by order, we can add this later if it is needed.
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
        elif node.op == "call_module":
            # if the node is an observer, just continue - don't add it to the qconfig_map
            if _is_activation_post_process(modules[node.target]):
                continue

            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)

            module_path, module_type = node_name_to_scope[node.name]
            # Note: for call_module, the module_path is the current module's name,
            # so to meaningfully count invocations we need to count them in the
            # parent module.
            parent_name, _ = _parent_name(module_path)
            cur_object_type_idx = \
                submodule_to_object_type_to_cur_idx[parent_name][module_type]
            submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_mapping, parent_name, module_type, cur_object_type_idx,
                qconfig)
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))

            # regex is not supported in eager mode propagate_qconfig_, so we
            # need to set the qconfig explicitly here in case regex is used
            modules[node.target].qconfig = qconfig_with_device_check
        else:
            qconfig_with_device_check = None

        node_name_to_qconfig[node.name] = qconfig_with_device_check
    return node_name_to_qconfig
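# Illustrative sketch (node names are hypothetical): for a traced model with a
# `conv` submodule followed by a functional relu, the returned dict could look
# like
#
#     {"x": None, "conv": conv_qconfig, "relu": global_qconfig, "output": None}
#
# i.e. placeholder/output nodes map to None, while every call_* and get_attr
# node maps to the qconfig resolved by the precedence rules above.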
def _check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
    r""" Checks if the given config_dict has the correct keys

    Args:
      `config_dict`: dictionary whose keys we want to check
      `allowed_keys`: set of keys the dictionary is allowed to contain
      `dict_name`: name of the dictionary, used in the error message
    """
    for k in config_dict.keys():
        if k not in allowed_keys:
            raise ValueError(
                f"Expected {dict_name} to have the following keys: "
                f"{allowed_keys}. But found '{k}' instead.")
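# Illustrative usage (hypothetical values): this raises a ValueError because
# "module_nam" is misspelled and therefore not an allowed key:
#
#     _check_is_valid_config_dict(
#         {"module_nam": []}, {"", "object_type", "module_name"}, "qconfig_dict")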
def _compare_prepare_convert_qconfig_mappings(
        prepare_qconfig_mapping: QConfigMapping,
        convert_qconfig_mapping: QConfigMapping):
    r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values

    Args:
      `prepare_qconfig_mapping`: configuration for prepare quantization step
      `convert_qconfig_mapping`: configuration for convert quantization step
    """
    assert qconfig_equals(prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig), \
        "Expected global qconfigs to be the same in the prepare and convert quantization configs"
    prepare_dicts: List[OrderedDict] = [
        prepare_qconfig_mapping.object_type_qconfigs,
        prepare_qconfig_mapping.module_name_qconfigs,
        prepare_qconfig_mapping.module_name_regex_qconfigs,
    ]
    convert_dicts: List[OrderedDict] = [
        convert_qconfig_mapping.object_type_qconfigs,
        convert_qconfig_mapping.module_name_qconfigs,
        convert_qconfig_mapping.module_name_regex_qconfigs,
    ]
    dict_names = [_OBJECT_TYPE_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY]
    for i in range(len(prepare_dicts)):
        for name, qconfig in prepare_dicts[i].items():
            assert name in convert_dicts[i], \
                f"Missing key {dict_names[i]} {name} in convert QConfigMapping " \
                "when it was present in prepare"
            assert convert_dicts[i][name] is None \
                or qconfig_equals(prepare_dicts[i][name], convert_dicts[i][name]), \
                f"Expected convert QConfigMapping to have the same qconfig as prepare for key " \
                f"{dict_names[i]} {name}; prepare: {prepare_dicts[i][name]}; " \
                f"convert: {convert_dicts[i][name]}"
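# Illustrative sketch: a convert-time mapping may set an entry to None to skip
# quantizing that key, but any non-None entry must equal the prepare-time
# qconfig; e.g. the following trips the second assertion above:
#
#     prepare_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig_a)
#     convert_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig_b)
#     _compare_prepare_convert_qconfig_mappings(prepare_mapping, convert_mapping)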
def _is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]) -> bool:
    """Return True if `qconfig` is compatible with the dtypes (input, weight,
    bias, output) and dynamic-ness of at least one entry in `dtype_configs`.
    """
    for dtype_config in dtype_configs:
        is_dynamic = dtype_config.is_dynamic
        if is_dynamic is None:
            is_dynamic = False
        input_dtype = dtype_config.input_dtype or torch.float
        weight_dtype = dtype_config.weight_dtype or torch.float
        bias_dtype = dtype_config.bias_dtype or torch.float
        output_dtype = dtype_config.output_dtype or torch.float
        qconfig_activation_dtype, qconfig_weight_dtype, qconfig_input_act_is_dynamic = \
            get_qconfig_dtypes(qconfig)
        # an fp16 bias is only implied when both activations and weights are
        # fp16 in the static case; otherwise the bias stays in fp32
        qconfig_bias_dtype = torch.float16 \
            if (
                qconfig_activation_dtype == torch.float16
                and qconfig_weight_dtype == torch.float16
                and not is_dynamic
            ) else torch.float

        if is_dynamic:
            is_match = qconfig_input_act_is_dynamic and \
                input_dtype == qconfig_activation_dtype and \
                output_dtype == torch.float and \
                weight_dtype == qconfig_weight_dtype
        else:
            is_match = input_dtype == qconfig_activation_dtype and \
                output_dtype == qconfig_activation_dtype and \
                weight_dtype == qconfig_weight_dtype and \
                bias_dtype == qconfig_bias_dtype
        if is_match:
            return True
    return False
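# Illustrative sketch: the default static qconfig (quint8 activations, qint8
# weights) matches a backend DTypeConfig declared as below, so the check
# returns True:
#
#     qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
#     dtype_config = DTypeConfig(
#         input_dtype=torch.quint8,
#         output_dtype=torch.quint8,
#         weight_dtype=torch.qint8,
#         bias_dtype=torch.float)
#     assert _is_qconfig_supported_by_dtype_configs(qconfig, [dtype_config])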
def _get_object_type_qconfig(
        qconfig_mapping: QConfigMapping,
        object_type: Union[Callable, str],
        fallback_qconfig: QConfigAny) -> QConfigAny:
    return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)


def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
    for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
        if re.match(regex_pattern, module_name):
            # first match wins
            return qconfig
    return fallback_qconfig
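# Illustrative sketch: patterns are tried in insertion order and the first
# `re.match` wins, so with
#
#     qconfig_mapping = (QConfigMapping()
#         .set_module_name_regex("foo.*bar", qconfig1)
#         .set_module_name_regex("foo.*", qconfig2))
#
# module "foo.x.bar" resolves to qconfig1 and "foo.baz" to qconfig2.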
def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
    if module_name == '':
        # module name qconfig not found
        return fallback_qconfig
    if module_name in qconfig_mapping.module_name_qconfigs:
        return qconfig_mapping.module_name_qconfigs[module_name]
    else:
        # there is no qconfig for this exact module name; recurse on the
        # parent module's name until we reach the root ('')
        parent, _ = _parent_name(module_name)
        return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
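# Illustrative sketch: the lookup walks up the module hierarchy, so a qconfig
# set for "block" also applies to its children unless a more specific entry
# exists:
#
#     qconfig_mapping = QConfigMapping().set_module_name("block", qconfig)
#     _get_module_name_qconfig(qconfig_mapping, "block.conv", fallback)  # -> qconfig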
def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
    # get qconfig for module_name,
    # fallback to module_name_regex_qconfig, module_type_qconfig,
    # global_qconfig if necessary
    module_type_qconfig = _get_object_type_qconfig(
        qconfig_mapping, module_type, global_qconfig)
    module_name_regex_qconfig = _get_module_name_regex_qconfig(
        qconfig_mapping, module_name, module_type_qconfig)
    module_name_qconfig = _get_module_name_qconfig(
        qconfig_mapping, module_name, module_name_regex_qconfig)
    return module_name_qconfig
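# Illustrative sketch of the precedence chain above: module_name beats
# module_name_regex beats object_type beats global. With
#
#     qconfig_mapping = (QConfigMapping()
#         .set_global(global_qconfig)
#         .set_object_type(torch.nn.Conv2d, type_qconfig)
#         .set_module_name("features.0", name_qconfig))
#
# a Conv2d at path "features.0" resolves to name_qconfig, a Conv2d anywhere
# else to type_qconfig, and any unmatched module to global_qconfig.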
def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
    """ Flatten the global, object_type and module_name qconfigs
    into the same qconfig_dict so that it can be used by the
    propagate_qconfig_ function.
    "module_name_regex" is ignored for now since it's not supported
    in propagate_qconfig_, but it can be fixed later.

    For example:
    Input: {
      "": qconfig,
      "object_type": [
        (torch.add, qconfig)
      ],
      "module_name": [
        ("conv", qconfig)
      ]
    }

    Output: {
      "": qconfig,
      torch.add: qconfig,
      "conv": qconfig
    }
    """
    flattened: Dict[Union[Callable, str], QConfigAny] = {"": qconfig_mapping.global_qconfig}
    for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
        flattened[obj] = qconfig
    for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
        flattened[obj] = qconfig
    return flattened
def _update_qconfig_for_qat(
        qconfig_mapping: QConfigMapping,
        backend_config: BackendConfig):
    """
    Update the qconfig_mapping to account for module swaps during QAT.
    During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.
    """
    module_to_qat_module_class = get_module_to_qat_module(backend_config)
    object_type_dict = qconfig_mapping.object_type_qconfigs
    new_object_type_dict = object_type_dict.copy()
    for k, v in new_object_type_dict.items():
        if k in module_to_qat_module_class:
            object_type_dict[module_to_qat_module_class[k]] = v
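# Illustrative sketch: if the backend maps torch.nn.Linear to its QAT
# counterpart (e.g. torch.ao.nn.qat.Linear), a mapping configured with
#
#     qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
#
# gains an extra entry for the QAT class after this call, so the qconfig
# still applies once modules are swapped for QAT.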
|