#pragma once

#ifdef USE_XNNPACK
#include <cstdint>

#include <ATen/core/Tensor.h>
#include <ATen/native/xnnpack/Common.h>

using xnnpack_operator = at::native::xnnpack::Operator;

namespace at {
namespace native {
namespace xnnp_utils {

/*
 * Return shape in the same order as the memory format
 * e.g. channels_last will return NHWC instead of NCHW
 */
std::vector<size_t> get_mem_format_aware_shape(const at::Tensor& in);

/*
 * Input is always int8_t, output can be [int8_t, uint8_t].
 * input  + offset = output
 * int8_t + 128    = uint8_t
 * int8_t + 0      = int8_t
 */
template <typename T>
void q8_copy_int8_weight_and_add_offset(const at::Tensor& in, at::Tensor& out);

template <int kSpatialDim>
Tensor convert_conv_weights_to_channel_last_tensor(
    const at::Tensor& src,
    int groups,
    bool transpose);

/*
 * Series of create wrapper functions to call xnn_create_[de]conv* functions.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_convolution2d_nhwc(
    uint32_t pad_top,
    uint32_t pad_right,
    uint32_t pad_bottom,
    uint32_t pad_left,
    uint32_t kernel_h,
    uint32_t kernel_w,
    uint32_t stride_h,
    uint32_t stride_w,
    uint32_t dilation_h,
    uint32_t dilation_w,
    uint32_t groups,
    size_t group_input_channels,
    size_t group_output_channels,
    size_t ip_chan_stride,
    size_t op_chan_stride,
    int8_t izp,
    float ip_scale,
    int8_t kzp,
    const float* k_scales,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t ozp,
    float op_scale,
    int8_t op_min,
    int8_t op_max,
    uint32_t flags,
    xnn_operator_t* op,
    bool per_channel,
    bool transpose) {
  /* Symmetric quantization forces kzp = 0 */
  TORCH_CHECK(
      !kzp,
      "XNNPACK Q[SC]8 conv kernels expect kernel zero point to be zero. "
      "But got: ",
      kzp);

  if (transpose) {
    TORCH_CHECK(
        !per_channel,
        "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_create_deconvolution2d_nhwc_qs8(
        pad_top,               /* uint32_t output_padding_top */
        pad_right,             /* uint32_t output_padding_right */
        pad_bottom,            /* uint32_t output_padding_bottom */
        pad_left,              /* uint32_t output_padding_left */
        kernel_h,              /* uint32_t kernel_height */
        kernel_w,              /* uint32_t kernel_width */
        stride_h,              /* uint32_t stride_height */
        stride_w,              /* uint32_t stride_width */
        dilation_h,            /* uint32_t dilation_height */
        dilation_w,            /* uint32_t dilation_width */
        groups,                /* uint32_t groups */
        group_input_channels,  /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels */
        ip_chan_stride,        /* size_t input_pixel_stride */
        op_chan_stride,        /* size_t output_pixel_stride */
        izp,                   /* int8_t input_zero_point */
        ip_scale,              /* float input_scale */
        k_scales[0],           /* float kernel_scale */
        kernel,                /* const int8_t* kernel */
        bias,                  /* const int32_t* bias */
        ozp,                   /* int8_t output_zero_point */
        op_scale,              /* float output_scale */
        op_min,                /* int8_t output_min */
        op_max,                /* int8_t output_max */
        flags,                 /* uint32_t flags */
        nullptr,               /* xnn_caches_t caches */
        op);                   /* xnn_operator_t* deconvolution_op_out */
  }

  if (!per_channel) {
    return xnn_create_convolution2d_nhwc_qs8(
        pad_top,               /* uint32_t input_padding_top */
        pad_right,             /* uint32_t input_padding_right */
        pad_bottom,            /* uint32_t input_padding_bottom */
        pad_left,              /* uint32_t input_padding_left */
        kernel_h,              /* uint32_t kernel_height */
        kernel_w,              /* uint32_t kernel_width */
        stride_h,              /* uint32_t subsampling_height */
        stride_w,              /* uint32_t subsampling_width */
        dilation_h,            /* uint32_t dilation_height */
        dilation_w,            /* uint32_t dilation_width */
        groups,                /* uint32_t groups */
        group_input_channels,  /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels */
        ip_chan_stride,        /* size_t input_channel_stride */
        op_chan_stride,        /* size_t output_channel_stride */
        izp,                   /* int8_t input_zero_point */
        ip_scale,              /* float input_scale */
        k_scales[0],           /* float kernel_scale */
        kernel,                /* const int8_t* kernel */
        bias,                  /* const int32_t* bias */
        ozp,                   /* int8_t output_zero_point */
        op_scale,              /* float output_scale */
        op_min,                /* int8_t output_min */
        op_max,                /* int8_t output_max */
        flags,                 /* uint32_t flags */
        nullptr,               /* xnn_caches_t caches */
        op);                   /* xnn_operator_t* convolution_op_out */
  } else { /* per_channel */
    return xnn_create_convolution2d_nhwc_qc8(
        pad_top,               /* uint32_t input_padding_top */
        pad_right,             /* uint32_t input_padding_right */
        pad_bottom,            /* uint32_t input_padding_bottom */
        pad_left,              /* uint32_t input_padding_left */
        kernel_h,              /* uint32_t kernel_height */
        kernel_w,              /* uint32_t kernel_width */
        stride_h,              /* uint32_t subsampling_height */
        stride_w,              /* uint32_t subsampling_width */
        dilation_h,            /* uint32_t dilation_height */
        dilation_w,            /* uint32_t dilation_width */
        groups,                /* uint32_t groups */
        group_input_channels,  /* size_t group_input_channels */
        group_output_channels, /* size_t group_output_channels */
        ip_chan_stride,        /* size_t input_channel_stride */
        op_chan_stride,        /* size_t output_channel_stride */
        izp,                   /* int8_t input_zero_point */
        ip_scale,              /* float input_scale */
        k_scales,              /* const float* kernel_scale */
        kernel,                /* const int8_t* kernel */
        bias,                  /* const int32_t* bias */
        ozp,                   /* int8_t output_zero_point */
        op_scale,              /* float output_scale */
        op_min,                /* int8_t output_min */
        op_max,                /* int8_t output_max */
        flags,                 /* uint32_t flags */
        nullptr,               /* xnn_caches_t caches */
        op);                   /* xnn_operator_t* convolution_op_out */
  }
}
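/*
 * Minimal usage sketch for the create wrapper above. Every value here
 * (shapes, scales, zero points) and the pointers w_ptr/b_ptr are
 * hypothetical placeholders, not a real call site. The operator is
 * typically created once at weight-prepack time, then bound to buffers
 * via the setup wrapper defined below and executed with xnn_run_operator:
 *
 *   xnn_operator_t conv_op = nullptr;
 *   const float k_scale = 0.02f;  // single scale, since per_channel == false
 *   xnn_status status = xnnp_create_convolution2d_nhwc(
 *       1, 1, 1, 1,    // padding: top, right, bottom, left
 *       3, 3,          // 3x3 kernel
 *       1, 1,          // stride h, w
 *       1, 1,          // dilation h, w
 *       1,             // groups
 *       64, 64,        // group input/output channels
 *       64, 64,        // input/output channel strides
 *       0, 0.5f,       // input zero point, scale
 *       0, &k_scale,   // kernel zero point (must be 0), scale(s)
 *       w_ptr, b_ptr,  // prepacked int8 weights, int32 bias (hypothetical)
 *       0, 0.25f,      // output zero point, scale
 *       -128, 127,     // output clamp min, max
 *       0,             // flags
 *       &conv_op,
 *       false,         // per_channel
 *       false);        // transpose
 *   TORCH_CHECK(status == xnn_status_success, "conv create failed");
 *
 *   // Before each run, bind the current NHWC buffers and execute
 *   // (caffe2::pthreadpool_() is assumed as the thread pool handle):
 *   status = xnnp_setup_convolution2d_nhwc(
 *       conv_op, 1, 32, 32, in_ptr, out_ptr, caffe2::pthreadpool_());
 *   status = xnn_run_operator(conv_op, caffe2::pthreadpool_());
 */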
/*
 * Series of setup wrapper functions to call xnn_setup_[de]conv* functions.
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_convolution2d_nhwc(
    xnn_operator_t op,
    size_t batch,
    size_t in_h,
    size_t in_w,
    const int8_t* inp,
    int8_t* outp,
    pthreadpool_t pt_pool,
    bool per_channel = false,
    bool transpose = false,
    uint32_t adj_h = 0,
    uint32_t adj_w = 0) {
  if (transpose) {
    TORCH_CHECK(
        !per_channel,
        "XNNPACK Q[SC]8 does not have a per channel deconvolution!");
    return xnn_setup_deconvolution2d_nhwc_qs8(
        op,       /* xnn_operator_t deconvolution_op */
        batch,    /* size_t batch_size */
        in_h,     /* size_t input_height */
        in_w,     /* size_t input_width */
        adj_h,    /* uint32_t adjustment_height */
        adj_w,    /* uint32_t adjustment_width */
        inp,      /* const int8_t* input */
        outp,     /* int8_t* output */
        pt_pool); /* pthreadpool_t threadpool */
  }

  if (!per_channel) {
    return xnn_setup_convolution2d_nhwc_qs8(
        op,       /* xnn_operator_t convolution_op */
        batch,    /* size_t batch_size */
        in_h,     /* size_t input_height */
        in_w,     /* size_t input_width */
        inp,      /* const int8_t* input */
        outp,     /* int8_t* output */
        pt_pool); /* pthreadpool_t threadpool */
  } else { /* per_channel */
    return xnn_setup_convolution2d_nhwc_qc8(
        op,       /* xnn_operator_t convolution_op */
        batch,    /* size_t batch_size */
        in_h,     /* size_t input_height */
        in_w,     /* size_t input_width */
        inp,      /* const int8_t* input */
        outp,     /* int8_t* output */
        pt_pool); /* pthreadpool_t threadpool */
  }
}

/*
 * Series of wrapper functions to call xnn_create* and xnn_setup*
 * functions for linear
 */
C10_ALWAYS_INLINE
enum xnn_status xnnp_create_fully_connected_nc(
    size_t input_channels,
    size_t output_channels,
    size_t input_stride,
    size_t output_stride,
    int8_t input_zero_point,
    float input_scale,
    int8_t kernel_zero_point,
    float kernel_scale,
    const int8_t* kernel,
    const int32_t* bias,
    int8_t output_zero_point,
    float output_scale,
    int8_t output_min,
    int8_t output_max,
    uint32_t flags,
    xnn_operator_t* fully_connected_op_out) {
  /* Symmetric quantization forces kzp = 0 */
  TORCH_CHECK(
      !kernel_zero_point,
      "XNNPACK QS8 linear kernel expects kernel zero point to be zero. "
      "But got: ",
      kernel_zero_point);
  return xnn_create_fully_connected_nc_qs8(
      input_channels,          /* size_t input_channels */
      output_channels,         /* size_t output_channels */
      input_stride,            /* size_t input_stride */
      output_stride,           /* size_t output_stride */
      input_zero_point,        /* int8_t input_zero_point */
      input_scale,             /* float input_scale */
      kernel_scale,            /* float kernel_scale */
      kernel,                  /* const int8_t* kernel */
      bias,                    /* const int32_t* bias */
      output_zero_point,       /* int8_t output_zero_point */
      output_scale,            /* float output_scale */
      output_min,              /* int8_t output_min */
      output_max,              /* int8_t output_max */
      flags,                   /* uint32_t flags */
      nullptr,                 /* xnn_caches_t caches */
      fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
}

C10_ALWAYS_INLINE
enum xnn_status xnnp_setup_fully_connected_nc(
    xnn_operator_t fully_connected_op,
    size_t batch_size,
    const int8_t* input,
    int8_t* output,
    pthreadpool_t threadpool) {
  return xnn_setup_fully_connected_nc_qs8(
      fully_connected_op, /* xnn_operator_t fully_connected_op */
      batch_size,         /* size_t batch_size */
      input,              /* const int8_t* input */
      output,             /* int8_t* output */
      threadpool);        /* pthreadpool_t threadpool */
}
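/*
 * Minimal usage sketch for the linear wrappers above. The sizes,
 * quantization parameters, and pointers w_ptr/b_ptr/in_ptr/out_ptr are
 * hypothetical placeholders; caffe2::pthreadpool_() is assumed as the
 * thread pool handle passed down to XNNPACK:
 *
 *   xnn_operator_t linear_op = nullptr;
 *   xnn_status status = xnnp_create_fully_connected_nc(
 *       512, 256,      // input_channels, output_channels
 *       512, 256,      // input_stride, output_stride
 *       0, 0.5f,       // input zero point, scale
 *       0, 0.02f,      // kernel zero point (must be 0), scale
 *       w_ptr, b_ptr,  // prepacked int8 weights, int32 bias (hypothetical)
 *       0, 0.25f,      // output zero point, scale
 *       -128, 127,     // output clamp min, max
 *       0,             // flags
 *       &linear_op);
 *   TORCH_CHECK(status == xnn_status_success, "linear create failed");
 *
 *   status = xnnp_setup_fully_connected_nc(
 *       linear_op,
 *       8,             // batch_size: rows of the input matrix
 *       in_ptr,        // const int8_t* input
 *       out_ptr,       // int8_t* output
 *       caffe2::pthreadpool_());
 *   TORCH_CHECK(status == xnn_status_success, "linear setup failed");
 *   status = xnn_run_operator(linear_op, caffe2::pthreadpool_());
 */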
} // namespace xnnp_utils
} // namespace native
} // namespace at

#endif // USE_XNNPACK