12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- from torch import nn
class QuantStub(nn.Module):
    r"""Placeholder marking where a float tensor should be quantized.

    ``forward`` hands its input back untouched. The module exists so that
    ``convert`` can later swap it for ``nnq.Quantize``, which performs the
    actual quantization.

    Args:
        qconfig: quantization configuration for the tensor. When it is
            falsy (e.g. the default ``None``), no ``qconfig`` attribute is
            set on the stub, and the configuration is taken from the
            parent modules instead.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Leave the attribute unset so a parent's qconfig can apply.
        if not qconfig:
            return
        self.qconfig = qconfig

    def forward(self, x):
        # Identity until `convert` replaces this module.
        return x
class DeQuantStub(nn.Module):
    r"""Placeholder marking where a quantized tensor should return to float.

    ``forward`` simply returns its input, so before conversion the stub is
    an identity. ``convert`` swaps it for ``nnq.DeQuantize``.

    Args:
        qconfig: quantization configuration for the tensor. A falsy value
            (including the default ``None``) leaves the stub without a
            ``qconfig`` attribute, so the configuration is taken from the
            parent modules instead.
    """

    def __init__(self, qconfig=None):
        super().__init__()
        # Only record an explicitly supplied configuration.
        if not qconfig:
            return
        self.qconfig = qconfig

    def forward(self, x):
        # Pass-through; no computation happens in the stub itself.
        return x
class QuantWrapper(nn.Module):
    r"""Surround a module with a ``QuantStub`` and a ``DeQuantStub``.

    ``forward`` runs its input through ``quant``, then the wrapped
    ``module``, then ``dequant``. The ``quantization`` utility functions use
    this wrapper to add the quant and dequant modules. Before ``convert``,
    ``QuantStub`` merely observes the input tensor. After ``convert``, it is
    swapped for ``nnq.Quantize``, which does the actual quantization. The
    same applies to ``DeQuantStub``.

    Args:
        module: the module to wrap. If it has a ``qconfig`` attribute, that
            value is passed to both stubs; otherwise they receive ``None``.
            The wrapper adopts the module's train/eval mode.
    """

    # Class-level annotations declaring the submodule types
    # (presumably kept for TorchScript typing — confirm before removing).
    quant: QuantStub
    dequant: DeQuantStub
    module: nn.Module

    def __init__(self, module):
        super().__init__()
        shared_qconfig = getattr(module, 'qconfig', None)
        quant_stub = QuantStub(shared_qconfig)
        dequant_stub = DeQuantStub(shared_qconfig)
        # Submodules are registered as quant, dequant, module (in this order).
        self.add_module('quant', quant_stub)
        self.add_module('dequant', dequant_stub)
        self.add_module('module', module)
        # Mirror the wrapped module's training flag.
        self.train(module.training)

    def forward(self, X):
        return self.dequant(self.module(self.quant(X)))
|