import torch
import re
from collections import defaultdict, OrderedDict
from typing import Callable, Any, Dict, Tuple, Set, List, Union
from torch.ao.quantization import QConfig
from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
from torch.ao.quantization.observer import (
    _is_activation_post_process,
)
from torch.ao.quantization.backend_config import (
    BackendConfig,
    DTypeConfig,
)
from torch.ao.quantization.backend_config.utils import (
    get_module_to_qat_module,
)

from torch.fx import (
    GraphModule,
)
from torch.fx.graph import (
    Graph,
)
from torch.ao.nn.intrinsic import _FusedModule

from ..utils import (
    _parent_name,
    get_qconfig_dtypes,
)
from ..qconfig_mapping import (
    _OBJECT_TYPE_DICT_KEY,
    _MODULE_NAME_DICT_KEY,
    _MODULE_NAME_REGEX_DICT_KEY,
    QConfigMapping,
)

__all__: List[str] = []


def _maybe_adjust_qconfig_for_module_name_object_type_order(
    qconfig_mapping: QConfigMapping,
    cur_module_path: str,
    cur_object_type: Callable,
    cur_object_type_idx: int,
    fallback_qconfig: QConfigAny,
) -> QConfigAny:
    """
    Look up a qconfig keyed by (module path, object type, invocation index).

    Scans ``qconfig_mapping.module_name_object_type_order_qconfigs`` and returns
    the qconfig of the first entry whose three key components all compare equal
    to the given ones. Returns ``fallback_qconfig`` when no entry matches.
    """
    for (module_name, object_type, index), qconfig in \
            qconfig_mapping.module_name_object_type_order_qconfigs.items():
        if (
            (module_name == cur_module_path) and
            (object_type == cur_object_type) and
            (index == cur_object_type_idx)
        ):
            return qconfig
    return fallback_qconfig


def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
    """
    Update the QConfigMapping to account for fused modules such as LinearReLU.
    This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.

    ``qconfig_mapping.object_type_qconfigs`` is modified in place: for every
    ``call_module`` node whose target is a ``_FusedModule``, the qconfig shared
    by its child module types is also registered under the fused module's type.

    Returns:
        ``qconfig_mapping`` when it has no object-type qconfigs (early exit),
        otherwise ``None``. Callers should rely on the in-place update rather
        than on the return value.

    Raises:
        LookupError: if the child modules of a fused module do not all map to
            equal qconfigs in the object-type dict.
    """
    object_type_dict = qconfig_mapping.object_type_qconfigs
    if not object_type_dict:
        return qconfig_mapping

    modules = dict(model.named_modules())

    for node in model.graph.nodes:
        if node.op == 'call_module' and node.target in modules:
            maybe_fused_module = modules[str(node.target)]
            if not isinstance(maybe_fused_module, _FusedModule):
                continue

            ops = list(maybe_fused_module._modules.values())
            # The first child's qconfig is the reference the others must match.
            fused_qconfig = object_type_dict.get(type(ops[0]), None)

            # Raise an error if the modules in the fused module have
            # different qconfigs specified in the qconfig_dict
            # TODO: currently it only works for modules,
            # need to make this work for torch.nn.functional.relu
            # TODO: currently it only works for object_type configurations,
            # ideally it should work for different types of configurations,
            # maybe we want to redesign this part
            for op in ops[1:]:
                if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
                    raise LookupError(
                        "During fusion, we need to specify the same " +
                        f"qconfigs for all module types in {type(maybe_fused_module)} " +
                        f"offending type: {type(op)}")

            if fused_qconfig is not None:
                object_type_dict[type(maybe_fused_module)] = fused_qconfig


