import torch
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.quant_type import QuantType
from torch.jit._recursive import wrap_cpp_module

__all__ = [
    "script_qconfig",
    "script_qconfig_dict",
    "fuse_conv_bn_jit",
    "prepare_jit",
    "prepare_dynamic_jit",
    "convert_jit",
    "convert_dynamic_jit",
    "quantize_jit",
    "quantize_dynamic_jit",
]


def _check_is_script_module(model):
    if not isinstance(model, torch.jit.ScriptModule):
        raise ValueError('input must be a script module, got: ' + str(type(model)))


def _check_forward_method(model):
    if not model._c._has_method('forward'):
        raise ValueError('input script module does not have forward method')


def script_qconfig(qconfig):
    r"""Instantiate the activation and weight observer modules and script
    them. These observer module instances will be deep-copied during the
    `prepare_jit` step.
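
    Example, a minimal sketch (the default 'fbgemm' qconfig here is just
    an illustration):
    ```python
    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_jit import script_qconfig

    qconfig = get_default_qconfig('fbgemm')
    scripted_qconfig = script_qconfig(qconfig)
    ```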
- """
- return QConfig(
- activation=torch.jit.script(qconfig.activation())._c,
- weight=torch.jit.script(qconfig.weight())._c)


def script_qconfig_dict(qconfig_dict):
    r"""Helper function used by `prepare_jit`.
    Applies `script_qconfig` to all entries in `qconfig_dict` that are
    not None.
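
    Example, a minimal sketch (a None value skips quantization for that
    submodule; the submodule name 'fc' is a placeholder):
    ```python
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_jit import script_qconfig_dict

    qconfig_dict = {'': get_default_qconfig('fbgemm'), 'fc': None}
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    ```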
- """
- return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}


def fuse_conv_bn_jit(model, inplace=False):
    r"""Fuse conv-bn modules.
    Works for eval models only.

    Args:
        model: TorchScript model from scripting or tracing
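
    Example, a minimal sketch (``float_model`` is assumed to be a float
    model containing conv-bn patterns):
    ```python
    import torch
    from torch.ao.quantization.quantize_jit import fuse_conv_bn_jit

    ts_model = torch.jit.script(float_model.eval())
    fused_model = fuse_conv_bn_jit(ts_model)
    ```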
- """
- torch._C._log_api_usage_once("quantization_api.quantize_jit.fuse_conv_bn_jit")
- model_c = model._c
- model_c = torch._C._jit_pass_fold_convbn(model_c)
- if inplace:
- model._reconstruct(model_c)
- else:
- model = wrap_cpp_module(model_c)
- return model


def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    _check_forward_method(model)
    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError('qconfig_dict should only contain names(str) as keys.')
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    model = fuse_conv_bn_jit(model, inplace)
    model_c = torch._C._jit_pass_insert_observers(model._c,
                                                  'forward',
                                                  scripted_qconfig_dict,
                                                  inplace,
                                                  quant_type)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def _prepare_ondevice_jit(model, qconfig_dict, method_name='forward', inplace=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    if not all(isinstance(x, str) for x in qconfig_dict.keys()):
        raise ValueError('qconfig_dict should only contain names(str) as keys.')
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    method_graph = model._c._get_method(method_name).graph
    torch._C._jit_pass_inline(method_graph)
    model = fuse_conv_bn_jit(model, inplace)
    model_c = torch._C._jit_pass_insert_observer_method_for_ondevice_ptq(model._c,
                                                                         method_name,
                                                                         scripted_qconfig_dict,
                                                                         inplace,
                                                                         quant_type)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def prepare_jit(model, qconfig_dict, inplace=False):
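    r"""Prepare a scripted or traced model for post training static
    quantization by inserting observers according to `qconfig_dict`.
    See :func:`~torch.ao.quantization.quantize_jit` for the format of
    `qconfig_dict`.

    Example, a minimal sketch (``float_model`` is assumed):
    ```python
    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_jit import prepare_jit

    ts_model = torch.jit.script(float_model.eval())
    prepared = prepare_jit(ts_model, {'': get_default_qconfig('fbgemm')})
    ```
    """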
    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_jit")
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)


def prepare_dynamic_jit(model, qconfig_dict, inplace=False):
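    r"""Prepare a scripted or traced model for post training dynamic
    quantization by inserting weight observers according to `qconfig_dict`.

    Example, a minimal sketch (``float_model`` is assumed):
    ```python
    import torch
    from torch.ao.quantization import per_channel_dynamic_qconfig
    from torch.ao.quantization.quantize_jit import prepare_dynamic_jit

    ts_model = torch.jit.script(float_model.eval())
    prepared = prepare_dynamic_jit(ts_model, {'': per_channel_dynamic_qconfig})
    ```
    """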
    torch._C._log_api_usage_once("quantization_api.quantize_jit.prepare_dynamic_jit")
    return _prepare_jit(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)


def _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
    return _prepare_ondevice_jit(model, qconfig_dict, method_name, inplace, quant_type=QuantType.DYNAMIC)


def _convert_jit(model, inplace=False, debug=False, quant_type=QuantType.STATIC,
                 preserved_attrs=None):
    _check_is_script_module(model)
    model.eval()
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant(model_c, 'forward', inplace, debug, quant_type)
    if not debug:
        is_xpu = all(p.device.type == 'xpu' for p in model.parameters())
        if not is_xpu:
            # Moving model parameters to CPU since quantized operators
            # are only supported on CPU and XPU right now
            model.cpu()
        if preserved_attrs is None:
            preserved_attrs = []
        model_c = torch._C._jit_pass_quant_finalize(model_c, quant_type, preserved_attrs)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    torch._C._jit_pass_constant_propagation(model.graph)
    torch._C._jit_pass_dce(model.graph)
    return model


def _convert_ondevice_jit(model, method_name, inplace=False, debug=False, quant_type=QuantType.STATIC):
    _check_is_script_module(model)
    assert quant_type == QuantType.DYNAMIC, "While this API should work for static quantization, it is only tested for dynamic quantization."
    assert not method_name.startswith("observe_"), "Pass in a valid method to be quantized, e.g. forward"
    observe_method_name = "observe_" + method_name
    quantize_method_name = "quantize_" + method_name
    model_c = model._c
    model_c = torch._C._jit_pass_insert_quant_dequant_for_ondevice_ptq(
        model_c, observe_method_name, inplace, debug, QuantType.DYNAMIC)
    model_c = torch._C._jit_pass_quant_finalize_for_ondevice_ptq(model_c, QuantType.DYNAMIC, quantize_method_name)
    if inplace:
        model._reconstruct(model_c)
    else:
        model = wrap_cpp_module(model_c)
    return model


def convert_jit(model, inplace=False, debug=False, preserved_attrs=None):
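    r"""Convert a calibrated model produced by :func:`prepare_jit` to a
    statically quantized model.

    Example, a minimal sketch (``prepared`` is assumed to be the output of
    ``prepare_jit`` after calibration has been run):
    ```python
    from torch.ao.quantization.quantize_jit import convert_jit

    quantized_model = convert_jit(prepared)
    ```
    """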
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_jit")
    return _convert_jit(model, inplace, debug, quant_type=QuantType.STATIC, preserved_attrs=preserved_attrs)


def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None):
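    r"""Convert a model produced by :func:`prepare_dynamic_jit` to a
    dynamically quantized model.

    Example, a minimal sketch (``prepared`` is assumed to be the output of
    ``prepare_dynamic_jit``; dynamic quantization needs no calibration):
    ```python
    from torch.ao.quantization.quantize_jit import convert_dynamic_jit

    quantized_model = convert_dynamic_jit(prepared)
    ```
    """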
    torch._C._log_api_usage_once("quantization_api.quantize_jit.convert_dynamic_jit")
    return _convert_jit(model, inplace, debug, quant_type=QuantType.DYNAMIC, preserved_attrs=preserved_attrs)


def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
    return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)


def _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=False):
    model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace)
    model = _convert_ondevice_dynamic_jit(model, method_name, inplace)
    return model


def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
    # Always convert in place because the model has already been copied
    # in prepare_jit when inplace is False
    if quant_type == QuantType.DYNAMIC:
        model = prepare_dynamic_jit(model, qconfig_dict, inplace)
        model = convert_dynamic_jit(model, True, debug)
    else:
        assert run_fn, "Must provide calibration function for post training static quantization"
        assert run_args, "Must provide calibration dataset for post training static quantization"
        model = prepare_jit(model, qconfig_dict, inplace)
        run_fn(model, *run_args)
        model = convert_jit(model, True, debug)
    torch._C._jit_pass_constant_propagation(model.graph)
    torch._C._jit_pass_dce(model.graph)
    return model


def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
    r"""Quantize the input float TorchScript model with post training static
    quantization.
    First it prepares the model for calibration, then calls `run_fn`, which
    runs the calibration step, and finally converts the model to a
    quantized model.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: a dictionary with names of submodules as keys and the
            qconfig for that module as values; the empty string key means the
            qconfig is applied to the whole model unless it is overridden by a
            more specific configuration. The qconfig for each module is either
            found in the dictionary or falls back to the qconfig of the parent
            module. Right now qconfig_dict is the only way to configure how
            the model is quantized, and it operates at module granularity:
            we only support one qconfig per torch.nn.Module, and the qconfig
            for a submodule overrides the qconfig of its parent module.
        `run_fn`: a calibration function for calibrating the prepared model
        `run_args`: positional arguments for `run_fn`
        `inplace`: carry out model transformations in-place, the original
            module is mutated
        `debug`: flag for producing a debug friendly model (preserve weight
            attribute)

    Return:
        Quantized TorchScript model.

    Example:
    ```python
    import torch
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization import quantize_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    qconfig = get_default_qconfig('fbgemm')
    def calibrate(model, data_loader):
        model.eval()
        with torch.no_grad():
            for image, target in data_loader:
                model(image)
    quantized_model = quantize_jit(
        ts_model,
        {'': qconfig},
        calibrate,
        [data_loader_test])
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_jit")
    return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC)


def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
    r"""Quantize the input float TorchScript model with post training dynamic
    quantization.
    Currently only qint8 quantization of torch.nn.Linear is supported.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: a dictionary with names of submodules as keys and the
            qconfig for that module as values, please see detailed
            descriptions in :func:`~torch.ao.quantization.quantize_jit`
        `inplace`: carry out model transformations in-place, the original
            module is mutated
        `debug`: flag for producing a debug friendly model (preserve weight
            attribute)

    Return:
        Quantized TorchScript model.

    Example (dynamic quantization requires no calibration step):
    ```python
    import torch
    from torch.ao.quantization import per_channel_dynamic_qconfig
    from torch.ao.quantization import quantize_dynamic_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    quantized_model = quantize_dynamic_jit(
        ts_model,
        {'': per_channel_dynamic_qconfig})
    ```
    """
    torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
    return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)


def _quantize_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
    r"""Prepares the input float TorchScript model for *on-device* post
    training dynamic quantization.
    Currently only qint8 quantization of torch.nn.Linear is supported.

    Args:
        `model`: input float TorchScript model
        `qconfig_dict`: a dictionary with names of submodules as keys and the
            qconfig for that module as values, please see detailed
            descriptions in :func:`~torch.ao.quantization.quantize_jit`
        `method_name`: name of the method within the model to be prepared
            for quantization
        `inplace`: carry out model transformations in-place, the original
            module is mutated

    Return:
        TorchScript model that is ready for on-device quantization.
        This means that the returned model has:
        - The method inlined.
        - Observer modules inserted in the model.
        - Packed params inserted in the model. However, they are empty, in
          that they do not contain valid quantized weights.
        - An observe_<method_name> method that observes the values to be
          quantized.
        - A reset_observers_<method_name> method to reset observers.
        - A quantize_<method_name> method added to the model that:
          - Extracts scales and zero points.
          - Quantizes the observed weights.
          - Creates packed params from them and updates the model's
            attributes with the new values for the packed params.
          - Resets the original fp32 weights to empty tensors using SetAttr.
        - A quantized_<method_name> method added to the model that:
          - Uses quantized weights and quantized linear ops instead of fp32
            ops.
          - Should be used for inference post PTQ.
        - Note that all these methods have the same signature as
          method_name.

        Later on device (see the sketch after the example below):
        - Run reset_observers_<method_name>
        - Run observe_<method_name>
        - Run quantize_<method_name>
        - Now the model can be saved and loaded later.
        - Run the model with quantized_<method_name>

    Example:
    ```python
    import torch
    from torch.ao.quantization import per_channel_dynamic_qconfig
    from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit

    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
    quant_ready_model = _quantize_ondevice_dynamic_jit(
        ts_model,
        {'': per_channel_dynamic_qconfig},
        'forward',
        True)
    ```
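
    A hedged sketch of the later on-device steps (``example_input`` is an
    assumed placeholder; per the note above, the observe/quantize/quantized
    methods share ``forward``'s signature):
    ```python
    quant_ready_model.reset_observers_forward()
    quant_ready_model.observe_forward(example_input)
    quant_ready_model.quantize_forward(example_input)
    out = quant_ready_model.quantized_forward(example_input)
    ```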
- """
- return _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=inplace)