stubs.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from torch import nn
  2. class QuantStub(nn.Module):
  3. r"""Quantize stub module, before calibration, this is same as an observer,
  4. it will be swapped as `nnq.Quantize` in `convert`.
  5. Args:
  6. qconfig: quantization configuration for the tensor,
  7. if qconfig is not provided, we will get qconfig from parent modules
  8. """
  9. def __init__(self, qconfig=None):
  10. super().__init__()
  11. if qconfig:
  12. self.qconfig = qconfig
  13. def forward(self, x):
  14. return x
  15. class DeQuantStub(nn.Module):
  16. r"""Dequantize stub module, before calibration, this is same as identity,
  17. this will be swapped as `nnq.DeQuantize` in `convert`.
  18. Args:
  19. qconfig: quantization configuration for the tensor,
  20. if qconfig is not provided, we will get qconfig from parent modules
  21. """
  22. def __init__(self, qconfig=None):
  23. super().__init__()
  24. if qconfig:
  25. self.qconfig = qconfig
  26. def forward(self, x):
  27. return x
  28. class QuantWrapper(nn.Module):
  29. r"""A wrapper class that wraps the input module, adds QuantStub and
  30. DeQuantStub and surround the call to module with call to quant and dequant
  31. modules.
  32. This is used by the `quantization` utility functions to add the quant and
  33. dequant modules, before `convert` function `QuantStub` will just be observer,
  34. it observes the input tensor, after `convert`, `QuantStub`
  35. will be swapped to `nnq.Quantize` which does actual quantization. Similarly
  36. for `DeQuantStub`.
  37. """
  38. quant: QuantStub
  39. dequant: DeQuantStub
  40. module: nn.Module
  41. def __init__(self, module):
  42. super().__init__()
  43. qconfig = module.qconfig if hasattr(module, 'qconfig') else None
  44. self.add_module('quant', QuantStub(qconfig))
  45. self.add_module('dequant', DeQuantStub(qconfig))
  46. self.add_module('module', module)
  47. self.train(module.training)
  48. def forward(self, X):
  49. X = self.quant(X)
  50. X = self.module(X)
  51. return self.dequant(X)