_learnable_fake_quantize.py

import torch
from torch.nn.parameter import Parameter
from typing import List

__all__: List[str] = []


class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
    r"""This is an extension of the FakeQuantize module in fake_quantize.py, which
    supports more generalized lower-bit quantization and supports learning of the scale
    and zero point parameters through backpropagation. For literature references,
    please see the class _LearnableFakeQuantizePerTensorOp.

    In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
    module also includes the following attributes to support quantization parameter learning.

    * :attr:`channel_len` defines the length of the channel when initializing scale and zero point
      for the per-channel case.

    * :attr:`use_grad_scaling` defines the flag for whether the gradients for scale and zero point are
      normalized by a constant proportional to the square root of the number of
      elements in the tensor. The related literature justifying the use of this particular constant
      can be found here: https://openreview.net/pdf?id=rkgO66VKDS.

    * :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output.

    * :attr:`static_enabled` defines the flag for using the observer's static estimation of
      scale and zero point.

    * :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
    """
    def __init__(self, observer, quant_min=0, quant_max=255, scale=1., zero_point=0., channel_len=-1,
                 use_grad_scaling=False, **observer_kwargs):
        super().__init__()
        assert quant_min < quant_max, 'quant_min must be strictly less than quant_max.'
        self.quant_min = quant_min
        self.quant_max = quant_max
        # also pass quant_min and quant_max to observer
        observer_kwargs["quant_min"] = quant_min
        observer_kwargs["quant_max"] = quant_max
        self.use_grad_scaling = use_grad_scaling

        if channel_len == -1:
            self.scale = Parameter(torch.tensor([scale]))
            self.zero_point = Parameter(torch.tensor([zero_point]))
        else:
            assert isinstance(channel_len, int) and channel_len > 0, "Channel size must be a positive integer."
            self.scale = Parameter(torch.tensor([scale] * channel_len))
            self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))

        self.activation_post_process = observer(**observer_kwargs)
        assert torch.iinfo(self.activation_post_process.dtype).min <= quant_min, \
            'quant_min out of bound'
        assert quant_max <= torch.iinfo(self.activation_post_process.dtype).max, \
            'quant_max out of bound'
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        self.ch_axis = self.activation_post_process.ch_axis \
            if hasattr(self.activation_post_process, 'ch_axis') else -1
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('static_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('learning_enabled', torch.tensor([0], dtype=torch.uint8))

        bitrange = torch.tensor(quant_max - quant_min + 1).double()
        self.bitwidth = int(torch.log2(bitrange).item())
        self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
    @torch.jit.export
    def enable_param_learning(self):
        r"""Enables learning of the quantization parameters and
        disables static observer estimates. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=True) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=False)
        return self

    @torch.jit.export
    def enable_static_estimate(self):
        r"""Enables static observer estimates and disables learning of the
        quantization parameters. Forward path returns fake quantized X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=True) \
            .toggle_observer_update(enabled=True)

    @torch.jit.export
    def enable_static_observation(self):
        r"""Enables the static observer to accumulate data from the input, but
        doesn't update the quantization parameters. Forward path returns the original X.
        """
        self.toggle_qparam_learning(enabled=False) \
            .toggle_fake_quant(enabled=False) \
            .toggle_observer_update(enabled=True)
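
    # Summary of the flag states produced by the three enable_* helpers above
    # (1 = enabled, 0 = disabled):
    #
    #   method                      learning_enabled  fake_quant_enabled  static_enabled
    #   enable_param_learning              1                  1                  0
    #   enable_static_estimate             0                  1                  1
    #   enable_static_observation          0                  0                  1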
    @torch.jit.export
    def toggle_observer_update(self, enabled=True):
        self.static_enabled[0] = int(enabled)  # type: ignore[operator]
        return self

    @torch.jit.export
    def enable_observer(self, enabled=True):
        self.toggle_observer_update(enabled)

    @torch.jit.export
    def toggle_qparam_learning(self, enabled=True):
        self.learning_enabled[0] = int(enabled)  # type: ignore[operator]
        self.scale.requires_grad = enabled
        self.zero_point.requires_grad = enabled
        return self

    @torch.jit.export
    def toggle_fake_quant(self, enabled=True):
        self.fake_quant_enabled[0] = int(enabled)
        return self

    @torch.jit.export
    def observe_quant_params(self):
        print('_LearnableFakeQuantize Scale: {}'.format(self.scale.detach()))
        print('_LearnableFakeQuantize Zero Point: {}'.format(self.zero_point.detach()))

    @torch.jit.export
    def calculate_qparams(self):
        self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]
        scale = self.scale.detach()
        zero_point = self.zero_point.detach().round().clamp(self.quant_min, self.quant_max).long()
        return scale, zero_point
    def forward(self, X):
        if self.static_enabled[0] == 1:  # type: ignore[index]
            # Static estimation: let the observer track the input statistics and
            # copy its estimated scale and zero point into the learnable parameters.
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale = _scale.to(self.scale.device)
            _zero_point = _zero_point.to(self.zero_point.device)
            self.scale.data.copy_(_scale)
            self.zero_point.data.copy_(_zero_point)
        else:
            self.scale.data.clamp_(min=self.eps.item())  # type: ignore[operator]

        if self.fake_quant_enabled[0] == 1:
            if self.qscheme in (torch.per_channel_symmetric, torch.per_tensor_symmetric):
                self.zero_point.data.zero_()

            if self.use_grad_scaling:
                # Gradient scaling constant 1 / sqrt(numel * quant_max); see the
                # reference linked in the class docstring.
                grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
            else:
                grad_factor = 1.0
            if self.qscheme in (
                    torch.per_channel_symmetric, torch.per_channel_affine):
                X = torch._fake_quantize_learnable_per_channel_affine(
                    X, self.scale, self.zero_point, self.ch_axis,
                    self.quant_min, self.quant_max, grad_factor)
            else:
                X = torch._fake_quantize_learnable_per_tensor_affine(
                    X, self.scale, self.zero_point,
                    self.quant_min, self.quant_max, grad_factor)

        return X
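

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a per-tensor setup with torch.ao.quantization.MovingAverageMinMaxObserver
# as the underlying observer and a toy reconstruction loss; the observer choice,
# batch shapes, and optimizer settings below are arbitrary picks for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.ao.quantization import MovingAverageMinMaxObserver

    fq = _LearnableFakeQuantize(
        observer=MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        use_grad_scaling=True,
    )
    batches = [torch.randn(4, 16) for _ in range(8)]

    # Phase 1: calibrate scale / zero point from the observer's statistics.
    fq.enable_static_estimate()
    for x in batches:
        fq(x)

    # Phase 2: refine scale / zero point through backpropagation.
    fq.enable_param_learning()
    optimizer = torch.optim.SGD([fq.scale, fq.zero_point], lr=1e-3)
    for x in batches:
        loss = (fq(x) - x).pow(2).mean()  # toy objective for the sketch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(fq.calculate_qparams())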