// FakeQuantAffine.h: function-pointer types and dispatch-stub declarations for the fake-quantize kernels.
  1. #pragma once
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/Dispatch.h>
  4. #include <ATen/native/DispatchStub.h>
  5. namespace at {
  6. struct TensorIterator;
  7. namespace native {
  8. using fake_quant_tensor_cachemask_fn = void (*)(
  9. Tensor& output,
  10. Tensor& mask,
  11. const Tensor& input,
  12. float sc,
  13. int64_t z_point,
  14. int64_t quant_min,
  15. int64_t quant_max);
  16. using fake_quant_tensor_cachemask_tensor_qparams_fn = void (*)(
  17. Tensor& output,
  18. Tensor& mask,
  19. const Tensor& input,
  20. const Tensor& sc,
  21. const Tensor& z_point,
  22. const Tensor& fake_quant_enabled,
  23. int64_t quant_min,
  24. int64_t quant_max);
  25. using fake_quant_learnable_grad_tensor_fn = void (*)(
  26. TensorIterator& iter,
  27. float scale,
  28. float inv_scale,
  29. int64_t zero_point,
  30. int64_t quant_min,
  31. int64_t quant_max,
  32. float grad_factor);
  33. DECLARE_DISPATCH(fake_quant_tensor_cachemask_fn, fake_quant_tensor_cachemask_stub);
  34. DECLARE_DISPATCH(fake_quant_tensor_cachemask_tensor_qparams_fn, fake_quant_tensor_cachemask_tensor_qparams_stub);
  35. DECLARE_DISPATCH(fake_quant_learnable_grad_tensor_fn, fake_quant_grad_learnable_tensor_stub);
  36. using fake_quant_per_channel_fn = void (*)(
  37. TensorIterator &iter,
  38. int64_t quant_min,
  39. int64_t quant_max);
  40. using fake_quant_per_channel_cachemask_fn = void (*)(
  41. TensorIterator &iter,
  42. TensorIterator &iter_mask,
  43. int64_t quant_min,
  44. int64_t quant_max);
  45. DECLARE_DISPATCH(fake_quant_per_channel_cachemask_fn, fake_quant_per_channel_cachemask_stub);
  46. using fake_quant_learnable_per_channel_fn = void (*)(
  47. TensorIterator &iter,
  48. int64_t quant_min,
  49. int64_t quant_max,
  50. float grad_factor);
  51. DECLARE_DISPATCH(fake_quant_learnable_per_channel_fn, fake_quant_grad_learnable_channel_stub);
  52. } // namespace native
  53. } // namespace at