- """
- This module implements modules which are used to perform fake quantization
- during QAT.
- """
- import torch
- from torch.nn import Module
- from torch.ao.quantization.observer import (
- MovingAverageMinMaxObserver,
- HistogramObserver,
- MovingAveragePerChannelMinMaxObserver,
- FixedQParamsObserver,
- default_fixed_qparams_range_0to1_observer,
- default_fixed_qparams_range_neg1to1_observer,
- _with_args,
- )
- import re
- from abc import ABC, abstractmethod
- from typing import Any, Tuple

__all__ = [
    "FakeQuantizeBase",
    "FakeQuantize",
    "FixedQParamsFakeQuantize",
    "FusedMovingAvgObsFakeQuantize",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "default_fake_quant",
    "default_weight_fake_quant",
    "default_dynamic_fake_quant",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_fake_quant",
    "default_per_channel_weight_fake_quant",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_histogram_fake_quant",
    "default_fused_act_fake_quant",
    "default_fused_wt_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
]


def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]


def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]


def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]


def _is_float_qparams(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_channel_affine_float_qparams, ]
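
# A quick illustration of the helpers above (not exhaustive):
#   _is_per_channel(torch.per_channel_symmetric)              -> True
#   _is_per_tensor(torch.per_tensor_affine)                   -> True
#   _is_symmetric_quant(torch.per_tensor_affine)              -> False
#   _is_float_qparams(torch.per_channel_affine_float_qparams) -> True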


class FakeQuantizeBase(ABC, Module):
    r"""Base fake quantize module.

    Any fake quantize implementation should derive from this class.

    Concrete fake quantize modules should follow the same API. In forward, they will update
    the statistics of the observed Tensor and fake quantize the input. They should also provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.
    """

    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self):
        super().__init__()
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        pass

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        self.enable_observer(False)

    @classmethod
    def with_args(cls, **kwargs):
        fake_quant_constructor = _with_args(cls, **kwargs)
        # need to assign the correct module to fake_quantize
        # constructors to satisfy public vs private requirements
        fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
        return fake_quant_constructor
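

# Usage sketch for `with_args` (illustrative, not part of the documented API):
# it returns a constructor with the given arguments pre-bound, so the two
# lines below build equivalent modules.
#
#   fq_ctr = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255)
#   fq = fq_ctr()  # same as FakeQuantize(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255)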


class FakeQuantize(FakeQuantizeBase):
    r"""Simulate the quantize and dequantize operations in training time.

    The output of this module is given by::

        x_out = (
            clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point
        ) * scale

    * :attr:`scale` defines the scale factor used for quantization.

    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps.

    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors; note that
      statistics can still be updated.

    * :attr:`observer_enabled` controls statistics collection on tensors.

    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake quantization;
      allowable values are torch.qint8 and torch.quint8.

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
            and zero-point.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): User-provided module that collects statistics on the input tensor and
            provides a method to calculate scale and zero-point.
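
    Example (a minimal numeric sketch of the formula above; the scale and
    zero_point values are illustrative, not computed by an observer)::

        >>> import torch
        >>> x = torch.tensor([0.30, -0.51, 1.20])
        >>> scale, zero_point, quant_min, quant_max = 0.1, 0, -128, 127
        >>> x_q = torch.clamp(torch.round(x / scale + zero_point), quant_min, quant_max)
        >>> x_out = (x_q - zero_point) * scale  # the "fake quantized" output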
- """
- scale: torch.Tensor
- zero_point: torch.Tensor
- def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=None, quant_max=None, **observer_kwargs):
- super().__init__()
- # Populate quant_min/quant_max to observer_kwargs if valid
- if quant_min is not None and quant_max is not None:
- assert quant_min <= quant_max, \
- 'quant_min must be less than or equal to quant_max'
- dtype = observer_kwargs.get("dtype", torch.quint8)
- if hasattr(observer, "p"):
- # In case observer is _PartialWrapper, dtype can be stored in
- # observer.p.keywords["dtype"]
- dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
- "dtype", dtype
- )
- assert torch.iinfo(dtype).min <= quant_min, 'quant_min out of bound'
- assert quant_max <= torch.iinfo(dtype).max, 'quant_max out of bound'
- observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
- self.activation_post_process = observer(**observer_kwargs)
- # TODO: keeping self.quant_min/max for BC; remove after a couple releases
- # Users should use self.activation_post_process.quant_min
- self.quant_min = self.activation_post_process.quant_min
- self.quant_max = self.activation_post_process.quant_max
- if _is_float_qparams(self.activation_post_process.qscheme):
- zero_point_dtype = torch.float
- else:
- zero_point_dtype = torch.int
- self.register_buffer('scale', torch.tensor([1.0], dtype=torch.float))
- self.register_buffer('zero_point', torch.tensor([0], dtype=zero_point_dtype))
- self.dtype = self.activation_post_process.dtype
- self.qscheme = self.activation_post_process.qscheme
- self.ch_axis = self.activation_post_process.ch_axis \
- if hasattr(self.activation_post_process, 'ch_axis') else -1
- assert _is_per_channel(self.qscheme) or \
- _is_per_tensor(self.qscheme), \
- 'Only per channel and per tensor quantization are supported in fake quantize' + \
- ' got qscheme: ' + str(self.qscheme)
- self.is_per_channel = _is_per_channel(self.qscheme)

    @torch.jit.export
    def calculate_qparams(self):
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            if self.scale.shape != _scale.shape:
                self.scale.resize_(_scale.shape)
                self.zero_point.resize_(_zero_point.shape)
            self.scale.copy_(_scale)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.is_per_channel:
                X = torch.fake_quantize_per_channel_affine(
                    X, self.scale, self.zero_point,
                    self.ch_axis, self.activation_post_process.quant_min, self.activation_post_process.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.activation_post_process.quant_min, self.activation_post_process.quant_max)
        return X

    @torch.jit.export
    def extra_repr(self):
        return 'fake_quant_enabled={}, observer_enabled={}, ' \
               'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, ' \
               'scale={}, zero_point={}'.format(
                   self.fake_quant_enabled, self.observer_enabled,
                   self.activation_post_process.quant_min, self.activation_post_process.quant_max,
                   self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # We cannot currently register scalar values as buffers, so we need to manually
        # specify serialization here.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Removing this function throws an error that the size of the loaded tensor does not match the original size
        # i.e., these buffers start out with numel 0 and become numel 1 once they have their first forward pass.
        local_state = ['scale', 'zero_point']
        for name in local_state:
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Custom handling to allow loading scale and zero_point
                # of size N into uninitialized buffers of size 0. The
                # buffers are resized here, and the values are copied in
                # the default state_dict loading code of the parent.
                if name == 'scale':
                    self.scale.resize_(val.shape)
                else:
                    assert name == 'zero_point'
                    self.zero_point.resize_(val.shape)
                # For torchscript modules we need to update the attributes here since we do not
                # call the `_load_from_state_dict` function defined in module.py
                if torch.jit.is_scripting():
                    if name == 'scale':
                        self.scale.copy_(val)
                    else:
                        assert name == 'zero_point'
                        self.zero_point.copy_(val)
            elif strict:
                missing_keys.append(key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


class FixedQParamsFakeQuantize(FakeQuantize):
    """Simulate quantize and dequantize with fixed quantization
    parameters in training time. Only per tensor quantization
    is supported.
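
    Example (an illustrative sketch; uses the default range-[0, 1] observer
    imported at the top of this module)::

        >>> import torch
        >>> fq = FixedQParamsFakeQuantize(observer=default_fixed_qparams_range_0to1_observer)
        >>> y = fq(torch.rand(2, 3))  # fake quantize with the fixed scale/zero_point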
- """
- def __init__(self, observer):
- super().__init__(observer=observer)
- assert type(self.activation_post_process) == FixedQParamsObserver,\
- "%s's observer must be a %s" % (self.__class__.__name__, FixedQParamsObserver.__name__)
- self._observer_ctr = observer
- self.scale = self.activation_post_process.scale
- self.zero_point = self.activation_post_process.zero_point
- assert _is_per_tensor(self.qscheme), 'Only per tensor quantization is supported' + \
- ' FixedQParamsFakeQuantize module, got qscheme:' + str(self.qscheme)
- @torch.jit.export
- def calculate_qparams(self):
- return self.scale, self.zero_point
- @torch.jit.export
- def extra_repr(self):
- return 'fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, ' \
- 'dtype={}, quant_min={}, quant_max={}, qscheme={}'.format(
- self.fake_quant_enabled, self.observer_enabled,
- self.scale, self.zero_point, self.dtype,
- self.activation_post_process.quant_min, self.activation_post_process.quant_max, self.qscheme)


class FusedMovingAvgObsFakeQuantize(FakeQuantize):
    r"""Fused module that is used to observe the input tensor (compute min/max), compute
    scale/zero_point and fake_quantize the tensor.

    This module uses a calculation similar to MovingAverageMinMaxObserver for the inputs
    to compute the min/max values in order to compute the scale/zero_point.
    The qscheme input in the observer is used to differentiate between symmetric/affine
    quantization schemes.

    The output of this module is given by::

        x_out = (clamp(round(x / scale + zero_point), quant_min, quant_max) - zero_point) * scale

    This module is similar to :class:`~torch.ao.quantization.FakeQuantize` and accepts the same
    attributes as the base class.
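
    Example (an illustrative usage sketch)::

        >>> import torch
        >>> fq = FusedMovingAvgObsFakeQuantize()
        >>> y = fq(torch.randn(4, 4))  # observe min/max and fake quantize in one fused op
        >>> scale, zero_point = fq.calculate_qparams()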
- """
- def __init__(
- self,
- observer: Any = MovingAverageMinMaxObserver,
- quant_min: int = 0,
- quant_max: int = 255,
- **observer_kwargs: Any
- ) -> None:
- super().__init__(observer, quant_min, quant_max, **observer_kwargs)
- assert isinstance(self.activation_post_process, (MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver)),\
- "Fused observer+fake_quant module only works with MovingAverageMinMaxObserver"
- self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.long))
- self.register_buffer("observer_enabled", torch.tensor([1], dtype=torch.long))
- self.is_symmetric_quant = _is_symmetric_quant(self.activation_post_process.qscheme)

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.activation_post_process.calculate_qparams()

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            "fake_quant_enabled={}, observer_enabled={}, scale={}, zero_point={}, "
            "dtype={}, quant_min={}, quant_max={}, qscheme={}, reduce_range={}".format(
                self.fake_quant_enabled,
                self.observer_enabled,
                self.scale,
                self.zero_point,
                self.dtype,
                self.activation_post_process.quant_min,
                self.activation_post_process.quant_max,
                self.qscheme,
                self.activation_post_process.reduce_range,
            )
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return torch.fused_moving_avg_obs_fake_quant(
            X,
            self.observer_enabled,
            self.fake_quant_enabled,
            self.activation_post_process.min_val,
            self.activation_post_process.max_val,
            self.scale,
            self.zero_point,
            self.activation_post_process.averaging_constant,
            self.activation_post_process.quant_min,
            self.activation_post_process.quant_max,
            self.ch_axis,
            self.is_per_channel,
            self.is_symmetric_quant,
        )


default_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                            dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=True)
"""
Default fake_quant for activations.
"""

default_weight_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=-128, quant_max=127,
                                                   dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False)
"""
Default fake_quant for weights.
"""

default_dynamic_fake_quant = FakeQuantize.with_args(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255,
                                                    dtype=torch.quint8, averaging_constant=1)
"""
Default dynamic fake_quant for activations.
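The observer is memoryless since ``averaging_constant`` is 1: qparams are
recomputed from each batch alone, mirroring dynamic quantization.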
- """

default_fixed_qparams_range_neg1to1_fake_quant = (
    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_neg1to1_observer)
)
default_fixed_qparams_range_0to1_fake_quant = (
    FixedQParamsFakeQuantize.with_args(observer=default_fixed_qparams_range_0to1_observer)
)
# TODO: the following 2 variables are kept for backwards compatibility; remove after a few releases
default_symmetric_fixed_qparams_fake_quant = default_fixed_qparams_range_neg1to1_fake_quant
default_affine_fixed_qparams_fake_quant = default_fixed_qparams_range_0to1_fake_quant

default_per_channel_weight_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                               quant_min=-128,
                                                               quant_max=127,
                                                               dtype=torch.qint8,
                                                               qscheme=torch.per_channel_symmetric,
                                                               reduce_range=False,
                                                               ch_axis=0)
"""
Default fake_quant for per-channel weights.
"""

default_embedding_fake_quant = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                      qscheme=torch.per_channel_affine_float_qparams,
                                                      dtype=torch.quint8,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      ch_axis=0,
                                                      averaging_constant=1)
"""
Default fake_quant for embeddings.
Observer is memoryless since averaging_constant is 1.
"""

default_embedding_fake_quant_4bit = FakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                           qscheme=torch.per_channel_affine_float_qparams,
                                                           ch_axis=0,
                                                           dtype=torch.quint4x2,
                                                           averaging_constant=1)

default_histogram_fake_quant = FakeQuantize.with_args(observer=HistogramObserver,
                                                      quant_min=0,
                                                      quant_max=255,
                                                      dtype=torch.quint8,
                                                      qscheme=torch.per_tensor_affine,
                                                      reduce_range=True)
"""
Fake_quant for activations using a histogram.
"""

default_fused_act_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                       quant_min=0,
                                                                       quant_max=255,
                                                                       dtype=torch.quint8)
"""
Fused version of `default_fake_quant`, with improved performance.
"""

default_fused_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                      quant_min=-128,
                                                                      quant_max=127,
                                                                      dtype=torch.qint8,
                                                                      qscheme=torch.per_tensor_symmetric)
"""
Fused version of `default_weight_fake_quant`, with improved performance.
"""

default_fused_per_channel_wt_fake_quant = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                                                                  quant_min=-128,
                                                                                  quant_max=127,
                                                                                  dtype=torch.qint8,
                                                                                  qscheme=torch.per_channel_symmetric)
"""
Fused version of `default_per_channel_weight_fake_quant`, with improved performance.
"""

fused_wt_fake_quant_range_neg_127_to_127 = FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                                                   quant_min=-127,
                                                                                   quant_max=127,
                                                                                   dtype=torch.qint8,
                                                                                   qscheme=torch.per_tensor_symmetric,
                                                                                   eps=2 ** -12)
"""
Fused version of `default_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""

fused_per_channel_wt_fake_quant_range_neg_127_to_127 = \
    FusedMovingAvgObsFakeQuantize.with_args(observer=MovingAveragePerChannelMinMaxObserver,
                                            quant_min=-127,
                                            quant_max=127,
                                            dtype=torch.qint8,
                                            qscheme=torch.per_channel_symmetric,
                                            eps=2 ** -12)
"""
Fused version of `default_per_channel_weight_fake_quant`, with the 8-bit values restricted to [-127, +127], excluding -128.
"""


def _is_fake_quant_script_module(mod):
    ''' Returns true if given mod is an instance of FakeQuantize script module.
    '''
    if isinstance(mod, torch.jit.RecursiveScriptModule):
        # qualified name looks like '__torch__.torch.ao.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
        suffix = mod._c.qualified_name.split('.', 1)[1]
        name = re.sub(r'\.___torch_mangle_\d+', '', suffix)
        return name == 'torch.ao.quantization.fake_quantize.FakeQuantize' or \
            name == 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'
    return False


def disable_fake_quant(mod):
    """
    Disable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_fake_quant)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_fake_quant()


def enable_fake_quant(mod):
    """
    Enable fake quantization for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_fake_quant)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_fake_quant()


def disable_observer(mod):
    """
    Disable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.disable_observer)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.disable_observer()


def enable_observer(mod):
    """
    Enable observation for this module, if applicable. Example usage::

        # model is any PyTorch model
        model.apply(torch.ao.quantization.enable_observer)

    """
    if isinstance(mod, FakeQuantizeBase) or _is_fake_quant_script_module(mod):
        mod.enable_observer()
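

# A typical QAT toggling sequence (an illustrative sketch; `model` is any
# QAT-prepared model):
#
#   model.apply(torch.ao.quantization.enable_observer)    # collect statistics
#   model.apply(torch.ao.quantization.enable_fake_quant)  # simulate quantization
#   # ... train ...
#   model.apply(torch.ao.quantization.disable_observer)   # freeze qparams
#   # ... fine-tune with fixed quantization parameters ...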