def _generate_node_name_to_qconfig(
        root: torch.nn.Module,
        modules: Dict[str, torch.nn.Module],
        input_graph: Graph,
        qconfig_mapping: QConfigMapping,
        node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]:
    """
    Build a map from node name to the qconfig that applies to that node.

    Args:
        root: not read by this function.
        modules: map from module path to module; for ``call_module`` nodes the
            resolved qconfig is also assigned to the module's ``qconfig``
            attribute.
        input_graph: graph whose nodes are visited in order.
        qconfig_mapping: source of global, object-type, module-name,
            module-name-regex and ordered qconfigs.
        node_name_to_scope: map from node name to ``(module_path, module_type)``.

    Returns:
        Dict keyed by ``node.name``. Nodes whose op is not one of ``get_attr``,
        ``call_function``, ``call_method`` or ``call_module`` map to ``None``.
        ``call_module`` nodes that target an activation post process (observer)
        are left out entirely.
    """
    global_qconfig = qconfig_mapping.global_qconfig
    node_name_to_qconfig = {}

    # example:
    #
    #   {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
    #
    # meaning in submodule 'foo.bar', we have seen 0 F.linear and
    # 1 F.conv2d invocations so far.
    submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = \
        defaultdict(lambda: defaultdict(int))
    for node in input_graph.nodes:
        qconfig = None
        if node.op == "get_attr":
            module_name, _ = _parent_name(node.target)
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, type(modules[module_name]), module_name, global_qconfig)
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
        elif node.op == "call_function":
            # precedence: module_name_qconfig
            # > function_qconfig > global_qconfig
            # module_name takes precedence over function qconfig
            function_qconfig = _get_object_type_qconfig(
                qconfig_mapping, node.target, global_qconfig)
            module_path, module_type = node_name_to_scope[node.name]
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, module_type, module_path, function_qconfig)

            # Read the invocation index before bumping it, so the first call
            # of a function inside a submodule is seen as index 0.
            cur_object_type_idx = \
                submodule_to_object_type_to_cur_idx[module_path][node.target]
            submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_mapping, module_path, node.target, cur_object_type_idx,
                qconfig)
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
        elif node.op == "call_method":
            module_path, module_type = node_name_to_scope[node.name]
            # first use node.target (string) to get the qconfig
            # this is to support configs like
            # "object_type": [("reshape", qconfig)]
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, node.target, module_path, global_qconfig)
            # if there is no special config for the method, we'll fall back to the
            # config for the module that contains the call_method node
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, module_type, module_path, qconfig)
            # currently call_method does not support modifying qconfig
            # by order, we can add this later if it is needed.
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
        elif node.op == 'call_module':
            # if the node is an observer, just continue - don't add it to the qconfig_map
            if _is_activation_post_process(modules[node.target]):
                continue
            qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
                qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)

            module_path, module_type = node_name_to_scope[node.name]
            # Note: for call_module, the module_path is the current module's name.
            # to meaningfully count invocations, we need to count them in the parent
            # module.
            parent_name, _ = _parent_name(module_path)
            cur_object_type_idx = \
                submodule_to_object_type_to_cur_idx[parent_name][module_type]
            submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
            qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
                qconfig_mapping, parent_name, module_type, cur_object_type_idx,
                qconfig)
            qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))

            # regex is not supported eager mode propagate_qconfig_, we'll
            # need to set the qconfig explicitly here in case regex
            # is used
            modules[node.target].qconfig = qconfig_with_device_check
        else:
            qconfig_with_device_check = None

        node_name_to_qconfig[node.name] = qconfig_with_device_check
    return node_name_to_qconfig


def _check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
    r""" Checks if the given config_dict has the correct keys

    Args:
      `config_dict`: dictionary whose keys we want to check
      `allowed_keys`: the keys that are permitted in `config_dict`
      `dict_name`: name used for the dictionary in the error message

    Raises:
      ValueError: on the first key of `config_dict` that is not in `allowed_keys`
    """
    for k in config_dict.keys():
        if k not in allowed_keys:
            # Formatting (rather than concatenating) the key keeps this a
            # ValueError even when the offending key is not a string.
            raise ValueError(
                f"Expected {dict_name} to have the following keys: "
                f"{allowed_keys}. But found '{k}' instead.")


