# quantize.py
import copy
import itertools
import warnings

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.nn.intrinsic import _FusedModule

from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    no_observer_set,
    _has_special_act_post_process,
    _get_special_act_post_process,
)
from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.qconfig import (
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    _activation_is_memoryless)
from torch.nn.utils.parametrize import type_before_parametrizations
from torch.ao.quantization.observer import _is_activation_post_process
# TODO remove this once BC is no longer required to avoid a SEV
from torch.ao.quantization.observer import (  # noqa: F401
    _is_activation_post_process as is_activation_post_process
)

__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]

_DEFAULT_CUSTOM_CONFIG_DICT = {
    'float_to_observed_custom_module_class': {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    'observed_to_quantized_custom_module_class': {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    }
}


def get_default_custom_config_dict():
    r"""Defines the default custom config dict.
    """
    return _DEFAULT_CUSTOM_CONFIG_DICT


def _propagate_qconfig_helper(module, qconfig_dict,
                              qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name of submodule to quantization
            configuration
        qconfig_parent: quantization config of parent module, we will fall back to
            this config when there is no specified config for the current module
        prefix: corresponding prefix of the current module, used as key in
            qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """

    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)

    module.qconfig = qconfig_with_device_check
    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        # do not propagate qconfig to child if child is non traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )


def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign the `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from name or type of submodule to
            quantization configuration, qconfig applies to all submodules of a
            given module unless qconfig for the submodules is specified (when
            the submodule already has a qconfig attribute)
        prepare_custom_config_dict: dictionary for custom handling of modules
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, module is modified inplace with qconfig attached
    """
    if qconfig_dict is None:
        qconfig_dict = {}
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = {}
    _propagate_qconfig_helper(module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict)
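
# Example usage of `propagate_qconfig_` (a minimal sketch; `model` is an
# illustrative nn.Module, not part of this module):
#
#     qconfig_dict = {nn.Linear: torch.ao.quantization.default_qconfig}
#     propagate_qconfig_(model, qconfig_dict)
#     # every nn.Linear submodule of `model` now carries a `.qconfig` attribute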


def _observer_forward_hook(self, input, output):
    r"""Forward hook that calls observer on the output
    """
    return self.activation_post_process(output)


def _observer_forward_pre_hook(self, input):
    r"""Forward pre hook that calls observer on the input
    """
    return self.activation_post_process(input[0])


def _register_activation_post_process_hook(module, pre_hook=False):
    assert hasattr(module, 'activation_post_process'), \
        'Expect activation_post_process attribute already attached to the module'
    if pre_hook:
        handle = module.register_forward_pre_hook(
            _observer_forward_pre_hook, prepend=True
        )
    else:
        handle = module.register_forward_hook(
            _observer_forward_hook, prepend=True
        )


def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
    r"""Add observer for the leaf child of the module.

    This function inserts an observer module into every leaf child module that
    has a valid qconfig attribute.

    Args:
        module: input module with qconfig attributes for all the leaf modules that we want to quantize
        qconfig_propagation_list: a list of quantizable modules that will have observers added to them
            if they are leaf nodes
        device: parent device, if any
        non_leaf_module_list: list of non-leaf modules we want to add observer

    Return:
        None, module is modified inplace with added observer modules and forward_hooks
    """
    if qconfig_propagation_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()

    if custom_module_class_mapping is None:
        custom_module_class_mapping = {}

    # respect device affinity when adding observers
    if device is None:
        devices = _get_unique_devices_(module)
        assert len(devices) <= 1, (
            "_add_observer_ only works with cpu or single-device CUDA modules, "
            "but got devices {}".format(devices)
        )
        device = next(iter(devices)) if len(devices) > 0 else None

    def get_activation_post_process(qconfig, device, special_act_post_process=None):
        activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
        if device is not None:
            activation.to(device)
        return activation

    def needs_observation(m):
        return hasattr(m, 'qconfig') and m.qconfig is not None

    def insert_activation_post_process(m, special_act_post_process=None):
        """ Adds an activation post process module and register
        a pre or post hook that calls the module
        """
        # We don't insert observer/fake_quantize for DeQuantStub
        if needs_observation(m) and not isinstance(m, DeQuantStub):
            # observer and hook will be gone after we swap the module
            m.add_module('activation_post_process', get_activation_post_process(
                m.qconfig, device, special_act_post_process))
            # Register observer as the first entry in the hook list
            # All post forward hooks are preserved and will be executed after the observer before convert
            _register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))

    for name, child in module.named_children():
        # TODO remove Dropout special after codebase stable
        if type_before_parametrizations(child) in [nn.Dropout]:
            continue
        elif type_before_parametrizations(child) in [nnq.FloatFunctional, nnq.QFunctional]:
            if needs_observation(child):
                child.activation_post_process = get_activation_post_process(child.qconfig, device)
        elif isinstance(child, _FusedModule):
            # activation_post_process are now added directly to nn.Sequential/_FusedModule
            if needs_observation(child):
                insert_activation_post_process(child)
        elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list:
            if needs_observation(child):
                insert_activation_post_process(child)
        elif _has_special_act_post_process(child):
            special_act_post_process = _get_special_act_post_process(child)
            insert_activation_post_process(child, special_act_post_process)
        elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping:
            observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child)
            setattr(module, name, observed_child)
            # TODO: These are the modules that cannot be observed
            #       Once there are more, we should move them to a separate list
            if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
                insert_activation_post_process(observed_child)
        else:
            _add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)

    # Insert observers only for leaf nodes, note that this observer is for
    # the output of the module, for input QuantStub will observe them
    if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
            and type_before_parametrizations(module) in qconfig_propagation_list:
        insert_activation_post_process(module)


def _get_unique_devices_(module):
    return {p.device for p in module.parameters()} | \
        {p.device for p in module.buffers()}


def add_quant_dequant(module):
    r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
    Note that this function will modify the children of the module in place, and it
    can also return a new module that wraps the input module.

    Args:
        module: input module with qconfig attributes for all the leaf modules
            that we want to quantize

    Return:
        Either the inplace modified module with submodules wrapped in
        `QuantWrapper` based on qconfig or a new `QuantWrapper` module which
        wraps the input module, the latter case only happens when the input
        module is a leaf module and we want to quantize it.
    """
    if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig:
        return QuantWrapper(module)

    for name, child in module.named_children():
        module._modules[name] = add_quant_dequant(child)
    return module
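
# Example usage of `add_quant_dequant` (a minimal sketch; the Sequential model
# below is illustrative, not part of this module):
#
#     model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
#     model[0].qconfig = torch.ao.quantization.default_qconfig
#     model = add_quant_dequant(model)
#     # the Conv2d child is now wrapped in QuantWrapper(Conv2d)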


def prepare(model, inplace=False, allow_list=None,
            observer_non_leaf_module_list=None,
            prepare_custom_config_dict=None):
    r"""Prepares a copy of the model for quantization calibration or quantization-aware training.

    Quantization configuration should be assigned preemptively
    to individual submodules in the `.qconfig` attribute.

    The model will be attached with observer or fake quant modules, and qconfig
    will be propagated.

    Args:
        `model`: input model to be modified in-place
        `inplace`: carry out model transformations in-place, the original module is mutated
        `allow_list`: list of quantizable modules
        `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
        `prepare_custom_config_dict`: customization configuration dictionary for prepare function

    .. code-block:: python

       # Example of prepare_custom_config_dict:
       prepare_custom_config_dict = {
           # user will manually define the corresponding observed
           # module class which has a from_float class method that converts
           # float custom module to observed custom module
           "float_to_observed_custom_module_class": {
               CustomModule: ObservedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare")
    if prepare_custom_config_dict is None:
        prepare_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})

    if not inplace:
        model = copy.deepcopy(model)

    # TODO: remove allow_list
    qconfig_propagation_list = allow_list
    if allow_list is None:
        qconfig_propagation_list = get_default_qconfig_propagation_list()
    propagate_qconfig_(model, qconfig_dict=None)

    # sanity check common API misusage
    if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
        warnings.warn("None of the submodules got qconfig applied. Make sure you "
                      "passed correct configuration through `qconfig_dict` or "
                      "by assigning the `.qconfig` attribute directly on submodules")

    _add_observer_(
        model, qconfig_propagation_list, observer_non_leaf_module_list,
        custom_module_class_mapping=custom_module_class_mapping)
    return model
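
# Example usage of `prepare` (a minimal sketch; `float_model` and
# `calibration_loader` are illustrative placeholders, not part of this module;
# a real model would typically wrap its inputs/outputs with QuantStub/DeQuantStub):
#
#     float_model.eval()
#     float_model.qconfig = torch.ao.quantization.default_qconfig
#     prepared = prepare(float_model)
#     with torch.no_grad():
#         for batch in calibration_loader:
#             prepared(batch)          # calibration data populates the observers
#     quantized = convert(prepared)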


def _remove_activation_post_process(module):
    # TODO: maybe we should change activation_post_process to _activation_post_process
    # to prevent it from being used by user
    if hasattr(module, 'activation_post_process') and \
       _is_activation_post_process(module.activation_post_process):
        delattr(module, 'activation_post_process')

    # remove activation_post_process pre and post hooks
    def remove_hooks(pre_hook=False):
        hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
        observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook
        handle_ids_to_remove = set()
        for handle_id, hook_fn in hook_map.items():
            if hook_fn is observer_hook:
                handle_ids_to_remove.add(handle_id)
        for handle_id in handle_ids_to_remove:
            hook_map.pop(handle_id)

    remove_hooks(pre_hook=True)
    remove_hooks(pre_hook=False)


# TODO: rename to something more general
def _remove_qconfig(module):
    r"""Clean up the qconfig left in the module so that new qconfig can be
    propagated.

    Args:
        module: module to be cleaned up
    """
    for child in module.children():
        _remove_qconfig(child)

    if hasattr(module, "qconfig"):
        del module.qconfig

    _remove_activation_post_process(module)


def quantize(model, run_fn, run_args, mapping=None, inplace=False):
    r"""Quantize the input float model with post training static quantization.

    It first prepares the model for calibration, then calls
    `run_fn`, which runs the calibration step, and finally
    converts the model to a quantized model.

    Args:
        model: input float model
        run_fn: a calibration function for calibrating the prepared model
        run_args: positional arguments for `run_fn`
        inplace: carry out model transformations in-place, the original module is mutated
        mapping: correspondence between original module types and quantized counterparts

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize")
    if mapping is None:
        mapping = get_default_static_quant_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    prepare(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, mapping, inplace=True)
    return model
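
# Example usage of `quantize` (a minimal sketch; `float_model` and
# `calibration_loader` are illustrative placeholders, not part of this module):
#
#     def calibrate(model, data_loader):
#         with torch.no_grad():
#             for batch in data_loader:
#                 model(batch)
#
#     float_model.qconfig = torch.ao.quantization.default_qconfig
#     quantized_model = quantize(float_model, calibrate, [calibration_loader])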


def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
                     mapping=None, inplace=False):
    r"""Converts a float model to a dynamic (i.e. weights-only) quantized model.

    Replaces specified modules with dynamic weight-only quantized versions and outputs the quantized model.

    For simplest usage provide the `dtype` argument, which can be float16 or qint8. Weight-only quantization
    by default is performed for layers with large weight sizes, i.e. Linear and RNN variants.

    Fine grained control is possible with `qconfig_spec` and `mapping`, which act similarly to `quantize()`.
    If `qconfig_spec` is provided, the `dtype` argument is ignored.

    Args:
        model: input model
        qconfig_spec: Either:

            - A dictionary that maps from name or type of submodule to quantization
              configuration, qconfig applies to all submodules of a given
              module unless qconfig for the submodules is specified (when the
              submodule already has a qconfig attribute). Entries in the dictionary
              need to be QConfig instances.

            - A set of types and/or submodule names to apply dynamic quantization to,
              in which case the `dtype` argument is used to specify the bit-width

        inplace: carry out model transformations in-place, the original module is mutated
        mapping: maps type of a submodule to a type of corresponding dynamically quantized version
            with which the submodule needs to be replaced
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
    if qconfig_spec is None:
        if dtype == torch.qint8:
            qconfig_spec = {
                nn.Linear : default_dynamic_qconfig,
                nn.LSTM : default_dynamic_qconfig,
                nn.GRU : default_dynamic_qconfig,
                nn.LSTMCell : default_dynamic_qconfig,
                nn.RNNCell : default_dynamic_qconfig,
                nn.GRUCell : default_dynamic_qconfig,
            }
        elif dtype == torch.float16:
            qconfig_spec = {
                nn.Linear : float16_dynamic_qconfig,
                nn.LSTM : float16_dynamic_qconfig,
                nn.GRU : float16_dynamic_qconfig,
                nn.LSTMCell : float16_dynamic_qconfig,
                nn.RNNCell : float16_dynamic_qconfig,
                nn.GRUCell : float16_dynamic_qconfig,
            }
        elif dtype == torch.quint8:
            qconfig_spec = {
                nn.EmbeddingBag : float_qparams_weight_only_qconfig,
                nn.Embedding : float_qparams_weight_only_qconfig,
            }
        elif dtype == torch.quint4x2:
            qconfig_spec = {
                nn.EmbeddingBag : float_qparams_weight_only_qconfig_4bit,
            }
        else:
            raise ValueError(
                "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
    elif isinstance(qconfig_spec, set):
        if dtype is torch.qint8:
            default_qconfig = default_dynamic_qconfig
        elif dtype is torch.float16:
            default_qconfig = float16_dynamic_qconfig
        elif dtype is torch.quint8:
            default_qconfig = float_qparams_weight_only_qconfig
        elif dtype is torch.quint4x2:
            default_qconfig = float_qparams_weight_only_qconfig_4bit
        else:
            raise RuntimeError('Unknown dtype specified for quantize_dynamic: ', str(dtype))
        qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))

    if mapping is None:
        mapping = get_default_dynamic_quant_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)
    model.eval()
    propagate_qconfig_(model, qconfig_spec)
    convert(model, mapping, inplace=True)
    return model
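
# Example usage of `quantize_dynamic` (a minimal sketch; `float_model` is an
# illustrative placeholder, not part of this module):
#
#     # qint8 weight-only quantization of all Linear/RNN layers (the defaults)
#     dq_model = quantize_dynamic(float_model)
#
#     # restrict quantization to nn.Linear and use float16 weights instead
#     dq_model = quantize_dynamic(float_model, {nn.Linear}, dtype=torch.float16)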


def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization-aware training and
    converts its submodules to their quantization-aware versions according
    to `mapping`.

    Quantization configuration should be assigned preemptively
    to individual submodules in the `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
            replaced.
        inplace: carry out model transformations in-place, the original module
            is mutated
    """
    torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
    assert model.training, "prepare_qat only works on models in training mode"
    if mapping is None:
        mapping = get_default_qat_module_mappings()

    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model
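
# Example usage of `prepare_qat` (a minimal sketch; `float_model` and
# `train_one_epoch` are illustrative placeholders, not part of this module):
#
#     float_model.train()
#     float_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
#     qat_model = prepare_qat(float_model)
#     train_one_epoch(qat_model)        # fine-tune with fake quantization enabled
#     qat_model.eval()
#     quantized = convert(qat_model)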


def quantize_qat(model, run_fn, run_args, inplace=False):
    r"""Do quantization aware training and output a quantized model

    Args:
        model: input model
        run_fn: a function for evaluating the prepared model, can be a
            function that simply runs the prepared model or a training
            loop
        run_args: positional arguments for `run_fn`

    Return:
        Quantized model.
    """
    torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
    if not inplace:
        model = copy.deepcopy(model)
    model.train()
    prepare_qat(model, inplace=True)
    run_fn(model, *run_args)
    convert(model, inplace=True)
    return model
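
# Example usage of `quantize_qat` (a minimal sketch; `float_model`, `train_fn`
# and `train_loader` are illustrative placeholders, not part of this module):
#
#     def train_fn(model, data_loader):
#         for batch, target in data_loader:
#             model(batch)              # a real loop would also compute loss and step
#
#     float_model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
#     quantized_model = quantize_qat(float_model, train_fn, [train_loader])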


def convert(
        module, mapping=None, inplace=False, remove_qconfig=True,
        is_reference=False, convert_custom_config_dict=None):
    r"""Converts submodules in the input module to a different module according to `mapping`
    by calling the `from_float` method on the target module class, and removes qconfig at the
    end if `remove_qconfig` is set to True.

    Args:
        `module`: prepared and calibrated module
        `mapping`: a dictionary that maps from source module type to target
            module type, can be overwritten to allow swapping user defined
            Modules
        `inplace`: carry out model transformations in-place, the original module
            is mutated
        `is_reference`: a flag to enable quantized reference module
        `convert_custom_config_dict`: custom configuration dictionary for convert function

    .. code-block:: python

       # Example of convert_custom_config_dict:
       convert_custom_config_dict = {
           # user will manually define the corresponding quantized
           # module class which has a from_observed class method that converts
           # observed custom module to quantized custom module
           "observed_to_quantized_custom_module_class": {
               ObservedCustomModule: QuantizedCustomModule
           }
       }

    """
    torch._C._log_api_usage_once("quantization_api.quantize.convert")
    if not inplace:
        module = copy.deepcopy(module)
    _convert(
        module, mapping, inplace=True, is_reference=is_reference,
        convert_custom_config_dict=convert_custom_config_dict)
    if remove_qconfig:
        _remove_qconfig(module)
    return module
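
# Example usage of `convert` (a minimal sketch; `prepared_model` is an
# illustrative, already-calibrated model, not part of this module):
#
#     quantized = convert(prepared_model)
#
#     # or produce reference quantized modules instead of backend kernels
#     reference = convert(prepared_model, is_reference=True)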


def _convert(
        module, mapping=None, inplace=False,
        is_reference=False, convert_custom_config_dict=None):
    r"""Converts submodules in the input module to a different module according to `mapping`
    by calling the `from_float` method on the target module class

    Args:
        module: input module
        mapping: a dictionary that maps from source module type to target
            module type, can be overwritten to allow swapping user defined
            Modules
        inplace: carry out model transformations in-place, the original module
            is mutated
        is_reference: a flag to enable quantized reference module

    """
    if mapping is None:
        mapping = get_default_static_quant_reference_module_mappings() if is_reference \
            else get_default_static_quant_module_mappings()
    if convert_custom_config_dict is None:
        convert_custom_config_dict = get_default_custom_config_dict()
    custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})

    if not inplace:
        module = copy.deepcopy(module)
    reassign = {}
    for name, mod in module.named_children():
        # both fused modules and observed custom modules are
        # swapped as one unit
        if not isinstance(mod, _FusedModule) and \
           type_before_parametrizations(mod) not in custom_module_class_mapping:
            _convert(mod, mapping, True,  # inplace
                     is_reference, convert_custom_config_dict)
        reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)

    for key, value in reassign.items():
        module._modules[key] = value

    return module


def swap_module(mod, mapping, custom_module_class_mapping):
    r"""Swaps the module if it has a quantized counterpart and it has an
    `observer` attached.

    Args:
        mod: input module
        mapping: a dictionary that maps from nn module to nnq module

    Return:
        The corresponding quantized module of `mod`
    """
    new_mod = mod
    if hasattr(mod, 'qconfig') and mod.qconfig is not None:
        swapped = False
        if type_before_parametrizations(mod) in custom_module_class_mapping:
            new_mod = custom_module_class_mapping[type_before_parametrizations(mod)].from_observed(mod)
            swapped = True
        elif type_before_parametrizations(mod) in mapping:
            qmod = mapping[type_before_parametrizations(mod)]
            if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE:
                assert mod.qconfig is not None
                weight_post_process = mod.qconfig.weight()
                weight_post_process(mod.weight)
                weight_qparams = get_qparam_dict(weight_post_process)
                new_mod = qmod.from_float(mod, weight_qparams)
            else:
                new_mod = qmod.from_float(mod)
            swapped = True

        if swapped:
            # Preserve module's pre forward hooks. They'll be called on quantized input
            for pre_hook_fn in mod._forward_pre_hooks.values():
                new_mod.register_forward_pre_hook(pre_hook_fn)
            # Preserve module's post forward hooks except _observer_forward_hook
            # After convert they'll work with quantized output
            for hook_fn in mod._forward_hooks.values():
                if hook_fn is not _observer_forward_hook:
                    new_mod.register_forward_hook(hook_fn)

            # respect device affinity when swapping modules
            devices = _get_unique_devices_(mod)
            assert len(devices) <= 1, (
                "swap_module only works with cpu or single-device CUDA modules, "
                "but got devices {}".format(devices)
            )
            device = next(iter(devices)) if len(devices) > 0 else None
            if device:
                new_mod.to(device)
    return new_mod
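
# Example usage of `swap_module` (a minimal sketch; `observed_linear` is an
# illustrative, already-observed nn.Linear with a qconfig attached, not part
# of this module):
#
#     mapping = get_default_static_quant_module_mappings()
#     quantized_linear = swap_module(observed_linear, mapping, {})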


def _get_observer_dict(mod, target_dict, prefix=""):
    r"""Traverse the modules and save all observers into dict.
    This is mainly used for quantization accuracy debugging.

    Args:
        mod: the top module we want to save all observers
        prefix: the prefix for the current module
        target_dict: the dictionary used to save all the observers
    """
    def get_prefix(prefix):
        return prefix if prefix == "" else prefix + '.'

    if hasattr(mod, 'activation_post_process'):
        target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
    for name, child in mod.named_children():
        module_prefix = get_prefix(prefix) + name if prefix else name
        _get_observer_dict(child, target_dict, module_prefix)
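
# Example usage of `_get_observer_dict` (a minimal sketch; `prepared_model` is
# an illustrative, already-prepared and calibrated model, not part of this module):
#
#     observers = {}
#     _get_observer_dict(prepared_model, observers)
#     for name, obs in observers.items():
#         print(name, obs.calculate_qparams())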