#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/TensorOperators.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cmath>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/quantize_per_tensor_native.h>
#include <ATen/ops/quantize_per_channel_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace quant_utils {
namespace {
  float RawUint16ToFp16(unsigned short value) {
    // Convert a raw 16-bit half-precision floating point number
    // to a single-precision floating point number.
    const unsigned short sign_bits = value >> 15;
    const unsigned short exponent_bits = value >> 10 & 0x1f;
    const unsigned short significand_bits = value & 0x3ff;

    const float sign = sign_bits ? -1 : 1;
    const float significand =
        1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10
    const float exponent = exponent_bits - 0xf;
    return sign * std::ldexp(significand, exponent);
  }
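
  // Worked example (illustrative only): 0x7BFF encodes the largest finite
  // fp16 value. Its fields are sign = 0, exponent = 0x1e, significand = 0x3ff,
  // so RawUint16ToFp16(0x7BFF)
  //   = +1 * (1 + 1023 * 2^-10) * 2^(0x1e - 0xf)
  //   = 1.9990234375 * 2^15
  //   = 65504.0f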

  template <typename T>
  bool CheckAndSaturate(T max_val, T* element) {
    if (*element > max_val) {
      *element = max_val;
      return true;
    }
    if (*element < -max_val) {
      *element = -max_val;
      return true;
    }
    return false;
  }
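
  // Illustrative usage (hypothetical values): clamp a value to the
  // representable range [-max_val, max_val] and report whether it was clamped.
  //   float x = 1e6f;
  //   bool saturated = CheckAndSaturate<float>(65504.0f, &x);
  //   // saturated == true, x == 65504.0f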
} // namespace

using namespace std;

// A structure to hold the quantization parameters 'scale' and 'zero_point'.
// These values are the constants in the quantization equation
//
//   real_value = scale * (quantized_value - zero_point)
//
// In other words, 'zero_point' is the quantized value that corresponds to the
// real value 0, and 'scale' is the difference between the real values
// corresponding to consecutive quantized values.
struct TensorQuantizationParams {
  double scale;
  std::int32_t zero_point;
  int precision;
};
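
// Worked example (illustrative only): with scale = 0.5 and zero_point = 10,
// the quantized value 12 represents the real value 0.5 * (12 - 10) = 1.0,
// and the quantized value 10 represents exactly 0.0.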

// Use fp16_min as the small scale cutoff because we don't want to use scales
// in the fp16 subnormal range. This is to be consistent with Glow and the
// FakeLowP implementation for NNPI. (The smallest normal fp16 value is
// 2^-14 ~= 6.1e-5.)
constexpr float SMALL_SCALE_THRESHOLD = 6.1e-5f;

// The following implementation should be identical to
// fbgemm::ChooseQuantizationParams.
inline TensorQuantizationParams ChooseQuantizationParams(
    float min,
    float max,
    int32_t qmin,
    int32_t qmax,
    bool preserve_sparsity = false,
    bool force_scale_power_of_two = false,
    bool reduce_range = false) {
  TORCH_CHECK(
      min <= max,
      "In ChooseQuantizationParams, min should be less than or equal to max");

  if (reduce_range) {
    qmin = qmin / 2;
    qmax = qmax / 2;
  }
  if (min < 0 && max > 0 && preserve_sparsity) {
    int symmetric_qmin = -((qmax - qmin) / 2 + 1);
    int symmetric_qmax = (qmax - qmin) / 2;
    double max_scale =
        std::max(fabs(min / symmetric_qmin), fabs(max / symmetric_qmax));
    min = max_scale * symmetric_qmin;
    max = max_scale * symmetric_qmax;
  }

  // We extend the [min, max] interval to ensure that it contains 0.
  // Otherwise, we would not meet the requirement that 0 be an exactly
  // representable value.
  min = std::min(min, 0.f);
  max = std::max(max, 0.f);

  TORCH_CHECK(
      qmin < qmax,
      "In ChooseQuantizationParams, qmin should be less than qmax");

  // Use double precision for the intermediate computations, but single
  // precision for the final number, to reflect the precision actually used
  // during quantization.
  double scale = (static_cast<double>(max) - min) / (qmax - qmin);
  // If scale is 0 or so small that its reciprocal is infinity, we arbitrarily
  // adjust the scale to 0.1. We want to avoid scale's reciprocal being
  // infinity because some of the fbgemm code pre-computes scale's reciprocal
  // to do multiplication instead of division in the time-critical part of the
  // code.
  if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) {
    scale = 0.1;
  }
  TORCH_CHECK(scale > 0, "quantization scale should be > 0");

  if (force_scale_power_of_two) {
    if (scale < 1) {
      scale = 1.0 / (1 << static_cast<int>(floor(log(1.0 / scale) / log(2))));
    } else {
      scale = 1 << static_cast<int>(ceil(log(scale) / log(2)));
    }
  }

  // Cut off small scales.
  if (scale < SMALL_SCALE_THRESHOLD) {
    float org_scale = scale;
    scale = SMALL_SCALE_THRESHOLD;
    // Adjust min and max based on the new scale.
    if (min == 0.0f) {
      max = SMALL_SCALE_THRESHOLD * (qmax - qmin);
    } else if (max == 0.0f) {
      min = -SMALL_SCALE_THRESHOLD * (qmax - qmin);
    } else {
      float amplifier = SMALL_SCALE_THRESHOLD / org_scale;
      min *= amplifier;
      max *= amplifier;
    }
  }

  // Zero-point computation.
  // First, the initial floating-point computation. The zero point can be
  // determined by solving the affine equation for any known pair
  // (real value, corresponding quantized value).
  // We know two such pairs: (rmin, qmin) and (rmax, qmax).
  // The arithmetic error on the zero point computed from either pair
  // will be roughly machine_epsilon * (sum of absolute values of terms),
  // so we want to use the variant that adds the smaller terms.
  double zero_point_from_min = qmin - min / static_cast<double>(scale);
  double zero_point_from_max = qmax - max / static_cast<double>(scale);
  double zero_point_from_min_error =
      std::abs(qmin) + std::abs(min / static_cast<double>(scale));
  double zero_point_from_max_error =
      std::abs(qmax) + std::abs(max / static_cast<double>(scale));
  double initial_zero_point =
      zero_point_from_min_error < zero_point_from_max_error
      ? zero_point_from_min
      : zero_point_from_max;

  // For symmetric quantization (preserve_sparsity == true), we force
  // zero_point to be the middle value between qmin and qmax.
  // If either min or max is 0, then we just use 0 as the zero_point.
  if (min < 0 && max > 0 && preserve_sparsity) {
    initial_zero_point = static_cast<double>(qmin + qmax) / 2;
  }

  // Now we need to nudge the zero point to be an integer
  // (our zero points are integer, and this is motivated by the requirement
  // to be able to represent the real value "0" exactly as a quantized value,
  // which is required in multiple places, for example in Im2col with zero
  // padding).
  int32_t nudged_zero_point = 0;
  if (initial_zero_point < qmin) {
    nudged_zero_point = qmin;
  } else if (initial_zero_point > qmax) {
    nudged_zero_point = qmax;
  } else {
    nudged_zero_point = nearbyint(initial_zero_point);
  }

  TensorQuantizationParams result;
  result.scale = scale;
  result.zero_point = nudged_zero_point;
  return result;
}
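
// Illustrative usage (hypothetical values): for an activation range of
// [-1.0f, 1.0f] mapped onto uint8 with qmin = 0 and qmax = 255,
//   auto qparams = ChooseQuantizationParams(-1.0f, 1.0f, 0, 255);
// gives qparams.scale ~= 2.0 / 255 ~= 0.00784 and qparams.zero_point == 128,
// so the real value 0.0 maps exactly to the quantized value 128.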

// This function converts the Conv1d dimension arguments into a form usable by
// the Conv2d op.
constexpr int64_t kConv1dSqueezeDim = 0;
static C10_UNUSED torch::List<int64_t> MakeArgForConv1d(
    const torch::List<int64_t>& arg,
    int64_t base_value) {
  TORCH_CHECK(!arg.empty(), "Argument must have elements.");
  torch::List<int64_t> result({arg.get(0), base_value});
  if (arg.size() == 1) {
    result[1] = arg.get(0);
  } else {
    result[1] = arg.get(1);
  }
  result[kConv1dSqueezeDim] = base_value;
  return result;
}
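
// Illustrative usage (hypothetical values): a Conv1d stride of {2} becomes a
// Conv2d stride of {1, 2} when the squeezed dimension uses base_value = 1:
//   torch::List<int64_t> stride2d = MakeArgForConv1d({2}, 1);  // -> {1, 2}
// With a two-element argument, element kConv1dSqueezeDim is overwritten:
//   torch::List<int64_t> arg2d = MakeArgForConv1d({4, 7}, 1);  // -> {1, 7}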

// FP16 quantization of weights requires that the magnitude of each element be
// in the range [5.96e-8, 65504]. Elements outside this range are saturated to
// the maximum or minimum value representable in FP16.
inline void HandleWeightsSaturation(int64_t N, float* weight) {
  const float kFp16Max = RawUint16ToFp16(0x7BFF);
  bool found_out_of_range = false;
  for (const auto i : c10::irange(N)) {
    bool saturate = CheckAndSaturate<float>(kFp16Max, weight + i);
    if (saturate) {
      found_out_of_range = true;
    }
  }
  if (found_out_of_range) {
    TORCH_WARN("FOUND weight out of range ");
  }
}
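
// Illustrative usage (hypothetical values): any weight whose magnitude exceeds
// 65504 is clamped in place, and a single warning is emitted.
//   std::vector<float> w = {0.5f, 1e6f, -2.0f};
//   HandleWeightsSaturation(w.size(), w.data());
//   // w is now {0.5f, 65504.0f, -2.0f}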

// Util function for quantizing the bias.
inline at::Tensor QuantizeBias(
    bool is_per_channel,
    const at::Tensor& bias,
    const at::Tensor& weight_contig,
    double input_scale) {
  at::Tensor qbias;
  if (is_per_channel) {
    auto bias_quant_scales =
        weight_contig.q_per_channel_scales() * input_scale;
    auto bias_zp = at::zeros(bias_quant_scales.sizes(), c10::kInt);
    qbias = at::native::quantize_per_channel(
        bias, bias_quant_scales, bias_zp, 0, c10::kQInt32);
  } else {
    qbias = at::native::quantize_per_tensor(
        bias, weight_contig.q_scale() * input_scale, 0, c10::kQInt32);
  }
  return qbias;
}
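
// Illustrative usage (names hypothetical; assumes `qweight` is an already
// quantized weight tensor and `bias` is a float tensor):
//   at::Tensor qbias = QuantizeBias(
//       /*is_per_channel=*/false, bias, qweight.contiguous(), input_scale);
//   // qbias is a kQInt32 tensor quantized with
//   // scale = qweight.q_scale() * input_scale and zero_point = 0.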

} // namespace quant_utils