- import copy
- import itertools
- import warnings
- import torch
- import torch.nn as nn
- import torch.ao.nn.quantized as nnq
- from torch.ao.nn.intrinsic import _FusedModule
- from torch.ao.quantization.quantization_mappings import (
- get_default_dynamic_quant_module_mappings,
- get_default_static_quant_module_mappings,
- get_default_static_quant_reference_module_mappings,
- get_default_qat_module_mappings,
- get_default_qconfig_propagation_list,
- no_observer_set,
- _has_special_act_post_process,
- _get_special_act_post_process,
- )
- from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
- from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
- from torch.ao.quantization.qconfig import (
- _add_module_to_qconfig_obs_ctr,
- default_dynamic_qconfig,
- float16_dynamic_qconfig,
- float_qparams_weight_only_qconfig,
- float_qparams_weight_only_qconfig_4bit,
- _activation_is_memoryless)
- from torch.nn.utils.parametrize import type_before_parametrizations
- from torch.ao.quantization.observer import _is_activation_post_process
- # TODO remove this once BC is no longer required to avoid a SEV
- from torch.ao.quantization.observer import ( # noqa: F401
- _is_activation_post_process as is_activation_post_process
- )
- __all__ = [
- "get_default_custom_config_dict",
- "propagate_qconfig_",
- "add_quant_dequant",
- "prepare",
- "quantize",
- "quantize_dynamic",
- "prepare_qat",
- "quantize_qat",
- "convert",
- "swap_module",
- ]
- _DEFAULT_CUSTOM_CONFIG_DICT = {
- 'float_to_observed_custom_module_class': {
- nn.LSTM: nn.quantizable.LSTM,
- nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
- },
- 'observed_to_quantized_custom_module_class': {
- nn.quantizable.LSTM: nn.quantized.LSTM,
- nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
- }
- }
- def get_default_custom_config_dict():
- r"""Defines the default custom config dict.
- """
- return _DEFAULT_CUSTOM_CONFIG_DICT
- def _propagate_qconfig_helper(module, qconfig_dict,
- qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
- r"""This is a helper function for `propagate_qconfig_`
- Args:
- module: input module
- qconfig_dict: dictionary that maps from name of submodule to quantization
- configuration
- qconfig_parent: quantization config of the parent module; we fall back to
- this config when there is no config specified for the current
- module
- prefix: corresponding prefix of the current module, used as key in
- qconfig_dict
- prepare_custom_config_dict: dictionary for custom handling of modules
- see docs for :func:`~torch.ao.quantization.prepare_fx`
- Return:
- None, module is modified inplace with qconfig attached
- """
- module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
- module_qconfig = qconfig_dict.get(prefix, module_qconfig)
- module_qconfig = getattr(module, 'qconfig', module_qconfig)
- torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)
- qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
- module.qconfig = qconfig_with_device_check
- for name, child in module.named_children():
- module_prefix = prefix + '.' + name if prefix else name
- # do not propagate qconfig to the child if the child is non-traceable
- if prepare_custom_config_dict is None or not (
- name in prepare_custom_config_dict.get("non_traceable_module_name", [])
- or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", [])
- ):
- _propagate_qconfig_helper(
- child, qconfig_dict, qconfig_with_device_check, module_prefix
- )
- def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
- r"""Propagate qconfig through the module hierarchy and assign `qconfig`
- attribute on each leaf module
- Args:
- module: input module
- qconfig_dict: dictionary that maps the name or type of a submodule to a
- quantization configuration; a qconfig applies to all submodules of a
- given module unless a qconfig is specified for a submodule (i.e. the
- submodule already has a qconfig attribute)
- prepare_custom_config_dict: dictionary for custom handling of modules
- see docs for :func:`~torch.ao.quantization.prepare_fx`
- Return:
- None, module is modified inplace with qconfig attached
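- A minimal usage sketch (`float_model` is a placeholder for any `nn.Module`):
- .. code-block:: python
- float_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
- propagate_qconfig_(float_model)
- # every leaf submodule now carries a `qconfig` attribute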
- """
- if qconfig_dict is None:
- qconfig_dict = {}
- if prepare_custom_config_dict is None:
- prepare_custom_config_dict = {}
- _propagate_qconfig_helper(module, qconfig_dict, prepare_custom_config_dict=prepare_custom_config_dict)
- def _observer_forward_hook(self, input, output):
- r"""Forward hook that calls observer on the output
- """
- return self.activation_post_process(output)
- def _observer_forward_pre_hook(self, input):
- r"""Forward pre hook that calls observer on the output
- """
- return self.activation_post_process(input[0])
- def _register_activation_post_process_hook(module, pre_hook=False):
- assert hasattr(module, 'activation_post_process'), \
- 'Expect activation_post_process attribute already attached to the module'
- if pre_hook:
- handle = module.register_forward_pre_hook(
- _observer_forward_pre_hook, prepend=True
- )
- else:
- handle = module.register_forward_hook(
- _observer_forward_hook, prepend=True
- )
- def _add_observer_(module, qconfig_propagation_list=None, non_leaf_module_list=None, device=None, custom_module_class_mapping=None):
- r"""Add observer for the leaf child of the module.
- This function insert observer module to all leaf child module that
- has a valid qconfig attribute.
- Args:
- module: input module with qconfig attributes for all the leaf modules that we want to quantize
- qconfig_propagation_list: a list of quantizable modules that will have observers added to them
- if they are leaf nodes
- device: parent device, if any
- non_leaf_module_list: list of non-leaf modules we want to add observers to
- Return:
- None, module is modified inplace with added observer modules and forward_hooks
- """
- if qconfig_propagation_list is None:
- qconfig_propagation_list = get_default_qconfig_propagation_list()
- if custom_module_class_mapping is None:
- custom_module_class_mapping = {}
- # respect device affinity when adding observers
- if device is None:
- devices = _get_unique_devices_(module)
- assert len(devices) <= 1, (
- "_add_observer_ only works with cpu or single-device CUDA modules, "
- "but got devices {}".format(devices)
- )
- device = next(iter(devices)) if len(devices) > 0 else None
- def get_activation_post_process(qconfig, device, special_act_post_process=None):
- activation = qconfig.activation() if special_act_post_process is None else special_act_post_process()
- if device is not None:
- activation.to(device)
- return activation
- def needs_observation(m):
- return hasattr(m, 'qconfig') and m.qconfig is not None
- def insert_activation_post_process(m, special_act_post_process=None):
- """ Adds an activation post process module and register
- a pre or post hook that calls the module
- """
- # We don't insert observer/fake_quantize for DeQuantStub
- if needs_observation(m) and not isinstance(m, DeQuantStub):
- # observer and hook will be gone after we swap the module
- m.add_module('activation_post_process', get_activation_post_process(
- m.qconfig, device, special_act_post_process))
- # Register observer as the first entry in the hook list
- # All post forward hooks are preserved and will be executed after the observer before convert
- _register_activation_post_process_hook(m, pre_hook=_activation_is_memoryless(m.qconfig))
- for name, child in module.named_children():
- # TODO remove the Dropout special case once the codebase is stable
- if type_before_parametrizations(child) in [nn.Dropout]:
- continue
- elif type_before_parametrizations(child) in [nnq.FloatFunctional, nnq.QFunctional]:
- if needs_observation(child):
- child.activation_post_process = get_activation_post_process(child.qconfig, device)
- elif isinstance(child, _FusedModule):
- # activation_post_process is now added directly to nn.Sequential/_FusedModule
- if needs_observation(child):
- insert_activation_post_process(child)
- elif non_leaf_module_list is not None and type_before_parametrizations(child) in non_leaf_module_list:
- if needs_observation(child):
- insert_activation_post_process(child)
- elif _has_special_act_post_process(child):
- special_act_post_process = _get_special_act_post_process(child)
- insert_activation_post_process(child, special_act_post_process)
- elif needs_observation(child) and type_before_parametrizations(child) in custom_module_class_mapping:
- observed_child = custom_module_class_mapping[type_before_parametrizations(child)].from_float(child)
- setattr(module, name, observed_child)
- # TODO: These are the modules that cannot be observed
- # Once there are more, we should move them to a separate list
- if custom_module_class_mapping[type_before_parametrizations(child)] not in no_observer_set():
- insert_activation_post_process(observed_child)
- else:
- _add_observer_(child, qconfig_propagation_list, non_leaf_module_list, device, custom_module_class_mapping)
- # Insert observers only for leaf nodes; note that this observer observes
- # the output of the module, while inputs are observed by QuantStub
- if has_no_children_ignoring_parametrizations(module) and not isinstance(module, torch.nn.Sequential) \
- and type_before_parametrizations(module) in qconfig_propagation_list:
- insert_activation_post_process(module)
- def _get_unique_devices_(module):
- return {p.device for p in module.parameters()} | \
- {p.device for p in module.buffers()}
- def add_quant_dequant(module):
- r"""Wrap the leaf child module in QuantWrapper if it has a valid qconfig
- Note that this function will modify the children of module inplace and it
- can return a new module which wraps the input module as well.
- Args:
- module: input module with qconfig attributes for all the leaf modules
- that we want to quantize
- Return:
- Either the in-place modified module with submodules wrapped in
- `QuantWrapper` based on qconfig, or a new `QuantWrapper` module which
- wraps the input module; the latter case only happens when the input
- module is a leaf module that we want to quantize.
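- A minimal usage sketch (`float_model` is a placeholder whose submodules
- already have `qconfig` set, e.g. via `propagate_qconfig_`):
- .. code-block:: python
- wrapped_model = add_quant_dequant(float_model)
- # leaf submodules with a qconfig are now wrapped in `QuantWrapper`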
- """
- if has_no_children_ignoring_parametrizations(module) and hasattr(module, 'qconfig') and module.qconfig:
- return QuantWrapper(module)
- for name, child in module.named_children():
- module._modules[name] = add_quant_dequant(child)
- return module
- def prepare(model, inplace=False, allow_list=None,
- observer_non_leaf_module_list=None,
- prepare_custom_config_dict=None):
- r"""Prepares a copy of the model for quantization calibration or quantization-aware training.
- Quantization configuration should be assigned preemptively
- to individual submodules in `.qconfig` attribute.
- The model will be attached with observer or fake quant modules, and qconfig
- will be propagated.
- Args:
- `model`: input model to be modified in-place
- `inplace`: carry out model transformations in-place, the original module is mutated
- `allow_list`: list of quantizable modules
- `observer_non_leaf_module_list`: list of non-leaf modules we want to add observer
- `prepare_custom_config_dict`: customization configuration dictionary for prepare function
- .. code-block:: python
- # Example of prepare_custom_config_dict:
- prepare_custom_config_dict = {
- # user will manually define the corresponding observed
- # module class which has a from_float class method that converts
- # float custom module to observed custom module
- "float_to_observed_custom_module_class": {
- CustomModule: ObservedCustomModule
- }
- }
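- A typical eager mode usage sketch (`float_model` and `calibration_data` are
- placeholders):
- .. code-block:: python
- float_model.eval()
- float_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
- prepared_model = prepare(float_model)
- prepared_model(calibration_data)  # run calibration batches through the observers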
- """
- torch._C._log_api_usage_once("quantization_api.quantize.prepare")
- if prepare_custom_config_dict is None:
- prepare_custom_config_dict = get_default_custom_config_dict()
- custom_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
- if not inplace:
- model = copy.deepcopy(model)
- # TODO: remove allow_list
- qconfig_propagation_list = allow_list
- if allow_list is None:
- qconfig_propagation_list = get_default_qconfig_propagation_list()
- propagate_qconfig_(model, qconfig_dict=None)
- # sanity check for common API misuse
- if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
- warnings.warn("None of the submodules got a qconfig applied. Make sure you "
- "passed the correct configuration through `qconfig_dict` or "
- "by assigning the `.qconfig` attribute directly on submodules")
- _add_observer_(
- model, qconfig_propagation_list, observer_non_leaf_module_list,
- custom_module_class_mapping=custom_module_class_mapping)
- return model
- def _remove_activation_post_process(module):
- # TODO: maybe we should change activation_post_process to _activation_post_process
- # to prevent it from being used by user
- if hasattr(module, 'activation_post_process') and \
- _is_activation_post_process(module.activation_post_process):
- delattr(module, 'activation_post_process')
- # remove activation_post_process pre and post hooks
- def remove_hooks(pre_hook=False):
- hook_map = module._forward_pre_hooks if pre_hook else module._forward_hooks
- observer_hook = _observer_forward_pre_hook if pre_hook else _observer_forward_hook
- handle_ids_to_remove = set()
- for handle_id, hook_fn in hook_map.items():
- if hook_fn is observer_hook:
- handle_ids_to_remove.add(handle_id)
- for handle_id in handle_ids_to_remove:
- hook_map.pop(handle_id)
- remove_hooks(pre_hook=True)
- remove_hooks(pre_hook=False)
- # TODO: rename to something more general
- def _remove_qconfig(module):
- r"""Clean up the qconfig left in the module so that new qconfig can be
- propagated.
- Args:
- module: module to be cleaned up
- """
- for child in module.children():
- _remove_qconfig(child)
- if hasattr(module, "qconfig"):
- del module.qconfig
- _remove_activation_post_process(module)
- def quantize(model, run_fn, run_args, mapping=None, inplace=False):
- r"""Quantize the input float model with post training static quantization.
- First it will prepare the model for calibration, then it calls
- `run_fn` which will run the calibration step, after that we will
- convert the model to a quantized model.
- Args:
- model: input float model
- run_fn: a calibration function for calibrating the prepared model
- run_args: positional arguments for `run_fn`
- inplace: carry out model transformations in-place, the original module is mutated
- mapping: correspondence between original module types and quantized counterparts
- Return:
- Quantized model.
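- A minimal usage sketch (`float_model` and `calib_data` are placeholders; the
- calibration function simply runs data through the prepared model):
- .. code-block:: python
- def calibrate(model, data): model(data)
- float_model.eval()
- float_model.qconfig = torch.ao.quantization.get_default_qconfig('fbgemm')
- quantized_model = quantize(float_model, calibrate, [calib_data])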
- """
- torch._C._log_api_usage_once("quantization_api.quantize.quantize")
- if mapping is None:
- mapping = get_default_static_quant_module_mappings()
- if not inplace:
- model = copy.deepcopy(model)
- model.eval()
- prepare(model, inplace=True)
- run_fn(model, *run_args)
- convert(model, mapping, inplace=True)
- return model
- def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8,
- mapping=None, inplace=False):
- r"""Converts a float model to dynamic (i.e. weights-only) quantized model.
- Replaces specified modules with dynamic weight-only quantized versions and output the quantized model.
- For simplest usage provide `dtype` argument that can be float16 or qint8. Weight-only quantization
- by default is performed for layers with large weights size - i.e. Linear and RNN variants.
- Fine grained control is possible with `qconfig` and `mapping` that act similarly to `quantize()`.
- If `qconfig` is provided, the `dtype` argument is ignored.
- Args:
- model: input model
- qconfig_spec: Either:
- - A dictionary that maps from name or type of submodule to quantization
- configuration, qconfig applies to all submodules of a given
- module unless qconfig for the submodules are specified (when the
- submodule already has qconfig attribute). Entries in the dictionary
- need to be QConfig instances.
- - A set of types and/or submodule names to apply dynamic quantization to,
- in which case the `dtype` argument is used to specify the bit-width
- inplace: carry out model transformations in-place, the original module is mutated
- mapping: maps type of a submodule to a type of corresponding dynamically quantized version
- with which the submodule needs to be replaced
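- A minimal usage sketch (`float_model` is a placeholder):
- .. code-block:: python
- quantized_model = quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)
- # only nn.Linear submodules are replaced with dynamically quantized versions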
- """
- torch._C._log_api_usage_once("quantization_api.quantize.quantize_dynamic")
- if qconfig_spec is None:
- if dtype == torch.qint8:
- qconfig_spec = {
- nn.Linear : default_dynamic_qconfig,
- nn.LSTM : default_dynamic_qconfig,
- nn.GRU : default_dynamic_qconfig,
- nn.LSTMCell : default_dynamic_qconfig,
- nn.RNNCell : default_dynamic_qconfig,
- nn.GRUCell : default_dynamic_qconfig,
- }
- elif dtype == torch.float16:
- qconfig_spec = {
- nn.Linear : float16_dynamic_qconfig,
- nn.LSTM : float16_dynamic_qconfig,
- nn.GRU : float16_dynamic_qconfig,
- nn.LSTMCell : float16_dynamic_qconfig,
- nn.RNNCell : float16_dynamic_qconfig,
- nn.GRUCell : float16_dynamic_qconfig,
- }
- elif dtype == torch.quint8:
- qconfig_spec = {
- nn.EmbeddingBag : float_qparams_weight_only_qconfig,
- nn.Embedding : float_qparams_weight_only_qconfig,
- }
- elif dtype == torch.quint4x2:
- qconfig_spec = {
- nn.EmbeddingBag : float_qparams_weight_only_qconfig_4bit,
- }
- else:
- raise ValueError(
- "Don't know how to quantize with default settings for {}. Provide full qconfig please".format(dtype))
- elif isinstance(qconfig_spec, set):
- if dtype is torch.qint8:
- default_qconfig = default_dynamic_qconfig
- elif dtype is torch.float16:
- default_qconfig = float16_dynamic_qconfig
- elif dtype is torch.quint8:
- default_qconfig = float_qparams_weight_only_qconfig
- elif dtype is torch.quint4x2:
- default_qconfig = float_qparams_weight_only_qconfig_4bit
- else:
- raise RuntimeError('Unknown dtype specified for quantize_dynamic: {}'.format(dtype))
- qconfig_spec = dict(zip(qconfig_spec, itertools.repeat(default_qconfig)))
- if mapping is None:
- mapping = get_default_dynamic_quant_module_mappings()
- if not inplace:
- model = copy.deepcopy(model)
- model.eval()
- propagate_qconfig_(model, qconfig_spec)
- convert(model, mapping, inplace=True)
- return model
- def prepare_qat(model, mapping=None, inplace=False):
- r"""
- Prepares a copy of the model for quantization calibration or
- quantization-aware training and converts it to quantized version.
- Quantization configuration should be assigned preemptively
- to individual submodules in `.qconfig` attribute.
- Args:
- model: input model to be modified in-place
- mapping: dictionary that maps float modules to quantized modules to be
- replaced.
- inplace: carry out model transformations in-place, the original module
- is mutated
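- A minimal usage sketch (`float_model` is a placeholder):
- .. code-block:: python
- float_model.train()
- float_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
- qat_model = prepare_qat(float_model)
- # qat_model now contains fake-quant modules and is ready for fine-tuning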
- """
- torch._C._log_api_usage_once("quantization_api.quantize.prepare_qat")
- assert model.training, "prepare_qat only works on models in training mode"
- if mapping is None:
- mapping = get_default_qat_module_mappings()
- if not inplace:
- model = copy.deepcopy(model)
- propagate_qconfig_(model, qconfig_dict=None)
- convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
- prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
- return model
- def quantize_qat(model, run_fn, run_args, inplace=False):
- r"""Do quantization aware training and output a quantized model
- Args:
- model: input model
- run_fn: a function that runs quantization aware training on the prepared
- model; it can be a function that simply runs the prepared model or a
- training loop
- run_args: positional arguments for `run_fn`
- Return:
- Quantized model.
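- A minimal usage sketch (`float_model`, `train_fn` and `train_loader` are
- placeholders; `train_fn` runs the training loop on the prepared model):
- .. code-block:: python
- float_model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
- quantized_model = quantize_qat(float_model, train_fn, [train_loader])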
- """
- torch._C._log_api_usage_once("quantization_api.quantize.quantize_qat")
- if not inplace:
- model = copy.deepcopy(model)
- model.train()
- prepare_qat(model, inplace=True)
- run_fn(model, *run_args)
- convert(model, inplace=True)
- return model
- def convert(
- module, mapping=None, inplace=False, remove_qconfig=True,
- is_reference=False, convert_custom_config_dict=None):
- r"""Converts submodules in input module to a different module according to `mapping`
- by calling `from_float` method on the target module class. And remove qconfig at the
- end if remove_qconfig is set to True.
- Args:
- `module`: prepared and calibrated module
- `mapping`: a dictionary that maps from source module type to target
- module type, can be overwritten to allow swapping user defined
- Modules
- `inplace`: carry out model transformations in-place, the original module
- is mutated
- `convert_custom_config_dict`: custom configuration dictionary for convert function
- .. code-block:: python
- # Example of convert_custom_config_dict:
- convert_custom_config_dict = {
- # user will manually define the corresponding quantized
- # module class which has a from_observed class method that converts
- # observed custom module to quantized custom module
- "observed_to_quantized_custom_module_class": {
- ObservedCustomModule: QuantizedCustomModule
- }
- }
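- A typical eager mode usage sketch (`prepared_model` is a placeholder for a
- module that has already been prepared and calibrated):
- .. code-block:: python
- quantized_model = convert(prepared_model)
- # submodules are swapped for their quantized counterparts via `from_float`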
- """
- torch._C._log_api_usage_once("quantization_api.quantize.convert")
- if not inplace:
- module = copy.deepcopy(module)
- _convert(
- module, mapping, inplace=True, is_reference=is_reference,
- convert_custom_config_dict=convert_custom_config_dict)
- if remove_qconfig:
- _remove_qconfig(module)
- return module
- def _convert(
- module, mapping=None, inplace=False,
- is_reference=False, convert_custom_config_dict=None):
- r"""Converts submodules in input module to a different module according to `mapping`
- by calling `from_float` method on the target module class
- Args:
- module: input module
- mapping: a dictionary that maps from source module type to target
- module type, can be overwritten to allow swapping user defined
- Modules
- inplace: carry out model transformations in-place, the original module
- is mutated
- is_reference: a flag to enable quantized reference module
- """
- if mapping is None:
- mapping = get_default_static_quant_reference_module_mappings() if is_reference \
- else get_default_static_quant_module_mappings()
- if convert_custom_config_dict is None:
- convert_custom_config_dict = get_default_custom_config_dict()
- custom_module_class_mapping = convert_custom_config_dict.get("observed_to_quantized_custom_module_class", {})
- if not inplace:
- module = copy.deepcopy(module)
- reassign = {}
- for name, mod in module.named_children():
- # both fused modules and observed custom modules are
- # swapped as one unit
- if not isinstance(mod, _FusedModule) and \
- type_before_parametrizations(mod) not in custom_module_class_mapping:
- _convert(mod, mapping, True, # inplace
- is_reference, convert_custom_config_dict)
- reassign[name] = swap_module(mod, mapping, custom_module_class_mapping)
- for key, value in reassign.items():
- module._modules[key] = value
- return module
- def swap_module(mod, mapping, custom_module_class_mapping):
- r"""Swaps the module if it has a quantized counterpart and it has an
- `observer` attached.
- Args:
- mod: input module
- mapping: a dictionary that maps from nn module to nnq module
- Return:
- The corresponding quantized module of `mod`
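- A minimal usage sketch (`observed_linear` is a placeholder for a calibrated
- `nn.Linear` that already carries a qconfig and observers):
- .. code-block:: python
- mapping = get_default_static_quant_module_mappings()
- quantized_linear = swap_module(observed_linear, mapping, {})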
- """
- new_mod = mod
- if hasattr(mod, 'qconfig') and mod.qconfig is not None:
- swapped = False
- if type_before_parametrizations(mod) in custom_module_class_mapping:
- new_mod = custom_module_class_mapping[type_before_parametrizations(mod)].from_observed(mod)
- swapped = True
- elif type_before_parametrizations(mod) in mapping:
- qmod = mapping[type_before_parametrizations(mod)]
- if hasattr(qmod, '_IS_REFERENCE') and qmod._IS_REFERENCE:
- assert mod.qconfig is not None
- weight_post_process = mod.qconfig.weight()
- weight_post_process(mod.weight)
- weight_qparams = get_qparam_dict(weight_post_process)
- new_mod = qmod.from_float(mod, weight_qparams)
- else:
- new_mod = qmod.from_float(mod)
- swapped = True
- if swapped:
- # Preserve module's pre forward hooks. They'll be called on quantized input
- for pre_hook_fn in mod._forward_pre_hooks.values():
- new_mod.register_forward_pre_hook(pre_hook_fn)
- # Preserve module's post forward hooks except _observer_forward_hook
- # After convert they'll work with quantized output
- for hook_fn in mod._forward_hooks.values():
- if hook_fn is not _observer_forward_hook:
- new_mod.register_forward_hook(hook_fn)
- # respect device affinity when swapping modules
- devices = _get_unique_devices_(mod)
- assert len(devices) <= 1, (
- "swap_module only works with cpu or single-device CUDA modules, "
- "but got devices {}".format(devices)
- )
- device = next(iter(devices)) if len(devices) > 0 else None
- if device:
- new_mod.to(device)
- return new_mod
- def _get_observer_dict(mod, target_dict, prefix=""):
- r"""Traverse the modules and save all observers into dict.
- This is mainly used for quantization accuracy debug
- Args:
- mod: the top module we want to save all observers
- prefix: the prefix for the current module
- target_dict: the dictionary used to save all the observers
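- A minimal usage sketch (`prepared_model` is a placeholder for a prepared and
- calibrated module):
- .. code-block:: python
- observer_dict = {}
- _get_observer_dict(prepared_model, observer_dict)
- # maps "<module path>.activation_post_process" to each observer instance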
- """
- def get_prefix(prefix):
- return prefix if prefix == "" else prefix + '.'
- if hasattr(mod, 'activation_post_process'):
- target_dict[get_prefix(prefix) + 'activation_post_process'] = mod.activation_post_process
- for name, child in mod.named_children():
- module_prefix = get_prefix(prefix) + name if prefix else name
- _get_observer_dict(child, target_dict, module_prefix)
|