AffineQuantizer.h

#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>

namespace at {
namespace native {
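// Out-of-place quantize/dequantize entry points. Each function reads from its
// first argument, writes the result into the pre-allocated second argument,
// and returns a reference to that output tensor.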
Tensor& quantize_tensor_per_tensor_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point);

Tensor& quantize_tensor_per_channel_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

Tensor& quantize_tensor_per_channel_float_qparams(
    const Tensor& rtensor,
    Tensor& qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

Tensor& dequantize_tensor_per_tensor_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point);

Tensor& dequantize_tensor_per_channel_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

Tensor& dequantize_tensor_per_channel_float_qparams(
    const Tensor& qtensor,
    Tensor& rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);
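// Kernel signatures used by the dispatch stubs declared below; each backend
// (e.g. the vectorized CPU path or CUDA) registers an implementation against
// the matching stub.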
using quantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);

using quantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using quantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using dequantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);

using dequantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);
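// Sub-byte variants cover quantized dtypes narrower than 8 bits (e.g.
// quint4x2); note that their qparams are passed as float rather than
// double/int64_t.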
using quantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);

using dequantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);
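// Dispatch stubs: these declarations are paired with DEFINE_DISPATCH in the
// implementation file, and backend kernels are registered against them (e.g.
// via REGISTER_DISPATCH).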
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_fn,
    quantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_affine_fn,
    quantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_float_qparams_fn,
    quantize_tensor_per_channel_float_qparams_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_fn,
    dequantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_affine_fn,
    dequantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_float_qparams_fn,
    dequantize_tensor_per_channel_float_qparams_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_sub_byte_fn,
    quantize_tensor_per_tensor_affine_sub_byte_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_sub_byte_fn,
    dequantize_tensor_per_tensor_affine_sub_byte_stub);
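// Templated helpers parameterized on the quantized value type (e.g. quint8,
// qint8); TORCH_API exports them from the library.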
template <typename T>
TORCH_API Tensor quantize_tensor(
    Tensor rtensor,
    Tensor qtensor,
    double scale,
    int64_t zero_point);

template <typename T>
TORCH_API Tensor dequantize_tensor(
    Tensor qtensor,
    Tensor rtensor,
    double scale,
    int64_t zero_point);

} // namespace native
} // namespace at
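// Usage sketch (not part of this header): a minimal, illustrative round trip
// through the per-tensor affine path from the public ATen API, which routes to
// quantize_tensor_per_tensor_affine / dequantize_tensor_per_tensor_affine via
// the per-tensor affine quantizer. Values and dtype below are arbitrary.
//
//   #include <ATen/ATen.h>
//
//   void example() {
//     at::Tensor r = at::rand({2, 3});                       // float input
//     at::Tensor q = at::quantize_per_tensor(                // quantize with a
//         r, /*scale=*/0.1, /*zero_point=*/10, at::kQUInt8); // single scale/zp
//     at::Tensor back = q.dequantize();                      // back to float
//   }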