# _decomposed.py
import torch
from torch.library import Library, impl
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from typing import Tuple

# Note: decomposed means decomposed quantized tensor, using decomposed so that the
# name is not too long
quantized_decomposed_lib = Library("quantized_decomposed", "DEF")

_DTYPE_TO_QVALUE_BOUNDS = {
    torch.uint8: (0, 255),
    torch.int8: (-128, 127),
    torch.int32: (-(2**31), 2**31 - 1),
}
# Helper to check the passed in quant min and max are valid for the dtype
def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
        raise ValueError(f"Unsupported dtype: {dtype}")
    quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]

    assert quant_min >= quant_min_lower_bound, \
        "quant_min out of bound for dtype, " \
        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"

    assert quant_max <= quant_max_upper_bound, \
        "quant_max out of bound for dtype, " \
        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"

quantized_decomposed_lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor(
        input: torch.Tensor,
        scale: float,
        zero_point: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values

    Args:
       input (torch.Tensor): original float32 Tensor
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

    inv_scale = 1.0 / scale
    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
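
# Illustrative usage sketch (not part of the library): the define/impl calls above
# register the op under torch.ops, so it can be called as shown below. The scale and
# zero_point values here are assumed for illustration only.
#
#   x = torch.randn(2, 3)
#   xq = torch.ops.quantized_decomposed.quantize_per_tensor(
#       x, 0.1, 128, 0, 255, torch.uint8)
#   # xq has dtype torch.uint8; the quantization parameters are not attached to the
#   # tensor, they only exist as the arguments passed to the op.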

quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd")
def quantize_per_tensor_tensor(
        input: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine quantization for the Tensor using the same quantization parameters to map
    from floating point to quantized values

    Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
    scalar values
    """
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)

# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we find there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor(
        input: torch.Tensor,
        scale: float,
        zero_point: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values

    Args:
       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
       e.g. (`torch.uint8`), it is a per tensor quantized Tensor if combined with
       quantization parameters in the argument of this function (scale/zero_point)
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
       quant_min (int): minimum quantized value for input Tensor (not used in computation,
       reserved for pattern matching)
       quant_max (int): maximum quantized value for input Tensor (not used in computation,
       reserved for pattern matching)
       dtype (torch.dtype): dtype for input Tensor (not used in computation,
       reserved for pattern matching)

    Returns:
       dequantized float32 Tensor
    """
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
    if dtype in [torch.uint8, torch.int8, torch.int32]:
        # TODO: investigate why
        # (input - zero_point).to(torch.float32) * scale
        # failed the test
        return (input.to(torch.float32) - zero_point) * scale
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
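
# Illustrative usage sketch (not part of the library): dequantizing the uint8 tensor
# produced in the earlier sketch, using the same assumed scale/zero_point:
#
#   xdq = torch.ops.quantized_decomposed.dequantize_per_tensor(
#       xq, 0.1, 128, 0, 255, torch.uint8)
#   # xdq approximates x up to the rounding/clamping error introduced by quantization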

quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_tensor(
        input: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine dequantization for the Tensor using the same quantization parameters to map
    from quantized values to floating point values

    Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
    scalar values
    """
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)

@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
    assert zero_point.numel() == 1, f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
    assert scale.numel() == 1, f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
    if dtype in [torch.uint8, torch.int8, torch.int32]:
        return torch.empty_like(input, dtype=torch.float32)
    else:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")

quantized_decomposed_lib.define(
    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
    "ScalarType dtype) -> (Tensor, Tensor)")

@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
def choose_qparams_tensor(
        input: torch.Tensor,
        qmin: int,
        qmax: int,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Given an input Tensor, derive the per tensor affine quantization parameters
    (scale and zero_point) for the target quantized Tensor from the Tensor

    Args:
       input (torch.Tensor): floating point input Tensor
       quant_min (int): minimum quantized value for target quantized Tensor
       quant_max (int): maximum quantized value for target quantized Tensor
       dtype (torch.dtype): dtype for target quantized Tensor

    Returns:
       scale (float): quantization parameter for the target quantized Tensor
       zero_point (int): quantization parameter for the target quantized Tensor
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert dtype == torch.int8 or dtype == torch.uint8 or dtype == torch.int32, \
        f"Expecting target dtype to be int8, uint8 or int32, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    min_val, max_val = torch.aminmax(input)

    return determine_qparams(
        min_val, max_val, qmin, qmax, dtype, torch.Tensor([torch.finfo(torch.float32).eps]), has_customized_qrange=False)
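
# Illustrative usage sketch (not part of the library): derive scale/zero_point from a
# tensor's min/max and feed them to quantize_per_tensor. The returned values are
# one-element tensors, hence the .item() calls:
#
#   scale, zero_point = torch.ops.quantized_decomposed.choose_qparams.tensor(
#       x, 0, 255, torch.uint8)
#   xq = torch.ops.quantized_decomposed.quantize_per_tensor(
#       x, scale.item(), int(zero_point.item()), 0, 255, torch.uint8)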

quantized_decomposed_lib.define(
    "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, "
    "ScalarType dtype) -> (Tensor, Tensor)")

@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "CompositeExplicitAutograd")
def choose_qparams_symmetric_tensor(
        input: torch.Tensor,
        qmin: int,
        qmax: int,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """ Given an input Tensor, derive the per tensor symmetric quantization parameters
    (scale and zero_point) for the target quantized Tensor from the Tensor

    Args:
       input (torch.Tensor): floating point input Tensor
       quant_min (int): minimum quantized value for target quantized Tensor
       quant_max (int): maximum quantized value for target quantized Tensor
       dtype (torch.dtype): dtype for target quantized Tensor

    Returns:
       scale (float): quantization parameter for the target quantized Tensor
       zero_point (int): quantization parameter for the target quantized Tensor
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert dtype == torch.int8 or dtype == torch.uint8 or dtype == torch.int32, \
        f"Expecting target dtype to be int8, uint8 or int32, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    min_val, max_val = torch.aminmax(input)
    return determine_qparams(
        min_val,
        max_val,
        qmin,
        qmax,
        dtype,
        torch.Tensor([torch.finfo(torch.float32).eps]),
        has_customized_qrange=False,
        qscheme=torch.per_tensor_symmetric
    )

@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
def choose_qparams_tensor_meta(
        input: torch.Tensor,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert quant_min < quant_max, f"Expecting quant_min to be smaller than quant_max but received min: \
        {quant_min} max: {quant_max}"
    return torch.empty(1, dtype=torch.float, device=input.device), torch.empty(1, dtype=torch.int32, device=input.device)

@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta")
def choose_qparams_symmetric_tensor_meta(
        input: torch.Tensor,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.empty(1, dtype=torch.float, device=input.device), torch.empty(1, dtype=torch.int32, device=input.device)

# Helper function used to implement per-channel quantization against any axis
def _permute_to_axis_zero(x, axis):
    new_axis_list = list(range(x.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = x.permute(tuple(new_axis_list))
    return y, new_axis_list
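
# Illustrative sketch of the helper above: for a tensor of shape (2, 3, 4) and axis=1,
# the permutation swaps dims 0 and 1 and is its own inverse, so applying it again
# restores the original layout:
#
#   y, perm = _permute_to_axis_zero(torch.zeros(2, 3, 4), 1)
#   # y.shape == (3, 2, 4); perm == [1, 0, 2]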

quantized_decomposed_lib.define(
    "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd")
def quantize_per_channel(
        input: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        axis: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine per channel quantization for the Tensor using the same quantization
    parameters for each channel/axis to map from floating point to quantized values

    Args:
       input (torch.Tensor): original float32 Tensor
       scales (torch.Tensor): a list of scale quantization parameters for
       affine quantization, one per channel
       zero_points (torch.Tensor): a list of zero_point quantization parameters for
       affine quantization, one per channel
       axis (int): the channel axis of the input Tensor
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
       are not stored in the Tensor, we are storing them in function arguments instead
    """
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)
    res = torch.zeros_like(input)

    for i in range(input.size(0)):
        res[i] = torch.clamp(
            torch.round(input[i] * (1.0 / scales[i])) + zero_points[i],
            quant_min,
            quant_max
        )

    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)
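
# Illustrative usage sketch (not part of the library): per-channel quantization of a
# (2, 3) weight tensor along axis 0, with assumed scales and zero_points (one entry
# per channel):
#
#   w = torch.randn(2, 3)
#   scales = torch.tensor([0.1, 0.2])
#   zero_points = torch.tensor([0, 0])
#   wq = torch.ops.quantized_decomposed.quantize_per_channel(
#       w, scales, zero_points, 0, -128, 127, torch.int8)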

@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta")
def quantize_per_channel_meta(
        input: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        axis: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=dtype)

# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we find there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor")

@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
def dequantize_per_channel(
        input: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        axis: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    """ Affine per channel dequantization for the Tensor using the same quantization
    parameters for each channel/axis to map from quantized values to floating point values

    Args:
       input (torch.Tensor): Tensor with dtype matching `dtype` argument,
       e.g. (`torch.uint8`), it is a per channel quantized Tensor if combined with
       quantization parameters in the argument of this function (scales/zero_points/axis)
       scales (torch.Tensor): a list of scale quantization parameters for
       affine quantization, one per channel
       zero_points (torch.Tensor): a list of zero_point quantization parameters for
       affine quantization, one per channel
       axis (int): the channel axis of the input Tensor
       quant_min (int): minimum quantized value for output Tensor (not used in computation,
       reserved for pattern matching)
       quant_max (int): maximum quantized value for output Tensor (not used in computation,
       reserved for pattern matching)
       dtype (torch.dtype): requested dtype for output Tensor (not used in computation,
       reserved for pattern matching)

    Returns:
       dequantized float32 Tensor
    """
    assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    input, permute_axis_list = _permute_to_axis_zero(input, axis)
    res = torch.zeros_like(input, dtype=torch.float32)

    for i in range(input.size(0)):
        # TODO: investigate why
        # (input[i] - zero_points[i]).to(torch.float32) * scales[i]
        # failed the test
        res[i] = (input[i].to(torch.float32) - zero_points[i]) * scales[i]

    out = res.permute(tuple(permute_axis_list))
    return out
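
# Illustrative usage sketch (not part of the library): dequantizing the int8 tensor
# from the per-channel example above, with the same scales/zero_points and axis:
#
#   wdq = torch.ops.quantized_decomposed.dequantize_per_channel(
#       wq, scales, zero_points, 0, -128, 127, torch.int8)
#   # wdq approximates w up to per-channel rounding/clamping error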

@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta")
def dequantize_per_channel_meta(
        input: torch.Tensor,
        scales: torch.Tensor,
        zero_points: torch.Tensor,
        axis: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype
) -> torch.Tensor:
    assert input.dtype == dtype, f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    return torch.empty_like(input, dtype=torch.float32)