// XnnpackUtils.h — wrappers around XNNPACK create/setup calls for quantized
// convolution, deconvolution, and fully-connected (linear) operators.
  1. #pragma once
  2. #ifdef USE_XNNPACK
  3. #include <cstdint>
  4. #include <ATen/core/Tensor.h>
  5. #include <ATen/native/xnnpack/Common.h>
  6. using xnnpack_operator = at::native::xnnpack::Operator;
  7. namespace at {
  8. namespace native {
  9. namespace xnnp_utils {
/*
 * Return shape in the same order as the memory format
 * e.g. channels_last will return NHWC instead of NCHW.
 * (Declaration only; defined elsewhere in the project.)
 */
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);

/*
 * Input is always int8_t, output can be [int8_t, uint8_t].
 * input + offset = output
 * int8_t + 128 = uint8_t
 * int8_t + 0 = int8_t
 * PT selects the output element type (int8_t or uint8_t).
 * (Declaration only; defined elsewhere in the project.)
 */
template <typename PT>
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);

// NOTE(review): presumably permutes kSpatialDim-d (de)conv weights into a
// channels-last layout suitable for XNNPACK's NHWC kernels — confirm at the
// definition site; only the declaration is visible here.
template <int kSpatialDim>
Tensor convert_conv_weights_to_channel_last_tensor(
    const at::Tensor& src,
    int groups,
    bool transpose);
/*
 * Series of create wrapper functions to call xnn_create_[de]conv* functions.
 */
/*
 * Creates the appropriate XNNPACK 2d (de)convolution operator and writes it
 * into *op. Returns the xnn_status reported by the underlying xnn_create_*
 * call (after a TORCH_CHECK that the kernel zero point is zero).
 *
 * Dispatch:
 *   - transpose == true                     -> xnn_create_deconvolution2d_nhwc_qs8
 *                                              (per-channel deconvolution rejected)
 *   - transpose == false, !per_channel      -> xnn_create_convolution2d_nhwc_qs8
 *   - transpose == false,  per_channel      -> xnn_create_convolution2d_nhwc_qc8
 *
 * k_scales must point to at least one float: the per-tensor (QS8) paths read
 * only k_scales[0], while the per-channel (QC8) path forwards the whole
 * pointer to XNNPACK.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_convolution2d_nhwc(
    uint32_t pad_top,
    uint32_t pad_right,
    uint32_t pad_bottom,
    uint32_t pad_left,
    uint32_t kernel_h,
    uint32_t kernel_w,
    uint32_t stride_h,
    uint32_t stride_w,
    uint32_t dilation_h,
    uint32_t dilation_w,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t ip_chan_stride,
    size_t op_chan_stride,
    int8_t izp,
    float ip_scale,
    int8_t kzp,
    const float* k_scales,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t ozp,
    float op_scale,
    int8_t op_min,
    int8_t op_max,
    uint32_t flags,
    xnn_operator_t* op,
    bool per_channel,
    bool transpose) {
  /* Symmetric quantization forces kzp = 0 */
  TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero."
      "But got: ", kzp);

  if (transpose) {
    // XNNPACK only provides a per-tensor (QS8) deconvolution operator.
    TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_create_deconvolution2d_nhwc_qs8(
        pad_top,               /* uint32_t output_padding_top */
        pad_right,             /* uint32_t output_padding_right */
        pad_bottom,            /* uint32_t output_padding_bottom */
        pad_left,              /* uint32_t output_padding_left */
        kernel_h,              /* uint32_t kernel_height */
        kernel_w,              /* uint32_t kernel_width */
        stride_h,              /* uint32_t stride_height */
        stride_w,              /* uint32_t stride_width */
        dilation_h,            /* uint32_t dilation_height */
        dilation_w,            /* uint32_t dilation_width */
        groups,                /* uint32_t groups */
        group_input_channels,  /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels */
        ip_chan_stride,        /* size_t input_pixel_stride */
        op_chan_stride,        /* size_t output_pixel_stride */
        izp,                   /* int8_t input_zero_point */
        ip_scale,              /* float input_scale */
        k_scales[0],           /* float kernel_scale (per-tensor: first entry only) */
        kernel,                /* const int8_t* kernel */
        bias,                  /* const int32_t* bias */
        ozp,                   /* int8_t output_zero_point */
        op_scale,              /* float output_scale */
        op_min,                /* int8_t output_min */
        op_max,                /* int8_t output_max */
        flags,                 /* uint32_t flags */
        nullptr,               /* xnn_caches_t caches */
        op);                   /* xnn_operator_t* deconvolution_op_out */
  }
  if (!per_channel) {
    // Per-tensor quantized convolution (single kernel scale).
    return xnn_create_convolution2d_nhwc_qs8(
        pad_top,               /* uint32_t input_padding_top */
        pad_right,             /* uint32_t input_padding_right */
        pad_bottom,            /* uint32_t input_padding_bottom */
        pad_left,              /* uint32_t input_padding_left */
        kernel_h,              /* uint32_t kernel_height */
        kernel_w,              /* uint32_t kernel_width */
        stride_h,              /* uint32_t subsampling_height */
        stride_w,              /* uint32_t subsampling_width */
        dilation_h,            /* uint32_t dilation_height */
        dilation_w,            /* uint32_t dilation_width */
        groups,                /* uint32_t groups */
        group_input_channels,  /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels*/
        ip_chan_stride,        /* size_t input_channel_stride */
        op_chan_stride,        /* size_t output_channel_stride */
        izp,                   /* int8_t input_zero_point */
        ip_scale,              /* float input_scale */
        k_scales[0],           /* float kernel_scale (per-tensor: first entry only) */
        kernel,                /* const int8_t* kernel */
        bias,                  /* const int32_t* bias */
        ozp,                   /* int8_t output_zero_point */
        op_scale,              /* float output_scale */
        op_min,                /* int8_t output_min */
        op_max,                /* int8_t output_max */
        flags,                 /* uint32_t flags */
        nullptr,               /* xnn_caches_t caches */
        op);                   /* xnn_operator_t* convolution_op_out */
  } else { /* per_channel */
    // Per-channel quantized convolution (one kernel scale per output channel).
    return xnn_create_convolution2d_nhwc_qc8(
        pad_top,               /* uint32_t input_padding_top */
        pad_right,             /* uint32_t input_padding_right */
        pad_bottom,            /* uint32_t input_padding_bottom */
        pad_left,              /* uint32_t input_padding_left */
        kernel_h,              /* uint32_t kernel_height */
        kernel_w,              /* uint32_t kernel_width */
        stride_h,              /* uint32_t subsampling_height */
        stride_w,              /* uint32_t subsampling_width */
        dilation_h,            /* uint32_t dilation_height */
        dilation_w,            /* uint32_t dilation_width */
        groups,                /* uint32_t groups */
        group_input_channels,  /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels*/
        ip_chan_stride,        /* size_t input_channel_stride */
        op_chan_stride,        /* size_t output_channel_stride */
        izp,                   /* int8_t input_zero_point */
        ip_scale,              /* float input_scale */
        k_scales,              /* const float* kernel_scale */
        kernel,                /* const int8_t* kernel */
        bias,                  /* const int32_t* bias */
        ozp,                   /* int8_t output_zero_point */
        op_scale,              /* float output_scale */
        op_min,                /* int8_t output_min */
        op_max,                /* int8_t output_max */
        flags,                 /* uint32_t flags */
        nullptr,               /* xnn_caches_t caches */
        op);                   /* xnn_operator_t* convolution_op_out */
  }
}
  156. /*
  157. * Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
  158. */
  159. C10_ALWAYS_INLINE
  160. enum xnn_status xnnp_setup_convolution2d_nhwc(
  161. xnn_operator_t op,
  162. size_t batch,
  163. size_t in_h,
  164. size_t in_w,
  165. const int8_t* inp,
  166. int8_t* outp,
  167. pthreadpool_t pt_pool,
  168. bool per_channel = false,
  169. bool transpose = false,
  170. uint32_t adj_h = 0,
  171. uint32_t adj_w = 0) {
  172. if(transpose) {
  173. TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
  174. return xnn_setup_deconvolution2d_nhwc_qs8(
  175. op, /* xnn_operator_t deconvolution_op */
  176. batch, /* size_t batch_size */
  177. in_h, /* size_t input_height */
  178. in_w, /* size_t input_width */
  179. adj_h, /* uint32_t adjustment_height */
  180. adj_w, /* uint32_t adjustment_width */
  181. inp, /* const int8_t* input */
  182. outp, /* int8_t* output */
  183. pt_pool); /* pthreadpool_t threadpool */
  184. }
  185. if (!per_channel) {
  186. return xnn_setup_convolution2d_nhwc_qs8(
  187. op, /* xnn_operator_t convolution_op */
  188. batch, /* size_t batch_size */
  189. in_h, /* size_t input_height */
  190. in_w, /* size_t input_width */
  191. inp, /* const int8_t* input */
  192. outp, /* int8_t* output */
  193. pt_pool); /* pthreadpool_t threadpool */
  194. } else { /* per_channel */
  195. return xnn_setup_convolution2d_nhwc_qc8(
  196. op, /* xnn_operator_t convolution_op */
  197. batch, /* size_t batch_size */
  198. in_h, /* size_t input_height */
  199. in_w, /* size_t input_width */
  200. inp, /* const int8_t* input */
  201. outp, /* int8_t* output */
  202. pt_pool); /* pthreadpool_t threadpool */
  203. }
  204. }
  205. /*
  206. * Series of wrapper functions to call xnn_create* and xnn_setup*
  207. * functions for linear
  208. */
  209. C10_ALWAYS_INLINE
  210. enum xnn_status xnnp_create_fully_connected_nc(
  211. size_t input_channels,
  212. size_t output_channels,
  213. size_t input_stride,
  214. size_t output_stride,
  215. int8_t input_zero_point,
  216. float input_scale,
  217. int8_t kernel_zero_point,
  218. float kernel_scale,
  219. const int8_t* kernel,
  220. const int32_t* bias,
  221. int8_t output_zero_point,
  222. float output_scale,
  223. int8_t output_min,
  224. int8_t output_max,
  225. uint32_t flags,
  226. xnn_operator_t* fully_connected_op_out) {
  227. /* Symmetric quantization forces kzp = 0 */
  228. TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero."
  229. "But got: ", kernel_zero_point);
  230. return xnn_create_fully_connected_nc_qs8(
  231. input_channels, /* size_t input_channels */
  232. output_channels, /* size_t output_channels */
  233. input_stride, /* size_t input_stride */
  234. output_stride, /* size_t output_stride */
  235. input_zero_point, /* int8_t input_zero_point */
  236. input_scale, /* float input_scale */
  237. kernel_scale, /* float kernel_scale */
  238. kernel, /* const int8_t* kernel */
  239. bias, /* const int32_t* bias */
  240. output_zero_point, /* int8_t output_zero_point */
  241. output_scale, /* float output_scale */
  242. output_min, /* int8_t output_min */
  243. output_max, /* int8_t output_max */
  244. flags, /* uint32_t flags */
  245. nullptr, /* xnn_caches_t caches */
  246. fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
  247. }
  248. C10_ALWAYS_INLINE
  249. enum xnn_status xnnp_setup_fully_connected_nc(
  250. xnn_operator_t fully_connected_op,
  251. size_t batch_size,
  252. const int8_t* input,
  253. int8_t* output,
  254. pthreadpool_t threadpool) {
  255. return xnn_setup_fully_connected_nc_qs8(
  256. fully_connected_op, /* xnn_operator_t fully_connected_op */
  257. batch_size, /* size_t batch_size */
  258. input, /* const int8_t* input */
  259. output, /* int8_t* output */
  260. threadpool); /* pthreadpool_t threadpool */
  261. }
  262. } // namespace xnnp_utils
  263. } // namespace native
  264. } // namespace at
  265. #endif // USE_XNNPACK