def _compare_prepare_convert_qconfig_mappings(
        prepare_qconfig_mapping: QConfigMapping,
        convert_qconfig_mapping: QConfigMapping):
    r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values

    Args:
      `prepare_qconfig_mapping`: configuration for prepare quantization step
      `convert_qconfig_mapping`: configuration for convert quantization step

    Raises:
      AssertionError: if the global qconfigs differ, if a key present in one of
        the prepare dicts (object type, module name, module name regex) is
        missing from the matching convert dict, or if the convert qconfig for
        such a key is neither None nor equal to the prepare qconfig.
    """
    assert qconfig_equals(prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig), \
        "Expected global qconfigs to be the same in the prepare and convert quantization configs"
    prepare_dicts: List[OrderedDict] = [
        prepare_qconfig_mapping.object_type_qconfigs,
        prepare_qconfig_mapping.module_name_qconfigs,
        prepare_qconfig_mapping.module_name_regex_qconfigs,
    ]
    convert_dicts: List[OrderedDict] = [
        convert_qconfig_mapping.object_type_qconfigs,
        convert_qconfig_mapping.module_name_qconfigs,
        convert_qconfig_mapping.module_name_regex_qconfigs,
    ]
    dict_names = [_OBJECT_TYPE_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY]
    for prepare_dict, convert_dict, dict_name in zip(prepare_dicts, convert_dicts, dict_names):
        for name, qconfig in prepare_dict.items():
            assert name in convert_dict, (
                f"Missing key {dict_name} {name} in convert QConfigMapping "
                "when it was present in prepare")
            # A None entry on the convert side is accepted as-is.
            assert convert_dict[name] is None \
                or qconfig_equals(qconfig, convert_dict[name]), (
                    "Expected convert QConfigMapping to have the same qconfig as prepare "
                    f"for key {dict_name} {name}; "
                    f"prepare: {qconfig}; convert: {convert_dict[name]}")


def _is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]) -> bool:
    """
    Return True if `qconfig` matches at least one entry of `dtype_configs`.

    Unset (falsy) dtypes in a DTypeConfig are treated as ``torch.float`` and an
    ``is_dynamic`` of None is treated as False. Returns False when no entry
    matches, including when `dtype_configs` is empty.
    """
    for dtype_config in dtype_configs:
        is_dynamic = dtype_config.is_dynamic
        if is_dynamic is None:
            is_dynamic = False
        input_dtype = dtype_config.input_dtype or torch.float
        weight_dtype = dtype_config.weight_dtype or torch.float
        bias_dtype = dtype_config.bias_dtype or torch.float
        output_dtype = dtype_config.output_dtype or torch.float
        qconfig_activation_dtype, qconfig_weight_dtype, qconfig_input_act_is_dynamic = \
            get_qconfig_dtypes(qconfig)
        # The expected bias dtype is float16 only for a static config whose
        # activation and weight are both float16; otherwise it is float.
        qconfig_bias_dtype = torch.float16 \
            if (
                qconfig_activation_dtype == torch.float16
                and qconfig_weight_dtype == torch.float16
                and not is_dynamic
            ) else torch.float

        if is_dynamic:
            # Dynamic configs are matched on input and weight dtype and must
            # declare a float output; the bias dtype is not compared.
            is_match = qconfig_input_act_is_dynamic and \
                input_dtype == qconfig_activation_dtype and \
                output_dtype == torch.float and \
                weight_dtype == qconfig_weight_dtype
        else:
            is_match = input_dtype == qconfig_activation_dtype and \
                output_dtype == qconfig_activation_dtype and \
                weight_dtype == qconfig_weight_dtype and \
                bias_dtype == qconfig_bias_dtype
        if is_match:
            return True
    return False


def _get_object_type_qconfig(
        qconfig_mapping: QConfigMapping,
        object_type: Union[Callable, str],
        fallback_qconfig: QConfigAny) -> QConfigAny:
    """Return the qconfig registered for `object_type`, or `fallback_qconfig` if there is none."""
    return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)


def _get_module_name_regex_qconfig(
        qconfig_mapping: QConfigMapping,
        module_name: str,
        fallback_qconfig: QConfigAny) -> QConfigAny:
    """
    Return the qconfig of the first regex entry that matches `module_name`.

    Patterns are tried in the iteration order of
    ``qconfig_mapping.module_name_regex_qconfigs`` using ``re.match`` (anchored
    at the start of the name). Returns `fallback_qconfig` when none matches.
    """
    for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
        if re.match(regex_pattern, module_name):
            # first match wins
            return qconfig
    return fallback_qconfig


def _get_module_name_qconfig(
        qconfig_mapping: QConfigMapping,
        module_name: str,
        fallback_qconfig: QConfigAny) -> QConfigAny:
    """
    Return the qconfig registered for `module_name` or its closest ancestor.

    If `module_name` has no entry, the lookup is retried with its parent name,
    and so on, until the empty name is reached; then `fallback_qconfig` is
    returned.
    """
    if module_name == '':
        # module name qconfig not found
        return fallback_qconfig
    if module_name in qconfig_mapping.module_name_qconfigs:
        return qconfig_mapping.module_name_qconfigs[module_name]
    else:
        parent, _ = _parent_name(module_name)
        return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)


def _maybe_adjust_qconfig_for_module_type_or_name(
        qconfig_mapping: QConfigMapping,
        module_type: Union[Callable, str],
        module_name: str,
        global_qconfig: QConfigAny) -> QConfigAny:
    """
    Resolve the qconfig for a module from its name and type.

    Each lookup falls back to the result of the next one, so the precedence is
    module name > module name regex > module (object) type > `global_qconfig`.
    """
    # get qconfig for module_name,
    # fallback to module_name_regex_qconfig, module_type_qconfig,
    # global_qconfig if necessary
    module_type_qconfig = _get_object_type_qconfig(
        qconfig_mapping, module_type, global_qconfig)
    module_name_regex_qconfig = _get_module_name_regex_qconfig(
        qconfig_mapping, module_name, module_type_qconfig)
    module_name_qconfig = _get_module_name_qconfig(
        qconfig_mapping, module_name, module_name_regex_qconfig)
    return module_name_qconfig


def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
    """ flatten the global, object_type and module_name qconfig
    to the same qconfig_dict so that it can be used by
    propagate_qconfig_ function.
    "module_name_regex" is ignored for now since it's not supported
    in propagate_qconfig_, but it can be fixed later.

    For example:
    Input: {
      "": qconfig,
      "object_type": [
        (torch.add, qconfig)
      ],
      "module_name": [
        ("conv", qconfig)
      ]
    }

    Output: {
      "": qconfig,
      torch.add: qconfig,
      "conv": qconfig
    }
    """
    flattened: Dict[Union[Callable, str], QConfigAny] = {"": qconfig_mapping.global_qconfig}
    # Later entries overwrite earlier ones on key collision, so module-name
    # entries take priority over object-type entries and the global entry.
    for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
        flattened[obj] = qconfig
    for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
        flattened[obj] = qconfig
    return flattened


def _update_qconfig_for_qat(
        qconfig_mapping: QConfigMapping,
        backend_config: BackendConfig):
    """
    Update the qconfig_mapping to account for module swaps during QAT.
    During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.

    ``qconfig_mapping.object_type_qconfigs`` is modified in place: every key
    that has a QAT counterpart in the backend config gets its qconfig
    registered under that counterpart as well.
    """
    module_to_qat_module_class = get_module_to_qat_module(backend_config)
    object_type_dict = qconfig_mapping.object_type_qconfigs
    # Iterate over a snapshot because the live dict is mutated in the loop.
    new_object_type_dict = object_type_dict.copy()
    for k, v in new_object_type_dict.items():
        if k in module_to_qat_module_class:
            object_type_dict[module_to_qat_module_class[k]] = v