# utils.py

import torch
import typing

__all__ = [
    "ReferenceQuantizedModule",
]

class ReferenceQuantizedModule(torch.nn.Module):
    def _init_weight_qparams(self, weight_qparams, device):
        if weight_qparams is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0,
            }
        self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
        self.weight_dtype = weight_qparams["dtype"]
        assert self.weight_qscheme in [
            None, torch.per_tensor_affine, torch.per_channel_affine,
            torch.per_channel_affine_float_qparams], \
            f"qscheme: {self.weight_qscheme} is not supported in reference quantized {self._get_name()}"
        if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
            zero_point_dtype = weight_qparams["zero_point"].dtype if \
                isinstance(weight_qparams["zero_point"], torch.Tensor) else \
                torch.int
            w_scale = weight_qparams["scale"]
            w_scale_tensor = w_scale.clone().detach() \
                if isinstance(w_scale, torch.Tensor) \
                else torch.tensor(w_scale, dtype=torch.float, device=device)
            self.register_buffer("weight_scale", w_scale_tensor)
            w_zp = weight_qparams["zero_point"]
            w_zp_tensor = w_zp.clone().detach() \
                if isinstance(w_zp, torch.Tensor) \
                else torch.tensor(w_zp, dtype=zero_point_dtype, device=device)
            self.register_buffer("weight_zero_point", w_zp_tensor)
            if self.weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
                w_axis = weight_qparams["axis"]
                w_axis_tensor = w_axis.clone().detach() \
                    if isinstance(w_axis, torch.Tensor) \
                    else torch.tensor(w_axis, dtype=torch.int, device=device)
                self.register_buffer("weight_axis", w_axis_tensor)
            else:
                # added for TorchScriptability, not used
                self.register_buffer(
                    "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
        else:
            # added for TorchScriptability, and for torch.float
            self.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float, device=device))
            self.register_buffer("weight_zero_point", torch.tensor(0, dtype=torch.int, device=device))
            self.register_buffer(
                "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
        self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
        # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
        # for capturing `.item` operations
        self.weight_axis_int: int = self.weight_axis.item()  # type: ignore[operator, assignment]
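
    # Illustrative sketch (not part of the original module): a typical
    # per-channel qparams dict for a linear weight with two output channels;
    # the concrete scale/zero_point values here are assumptions.
    #
    #   weight_qparams = {
    #       "qscheme": torch.per_channel_affine,
    #       "dtype": torch.qint8,
    #       "scale": torch.tensor([0.02, 0.01]),  # one scale per output channel
    #       "zero_point": torch.tensor([0, 0]),
    #       "axis": 0,                            # channel axis of the weight
    #   }
    #   module._init_weight_qparams(weight_qparams, device=torch.device("cpu"))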

    def get_weight(self):
        """
        Fake quantize (quantize and dequantize) the weight with
        the quantization parameters for the weight; this is used to
        simulate the numerics of the quantized weight in a quantized
        model.
        """
        # suppress mypy warning
        assert isinstance(self.weight_scale, torch.Tensor)
        assert isinstance(self.weight_zero_point, torch.Tensor)
        if self.is_decomposed:
            return _quantize_and_dequantize_weight_decomposed(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int)
        else:
            return _quantize_and_dequantize_weight(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int)

    def get_quantized_weight(self):
        # suppress mypy warning
        assert isinstance(self.weight_scale, torch.Tensor)
        assert isinstance(self.weight_zero_point, torch.Tensor)
        # assert isinstance(self.weight_axis, torch.Tensor)
        if self.is_decomposed:
            return _quantize_weight_decomposed(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int)
        else:
            return _quantize_weight(
                self.weight,  # type: ignore[arg-type]
                self.weight_qscheme,
                self.weight_dtype,
                self.weight_scale,
                self.weight_zero_point,
                self.weight_axis_int)
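
    # Illustrative comparison (a sketch, assuming a subclass that defines
    # `self.weight` and uses the non-decomposed path): `get_weight()` returns
    # a float tensor carrying the rounding error of quantization, while
    # `get_quantized_weight()` returns the quantized representation itself.
    #
    #   >>> fq = mod.get_weight()            # torch.float32, fake-quantized
    #   >>> qw = mod.get_quantized_weight()  # e.g. a torch.qint8 quantized tensor
    #   >>> torch.equal(fq, qw.dequantize())
    #   True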

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        _save_weight_qparams(
            destination, prefix, self.weight_qscheme, self.weight_dtype,
            self.weight_scale, self.weight_zero_point, self.weight_axis)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        for key in _get_weight_qparam_keys(state_dict, prefix):
            setattr(self, key, state_dict[prefix + key])
            state_dict.pop(prefix + key)
        # the qparam entries (including the weight_scale / weight_zero_point
        # buffers) were consumed above, so load the rest non-strictly
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False,
            missing_keys, unexpected_keys, error_msgs)


def _quantize_weight_decomposed(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis: int
) -> torch.Tensor:
    # TODO: get the quant_min and quant_max from activation_post_process
    _DTYPE_TO_QVALUE_BOUNDS = {
        torch.uint8: (0, 255),
        torch.int8: (-128, 127),
        torch.int32: (-(2**31), 2**31 - 1),
    }
    # TODO: add a util function for converting qdtype to dtype
    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
    }
    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
            weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
            weight = torch.ops.quantized_decomposed.quantize_per_tensor(
                weight,
                weight_scale,
                weight_zero_point,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_
            )
            return weight
    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
        # TODO: torch.quint4x2 is not supported
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
            weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
            weight = torch.ops.quantized_decomposed.quantize_per_channel(
                weight,
                weight_scale,
                weight_zero_point,
                weight_axis,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_)  # type: ignore[arg-type]
            return weight
    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
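

# Illustrative sketch (assumed values, mirroring the call above): in the
# decomposed path the quantized weight is a plain integer tensor rather than
# a torch quantized tensor, which keeps the graph traceable for export.
#
#   >>> w = torch.randn(2, 3)
#   >>> q = _quantize_weight_decomposed(
#   ...     w, torch.per_tensor_affine, torch.qint8,
#   ...     torch.tensor(0.1), torch.tensor(0, dtype=torch.int), 0)
#   >>> q.dtype
#   torch.int8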


def _dequantize_weight_decomposed(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis: int
) -> torch.Tensor:
    # TODO: get the quant_min and quant_max from activation_post_process
    _DTYPE_TO_QVALUE_BOUNDS = {
        torch.uint8: (0, 255),
        torch.int8: (-128, 127),
        torch.int32: (-(2**31), 2**31 - 1),
    }
    # TODO: add a util function for converting qdtype to dtype
    _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE = {
        torch.quint8: torch.uint8,
        torch.qint8: torch.int8,
        torch.qint32: torch.int32,
    }
    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
            weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
            weight = torch.ops.quantized_decomposed.dequantize_per_tensor(
                weight,
                weight_scale,
                weight_zero_point,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_
            )
            return weight
    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
        # TODO: torch.quint4x2 is not supported
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight_dtype_ = _QDTYPE_TO_UNDERLYING_INT_REPR_DTYPE[weight_dtype]
            weight_quant_min, weight_quant_max = _DTYPE_TO_QVALUE_BOUNDS[weight_dtype_]
            weight = torch.ops.quantized_decomposed.dequantize_per_channel(
                weight,
                weight_scale,
                weight_zero_point,
                weight_axis,
                weight_quant_min,
                weight_quant_max,
                weight_dtype_)  # type: ignore[arg-type]
            return weight
    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")


def _quantize_weight(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis_int: int
) -> torch.Tensor:
    if weight_dtype == torch.float16:
        weight = weight.to(weight_dtype)
        return weight
    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
            return weight
    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
            weight = torch.quantize_per_channel(
                weight, weight_scale,
                weight_zero_point, weight_axis_int, weight_dtype)  # type: ignore[arg-type]
            return weight
    raise Exception(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
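

# Illustrative round trip (a sketch; the scale and zero_point values are made
# up): for values inside the representable range, quantize-then-dequantize
# bounds the per-element error by scale / 2.
#
#   >>> w = torch.randn(4, 4)
#   >>> qw = _quantize_weight(
#   ...     w, torch.per_tensor_affine, torch.qint8,
#   ...     torch.tensor(0.1), torch.tensor(0, dtype=torch.int), 0)
#   >>> bool(((qw.dequantize() - w).abs() <= 0.05 + 1e-6).all())
#   True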


def _quantize_and_dequantize_weight_decomposed(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis_int: int
) -> torch.Tensor:
    """ Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme in [
            torch.per_tensor_affine,
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams]:
        weight_quant = _quantize_weight_decomposed(
            weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int)
        weight_dequant = _dequantize_weight_decomposed(
            weight_quant, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int)
    else:
        weight_dequant = weight
    return weight_dequant


def _quantize_and_dequantize_weight(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis_int: int
) -> torch.Tensor:
    """ Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme in [
            torch.per_tensor_affine,
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams]:
        weight_quant = _quantize_weight(
            weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis_int)
        weight_dequant = weight_quant.dequantize()
    else:
        weight_dequant = weight
    return weight_dequant
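

# Illustrative behavior sketch (inputs are assumptions): a qscheme outside the
# supported set falls through untouched, so a module with weight_qscheme=None
# keeps its float weight bit-for-bit.
#
#   >>> w = torch.randn(3)
#   >>> torch.equal(
#   ...     _quantize_and_dequantize_weight(
#   ...         w, None, torch.float, torch.tensor(1.0), torch.tensor(0), 0),
#   ...     w)
#   True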


def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype,
                         weight_scale, weight_zero_point, weight_axis):
    destination[prefix + "weight_qscheme"] = weight_qscheme
    destination[prefix + "weight_dtype"] = weight_dtype
    if weight_qscheme is not None:
        destination[prefix + "weight_scale"] = weight_scale
        destination[prefix + "weight_zero_point"] = weight_zero_point
        if weight_qscheme == torch.per_channel_affine:
            destination[prefix + "weight_axis"] = weight_axis
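

# Illustrative sketch of the entries written above (key names are real, the
# values shown are assumptions): for a per-channel module saved under the
# prefix "fc.", the state dict gains
#
#   state_dict["fc.weight_qscheme"]     # torch.per_channel_affine
#   state_dict["fc.weight_dtype"]       # torch.qint8
#   state_dict["fc.weight_scale"]       # tensor of per-channel scales
#   state_dict["fc.weight_zero_point"]  # tensor of per-channel zero points
#   state_dict["fc.weight_axis"]        # tensor(0), only for per_channel_affine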


def _get_weight_qparam_keys(
        state_dict: typing.Dict[str, typing.Any],
        prefix: str):
    keys = ["weight_qscheme", "weight_dtype"]
    weight_qscheme = state_dict[prefix + "weight_qscheme"]
    if weight_qscheme is not None:
        keys.append("weight_scale")
        keys.append("weight_zero_point")
        # compare against the per-channel qscheme (the original comparison
        # against the torch.quantize_per_channel function could never match)
        if weight_qscheme == torch.per_channel_affine:
            keys.append("weight_axis")
    return keys
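

# Minimal end-to-end sketch (an illustrative demo, not part of the module):
# a hypothetical reference linear layer that mixes this class into nn.Linear
# and runs its forward pass on the fake-quantized weight, the way reference
# modules typically wire `get_weight` into `forward`.
if __name__ == "__main__":
    class _DemoReferenceLinear(torch.nn.Linear, ReferenceQuantizedModule):
        def __init__(self, in_features, out_features, weight_qparams=None):
            super().__init__(in_features, out_features)
            # None falls back to the default per-tensor affine quint8 qparams
            self._init_weight_qparams(weight_qparams, device="cpu")

        def forward(self, x):
            # simulate quantized numerics with the fake-quantized weight
            return torch.nn.functional.linear(x, self.get_weight(), self.bias)

    m = _DemoReferenceLinear(4, 2)
    print(m(torch.randn(1, 4)).shape)  # torch.Size([1, 2])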