123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- #pragma once
- #ifdef USE_XNNPACK
- #include <cstdint>
- #include <ATen/core/Tensor.h>
- #include <ATen/native/xnnpack/Common.h>
- using xnnpack_operator = at::native::xnnpack::Operator;
- namespace at {
- namespace native {
- namespace xnnp_utils {
- /*
- * Return shape in the same order as the memory format
- * e.g. channels_last will return NHWC instead of NCHW
- */
- std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);
- /*
- * Input is always int8_t, output can be [int8_t, uint8_t].
- * input + offset = output
- * int8_t + 128 = uint8_t
- * int8_t + 0 = int8_t
- */
- template <typename PT>
- void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);
- template <int kSpatialDim>
- Tensor convert_conv_weights_to_channel_last_tensor(
- const at::Tensor& src,
- int groups,
- bool transpose);
- /*
- * Series of create wrapper functions to call xnn_create_[de]conv* functions.
- */
- C10_ALWAYS_INLINE
- enum xnn_status xnnp_create_convolution2d_nhwc(
- uint32_t pad_top,
- uint32_t pad_right,
- uint32_t pad_bottom,
- uint32_t pad_left,
- uint32_t kernel_h,
- uint32_t kernel_w,
- uint32_t stride_h,
- uint32_t stride_w,
- uint32_t dilation_h,
- uint32_t dilation_w,
- uint32_t groups,
- size_t group_input_channels,
- size_t group_output_channels,
- size_t ip_chan_stride,
- size_t op_chan_stride,
- int8_t izp,
- float ip_scale,
- int8_t kzp,
- const float* k_scales,
- const int8_t* kernel,
- const int32_t* bias,
- int8_t ozp,
- float op_scale,
- int8_t op_min,
- int8_t op_max,
- uint32_t flags,
- xnn_operator_t* op,
- bool per_channel,
- bool transpose) {
- /* Symmetric quantization forces kzp = 0 */
- TORCH_CHECK(!kzp, "XNNPACK Q[SC]8 conv kernels expects kernel zero point to be zero."
- "But got: ", kzp);
- if (transpose) {
- TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
- return xnn_create_deconvolution2d_nhwc_qs8(
- pad_top, /* uint32_t output_padding_top */
- pad_right, /* uint32_t output_padding_right */
- pad_bottom, /* uint32_t output_padding_bottom */
- pad_left, /* uint32_t output_padding_left */
- kernel_h, /* uint32_t kernel_height */
- kernel_w, /* uint32_t kernel_width */
- stride_h, /* uint32_t stride_height */
- stride_w, /* uint32_t stride_width */
- dilation_h, /* uint32_t dilation_height */
- dilation_w, /* uint32_t dilation_width */
- groups, /* uint32_t groups */
- group_input_channels, /* size_t group_input_channels */
- group_output_channels, /* size_t group_output_channels */
- ip_chan_stride, /* size_t input_pixel_stride */
- op_chan_stride, /* size_t output_pixel_stride */
- izp, /* int8_t input_zero_point */
- ip_scale, /* float input_scale */
- k_scales[0], /* float kernel_scale */
- kernel, /* const int8_t* kernel */
- bias, /* const int32_t* bias */
- ozp, /* int8_t output_zero_point */
- op_scale, /* float output_scale */
- op_min, /* int8_t output_min */
- op_max, /* int8_t output_max */
- flags, /* uint32_t flags */
- nullptr, /* xnn_caches_t caches */
- op); /* xnn_operator_t* deconvolution_op_out */
- }
- if (!per_channel) {
- return xnn_create_convolution2d_nhwc_qs8(
- pad_top, /* uint32_t input_padding_top */
- pad_right, /* uint32_t input_padding_right */
- pad_bottom, /* uint32_t input_padding_bottom */
- pad_left, /* uint32_t input_padding_left */
- kernel_h, /* uint32_t kernel_height */
- kernel_w, /* uint32_t kernel_width */
- stride_h, /* uint32_t subsampling_height */
- stride_w, /* uint32_t subsampling_width */
- dilation_h, /* uint32_t dilation_height */
- dilation_w, /* uint32_t dilation_width */
- groups, /* uint32_t groups */
- group_input_channels, /* size_t group_input_channels */
- group_output_channels, /* size_t group_output_channels*/
- ip_chan_stride, /* size_t input_channel_stride */
- op_chan_stride, /* size_t output_channel_stride */
- izp, /* int8_t input_zero_point */
- ip_scale, /* float input_scale */
- k_scales[0], /* float kernel_scale */
- kernel, /* const int8_t* kernel */
- bias, /* const int32_t* bias */
- ozp, /* int8_t output_zero_point */
- op_scale, /* float output_scale */
- op_min, /* int8_t output_min */
- op_max, /* int8_t output_max */
- flags, /* uint32_t flags */
- nullptr, /* xnn_caches_t caches */
- op); /* xnn_operator_t* convolution_op_out */
- } else { /* per_channel */
- return xnn_create_convolution2d_nhwc_qc8(
- pad_top, /* uint32_t input_padding_top */
- pad_right, /* uint32_t input_padding_right */
- pad_bottom, /* uint32_t input_padding_bottom */
- pad_left, /* uint32_t input_padding_left */
- kernel_h, /* uint32_t kernel_height */
- kernel_w, /* uint32_t kernel_width */
- stride_h, /* uint32_t subsampling_height */
- stride_w, /* uint32_t subsampling_width */
- dilation_h, /* uint32_t dilation_height */
- dilation_w, /* uint32_t dilation_width */
- groups, /* uint32_t groups */
- group_input_channels, /* size_t group_input_channels */
- group_output_channels, /* size_t group_output_channels*/
- ip_chan_stride, /* size_t input_channel_stride */
- op_chan_stride, /* size_t output_channel_stride */
- izp, /* int8_t input_zero_point */
- ip_scale, /* float input_scale */
- k_scales, /* const float* kernel_scale */
- kernel, /* const int8_t* kernel */
- bias, /* const int32_t* bias */
- ozp, /* int8_t output_zero_point */
- op_scale, /* float output_scale */
- op_min, /* int8_t output_min */
- op_max, /* int8_t output_max */
- flags, /* uint32_t flags */
- nullptr, /* xnn_caches_t caches */
- op); /* xnn_operator_t* convolution_op_out */
- }
- }
- /*
- * Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
- */
- C10_ALWAYS_INLINE
- enum xnn_status xnnp_setup_convolution2d_nhwc(
- xnn_operator_t op,
- size_t batch,
- size_t in_h,
- size_t in_w,
- const int8_t* inp,
- int8_t* outp,
- pthreadpool_t pt_pool,
- bool per_channel = false,
- bool transpose = false,
- uint32_t adj_h = 0,
- uint32_t adj_w = 0) {
- if(transpose) {
- TORCH_CHECK(!per_channel, "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
- return xnn_setup_deconvolution2d_nhwc_qs8(
- op, /* xnn_operator_t deconvolution_op */
- batch, /* size_t batch_size */
- in_h, /* size_t input_height */
- in_w, /* size_t input_width */
- adj_h, /* uint32_t adjustment_height */
- adj_w, /* uint32_t adjustment_width */
- inp, /* const int8_t* input */
- outp, /* int8_t* output */
- pt_pool); /* pthreadpool_t threadpool */
- }
- if (!per_channel) {
- return xnn_setup_convolution2d_nhwc_qs8(
- op, /* xnn_operator_t convolution_op */
- batch, /* size_t batch_size */
- in_h, /* size_t input_height */
- in_w, /* size_t input_width */
- inp, /* const int8_t* input */
- outp, /* int8_t* output */
- pt_pool); /* pthreadpool_t threadpool */
- } else { /* per_channel */
- return xnn_setup_convolution2d_nhwc_qc8(
- op, /* xnn_operator_t convolution_op */
- batch, /* size_t batch_size */
- in_h, /* size_t input_height */
- in_w, /* size_t input_width */
- inp, /* const int8_t* input */
- outp, /* int8_t* output */
- pt_pool); /* pthreadpool_t threadpool */
- }
- }
- /*
- * Series of wrapper functions to call xnn_create* and xnn_setup*
- * functions for linear
- */
- C10_ALWAYS_INLINE
- enum xnn_status xnnp_create_fully_connected_nc(
- size_t input_channels,
- size_t output_channels,
- size_t input_stride,
- size_t output_stride,
- int8_t input_zero_point,
- float input_scale,
- int8_t kernel_zero_point,
- float kernel_scale,
- const int8_t* kernel,
- const int32_t* bias,
- int8_t output_zero_point,
- float output_scale,
- int8_t output_min,
- int8_t output_max,
- uint32_t flags,
- xnn_operator_t* fully_connected_op_out) {
- /* Symmetric quantization forces kzp = 0 */
- TORCH_CHECK(!kernel_zero_point, "XNNPACK QS8 linear kernel expects kernel zero point to be zero."
- "But got: ", kernel_zero_point);
- return xnn_create_fully_connected_nc_qs8(
- input_channels, /* size_t input_channels */
- output_channels, /* size_t output_channels */
- input_stride, /* size_t input_stride */
- output_stride, /* size_t output_stride */
- input_zero_point, /* int8_t input_zero_point */
- input_scale, /* float input_scale */
- kernel_scale, /* float kernel_scale */
- kernel, /* const int8_t* kernel */
- bias, /* const int32_t* bias */
- output_zero_point, /* int8_t output_zero_point */
- output_scale, /* float output_scale */
- output_min, /* int8_t output_min */
- output_max, /* int8_t output_max */
- flags, /* uint32_t flags */
- nullptr, /* xnn_caches_t caches */
- fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
- }
- C10_ALWAYS_INLINE
- enum xnn_status xnnp_setup_fully_connected_nc(
- xnn_operator_t fully_connected_op,
- size_t batch_size,
- const int8_t* input,
- int8_t* output,
- pthreadpool_t threadpool) {
- return xnn_setup_fully_connected_nc_qs8(
- fully_connected_op, /* xnn_operator_t fully_connected_op */
- batch_size, /* size_t batch_size */
- input, /* const int8_t* input */
- output, /* int8_t* output */
- threadpool); /* pthreadpool_t threadpool */
- }
- } // namespace xnnp_utils
- } // namespace native
- } // namespace at
- #endif // USE_XNNPACK
|