import torch

__all__ = [
    "ReLU6",
    "Hardswish",
    "ELU",
    "LeakyReLU",
    "Sigmoid",
    "Softmax",
    "MultiheadAttention",
    "PReLU"
]


class ReLU6(torch.nn.ReLU):
    r"""Applies the element-wise function:

    :math:`\text{ReLU6}(x) = \min(\max(x_0, x), q(6))`, where :math:`x_0` is the
    zero_point, and :math:`q(6)` is the quantized representation of number 6.

    Args:
        inplace: can optionally do the operation in-place. Default: ``False``

    Shape:
        - Input: :math:`(N, *)` where `*` means any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: ../scripts/activation_images/ReLU6.png

    Examples::

        >>> m = nn.quantized.ReLU6()
        >>> input = torch.randn(2)
        >>> # xdoctest: +SKIP
        >>> input = torch.quantize_per_tensor(input, 1.0, 0, dtype=torch.qint32)
        >>> output = m(input)
    """
    def __init__(self, inplace=False):
        super().__init__(inplace)
        self.inplace = inplace

    def forward(self, input):
        return torch.ops.quantized.relu6(input, self.inplace)

    def _get_name(self):
        return 'QuantizedReLU6'

    @staticmethod
    def from_float(mod):
        return ReLU6(mod.inplace)


class Hardswish(torch.nn.Hardswish):
    r"""This is the quantized version of :class:`~torch.nn.Hardswish`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
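
    Example (a minimal sketch; the ``scale``/``zero_point`` values below are
    placeholders, and running it requires a backend with quantized operator
    support)::

        >>> # xdoctest: +SKIP
        >>> m = Hardswish(scale=1.0, zero_point=0)
        >>> x = torch.quantize_per_tensor(torch.randn(4), 1.0, 0, torch.quint8)
        >>> y = m(x)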
  45. """
  46. def __init__(self, scale, zero_point):
  47. super().__init__()
  48. self.scale = scale
  49. self.zero_point = zero_point
  50. def forward(self, input):
  51. return torch.ao.nn.quantized.functional.hardswish(
  52. input, scale=self.scale, zero_point=self.zero_point)
  53. def _get_name(self):
  54. return 'QuantizedHardswish'
  55. @staticmethod
  56. def from_float(mod):
  57. scale, zero_point = mod.activation_post_process.calculate_qparams()
  58. return Hardswish(float(scale), int(zero_point))
  59. @classmethod
  60. def from_reference(cls, mod, scale, zero_point):
  61. return cls(float(scale), int(zero_point))


class ELU(torch.nn.ELU):
    r"""This is the quantized equivalent of :class:`~torch.nn.ELU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        alpha: the alpha constant
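
    Example (a minimal sketch; the quantization parameters below are
    placeholders, and running it requires a backend with quantized operator
    support)::

        >>> # xdoctest: +SKIP
        >>> m = ELU(scale=1.0, zero_point=0, alpha=1.0)
        >>> x = torch.quantize_per_tensor(torch.randn(4), 1.0, 0, torch.quint8)
        >>> y = m(x)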
  68. """
  69. def __init__(self, scale, zero_point, alpha=1.):
  70. super().__init__(alpha)
  71. self.scale = scale
  72. self.zero_point = zero_point
  73. def forward(self, input):
  74. return torch.ao.nn.quantized.functional.elu(
  75. input, self.scale, self.zero_point, self.alpha)
  76. def _get_name(self):
  77. return 'QuantizedELU'
  78. @staticmethod
  79. def from_float(mod):
  80. scale, zero_point = mod.activation_post_process.calculate_qparams()
  81. return ELU(float(scale), int(zero_point), mod.alpha)
  82. @classmethod
  83. def from_reference(cls, mod, scale, zero_point):
  84. return cls(float(scale), int(zero_point), mod.alpha)


class LeakyReLU(torch.nn.LeakyReLU):
    r"""This is the quantized equivalent of :class:`~torch.nn.LeakyReLU`.

    Args:
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
        negative_slope: Controls the angle of the negative slope. Default: 1e-2
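
    Example (a minimal sketch; the quantization parameters below are
    placeholders, and running it requires a backend with quantized operator
    support)::

        >>> # xdoctest: +SKIP
        >>> m = LeakyReLU(scale=1.0, zero_point=0, negative_slope=0.01)
        >>> x = torch.quantize_per_tensor(torch.randn(4), 1.0, 0, torch.quint8)
        >>> y = m(x)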
  91. """
  92. def __init__(self, scale: float, zero_point: int, negative_slope: float = 1e-2,
  93. inplace: bool = False, device=None, dtype=None) -> None:
  94. factory_kwargs = {'device': device, 'dtype': dtype}
  95. super().__init__(negative_slope, inplace)
  96. self.register_buffer('scale', torch.tensor(scale, **factory_kwargs))
  97. self.register_buffer('zero_point', torch.tensor(zero_point, **factory_kwargs))
  98. def forward(self, input):
  99. return torch.ops.quantized.leaky_relu(
  100. input, self.negative_slope, self.inplace, self.scale, self.zero_point)
  101. def _get_name(self):
  102. return 'QuantizedLeakyReLU'
  103. @classmethod
  104. def from_float(cls, mod):
  105. scale, zero_point = mod.activation_post_process.calculate_qparams()
  106. return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)
  107. @classmethod
  108. def from_reference(cls, mod, scale, zero_point):
  109. return cls(float(scale), int(zero_point), mod.negative_slope, mod.inplace)


class Sigmoid(torch.nn.Sigmoid):
    r"""This is the quantized equivalent of :class:`~torch.nn.Sigmoid`.

    Args:
        output_scale: quantization scale of the output tensor
        output_zero_point: quantization zero point of the output tensor
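
    Example (a minimal sketch; the output quantization parameters below are
    placeholders, and running it requires a backend with quantized operator
    support)::

        >>> # xdoctest: +SKIP
        >>> m = Sigmoid(output_scale=1.0 / 256.0, output_zero_point=0)
        >>> x = torch.quantize_per_tensor(torch.randn(4), 1.0, 0, torch.quint8)
        >>> y = m(x)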
  115. """
  116. def __init__(self, output_scale: float, output_zero_point: int):
  117. super().__init__()
  118. self.output_scale = output_scale
  119. self.output_zero_point = output_zero_point
  120. def forward(self, input):
  121. return torch.ops.quantized.sigmoid(input, self.output_scale, self.output_zero_point)
  122. @classmethod
  123. def from_float(cls, mod):
  124. output_scale, output_zero_point = mod.activation_post_process.calculate_qparams()
  125. return cls(float(output_scale), int(output_zero_point))


class Softmax(torch.nn.Softmax):
    r"""This is the quantized version of :class:`~torch.nn.Softmax`.

    Args:
        dim: A dimension along which Softmax will be computed (so every slice along dim will sum to 1).
        scale: quantization scale of the output tensor
        zero_point: quantization zero point of the output tensor
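
    Example (a minimal sketch; the quantization parameters below are
    placeholders, and running it requires a backend with quantized operator
    support)::

        >>> # xdoctest: +SKIP
        >>> m = Softmax(dim=-1, scale=1.0 / 256.0, zero_point=0)
        >>> x = torch.quantize_per_tensor(torch.randn(2, 3), 1.0, 0, torch.quint8)
        >>> y = m(x)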
  132. """
  133. def __init__(self, dim=None, scale=1.0, zero_point=0):
  134. super().__init__()
  135. self.dim = dim
  136. self.scale = scale
  137. self.zero_point = zero_point
  138. def forward(self, input):
  139. dim = self.dim
  140. if dim is None:
  141. stacklevel = 3
  142. # Note: adding the mypy ignore on _get_softmax_dim seems less bad
  143. # than making `_get_softmax_dim` an official API.
  144. dim = torch.nn.functional._get_softmax_dim( # type: ignore[attr-defined]
  145. "softmax", input.dim(), stacklevel)
  146. return torch.ops.quantized.softmax(
  147. input, dim, self.scale, self.zero_point)
  148. def _get_name(self):
  149. return 'QuantizedSoftmax'
  150. @staticmethod
  151. def from_float(mod):
  152. scale, zero_point = mod.activation_post_process.calculate_qparams()
  153. return Softmax(mod.dim, float(scale), int(zero_point))
  154. @classmethod
  155. def from_reference(cls, mod, scale, zero_point):
  156. return cls(mod.dim, float(scale), int(zero_point))


class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
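    r"""This is the quantized version of :class:`~torch.ao.nn.quantizable.MultiheadAttention`.

    The intended flow is float -> observed (quantizable) -> quantized: instances
    are produced from an observed quantizable module via :meth:`from_observed`,
    not directly from a float module.
    """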
    _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention

    def _get_name(self):
        return "QuantizedMultiheadAttention"

    @classmethod
    def from_float(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does observed -> quantized only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-observed MHA module. Please, see "
                                  "the examples on quantizable MHAs.")

    @classmethod
    def from_observed(cls, other):
        converted = torch.ao.quantization.convert(other, mapping=None,
                                                  inplace=False,
                                                  remove_qconfig=True,
                                                  convert_custom_config_dict=None)
        converted.__class__ = cls
        # Remove the parameters for the bias_k and bias_v to quantize them
        # TODO: This is a potential source of accuracy drop.
        # quantized cat takes the scale and zp of the first
        # element, which might lose the precision in the bias_k
        # and the bias_v (which are cat'ed with k/v being first).
        if converted.bias_k is not None:
            bias_k = converted._parameters.pop('bias_k')
            sc, zp = torch._choose_qparams_per_tensor(bias_k,
                                                      reduce_range=False)
            bias_k = torch.quantize_per_tensor(bias_k, sc, zp, torch.quint8)
            setattr(converted, 'bias_k', bias_k)  # noqa: B010

        if converted.bias_v is not None:
            bias_v = converted._parameters.pop('bias_v')
            sc, zp = torch._choose_qparams_per_tensor(bias_v,
                                                      reduce_range=False)
            bias_v = torch.quantize_per_tensor(bias_v, sc, zp, torch.quint8)
            setattr(converted, 'bias_v', bias_v)  # noqa: B010
        return converted


class PReLU(torch.nn.Module):
    r"""This is the quantized equivalent of :class:`~torch.nn.PReLU`.

    Args:
        output_scale: quantization scale of the output tensor
        output_zero_point: quantization zero point of the output tensor
        num_parameters: number of parameters: 1, or the number of channels at input. Default: 1
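
    Example (a minimal sketch; the quantization parameters below are
    placeholders, and running it requires a backend with quantized operator
    support)::

        >>> # xdoctest: +SKIP
        >>> m = PReLU(output_scale=1.0, output_zero_point=0, num_parameters=1)
        >>> x = torch.quantize_per_tensor(torch.randn(4), 1.0, 0, torch.quint8)
        >>> y = m(x)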
  199. """
  200. def __init__(self, output_scale: float, output_zero_point: int,
  201. num_parameters: int = 1) -> None:
  202. super().__init__()
  203. self.num_parameters = num_parameters
  204. self.scale = output_scale
  205. self.zero_point = output_zero_point
  206. w = torch.randn(num_parameters, dtype=torch.float)
  207. qw = torch.quantize_per_tensor(w, scale=1.0, zero_point=0, dtype=torch.quint8)
  208. self.set_weight(qw)

    def set_weight(self, w: torch.Tensor) -> None:
        self.weight = w

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.ops.quantized.prelu(input, self.weight, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedPReLU'

    @classmethod
    def from_float(cls, mod):
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        # Run the observer over the float weight so calculate_qparams()
        # reflects the actual weight statistics.
        observer(float_wt)
        wt_scale, wt_zp = observer.calculate_qparams()
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
        qprelu.set_weight(qweight)
        return qprelu

    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        qprelu = cls(float(scale), int(zero_point), mod.num_parameters)
        float_wt = mod.weight.float()
        observer = mod.qconfig.weight()
        # Same as in from_float(): observe the weight before computing qparams.
        observer(float_wt)
        wt_scale, wt_zp = observer.calculate_qparams()
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.quint8)
        qprelu.set_weight(qweight)
        return qprelu