qconfig_multi_mapping.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. from __future__ import annotations
  2. import copy
  3. from typing import Any, Callable, Dict, List, Union
  4. import torch
  5. from torch.ao.quantization import QConfigMapping
  6. from torch.ao.quantization.qconfig_mapping import _QCONFIG_STYLE_ORDER
  7. from torch.ao.quantization.qconfig import QConfigAny
  8. __all__ = ["QConfigMultiMapping"]
  9. _QCONFIG_STYLE_TO_METHOD: Dict[str, str] = {
  10. "global_qconfig": "set_global",
  11. "object_type_qconfigs": "set_object_type",
  12. "module_name_regex_qconfigs": "set_module_name_regex",
  13. "module_name_qconfigs": "set_module_name",
  14. "module_name_object_type_order_qconfigs": "set_module_name_object_type_order",
  15. }
  16. def _remove_duplicates_and_none(qconfig_list: List[QConfigAny]) -> None:
  17. to_remove = []
  18. for index, cur_qconfig in enumerate(qconfig_list):
  19. if cur_qconfig is None:
  20. to_remove.append(index)
  21. break
  22. for checked_qconfig in qconfig_list[:index]:
  23. if torch.ao.quantization.qconfig_equals(cur_qconfig, checked_qconfig):
  24. to_remove.append(index)
  25. break
  26. for index in to_remove[::-1]:
  27. qconfig_list.pop(index)
  28. class QConfigMultiMapping:
  29. """
  30. This class, used with the prepare_n_shadows_model API, stores a list of :class:`torch.ao.quantization.QConfigMapping`s
  31. so that multiple QConfigs can be specified for each QConfig matching style.
  32. The user can specify QConfigs using the following methods (in increasing match priority):
  33. ``set_global`` : sets the global (default) QConfigs
  34. ``set_object_type`` : sets the QConfigs for a given module type, function, or method name
  35. ``set_module_name_regex`` : sets the QConfigs for modules matching the given regex string
  36. ``set_module_name`` : sets the QConfigs for modules matching the given module name
  37. ``set_module_name_object_type_order`` : sets the QConfigs for modules matching a combination
  38. of the given module name, object type, and the index at which the module appears
  39. Note: Usage of set methods is the same as in QConfigMapping except with a passed in list of QConfigs rather than a
  40. single QConfig.
  41. Example usage::
  42. qconfig_mapping = QConfigMultiMapping()
  43. .set_global([qconfig1, qconfig2])
  44. .set_object_type(torch.nn.Linear, [qconfig2, qconfig3])
  45. .set_object_type(torch.nn.ReLU, [qconfig1])
  46. .set_module_name_regex("foo.*bar.*conv[0-9]+", [qconfig2])
  47. .set_module_name_regex("foo.*", [qconfig1, qconfig2, qconfig3])
  48. .set_module_name("module1", [None])
  49. .set_module_name("module2", [qconfig2])
  50. .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, [qconfig3])
  51. """
  52. def __init__(self):
  53. # initialize this with 1 QConfigMapping to avoid corner cases
  54. self.qconfig_mappings_list: List[QConfigMapping] = [QConfigMapping()]
  55. def _handle_list_size_mismatch(
  56. self, qconfig_list: List[QConfigAny], style: str
  57. ) -> None:
  58. # this method handles cases where the size of qconfig_list does not match
  59. # the size of qconfig_mappings_list.
  60. # Issue: Consider a user inserting global_qconfig A and B first, then inserting
  61. # qconfig C as an object_type_qconfig for conv ops. If we internally store
  62. # 1 QConfigMapping with A and C and another with just B, then the
  63. # second QConfigMapping will match B to conv ops (which is not wanted), since B is global.
  64. # we avoid this by maintaining the invariant that if any QConfigMapping
  65. # has a qconfig style+key with a qconfig in it, all QConfigMappings must
  66. # have either a qconfig or None for that same style+key. In the above
  67. # example, a None qconfig would prevent the unwanted match in the
  68. # second QConfigMapping
  69. if len(qconfig_list) > len(self.qconfig_mappings_list):
  70. # Case: we have more qconfigs (in qconfig_list) than QConfigMappings
  71. # Add new QConfigMappings (initialized so we maintain the `invariant`)
  72. new_qconfig_mapping = QConfigMapping()
  73. # searches other QConfigMappings for qconfig style+keys
  74. # that need to be inserted as `None` into the new QConfigMapping
  75. for qconfig_mapping in self.qconfig_mappings_list:
  76. # global_qconfig has None by default
  77. for check_style in _QCONFIG_STYLE_ORDER[1:]:
  78. qconfigs_dict = getattr(qconfig_mapping, check_style)
  79. target_qconfigs_dict = getattr(new_qconfig_mapping, check_style)
  80. for key in qconfigs_dict:
  81. target_qconfigs_dict[key] = None
  82. break
  83. # insert copies of this new QConfigMapping until all entires
  84. # in qconfig_list can fit among the QConfigMappings
  85. while len(qconfig_list) > len(self.qconfig_mappings_list):
  86. self.qconfig_mappings_list.append(copy.deepcopy(new_qconfig_mapping))
  87. else:
  88. # Case: we have fewer qconfigs in qconfig_list than QConfigMappings
  89. # pad qconfig_list with `None` until length is same
  90. while len(qconfig_list) < len(self.qconfig_mappings_list):
  91. qconfig_list.append(None)
  92. # this function applies the insertion method across each QConfigMapping
  93. def _insert_qconfig_list(
  94. self,
  95. style: str,
  96. args: List[Union[str, int, Callable]],
  97. qconfig_list: List[QConfigAny],
  98. ) -> None:
  99. # we remove duplicates and None to make the ordering of qconfigs
  100. # deterministic upon insertion.
  101. _remove_duplicates_and_none(qconfig_list)
  102. self._handle_list_size_mismatch(qconfig_list, style)
  103. method_name = _QCONFIG_STYLE_TO_METHOD[style]
  104. for qconfig_mapping, qconfig in zip(self.qconfig_mappings_list, qconfig_list):
  105. # uses QConfigMapping set method to insert qconfig
  106. set_method = getattr(qconfig_mapping, method_name)
  107. set_method(*args, qconfig)
  108. def set_global(self, global_qconfig_list: List[QConfigAny]) -> QConfigMultiMapping:
  109. """
  110. Set global QConfigs
  111. see :func:`~torch.ao.quantization.QConfigMapping.set_global()` for more info
  112. """
  113. self._insert_qconfig_list("global_qconfig", [], global_qconfig_list)
  114. return self
  115. def set_object_type(
  116. self, object_type: Union[Callable, str], qconfig_list: List[QConfigAny]
  117. ) -> QConfigMultiMapping:
  118. """
  119. Set object type QConfigs
  120. see :func:`~torch.ao.quantization.QConfigMapping.set_object_type()` for more info
  121. """
  122. self._insert_qconfig_list("object_type_qconfigs", [object_type], qconfig_list)
  123. return self
  124. def set_module_name_regex(
  125. self, module_name_regex: str, qconfig_list: List[QConfigAny]
  126. ) -> QConfigMultiMapping:
  127. """
  128. Set module_name_regex QConfigs
  129. see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_regex()` for more info
  130. """
  131. self._insert_qconfig_list(
  132. "module_name_regex_qconfigs", [module_name_regex], qconfig_list
  133. )
  134. return self
  135. def set_module_name(
  136. self, module_name: str, qconfig_list: List[QConfigAny]
  137. ) -> QConfigMultiMapping:
  138. """
  139. Set module_name QConfigs
  140. see :func:`~torch.ao.quantization.QConfigMapping.set_module_name()` for more info
  141. """
  142. self._insert_qconfig_list("module_name_qconfigs", [module_name], qconfig_list)
  143. return self
  144. def set_module_name_object_type_order(
  145. self,
  146. module_name: str,
  147. object_type: Callable,
  148. index: int,
  149. qconfig_list: List[QConfigAny],
  150. ) -> QConfigMultiMapping:
  151. """
  152. Set module_name QConfigs
  153. see :func:`~torch.ao.quantization.QConfigMapping.set_module_name_object_type_order()` for more info
  154. """
  155. self._insert_qconfig_list(
  156. "module_name_object_type_order_qconfigs",
  157. [module_name, object_type, index],
  158. qconfig_list,
  159. )
  160. return self
  161. def __repr__(self):
  162. return (
  163. self.__class__.__name__ +
  164. " [" +
  165. "".join(f"\n{qconfig_mapping.__repr__()}," for qconfig_mapping in self.qconfig_mappings_list) +
  166. "\n]"
  167. )
  168. @classmethod
  169. def from_list_qconfig_mapping(
  170. cls, qconfig_mapping_list: List[QConfigMapping]
  171. ) -> QConfigMultiMapping:
  172. """
  173. Creates a QConfigMultiMapping from a list of QConfigMappings
  174. """
  175. new_qconfig_multi_mapping = cls()
  176. new_qconfig_multi_mapping.qconfig_mappings_list = copy.deepcopy(
  177. qconfig_mapping_list
  178. )
  179. # we need to avoid the issue described in _handle_list_size_mismatch,
  180. # so we reinsert all the qconfigs using the QConfigMultiMapping
  181. # set methods
  182. # go through all qconfig styles
  183. # note: global can be ignored since it is None by default
  184. for style in _QCONFIG_STYLE_ORDER[1:]:
  185. # gather all key+qconfigs for current style
  186. # into qconfig_dict_list
  187. qconfig_dict_list: Dict[Any, List[QConfigAny]] = {}
  188. for qconfig_mapping in qconfig_mapping_list:
  189. qconfig_dict = getattr(qconfig_mapping, style)
  190. for key, qconfig in qconfig_dict.items():
  191. if key not in qconfig_dict_list:
  192. qconfig_dict_list[key] = []
  193. qconfig_dict_list[key].append(qconfig)
  194. # reinsert all gathered key+qconfigs
  195. set_method_name = _QCONFIG_STYLE_TO_METHOD[style]
  196. set_method = getattr(new_qconfig_multi_mapping, set_method_name)
  197. for key, qconfig_list in qconfig_dict_list.items():
  198. if isinstance(key, tuple):
  199. set_method(*key, qconfig_list)
  200. else:
  201. set_method(key, qconfig_list)
  202. return new_qconfig_multi_mapping