// PackedParams.h
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/core/ivalue.h>

#include <stdexcept>
#include <tuple>
  4. struct LinearPackedParamsBase : public torch::jit::CustomClassHolder {
  5. virtual at::Tensor apply(
  6. at::Tensor input,
  7. double output_scale,
  8. int64_t output_zero_point) = 0;
  9. virtual at::Tensor apply_relu(
  10. at::Tensor input,
  11. double output_scale,
  12. int64_t output_zero_point) = 0;
  13. // out variant of LinearPackedParamsBase::apply
  14. virtual at::Tensor& apply_out(
  15. const at::Tensor& /*input*/,
  16. double /*output_scale*/,
  17. int64_t /*output_zero_point*/,
  18. at::Tensor& output) {
  19. throw std::runtime_error(
  20. "apply_out is not implemented for this packed "
  21. "parameter type");
  22. return output;
  23. }
  24. virtual at::Tensor& apply_relu_out(
  25. const at::Tensor& /*input*/,
  26. double /*output_scale*/,
  27. int64_t /*output_zero_point*/,
  28. at::Tensor& output) {
  29. throw std::runtime_error(
  30. "apply_relu_out is not implemented for this packed "
  31. "parameter type");
  32. return output;
  33. }
  34. // Corresponding pattern (the ops with `*` are part of the pattern that
  35. // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_output_fp32):
  36. // input -> q* -> dq* -> linear* ->
  37. // qweight -> dq* /
  38. //
  39. // After fusion:
  40. // input -> quantized::linear_with_input_q_dq_qweight_dq_output_fp32* ->
  41. // qweight /
  42. //
  43. // Additional Note: the weight is packed as well
  44. // Params:
  45. // X: float32 Tensor, will be quantized to quint8 in the op
  46. // W_prepack: packed qint8 quantized weight and bias
  47. // Returns:
  48. // Y: float32 Tensor
  49. virtual at::Tensor apply_with_input_q_dq_qweight_dq_output_fp32(
  50. at::Tensor input,
  51. double input_scale,
  52. int64_t input_zero_point) {
  53. throw std::runtime_error(
  54. "apply_with_input_q_dq_qweight_dq_output_fp32 is not implemented for this packed "
  55. "parameter type");
  56. return {};
  57. }
  58. // Corresponding pattern (the ops with `*` are part of the pattern that
  59. // represents the computation of quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32):
  60. // input -> q* -> dq* -> linear* -> relu* ->
  61. // qweight -> dq* /
  62. //
  63. // After fusion:
  64. // input -> quantized::linear_with_input_q_dq_qweight_dq_relu_output_fp32* ->
  65. // qweight /
  66. //
  67. // Additional Note: the weight is packed as well
  68. // Params:
  69. // input: float32 Tensor, will be quantized to quint8 in the op
  70. // Returns:
  71. // float32 Tensor
  72. virtual at::Tensor apply_with_input_q_dq_qweight_dq_relu_output_fp32(
  73. at::Tensor input,
  74. double input_scale,
  75. int64_t input_zero_point) {
  76. throw std::runtime_error(
  77. "apply_with_input_q_dq_qweight_dq_relu_output_fp32 is not implemented for this packed "
  78. "parameter type");
  79. return {};
  80. }
  81. virtual at::Tensor apply_dynamic(
  82. at::Tensor input,
  83. bool reduce_range = false) = 0;
  84. virtual at::Tensor apply_dynamic_relu(
  85. at::Tensor input,
  86. bool reduce_range = false) = 0;
  87. virtual at::Tensor& apply_dynamic_out(
  88. const at::Tensor& /* input */,
  89. at::Tensor& output,
  90. bool /* reduce_range */) {
  91. throw std::runtime_error(
  92. "apply_dynamic_out is not implemented for this packed "
  93. "parameter type");
  94. return output;
  95. }
  96. virtual at::Tensor& apply_dynamic_relu_out(
  97. const at::Tensor& /* input */,
  98. at::Tensor& output,
  99. bool /* reduce_range */) {
  100. throw std::runtime_error(
  101. "apply_dynamic_relu_out is not implemented for this packed "
  102. "parameter type");
  103. return output;
  104. }
  105. virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;
  106. virtual c10::optional<at::Tensor> bias() = 0;
  107. virtual void set_bias(c10::optional<at::Tensor> /*bias*/) {
  108. throw std::runtime_error(
  109. "set_bias is not implemented for this packed "
  110. "parameter type");
  111. }
  112. };
  113. template <int kSpatialDim = 2>
  114. struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
  115. virtual at::Tensor apply(
  116. const at::Tensor& input,
  117. double output_scale,
  118. int64_t output_zero_point) = 0;
  119. virtual at::Tensor apply_relu(
  120. const at::Tensor& input,
  121. double output_scale,
  122. int64_t output_zero_point) = 0;
  123. virtual at::Tensor apply_dynamic(
  124. const at::Tensor& input,
  125. bool reduce_range) = 0;
  126. virtual std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() = 0;
  127. virtual torch::List<int64_t> stride() const = 0;
  128. virtual torch::List<int64_t> padding() const = 0;
  129. virtual torch::List<int64_t> output_padding() const = 0;
  130. virtual torch::List<int64_t> dilation() const = 0;
  131. virtual int64_t groups() const = 0;
  132. virtual bool transpose() const = 0;
  133. };