from __future__ import annotations import copy from typing import Any, Callable, Dict, List, Union import torch from torch.ao.quantization import QConfigMapping from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER from torch.ao.quantization.qconfig import QConfigAny __all__ = ["QConfigMultiMapping"] _QCONFIG_STYLE_TO_METHOD: Dict[str, str] = { "global_qconfig": "set_global", "object_type_qconfigs": "set_object_type", "module_name_regex_qconfigs": "set_module_name_regex", "module_name_qconfigs": "set_module_name", "module_name_object_type_order_qconfigs": "set_module_name_object_type_order", } def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None: to_remove = [] for index, cur_qconfig in enumerate(qconfig_list): if cur_qconfig is None: to_remove.append(index) break for checked_qconfig in qconfig_list[:index]: if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig): to_remove.append(index) break for index in to_remove[::-1]: qconfig_list.pop(index) class QConfigMultiMapping: """ This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s so that multiple QConfigs can be specified for each QConfig matching style. The user can specify QConfigs using the following methods (in increasing match priority): ``set_global`` : sets the global (default) QConfigs ``set_object_type`` : sets the QConfigs for a given module type, function, or method name ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string ``set_module_name`` : sets the QConfigs for modules matching the given module name ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination of the given module name, object type, and the index at which the module appears Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a single QConfig. Example usage:: qconfig_mapping = QConfigMultiMapping() .set_global([qconfig1, qconfig2]) .set_object_type(torch.nn.Linear, [qconfig2, qconfig3]) .set_object_type(torch.nn.ReLU, [qconfig1]) .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2]) .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3]) .set_module_name("module1", [None]) .set_module_name("module2", [qconfig2]) .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3]) """ def __init__(self): # initialize this with 1 QConfigMapping to avoid corner cases self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()] def _handle_list_size_mismatch( self, qconfig_list: List[QConfigAny], style: str ) -> None: # this method handles cases where the size of qconfig_list does not match # the size of qconfig_mappings_list. # Issue: Consider a user inserting global_qconfig A and B first, then inserting # qconfig C as an object_type_qconfig for conv ops. If we internally store # 1 QConfigMapping with A and C and another with just B, then the # second QConfigMapping will match B to conv ops (which is not wanted), since B is global. # we avoid this by maintaining the invariant that if any QConfigMapping # has a qconfig style+key with a qconfig in it, all QConfigMappings must # have either a qconfig or None for that same style+key. In the above # example, a None qconfig would prevent the unwanted match in the # second QConfigMapping if len(qconfig_list) > len(self.qconfig_mappings_list): # Case: we have more qconfigs (in qconfig_list) than QConfigMappings # Add new QConfigMappings (initialized so we maintain the `invariant`) new_qconfig_mapping = QConfigMapping() # searches other QConfigMappings for qconfig style+keys # that need to be inserted as `None` into the new QConfigMapping for qconfig_mapping in self.qconfig_mappings_list: # global_qconfig has None by default for check_style in _QCONFIG_STYLE_ORDER[1:]: qconfigs_dict = getattr(qconfig_mapping, check_style) target_qconfigs_dict = getattr(new_qconfig_mapping, check_style) for key in qconfigs_dict: target_qconfigs_dict[key] = None break # insert copies of this new QConfigMapping until all entires # in qconfig_list can fit among the QConfigMappings while len(qconfig_list) > len(self.qconfig_mappings_list): self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping)) else: # Case: we have fewer qconfigs in qconfig_list than QConfigMappings # pad qconfig_list with `None` until length is same while len(qconfig_list) < len(self.qconfig_mappings_list): qconfig_list.append(None) # this function applies the insertion method across each QConfigMapping def _insert_qconfig_list( self, style: str, args: List[Union[str, int, Callable]], qconfig_list: List[QConfigAny], ) -> None: # we remove duplicates and None to make the ordering of qconfigs # deterministic upon insertion. _remove_duplicates_and_none(qconfig_list) self._handle_list_size_mismatch(qconfig_list, style) method_name = _QCONFIG_STYLE_TO_METHOD[style] for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list): # uses QConfigMapping set method to insert qconfig set_method = getattr(qconfig_mapping, method_name) set_method(*args, qconfig) def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping: """ Set global QConfigs see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info """ self._insert_qconfig_list("global_qconfig", [], global_qconfig_list) return self def set_object_type( self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny] ) -> QConfigMultiMapping: """ Set object type QConfigs see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info """ self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list) return self def set_module_name_regex( self, module_name_regex: str, qconfig_list: List[QConfigAny] ) -> QConfigMultiMapping: """ Set module_name_regex QConfigs see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info """ self._insert_qconfig_list( "module_name_regex_qconfigs", [module_name_regex], qconfig_list ) return self def set_module_name( self, module_name: str, qconfig_list: List[QConfigAny] ) -> QConfigMultiMapping: """ Set module_name QConfigs see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info """ self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list) return self def set_module_name_object_type_order( self, module_name: str, object_type: Callable, index: int, qconfig_list: List[QConfigAny], ) -> QConfigMultiMapping: """ Set module_name QConfigs see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info """ self._insert_qconfig_list( "module_name_object_type_order_qconfigs", [module_name, object_type, index], qconfig_list, ) return self def __repr__(self): return ( self.__class__.__name__ + " [" + "".join(f"\n{qconfig_mapping.__repr__()}," for qconfig_mapping in self.qconfig_mappings_list) + "\n]" ) @classmethod def from_list_qconfig_mapping( cls, qconfig_mapping_list: List[QConfigMapping] ) -> QConfigMultiMapping: """ Creates a QConfigMultiMapping from a list of QConfigMappings """ new_qconfig_multi_mapping = cls() new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy( qconfig_mapping_list ) # we need to avoid the issue described in _handle_list_size_mismatch, # so we reinsert all the qconfigs using the QConfigMultiMapping # set methods # go through all qconfig styles # note: global can be ignored since it is None by default for style in _QCONFIG_STYLE_ORDER[1:]: # gather all key+qconfigs for current style # into qconfig_dict_list qconfig_dict_list: Dict[Any, List[QConfigAny]] = {} for qconfig_mapping in qconfig_mapping_list: qconfig_dict = getattr(qconfig_mapping, style) for key, qconfig in qconfig_dict.items(): if key not in qconfig_dict_list: qconfig_dict_list[key] = [] qconfig_dict_list[key].append(qconfig) # reinsert all gathered key+qconfigs set_method_name = _QCONFIG_STYLE_TO_METHOD[style] set_method = getattr(new_qconfig_multi_mapping, set_method_name) for key, qconfig_list in qconfig_dict_list.items(): if isinstance(key, tuple): set_method(*key, qconfig_list) else: set_method(key, qconfig_list) return new_qconfig_multi_mapping