r"""This module provides common utility methods for checking quantized
tensors and modules.
"""
import numpy as np
import torch
from contextlib import contextmanager
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_PPC, IS_MACOS, IS_WINDOWS

supported_qengines = torch.backends.quantized.supported_engines
supported_qengines.remove('none')
# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
# QNNPACK is not supported on PPC
# QNNPACK throws ASAN heap-buffer-overflow error.
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_MACOS, IS_WINDOWS]):
    supported_qengines.remove('qnnpack')

def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
                       output_padding=0):
    """Computes the output shape given convolution parameters."""
    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
                     * (dilation - 1)) / stride) + 2 * output_padding + 1
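
# Hypothetical usage sketch (not part of the upstream helpers): for a 3x3
# convolution with padding=1, stride=1, dilation=1, a width-5 input keeps its
# spatial size.
def _example_conv_output_shape():
    out = _conv_output_shape(input_size=5, kernel_size=3, padding=1,
                             stride=1, dilation=1)
    assert out == 5.0
    return out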

# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
    """Quantizes a numpy array."""
    if qmin is None:
        qmin = np.iinfo(dtype).min
    if qmax is None:
        qmax = np.iinfo(dtype).max
    qx = np.round(x / scale + zero_point).astype(np.int64)
    qx = np.clip(qx, qmin, qmax)
    qx = qx.astype(dtype)
    return qx

def _dequantize(qx, scale, zero_point):
    """Dequantizes a numpy array."""
    x = (qx.astype(float) - zero_point) * scale
    return x
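
# Hypothetical usage sketch (not part of the upstream helpers): a quantize /
# dequantize round trip should reproduce each value to within half a
# quantization step, since _quantize rounds to the nearest integer level.
def _example_quantize_roundtrip():
    x = np.array([-0.51, 0.0, 0.49, 1.0], dtype=np.float32)
    scale, zero_point = 0.02, 128
    qx = _quantize(x, scale, zero_point)        # uint8 values in [0, 255]
    x_hat = _dequantize(qx, scale, zero_point)  # float approximation of x
    assert np.all(np.abs(x - x_hat) <= scale / 2 + 1e-6)
    return qx, x_hat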

def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
    converted back to the given type."""
    qx = (x * multiplier).round() + zero_point
    qx = np.clip(qx, qmin, qmax).astype(qtype)
    return qx
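
# Hypothetical usage sketch (not part of the upstream helpers): collapse int32
# accumulator values (as produced by a quantized matmul) back to uint8 using
# multiplier = input_scale * weight_scale / output_scale.
def _example_requantize():
    acc = np.array([-1000, 0, 500, 100000], dtype=np.int32)
    multiplier = (0.02 * 0.01) / 0.05
    out = _requantize(acc, multiplier, zero_point=128)
    assert out.dtype == np.uint8  # values saturate to [0, 255]
    return out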

def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max elements of the tensor."""
    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
    if qscheme == torch.per_tensor_symmetric:
        assert dtype == torch.qint8
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    if dtype == torch.qint8:
        if reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:  # dtype == torch.quint8
        if reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255
    min_val = X.min()
    max_val = X.max()
    is_symmetric = (qscheme == torch.per_tensor_symmetric)
    if min_val == max_val:
        scale = 1.0
        zero_point = 0
    else:
        if is_symmetric:
            max_val = max(max_val, -min_val)
            min_val = -max_val
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = qmin - round(min_val / scale)
            zero_point = max(qmin, zero_point)
            zero_point = min(qmax, zero_point)
    return [float(scale), int(zero_point)]
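
# Hypothetical usage sketch (not part of the upstream helpers): derive affine
# qparams from a tensor's observed range and feed them to the reference
# _quantize; the extremes of the range land on the ends of [0, 255].
def _example_dynamic_qparams():
    X = torch.tensor([-1.0, 0.0, 2.0, 3.0])
    scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
    qx = _quantize(X.numpy(), scale, zero_point)
    assert qx.min() == 0 and qx.max() == 255
    return scale, zero_point, qx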

def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max elements of each channel of the tensor."""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        # Per-channel statistics along the leading (channel) axis.
        min_val = X[i].min()
        max_val = X[i].max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])
    return scale, zero_point
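
# Hypothetical usage sketch (not part of the upstream helpers): per-channel
# qparams yield one (scale, zero_point) pair per row of a 2D weight-like
# tensor; torch.uint8 is used here because torch.iinfo expects an integer dtype.
def _example_per_channel_qparams():
    W = torch.tensor([[-0.5, 0.5], [-2.0, 2.0]])
    scales, zero_points = _calculate_dynamic_per_channel_qparams(W, torch.uint8)
    assert scales.shape == (2,) and zero_points.shape == (2,)
    return scales, zero_points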

def _snr(x, x_hat):
    """Calculates the signal-to-noise ratio and returns the signal and noise
    magnitudes, as well as the SNR in dB.
    If the input is a list/tuple this function is called recursively on each
    element. The result will have the same nested structure as the inputs.

    Args:
        x, x_hat: Either a tensor or a nested list/tuple of tensors.
    Returns:
        signal, noise, SNR (in dB): Either floats or a nested list of floats.
    """
    if isinstance(x, (list, tuple)):
        assert len(x) == len(x_hat)
        res = []
        for idx in range(len(x)):
            res.append(_snr(x[idx], x_hat[idx]))
        return res
    if x_hat.is_quantized:
        x_hat = x_hat.dequantize()
    if x.is_quantized:
        x = x.dequantize()
    noise = (x - x_hat).norm()
    if noise == 0:
        return 0.0, float('inf'), float('inf')
    signal = x.norm()
    snr = signal / noise
    snr_db = 20 * snr.log10()
    return signal, noise, snr_db
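
# Hypothetical usage sketch (not part of the upstream helpers): measure how
# much signal survives an 8-bit quantize/dequantize round trip; a reasonably
# scaled quint8 quantization of a smooth ramp should give a clearly positive SNR.
def _example_snr():
    x = torch.linspace(-1.0, 1.0, steps=100)
    x_q = torch.quantize_per_tensor(x, scale=0.01, zero_point=128, dtype=torch.quint8)
    signal, noise, snr_db = _snr(x, x_q)
    assert snr_db > 20.0
    return signal, noise, snr_db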

@contextmanager
def override_quantized_engine(qengine):
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = qengine
    try:
        yield
    finally:
        torch.backends.quantized.engine = previous

@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
    try:
        if qengine_is_qnnpack:
            torch._C._set_default_mobile_cpu_allocator()
        yield
    finally:
        if qengine_is_qnnpack:
            torch._C._unset_default_mobile_cpu_allocator()

# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
    def test_fn(*args, **kwargs):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # qfunction should not return anything.
                qfunction(*args, **kwargs)
    return test_fn
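
# Hypothetical usage sketch (not part of the upstream helpers): the decorator
# re-runs the wrapped body once per available backend, and
# override_quantized_engine restores the previously active engine afterwards.
@override_qengines
def _example_engine_sweep():
    # Runs once for each entry in supported_qengines (e.g. fbgemm, qnnpack).
    assert torch.backends.quantized.engine in supported_qengines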

def qengine_is_fbgemm():
    return torch.backends.quantized.engine == 'fbgemm'

def qengine_is_qnnpack():
    return torch.backends.quantized.engine == 'qnnpack'

def qengine_is_onednn():
    return torch.backends.quantized.engine == 'onednn'

def qengine_is_x86():
    return torch.backends.quantized.engine == 'x86'

# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
    new_axis_list = list(range(X.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = X.permute(tuple(new_axis_list))
    return y, new_axis_list

# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    res = torch.zeros_like(X)
    for i in range(X.size()[0]):
        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]
    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)

# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    Xq = torch.zeros_like(X)
    for i in range(X.size()[0]):
        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
    Xq = Xq.permute(tuple(permute_axis_list))
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res.to(dtype)
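
# Hypothetical usage sketch (not part of the upstream helpers): the per-channel
# reference above is expected to track the public
# torch.fake_quantize_per_channel_affine op on a small deterministic input.
def _example_fake_quant_reference():
    X = torch.linspace(-1.0, 1.0, steps=12).reshape(3, 4)
    scale = torch.tensor([0.1, 0.2, 0.05])
    zero_point = torch.tensor([0, 1, 2], dtype=torch.int32)
    ref = _fake_quantize_per_channel_affine_reference(X, scale, zero_point,
                                                      axis=0, quant_min=0, quant_max=255)
    out = torch.fake_quantize_per_channel_affine(X, scale, zero_point, 0, 0, 255)
    assert torch.allclose(ref, out)
    return ref, out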

def to_tensor(X, device):
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X)
    else:
        X = X.clone().detach()
    return X.to(device=torch.device(device), dtype=torch.float32)