#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/quantized/AffineQuantizerBase.h>

namespace at {
namespace native {
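
// Helpers for converting between real-valued (float) tensors and affine
// quantized tensors. For per-tensor affine quantization these routines
// implement the standard mapping (sketch):
//   q = clamp(round(r / scale) + zero_point, qmin, qmax)
//   r = (q - zero_point) * scale
// where r is the real value and q is the stored integer value. The
// per-channel variants apply a separate (scale, zero_point) pair along
// the given axis.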
Tensor& quantize_tensor_per_tensor_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    double scale,
    int64_t zero_point);

Tensor& quantize_tensor_per_channel_affine(
    const Tensor& rtensor,
    Tensor& qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

Tensor& quantize_tensor_per_channel_float_qparams(
    const Tensor& rtensor,
    Tensor& qtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

Tensor& dequantize_tensor_per_tensor_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    double scale,
    int64_t zero_point);

Tensor& dequantize_tensor_per_channel_affine(
    const Tensor& qtensor,
    Tensor& rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);

Tensor& dequantize_tensor_per_channel_float_qparams(
    const Tensor& qtensor,
    Tensor& rtensor,
    Tensor scales,
    Tensor zero_points,
    int64_t axis);
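
// Kernel function-pointer types used by the dispatch stubs declared below.
// The sub-byte variants cover dtypes narrower than 8 bits (e.g. quint4x2)
// and take scale and zero_point as float rather than double/int64_t.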
using quantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, double scale, int64_t zero_point);

using quantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using quantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& rtensor,
    Tensor& qtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using dequantize_tensor_per_tensor_affine_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, double scale, int64_t zero_point);

using dequantize_tensor_per_channel_affine_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using dequantize_tensor_per_channel_float_qparams_fn = void (*)(
    const Tensor& qtensor,
    Tensor& rtensor,
    const Tensor& scales,
    const Tensor& zero_points,
    int64_t axis);

using quantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& rtensor, Tensor& qtensor, float scale, float zero_point);

using dequantize_tensor_per_tensor_affine_sub_byte_fn =
    void (*)(const Tensor& qtensor, Tensor& rtensor, float scale, float zero_point);
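
// Dispatch stubs: each backend (CPU, CUDA, ...) registers its kernel for
// these via REGISTER_DISPATCH in the corresponding implementation file.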
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_fn,
    quantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_affine_fn,
    quantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_channel_float_qparams_fn,
    quantize_tensor_per_channel_float_qparams_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_fn,
    dequantize_tensor_per_tensor_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_affine_fn,
    dequantize_tensor_per_channel_affine_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_channel_float_qparams_fn,
    dequantize_tensor_per_channel_float_qparams_stub);
DECLARE_DISPATCH(
    quantize_tensor_per_tensor_affine_sub_byte_fn,
    quantize_tensor_per_tensor_affine_sub_byte_stub);
DECLARE_DISPATCH(
    dequantize_tensor_per_tensor_affine_sub_byte_fn,
    dequantize_tensor_per_tensor_affine_sub_byte_stub);
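
// Templated on the underlying quantized value type (e.g. c10::qint8,
// c10::quint8); these operate on a single (scale, zero_point) pair.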
template <typename T>
TORCH_API Tensor quantize_tensor(
    Tensor rtensor,
    Tensor qtensor,
    double scale,
    int64_t zero_point);

template <typename T>
TORCH_API Tensor dequantize_tensor(
    Tensor qtensor,
    Tensor rtensor,
    double scale,
    int64_t zero_point);

} // namespace native
} // namespace at
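
// Usage sketch (illustrative only, not part of this header): quantize a
// float tensor into a preallocated quint8 tensor with a given scale and
// zero_point. The factory calls below are assumptions about the caller's
// setup, not requirements of this API.
//
//   at::Tensor r = at::rand({2, 3});
//   at::Tensor q = at::_empty_affine_quantized(
//       r.sizes(),
//       r.options().dtype(at::kQUInt8),
//       /*scale=*/0.1,
//       /*zero_point=*/10);
//   at::native::quantize_tensor_per_tensor_affine(
//       r, q, /*scale=*/0.1, /*zero_point=*/10);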