qconfig_mapping_utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. import torch
  2. import re
  3. from collections import defaultdict, OrderedDict
  4. from typing import Callable, Any, Dict, Tuple, Set, List, Union
  5. from torch.ao.quantization import QConfig
  6. from torch.ao.quantization.qconfig import _add_module_to_qconfig_obs_ctr, QConfigAny, qconfig_equals
  7. from torch.ao.quantization.observer import (
  8. _is_activation_post_process,
  9. )
  10. from torch.ao.quantization.backend_config import (
  11. BackendConfig,
  12. DTypeConfig,
  13. )
  14. from torch.ao.quantization.backend_config.utils import (
  15. get_module_to_qat_module,
  16. )
  17. from torch.fx import (
  18. GraphModule,
  19. )
  20. from torch.fx.graph import (
  21. Graph,
  22. )
  23. from torch.ao.nn.intrinsic import _FusedModule
  24. from ..utils import (
  25. _parent_name,
  26. get_qconfig_dtypes,
  27. )
  28. from ..qconfig_mapping import (
  29. _OBJECT_TYPE_DICT_KEY,
  30. _MODULE_NAME_DICT_KEY,
  31. _MODULE_NAME_REGEX_DICT_KEY,
  32. QConfigMapping,
  33. )
  34. __all__: List[str] = []
  35. def _maybe_adjust_qconfig_for_module_name_object_type_order(
  36. qconfig_mapping: QConfigMapping,
  37. cur_module_path: str,
  38. cur_object_type: Callable,
  39. cur_object_type_idx: int,
  40. fallback_qconfig: QConfigAny,
  41. ) -> QConfigAny:
  42. for (module_name, object_type, index), qconfig in qconfig_mapping.module_name_object_type_order_qconfigs.items():
  43. if (
  44. (module_name == cur_module_path) and
  45. (object_type == cur_object_type) and
  46. (index == cur_object_type_idx)
  47. ):
  48. return qconfig
  49. return fallback_qconfig
  50. def _update_qconfig_for_fusion(model: GraphModule, qconfig_mapping: QConfigMapping):
  51. """
  52. Update the QConfigMapping to account for fused modules such as LinearReLU.
  53. This assumes the QConfigMapping's attributes have already been converted to OrderedDicts.
  54. """
  55. object_type_dict = qconfig_mapping.object_type_qconfigs
  56. if len(object_type_dict) == 0:
  57. return qconfig_mapping
  58. modules = dict(model.named_modules())
  59. for node in model.graph.nodes:
  60. if node.op == 'call_module' and node.target in modules:
  61. maybe_fused_module = modules[str(node.target)]
  62. if not isinstance(maybe_fused_module, _FusedModule):
  63. continue
  64. ops = list(maybe_fused_module._modules.values())
  65. fused_qconfig = object_type_dict.get(type(ops[0]), None)
  66. # Raise an error if the modules in the fused module have
  67. # different qconfigs specified in the qconfig_dict
  68. # TODO: currently it only works for modules,
  69. # need to make this work for torch.nn.functional.relu
  70. # TODO: currently it only works for object_type configurations,
  71. # ideally it should work for different types of configurations,
  72. # maybe we want to redesign this part
  73. for op in ops[1:]:
  74. if not qconfig_equals(object_type_dict.get(type(op), None), fused_qconfig):
  75. raise LookupError(
  76. "During fusion, we need to specify the same " +
  77. f"qconfigs for all module types in {type(maybe_fused_module)} " +
  78. f"offending type: {type(op)}")
  79. if fused_qconfig is not None:
  80. object_type_dict[type(maybe_fused_module)] = fused_qconfig
  81. def _generate_node_name_to_qconfig(
  82. root: torch.nn.Module,
  83. modules: Dict[str, torch.nn.Module],
  84. input_graph: Graph,
  85. qconfig_mapping: QConfigMapping,
  86. node_name_to_scope: Dict[str, Tuple[str, type]]) -> Dict[str, QConfigAny]:
  87. global_qconfig = qconfig_mapping.global_qconfig
  88. node_name_to_qconfig = {}
  89. # example:
  90. #
  91. # {'foo.bar': {F.linear: 0, F.conv2d: 1, ...}, ...}
  92. #
  93. # meaning in submodule 'foo.bar', we have seen 0 F.linear and
  94. # 1 F.conv2d invocations so far.
  95. submodule_to_object_type_to_cur_idx: Dict[str, Dict[Callable, int]] = \
  96. defaultdict(lambda: defaultdict(int))
  97. for node in input_graph.nodes:
  98. qconfig = None
  99. if node.op == "get_attr":
  100. module_name, _ = _parent_name(node.target)
  101. qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
  102. qconfig_mapping, type(modules[module_name]), module_name, global_qconfig)
  103. qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
  104. elif node.op == "call_function":
  105. # precedence: module_name_qconfig
  106. # > function_qconfig > global_qconfig
  107. # module_name takes precedence over function qconfig
  108. function_qconfig = _get_object_type_qconfig(
  109. qconfig_mapping, node.target, global_qconfig)
  110. module_path, module_type = node_name_to_scope[node.name]
  111. qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
  112. qconfig_mapping, module_type, module_path, function_qconfig)
  113. cur_object_type_idx = \
  114. submodule_to_object_type_to_cur_idx[module_path][node.target]
  115. submodule_to_object_type_to_cur_idx[module_path][node.target] += 1
  116. qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
  117. qconfig_mapping, module_path, node.target, cur_object_type_idx, qconfig)
  118. qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
  119. elif node.op == "call_method":
  120. module_path, module_type = node_name_to_scope[node.name]
  121. # first use node.target (string) to get the qconfig
  122. # this is to support configs like
  123. # "object_type": [("reshpe", qconfig)]
  124. qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
  125. qconfig_mapping, node.target, module_path, global_qconfig)
  126. # if there is no special config for the method, we'll fall back to the
  127. # config for the module that contains the call_method node
  128. qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
  129. qconfig_mapping, module_type, module_path, qconfig)
  130. # currently call_method does not support modifying qconfig
  131. # by order, we can add this later if it is needed.
  132. qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
  133. elif node.op == 'call_module':
  134. # if the node is an observer, just continue - don't add it to the qconfig_map
  135. if _is_activation_post_process(modules[node.target]):
  136. continue
  137. qconfig = _maybe_adjust_qconfig_for_module_type_or_name(
  138. qconfig_mapping, type(modules[node.target]), node.target, global_qconfig)
  139. module_path, module_type = node_name_to_scope[node.name]
  140. # Note: for call_module, the module_path is the current module's name.
  141. # to meaningfully count invocations, we need to count them in the parent
  142. # module.
  143. parent_name, _ = _parent_name(module_path)
  144. cur_object_type_idx = \
  145. submodule_to_object_type_to_cur_idx[parent_name][module_type]
  146. submodule_to_object_type_to_cur_idx[parent_name][module_type] += 1
  147. qconfig = _maybe_adjust_qconfig_for_module_name_object_type_order(
  148. qconfig_mapping, parent_name, module_type, cur_object_type_idx,
  149. qconfig)
  150. qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(qconfig, modules.get(node.target, None))
  151. # regex is not supported eager mode propagate_qconfig_, we'll
  152. # need to set the qconfig explicitly here in case regex
  153. # is used
  154. modules[node.target].qconfig = qconfig_with_device_check
  155. else:
  156. qconfig_with_device_check = None
  157. node_name_to_qconfig[node.name] = qconfig_with_device_check
  158. return node_name_to_qconfig
  159. def _check_is_valid_config_dict(config_dict: Any, allowed_keys: Set[str], dict_name: str) -> None:
  160. r""" Checks if the given config_dict has the correct keys
  161. Args:
  162. `config_dict`: dictionary whose keys we want to check
  163. """
  164. for k in config_dict.keys():
  165. if k not in allowed_keys:
  166. raise ValueError(
  167. 'Expected ' + dict_name + ' to have the following keys: ' +
  168. str(allowed_keys) + '. But found \'' + k +
  169. '\' instead.')
  170. def _compare_prepare_convert_qconfig_mappings(
  171. prepare_qconfig_mapping: QConfigMapping,
  172. convert_qconfig_mapping: QConfigMapping):
  173. r""" Compare the qconfig_mapping passed in convert to the one from prepare and check the values
  174. Args:
  175. `prepare_qconfig_mapping`: configuration for prepare quantization step
  176. `convert_qconfig_mapping`: configuration for convert quantization step
  177. """
  178. assert qconfig_equals(prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig), \
  179. "Expected global qconfigs to be the same in the prepare and convert quantization configs"
  180. prepare_dicts: List[OrderedDict] = [
  181. prepare_qconfig_mapping.object_type_qconfigs,
  182. prepare_qconfig_mapping.module_name_qconfigs,
  183. prepare_qconfig_mapping.module_name_regex_qconfigs,
  184. ]
  185. convert_dicts: List[OrderedDict] = [
  186. convert_qconfig_mapping.object_type_qconfigs,
  187. convert_qconfig_mapping.module_name_qconfigs,
  188. convert_qconfig_mapping.module_name_regex_qconfigs,
  189. ]
  190. dict_names = [_OBJECT_TYPE_DICT_KEY, _MODULE_NAME_DICT_KEY, _MODULE_NAME_REGEX_DICT_KEY]
  191. for i in range(len(prepare_dicts)):
  192. for name, qconfig in prepare_dicts[i].items():
  193. assert name in convert_dicts[i], "Missing key {} {} in convert QConfigMapping \
  194. when it was present in prepare".format(dict_names[i], name)
  195. assert convert_dicts[i][name] is None \
  196. or qconfig_equals(prepare_dicts[i][name], convert_dicts[i][name]), \
  197. "Expected convert QConfigMapping to have the same qconfig as prepare for key {} {}; \
  198. prepare: {}; convert: {}".format(dict_names[i], name, prepare_dicts[i][name], convert_dicts[i][name])
  199. def _is_qconfig_supported_by_dtype_configs(qconfig: QConfig, dtype_configs: List[DTypeConfig]):
  200. for dtype_config in dtype_configs:
  201. is_dynamic = dtype_config.is_dynamic
  202. if is_dynamic is None:
  203. is_dynamic = False
  204. input_dtype = dtype_config.input_dtype or torch.float
  205. weight_dtype = dtype_config.weight_dtype or torch.float
  206. bias_dtype = dtype_config.bias_dtype or torch.float
  207. output_dtype = dtype_config.output_dtype or torch.float
  208. qconfig_activation_dtype, qconfig_weight_dtype, qconfig_input_act_is_dynamic = \
  209. get_qconfig_dtypes(qconfig)
  210. qconfig_bias_dtype = torch.float16 \
  211. if (
  212. qconfig_activation_dtype == torch.float16
  213. and qconfig_weight_dtype == torch.float16
  214. and not is_dynamic
  215. ) else torch.float
  216. if is_dynamic:
  217. is_match = qconfig_input_act_is_dynamic and \
  218. input_dtype == qconfig_activation_dtype and \
  219. output_dtype == torch.float and \
  220. weight_dtype == qconfig_weight_dtype
  221. else:
  222. is_match = input_dtype == qconfig_activation_dtype and \
  223. output_dtype == qconfig_activation_dtype and \
  224. weight_dtype == qconfig_weight_dtype and \
  225. bias_dtype == qconfig_bias_dtype
  226. if is_match:
  227. return True
  228. return False
  229. def _get_object_type_qconfig(
  230. qconfig_mapping: QConfigMapping,
  231. object_type: Union[Callable, str],
  232. fallback_qconfig: QConfigAny) -> QConfigAny:
  233. return qconfig_mapping.object_type_qconfigs.get(object_type, fallback_qconfig)
  234. def _get_module_name_regex_qconfig(qconfig_mapping, module_name, fallback_qconfig):
  235. for regex_pattern, qconfig in qconfig_mapping.module_name_regex_qconfigs.items():
  236. if re.match(regex_pattern, module_name):
  237. # first match wins
  238. return qconfig
  239. return fallback_qconfig
  240. def _get_module_name_qconfig(qconfig_mapping, module_name, fallback_qconfig):
  241. if module_name == '':
  242. # module name qconfig not found
  243. return fallback_qconfig
  244. if module_name in qconfig_mapping.module_name_qconfigs:
  245. return qconfig_mapping.module_name_qconfigs[module_name]
  246. else:
  247. parent, _ = _parent_name(module_name)
  248. return _get_module_name_qconfig(qconfig_mapping, parent, fallback_qconfig)
  249. def _maybe_adjust_qconfig_for_module_type_or_name(qconfig_mapping, module_type, module_name, global_qconfig):
  250. # get qconfig for module_name,
  251. # fallback to module_name_regex_qconfig, module_type_qconfig,
  252. # global_qconfig if necessary
  253. module_type_qconfig = _get_object_type_qconfig(
  254. qconfig_mapping, module_type, global_qconfig)
  255. module_name_regex_qconfig = _get_module_name_regex_qconfig(
  256. qconfig_mapping, module_name, module_type_qconfig)
  257. module_name_qconfig = _get_module_name_qconfig(
  258. qconfig_mapping, module_name, module_name_regex_qconfig)
  259. return module_name_qconfig
  260. def _get_flattened_qconfig_dict(qconfig_mapping: QConfigMapping) -> Dict[Union[Callable, str], QConfigAny]:
  261. """ flatten the global, object_type and module_name qconfig
  262. to the same qconfig_dict so that it can be used by
  263. propagate_qconfig_ function.
  264. "module_name_regex" is ignored for now since it's not supported
  265. in propagate_qconfig_, but it can be fixed later.
  266. For example:
  267. Input: {
  268. "": qconfig,
  269. "object_type": [
  270. (torch.add, qconfig)
  271. ],
  272. "module_name": [
  273. ("conv", qconfig)
  274. ]
  275. }
  276. Output: {
  277. "": qconfig,
  278. torch.add: qconfig,
  279. "conv": qconfig
  280. }
  281. """
  282. flattened: Dict[Union[Callable, str], QConfigAny] = {"": qconfig_mapping.global_qconfig}
  283. for obj, qconfig in qconfig_mapping.object_type_qconfigs.items():
  284. flattened[obj] = qconfig
  285. for obj, qconfig in qconfig_mapping.module_name_qconfigs.items():
  286. flattened[obj] = qconfig
  287. return flattened
  288. def _update_qconfig_for_qat(
  289. qconfig_mapping: QConfigMapping,
  290. backend_config: BackendConfig):
  291. """
  292. Update the qconfig_mapping to account for module swaps during QAT.
  293. During QAT we perform a module swap on the nn.Module types to the corresponding nn.qat.modules types.
  294. """
  295. module_to_qat_module_class = get_module_to_qat_module(backend_config)
  296. object_type_dict = qconfig_mapping.object_type_qconfigs
  297. new_object_type_dict = object_type_dict.copy()
  298. for k, v in new_object_type_dict.items():
  299. if k in module_to_qat_module_class:
  300. object_type_dict[module_to_qat_module_class[k]] = v