- """
- Utils shared by different modes of quantization (eager/graph)
- """
- import functools
- import warnings
- from collections import OrderedDict
- from inspect import getfullargspec, signature
- from typing import Any, Callable, Dict, Optional, Tuple, Union
- import torch
- from torch.ao.quantization.quant_type import QuantType
- from torch.fx import Node
- from torch.nn.utils.parametrize import is_parametrized
- NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
- NodePattern.__module__ = "torch.ao.quantization.utils"
- # This stands in for the Quantizer class from torch/quantization/fx/quantize.py,
- # defined separately to prevent circular imports.
- # TODO(future PR): improve this.
- # Make this public once fixed (it can't be public as-is because setting the
- # module directly doesn't work).
- QuantizerCls = Any
- # Type for fusion patterns; in practice it can be more complicated than
- # the following, see pattern.md for docs
- # TODO: not sure if typing supports recursive data types
- Pattern = Union[
- Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any
- ]
- Pattern.__module__ = "torch.ao.quantization.utils"
- # TODO: maybe rename this to MatchInputNode
- class MatchAllNode:
- """ A node pattern that matches all nodes, used in defining
- fusion patterns in FX Graph Mode Quantization
- """
- pass
- module_type_list = {
- torch.nn.ReLU,
- torch.nn.ReLU6,
- torch.nn.AdaptiveAvgPool1d,
- torch.nn.AdaptiveAvgPool2d,
- torch.nn.AdaptiveAvgPool3d,
- torch.nn.AvgPool1d,
- torch.nn.AvgPool2d,
- torch.nn.AvgPool3d,
- torch.nn.MaxPool1d,
- torch.nn.MaxPool2d,
- torch.nn.MaxPool3d,
- torch.nn.Identity,
- torch.nn.Hardsigmoid,
- torch.nn.Sigmoid,
- torch.nn.Tanh,
- }
- func_list = {
- torch.nn.functional.adaptive_avg_pool1d,
- torch.nn.functional.adaptive_avg_pool2d,
- torch.nn.functional.adaptive_avg_pool3d,
- torch.nn.functional.elu,
- torch.nn.functional.hardswish,
- torch.nn.functional.instance_norm,
- torch.nn.functional.layer_norm,
- torch.nn.functional.leaky_relu,
- torch.nn.functional.silu,
- torch.nn.functional.mish,
- torch.nn.functional.dropout,
- torch.nn.functional.max_pool1d,
- torch.nn.functional.max_pool2d,
- torch.nn.functional.max_pool3d,
- torch.nn.functional.relu,
- torch.nn.functional.hardtanh,
- torch.nn.functional.hardtanh_,
- torch.nn.functional.hardsigmoid,
- torch.nn.functional.sigmoid,
- torch.transpose,
- torch.repeat_interleave,
- torch.sigmoid,
- torch.squeeze,
- torch.stack,
- torch.sum,
- torch.tanh,
- torch.unsqueeze,
- torch.cat,
- }
- method_list = {
- torch.mean,
- 'relu',
- 'relu_',
- 'contiguous',
- 'detach',
- 'detach_',
- 'hardsigmoid',
- 'hardsigmoid_',
- 'permute',
- 'repeat',
- 'repeat_interleave',
- 'reshape',
- 'resize_',
- 'shape',
- 'sigmoid',
- 'sigmoid_',
- 'size',
- 'squeeze',
- 'squeeze_',
- 'tanh',
- 'tanh_',
- 'transpose',
- 'unsqueeze',
- 'unsqueeze_',
- 'view',
- }
- # TODO: not used now, remove
- def check_node(node, modules):
- # TODO: reuse is_fixed_qparam_node after we move this function to _lower_to_native_backend.py
- is_call_function = node.op == "call_function" and node.target in func_list
- is_call_method = node.op == "call_method" and node.target in method_list
- is_call_module = node.op == "call_module" and type(modules[str(node.target)]) in module_type_list
- return is_call_function, is_call_method, is_call_module
- def get_combined_dict(default_dict, additional_dict):
- d = default_dict.copy()
- d.update(additional_dict)
- return d
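- # A minimal usage sketch (illustrative values); entries in `additional_dict`
- # override those in `default_dict`:
- #   get_combined_dict({"a": 1, "b": 2}, {"b": 3})  # -> {"a": 1, "b": 3}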
- def is_per_tensor(qscheme):
- return qscheme == torch.per_tensor_affine or \
- qscheme == torch.per_tensor_symmetric
- def is_per_channel(qscheme):
- return qscheme in [torch.per_channel_affine,
- torch.per_channel_affine_float_qparams,
- torch.per_channel_symmetric]
- def getattr_from_fqn(obj: Any, fqn: str) -> Any:
- """
- Given an obj and a fqn such as "foo.bar.baz", returns obj.foo.bar.baz.
- """
- return functools.reduce(getattr, fqn.split("."), obj)
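- # A minimal sketch, assuming a Sequential whose submodule is registered as "0":
- #   m = torch.nn.Sequential(torch.nn.Linear(2, 2))
- #   getattr_from_fqn(m, "0.weight")  # equivalent to m[0].weight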
- def to_underlying_dtype(qdtype):
- DTYPE_MAPPING = {
- torch.quint8: torch.uint8,
- torch.qint8: torch.int8,
- torch.qint32: torch.int32,
- torch.quint4x2: torch.uint8,
- torch.quint2x4: torch.uint8,
- }
- assert qdtype in DTYPE_MAPPING, "Unsupported dtype: " + str(qdtype)
- return DTYPE_MAPPING[qdtype]
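- # Sketch: each quantized dtype maps to its underlying storage dtype, e.g.
- #   to_underlying_dtype(torch.quint8)  # -> torch.uint8
- #   to_underlying_dtype(torch.qint32)  # -> torch.int32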
- def get_qparam_dict(observer_or_fake_quant):
- qscheme = observer_or_fake_quant.qscheme if hasattr(observer_or_fake_quant, "qscheme") else None
- dtype = observer_or_fake_quant.dtype
- qparams = {"qscheme": qscheme, "dtype": dtype}
- if not qscheme:
- return qparams
- if is_per_tensor(qscheme):
- qscheme = torch.per_tensor_affine
- elif is_per_channel(qscheme):
- # change symmetric to affine since we do not have symmetric
- # quantized Tensor
- if qscheme == torch.per_channel_symmetric:
- qscheme = torch.per_channel_affine
- qparams["axis"] = observer_or_fake_quant.ch_axis
- else:
- raise RuntimeError(f"Unrecognized qscheme: {qscheme}")
- # update qscheme, since we don't have symmetric quant qscheme
- # in quantized Tensor
- qparams["qscheme"] = qscheme
- scale, zero_point = observer_or_fake_quant.calculate_qparams()
- qparams["scale"] = scale
- qparams["zero_point"] = zero_point
- return qparams
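- # A minimal sketch, assuming a default MinMaxObserver (per-tensor affine, quint8):
- #   obs = torch.ao.quantization.MinMaxObserver()
- #   obs(torch.randn(4))
- #   get_qparam_dict(obs)
- #   # -> {"qscheme": torch.per_tensor_affine, "dtype": torch.quint8,
- #   #     "scale": <tensor>, "zero_point": <tensor>}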
- def get_swapped_custom_module_class(custom_module, custom_module_class_mapping, qconfig):
- """ Get the observed/quantized custom module class that we need
- to swap `custom_module` to
- Input:
- custom_module: input, can be an instance of either a float or observed custom module
- custom_module_class_mapping: the float to observed or observed to quantized custom module class mapping
- qconfig: qconfig configured for the custom module
- Output:
- corresponding observed/quantized custom module class for input custom module instance
- """
- quant_type = get_quant_type(qconfig)
- class_mapping = custom_module_class_mapping.get(quant_type, {})
- assert type(custom_module) in class_mapping, "did not find corresponding observed " \
- "module class for {} in mapping: {}".format(type(custom_module), class_mapping)
- return class_mapping[type(custom_module)]
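- # Sketch of the expected mapping structure; MyFloatLSTM and MyObservedLSTM are
- # hypothetical classes:
- #   custom_module_class_mapping = {
- #       QuantType.STATIC: {MyFloatLSTM: MyObservedLSTM},
- #   }
- #   get_swapped_custom_module_class(my_float_lstm, custom_module_class_mapping, qconfig)
- #   # -> MyObservedLSTM, assuming qconfig maps to QuantType.STATIC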
- def activation_dtype(qconfig):
- assert qconfig is not None
- activation = qconfig.activation()
- return activation.dtype
- def weight_dtype(qconfig):
- assert qconfig is not None
- weight = qconfig.weight()
- return weight.dtype
- def activation_is_statically_quantized(qconfig):
- """ Given a qconfig, decide if the activation needs to be
- quantized or not, this includes quantizing to quint8, qint8 and qint32 and float16
- """
- return (
- activation_dtype(qconfig) in [torch.quint8, torch.qint8, torch.qint32, torch.float16]
- and (not activation_is_dynamically_quantized(qconfig))
- )
- def activation_is_dynamically_quantized(qconfig):
- """ Given a qconfig, decide if the activation needs to be
- dynamically quantized or not; this includes dynamically quantizing to
- quint8, qint8, and float16
- """
- activation_dtype, _, activation_is_dynamic = \
- get_qconfig_dtypes(qconfig)
- return activation_is_dynamic
- def activation_is_int8_quantized(qconfig):
- """ Given a qconfig, decide if the activation needs to be
- quantized to int8 or not; this includes quantizing to quint8 and qint8
- """
- return activation_dtype(qconfig) in [torch.quint8, torch.qint8]
- def activation_is_int32_quantized(qconfig):
- """ Given a qconfig, decide if the activation needs to be
- quantized to int32 or not
- """
- return activation_dtype(qconfig) == torch.qint32
- def weight_is_quantized(qconfig):
- """ Given a qconfig, decide if the weight needs to be
- quantized or not
- """
- return weight_dtype(qconfig) in [torch.quint8, torch.qint8, torch.float16, torch.quint4x2]
- def weight_is_statically_quantized(qconfig):
- """ Given a qconfig, decide if the weight needs to be statically
- quantized or not
- """
- return weight_dtype(qconfig) in [torch.quint8, torch.qint8]
- def op_is_int8_dynamically_quantized(qconfig) -> bool:
- """ Given a qconfig, returns True if this op is using int8 dynamic
- quantization
- """
- activation_dtype, weight_dtype, activation_is_dynamic = \
- get_qconfig_dtypes(qconfig)
- return (
- activation_dtype is torch.quint8 and
- # for now, the lines below assume fbgemm or qnnpack
- weight_dtype is torch.qint8 and
- activation_is_dynamic
- )
- def get_qconfig_dtypes(qconfig):
- r""" returns the qconfig tuple for qconfig:
- (activation_dtype, weight_dtype, activation_is_dynamic)
- """
- assert qconfig is not None
- activation = qconfig.activation()
- weight = qconfig.weight()
- act_is_dynamic = activation.is_dynamic if hasattr(activation, 'is_dynamic') else False
- return (activation.dtype, weight.dtype, act_is_dynamic)
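- # A minimal sketch with a stock qconfig (exact dtypes may vary across versions):
- #   get_qconfig_dtypes(torch.ao.quantization.default_qconfig)
- #   # -> (torch.quint8, torch.qint8, False)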
- def get_quant_type(qconfig):
- assert qconfig is not None
- activation = qconfig.activation()
- weight = qconfig.weight()
- static_dtypes = [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]
- if weight.dtype in static_dtypes:
- if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
- return QuantType.DYNAMIC
- elif activation.dtype in static_dtypes:
- return QuantType.STATIC
- else:
- return QuantType.WEIGHT_ONLY
- if weight.dtype == torch.float16:
- if hasattr(activation, 'is_dynamic') and activation.is_dynamic:
- return QuantType.DYNAMIC
- elif activation.dtype == torch.float16:
- return QuantType.STATIC
- raise Exception("Unrecognized dtype combination in get_quant_type: activation({}), "
- "weight({})".format(activation.dtype, weight.dtype))
- def check_min_max_valid(min_val: torch.Tensor, max_val: torch.Tensor) -> bool:
- """ Checks if the given minimum and maximum values are valid, meaning that
- they exist and the min value is less than or equal to the max value.
- """
- if min_val.numel() == 0 or max_val.numel() == 0:
- warnings.warn(
- "must run observer before calling calculate_qparams. " +
- "Returning default values."
- )
- return False
- if min_val.dim() == 0 or max_val.dim() == 0:
- if min_val == float("inf") and max_val == float("-inf"):
- warnings.warn(
- "must run observer before calling calculate_qparams. " +
- "Returning default values."
- )
- return False
- assert min_val <= max_val, "min {} should be less than max {}".format(
- min_val, max_val
- )
- else:
- assert torch.all(
- min_val <= max_val
- ), "min {} should be less than max {}".format(min_val, max_val)
- return True
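- # Sketch:
- #   check_min_max_valid(torch.tensor(-1.0), torch.tensor(1.0))  # -> True
- #   check_min_max_valid(torch.tensor([]), torch.tensor([]))     # -> False (warns)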
- def calculate_qmin_qmax(quant_min: int, quant_max: int, has_customized_qrange: bool, dtype: torch.dtype,
- reduce_range: bool) -> Tuple[int, int]:
- r"""Calculates actual qmin and qmax based on the quantization range,
- observer datatype, and whether the range is reduced.
- """
- # TODO(jerryzh): Figure out why custom quant_min/quant_max are still adjusted.
- if has_customized_qrange:
- # This initialization is here to resolve TorchScript compilation issues and to allow
- # refinement to decouple initial_quant_min and initial_quant_max from the quantization range.
- # The actual values of initial_quant_min and initial_quant_max will be reset below.
- if dtype == torch.qint32:
- initial_quant_min, initial_quant_max = 0, 2**31 - 1
- else:
- initial_quant_min, initial_quant_max = 0, 255
- # The following assignment of quant_min and quant_max to local variables and the if check refine
- # the attributes from Optional to valid integers for use, per TorchScript's requirements.
- custom_quant_min, custom_quant_max = quant_min, quant_max
- if custom_quant_min is not None and custom_quant_max is not None:
- initial_quant_min, initial_quant_max = (
- custom_quant_min,
- custom_quant_max,
- )
- qrange_len = initial_quant_max - initial_quant_min + 1
- if dtype == torch.qint8:
- assert (
- 0 < qrange_len <= 256
- ), "quantization range should be positive and not exceed the maximum bit range (=256)."
- elif dtype == torch.qint32:
- assert (
- 0 < qrange_len <= 2**31
- ), "quantization range should be positive and not exceed the maximum bit range (=4294967296)."
- if reduce_range:
- quant_min, quant_max = quant_min // 2, quant_max // 2
- else:
- # Fall back to the default qmin and qmax calculation if a customized range is not used.
- if dtype == torch.qint8:
- if reduce_range:
- quant_min, quant_max = -64, 63
- else:
- quant_min, quant_max = -128, 127
- elif dtype == torch.quint8:
- if reduce_range:
- quant_min, quant_max = 0, 127
- else:
- quant_min, quant_max = 0, 255
- elif dtype == torch.qint32:
- quant_min, quant_max = -1 * (2 ** 31), (2 ** 31) - 1
- else:
- # assumed 4-bit dtypes, e.g. torch.quint4x2
- quant_min, quant_max = 0, 15
- return quant_min, quant_max
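- # Sketch of the default (non-customized) branch, where the quant_min/quant_max
- # arguments are ignored:
- #   calculate_qmin_qmax(0, 0, False, torch.qint8, reduce_range=False)  # -> (-128, 127)
- #   calculate_qmin_qmax(0, 0, False, torch.quint8, reduce_range=True)  # -> (0, 127)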
- def _parent_name(target):
- """
- Splits 'foo.bar' into ('foo', 'bar'), i.e. (parent path, attribute name)
- """
- r = target.rsplit('.', 1)
- if len(r) == 1:
- return '', r[0]
- else:
- return r[0], r[1]
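- # Sketch:
- #   _parent_name("foo.bar.baz")  # -> ("foo.bar", "baz")
- #   _parent_name("foo")          # -> ("", "foo")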
- def has_no_children_ignoring_parametrizations(module):
- """
- Checks if module._modules is empty; if the module is parametrized,
- checks that module._modules only contains the 'parametrizations'
- module
- """
- if len(module._modules) == 0:
- return True
- elif is_parametrized(module):
- return len(module._modules) == 1 and 'parametrizations' in module._modules
- else:
- return False
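- # A minimal sketch, assuming torch.nn.utils.parametrize:
- #   lin = torch.nn.Linear(2, 2)
- #   has_no_children_ignoring_parametrizations(lin)  # -> True
- #   torch.nn.utils.parametrize.register_parametrization(lin, "weight", torch.nn.Identity())
- #   has_no_children_ignoring_parametrizations(lin)  # -> still True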
- def _get_path_of_module(root: torch.nn.Module, submodule: torch.nn.Module) -> Optional[str]:
- """ Get the path (fully qualified name) of a submodule
- Example::
- >> class M(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.linear = torch.nn.Linear(5, 5)
- def forward(self, x):
- return self.linear(x)
- >> m = M()
- >> l = m.linear
- >> _get_path_of_module(m, l)
- "linear"
- """
- for n, p in root.named_modules():
- if submodule is p:
- return n
- return None
- def _get_signature_locals(f: Callable, loc: Dict[str, Any]) -> Dict[str, Any]:
- """ Get local keyword arguments
- Example::
- >> def f(self, a, b=9):
- pass
- >> loc = {"a": 6, "c": 7}
- >> _get_signature_locals(f, loc)
- {"a": 6}
- """
- return {k: v for k, v in loc.items() if k in signature(f).parameters}
- def _get_default_kwargs(f: Callable) -> "OrderedDict[str, Any]":
- """ Get all default keyword arguments from function signature
- Example::
- >> def f(self, a, b=9):
- pass
- >> _get_default_kwargs(f)
- {"b": 9}
- """
- kwargs = {}
- for name, param in signature(f).parameters.items():
- if param.default is not param.empty:
- kwargs[name] = param.default
- elif param.kind is param.VAR_POSITIONAL:
- kwargs[name] = ()
- elif param.kind is param.VAR_KEYWORD:
- kwargs[name] = {}
- return OrderedDict(kwargs)
- def _normalize_kwargs(func: Callable, loc: Dict[str, Any]) -> "OrderedDict[str, Any]":
- """ Given a function and local function arguments, normalize the keyword
- arguments by filling in default arguments from function signature
- Example::
- >> def f(self, key1=3, key2=3):
- pass
- >> loc = {"key2": 6}
- >> _normalize_kwargs(f, loc)
- {"key1": 3, "key2": 6}
- """
- default_kwargs = _get_default_kwargs(func)
- local_kwargs = _get_signature_locals(func, loc)
- normalized_kwargs = default_kwargs.copy()
- for attr, val in local_kwargs.items():
- if attr in normalized_kwargs:
- # override the default keyword arguments
- normalized_kwargs[attr] = val
- return normalized_kwargs
- def validate_qmin_qmax(quant_min: int, quant_max: int) -> None:
- r"""Validates that the user-specified quantization range is properly initialized
- and within the given bound supported by the observer dtype.
- To accommodate lower-bit quantization with respect to the existing torch.qint8 and
- torch.quint8 datatypes, the user can choose to use a customized quantization range by passing
- in a tuple of initial qmin and qmax values. One use case is that these customized qmin and qmax
- values can be used to calculate static estimates of the scale and zero point for aggressive lower-bit
- fake quantization. These estimates are compared against parameters learned through backpropagation.
- The related literature on learning scale and zero point via backpropagation is as follows:
- Learned Step Size Quantization: https://openreview.net/pdf?id=rkgO66VKDS
- Trained Quantization Thresholds: https://arxiv.org/pdf/1903.08066.pdf
- """
- # The qmin and qmax values are only validated here; they might later be adjusted based on
- # whether the quantization range is reduced and the datatype (signed/unsigned) used by the observer.
- assert (
- quant_min <= 0 <= quant_max
- ), "Used-specified quantization range must include 0."
- assert (
- quant_min < quant_max
- ), "qmin must be strictly less than qmax for user-specified quantization range."
- # Functionally equivalent to '_calculate_qparams' in observer.py. However, observers must be torchscriptable,
- # and qscheme, as far as I can tell, is not allowed to be passed as a parameter to torchscript functions. This makes
- # refactoring observer to use this utility a massive pain and very gross. For now I'm opting to just duplicate it, as
- # this code seems unlikely to change (last update over 1 year ago); when torchscript is fully deprecated we can
- # refactor. TODO(jakeszwe, jerryzh168)
- def determine_qparams(
- min_val: torch.Tensor, max_val: torch.Tensor, quant_min: int, quant_max: int,
- dtype: torch.dtype, eps: torch.Tensor, has_customized_qrange: bool,
- qscheme: torch.qscheme = torch.per_tensor_affine) -> Tuple[torch.Tensor, torch.Tensor]:
- r"""Calculates the quantization parameters, given min and max
- value tensors. Works for both per tensor and per channel cases
- Args:
- min_val: Minimum values per channel (or a single value for the per-tensor case)
- max_val: Maximum values per channel (or a single value for the per-tensor case)
- Returns:
- scales: Scales tensor of shape (#channels,)
- zero_points: Zero points tensor of shape (#channels,)
- """
- if not check_min_max_valid(min_val, max_val):
- return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)
- min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
- max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
- device = min_val_neg.device
- scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
- zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
- if (
- qscheme == torch.per_tensor_symmetric
- or qscheme == torch.per_channel_symmetric
- ):
- max_val_pos = torch.max(-min_val_neg, max_val_pos)
- scale = max_val_pos / (float(quant_max - quant_min) / 2)
- scale = torch.max(scale, eps)
- if dtype == torch.uint8 or dtype == torch.quint8:
- if has_customized_qrange:
- # When customized quantization range is used, down-rounded midpoint of the range is chosen.
- zero_point = zero_point.new_full(
- zero_point.size(), (quant_min + quant_max) // 2
- )
- else:
- zero_point = zero_point.new_full(zero_point.size(), 128)
- elif qscheme == torch.per_channel_affine_float_qparams:
- scale = (max_val - min_val) / float(quant_max - quant_min)
- scale = torch.where(scale > eps, scale, torch.ones_like(scale))
- # We use the quantize function
- # Xq = Round(Xf * inv_scale + zero_point);
- # setting zero_point to (-1 * min * inv_scale) we get
- # Xq = Round((Xf - min) * inv_scale)
- zero_point = -1 * min_val / scale
- else:
- scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
- scale = torch.max(scale, eps)
- zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
- zero_point = torch.clamp(zero_point, quant_min, quant_max)
- # For scalar values, cast them to Tensors of size 1 to keep the shape
- # consistent with default values in FakeQuantize.
- if len(scale.shape) == 0:
- # TODO: switch to scale.item() after adding JIT support
- scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
- if len(zero_point.shape) == 0:
- # TODO: switch to zero_point.item() after adding JIT support
- zero_point = torch.tensor(
- [int(zero_point)], dtype=zero_point.dtype, device=device
- )
- if qscheme == torch.per_channel_affine_float_qparams:
- zero_point = torch.tensor(
- [float(zero_point)], dtype=zero_point.dtype, device=device
- )
- return scale, zero_point
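- # A minimal per-tensor affine sketch (illustrative values):
- #   eps = torch.tensor([torch.finfo(torch.float32).eps])
- #   scale, zp = determine_qparams(
- #       torch.tensor([-1.0]), torch.tensor([1.0]), 0, 255, torch.quint8, eps,
- #       has_customized_qrange=False)
- #   # scale ~= 2/255, zp ~= 128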
- def _get_num_pos_args(f: Callable) -> int:
- """ Get number of positional args for a function
- Example::
- >> def f(self, key1=3, key2=3):
- pass
- >> _get_num_pos_args(f)
- 3
- """
- return len(getfullargspec(f).args)
- def get_fqn_to_example_inputs(
- model: torch.nn.Module,
- example_inputs: Tuple[Any, ...]
- ) -> Dict[str, Tuple[Any, ...]]:
- """ Given a model and its example inputs, return a dictionary from
- fully qualified name of submodules to example_inputs for that submodule,
- e.g. {"linear1": (tensor1,), "linear2": (tensor2,), "sub": (tensor3,),
- "sub.linear1": (tensor4,), ...}
- Used to make quantizing submodules easier now that FX Graph Mode Quantization requires
- example inputs.
- Also works for keyword arguments with default values: keyword arguments are flattened
- into positional arguments and missing keyword args are filled in with their default
- values. For example, if we have a forward function:
- def forward(self, x, key1=3, key2=3):
- ...
- and we call it with self.submodule(x, key2=6),
- we'll get example_inputs: (x, 3, 6).
- The user can also override `key1` with a positional argument:
- for self.submodule(x, 5, key2=6)
- we'll get: (x, 5, 6).
- Variable positional arguments (*args) and variable keyword arguments (**kwargs) in the
- forward function are not currently supported, so please make sure no submodule
- uses them.
- """
- root = model
- fqn_to_example_inputs = {}
- def _patched_module_call(self, *args, **kwargs):
- submodule_example_inputs = list(args)
- normalized_kwargs = _normalize_kwargs(self.forward, kwargs)
- # minus 1 to skip counting `self`
- num_args = _get_num_pos_args(self.forward) - 1
- num_to_pop = num_args - len(submodule_example_inputs)
- while num_to_pop > 0 and normalized_kwargs:
- normalized_kwargs.popitem(last=False)
- num_to_pop -= 1
- submodule_example_inputs.extend(normalized_kwargs.values())
- submodule_example_inputs_tuple = tuple(submodule_example_inputs)
- fqn = _get_path_of_module(root, self)
- if fqn is not None:
- fqn_to_example_inputs[fqn] = submodule_example_inputs_tuple
- return orig_module_call(self, *args, **kwargs)
- orig_module_call = torch.nn.Module.__call__
- torch.nn.Module.__call__ = _patched_module_call
- try:
- model(*example_inputs)
- finally:
- # restore the module call even if there is an exception
- torch.nn.Module.__call__ = orig_module_call
- return fqn_to_example_inputs
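- # A minimal sketch, assuming a hypothetical module with a single linear submodule:
- #   class M(torch.nn.Module):
- #       def __init__(self):
- #           super().__init__()
- #           self.linear = torch.nn.Linear(5, 5)
- #       def forward(self, x):
- #           return self.linear(x)
- #   get_fqn_to_example_inputs(M(), (torch.randn(1, 5),))
- #   # includes an entry {"linear": (<tensor of shape (1, 5)>,)}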
- def _get_lstm_with_individually_observed_parts(
- float_lstm: torch.nn.LSTM,
- # Use Callable instead of _PartialWrapper here to avoid circular dependencies
- linear_output_obs_ctr: Optional[Callable] = None,
- sigmoid_obs_ctr: Optional[Callable] = None,
- tanh_obs_ctr: Optional[Callable] = None,
- cell_state_obs_ctr: Optional[Callable] = None,
- hidden_state_obs_ctr: Optional[Callable] = None,
- ) -> torch.ao.nn.quantizable.LSTM:
- """
- Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
- with specific observers or fake quantizes assigned to the inner ops or submodules.
- In both eager and FX graph mode quantization, `torch.ao.nn.quantizable.LSTM` is
- used as an observed custom module, which is responsible for inserting its own
- observers. By default, all inner ops inherit the parent custom module's QConfig.
- Users who wish to override this behavior may extend `torch.ao.nn.quantizable.LSTM`
- and use this helper function to customize the observer insertion logic.
- Args:
- `float_lstm`: The float LSTM module
- `linear_output_obs_ctr`: observer or fake quantize for linear outputs Wx + b,
- where W is the weight matrix, b is the bias, and x is either the inputs
- or the hidden state from the previous layer (if any)
- `sigmoid_obs_ctr`: observer or fake quantize for sigmoid activations
- `tanh_obs_ctr`: observer or fake quantize for tanh activations
- `cell_state_obs_ctr`: observer or fake quantize for the cell state
- `hidden_state_obs_ctr`: observer or fake quantize for the hidden state and
- the output
- Return:
- A `torch.ao.nn.quantizable.LSTM` with the specified observers or fake quantizes
- attached to the inner submodules.
- """
- def make_qconfig(obs_ctr: Callable) -> torch.ao.quantization.QConfig:
- """
- Make a QConfig with fixed qparams observers or fake quantizes.
- """
- if isinstance(obs_ctr(), torch.ao.quantization.FakeQuantizeBase):
- weight = torch.ao.quantization.default_weight_fake_quant
- else:
- weight = torch.ao.quantization.default_weight_observer
- return torch.ao.quantization.QConfig(activation=obs_ctr, weight=weight)
- observed_lstm = torch.ao.nn.quantizable.LSTM(
- float_lstm.input_size, float_lstm.hidden_size, float_lstm.num_layers, float_lstm.bias,
- float_lstm.batch_first, float_lstm.dropout, float_lstm.bidirectional)
- # Assign QConfigs with fixed qparams to all inner submodules
- # Module hierarchy: LSTM > _LSTMLayer > _LSTMSingleLayer (forward or backward) > LSTMCell
- for layer in observed_lstm.layers:
- inner_layers = [layer.layer_fw]
- if float_lstm.bidirectional:
- inner_layers.append(layer.layer_bw)
- for inner_layer in inner_layers:
- cell = inner_layer.cell
- if linear_output_obs_ctr is not None:
- qconfig = make_qconfig(linear_output_obs_ctr)
- cell.igates.qconfig = qconfig
- cell.hgates.qconfig = qconfig
- if sigmoid_obs_ctr is not None:
- qconfig = make_qconfig(sigmoid_obs_ctr)
- cell.input_gate.qconfig = qconfig
- cell.forget_gate.qconfig = qconfig
- cell.output_gate.qconfig = qconfig
- if tanh_obs_ctr is not None:
- cell.cell_gate.qconfig = make_qconfig(tanh_obs_ctr)
- if cell_state_obs_ctr is not None:
- cell.fgate_cx_igate_cgate.qconfig = make_qconfig(cell_state_obs_ctr)
- obs = cell_state_obs_ctr()
- if hasattr(obs, "scale") and hasattr(obs, "zero_point"):
- cell.initial_cell_state_qparams = (obs.scale, obs.zero_point)
- cell.cell_state_dtype = obs.dtype
- if hidden_state_obs_ctr is not None:
- cell.ogate_cy.qconfig = make_qconfig(hidden_state_obs_ctr)
- obs = hidden_state_obs_ctr()
- if hasattr(obs, "scale") and hasattr(obs, "zero_point"):
- cell.initial_hidden_state_qparams = (obs.scale, obs.zero_point)
- cell.hidden_state_dtype = obs.dtype
- # need to do this here to avoid circular dependency
- from torch.ao.quantization.quantize import _add_observer_
- # Insert the observers based on the previously attached QConfigs
- # Pass in non_leaf_module_list to prevent the observers for sigmoid/tanh from being overridden
- _add_observer_( # type: ignore[attr-defined]
- observed_lstm,
- non_leaf_module_list=[torch.nn.Sigmoid, torch.nn.Tanh]
- )
- return observed_lstm
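- # A minimal usage sketch, assuming fixed qparams for sigmoid (scale=1/256, zero_point=0):
- #   from torch.ao.quantization.observer import FixedQParamsObserver
- #   observed = _get_lstm_with_individually_observed_parts(
- #       torch.nn.LSTM(2, 2),
- #       sigmoid_obs_ctr=FixedQParamsObserver.with_args(scale=1.0 / 256.0, zero_point=0),
- #   )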
- __all__ = [
- "NodePattern",
- "Pattern",
- "MatchAllNode",
- "check_node",
- "get_combined_dict",
- "is_per_tensor",
- "is_per_channel",
- "getattr_from_fqn",
- "get_qparam_dict",
- "get_swapped_custom_module_class",
- "activation_dtype",
- "weight_dtype",
- "activation_is_statically_quantized",
- "activation_is_dynamically_quantized",
- "activation_is_int8_quantized",
- "activation_is_int32_quantized",
- "weight_is_quantized",
- "weight_is_statically_quantized",
- "op_is_int8_dynamically_quantized",
- "get_qconfig_dtypes",
- "get_quant_type",
- "check_min_max_valid",
- "calculate_qmin_qmax",
- "has_no_children_ignoring_parametrizations",
- "get_fqn_to_example_inputs",
- "to_underlying_dtype",
- "determine_qparams",
- "validate_qmin_qmax",
- ]