OnednnUtils.h

#pragma once

#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/Tensor.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ideep.hpp>
#include <cpuinfo.h>
#include <c10/util/CallOnce.h>

using PrimitiveCacheKey = std::tuple<
    double, // input_scale
    int64_t, // input_zero_point
    std::vector<int64_t>, // input_shape
    double, // output_scale
    int64_t, // output_zero_point
    int64_t, // OMP_number_of_threads
    double, // accum_scale
    int64_t>; // accum_zero_point

enum CacheKeyIndex {
  InputScale,
  InputZeroPoint,
  InputShape,
  OutputScale,
  OutputZeroPoint,
  NumOfThreads,
};

// Base class of primitive cache
struct PrimitiveCache {
  PrimitiveCacheKey key;

  bool hit(const PrimitiveCacheKey& key) {
    return this->key == key;
  }
};
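
// Usage sketch (illustrative only; the caller code below is hypothetical and
// not part of this header): build a key from the current run's quantization
// parameters and thread count, then reuse the cached primitive on a hit.
//
//   PrimitiveCacheKey key = std::make_tuple(
//       input_scale, input_zero_point, input.sizes().vec(),
//       output_scale, output_zero_point,
//       (int64_t)at::get_num_threads(),
//       /*accum_scale=*/1.0, /*accum_zero_point=*/(int64_t)0);
//   if (cache.hit(key)) {
//     // Same configuration as last time: reuse the cached oneDNN primitive.
//   }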

using LinearParams = ideep::matmul_forward_params;
using Conv = dnnl::convolution_forward;
using ConvDesc = dnnl::convolution_forward::primitive_desc;
using ConvParams = ideep::convolution_forward_params;
using Deconv = dnnl::deconvolution_forward;
using DeconvDesc = dnnl::deconvolution_forward::primitive_desc;
using DeconvParams = ideep::deconv_forward_params;

struct LinearPrimitiveCache : PrimitiveCache {
  LinearPrimitiveCache() {}

  LinearPrimitiveCache(
      const PrimitiveCacheKey& key,
      const LinearParams& param) {
    this->key = key;
    this->param = param;
  }

  LinearPrimitiveCache(
      const PrimitiveCacheKey& key,
      const LinearParams& param,
      const ideep::tensor& bias) {
    this->key = key;
    this->param = param;
    if (!bias.is_empty()) {
      expected_bias =
          bias.reorder_if_differ_in(param.pd.bias_desc(), param.bias_attr);
    }
  }

  LinearParams param;
  ideep::tensor expected_bias;

  // For dynamic qlinear, scale and zero point are set at execution time,
  // so we only need to compare the remaining fields of the key.
  bool hit_dynamic(const PrimitiveCacheKey& new_key) {
    auto cached_input_shape = std::get<InputShape>(this->key);
    auto new_input_shape = std::get<InputShape>(new_key);
    return (
        cached_input_shape == new_input_shape &&
        std::get<NumOfThreads>(this->key) == std::get<NumOfThreads>(new_key));
  }

  LinearParams& get_param() {
    return param;
  }

  ideep::tensor& get_expected_bias() {
    return expected_bias;
  }
};
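
// Sketch of the dynamic-case check (hypothetical values; `params` stands in
// for a previously prepared LinearParams): keys that differ only in scales
// and zero points still count as a hit for dynamic qlinear.
//
//   std::vector<int64_t> shape{4, 64};
//   PrimitiveCacheKey key_a{0.5, 10, shape, 1.0, 0, 8, 1.0, 0};
//   PrimitiveCacheKey key_b{0.7, 3, shape, 2.0, 5, 8, 1.0, 0};
//   LinearPrimitiveCache cache(key_a, params);
//   cache.hit(key_b);         // false: scales/zero points differ
//   cache.hit_dynamic(key_b); // true: input shape and thread count match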

struct ConvPrimitiveCache : PrimitiveCache {
  ConvPrimitiveCache() {}

  ConvPrimitiveCache(
      const PrimitiveCacheKey& key,
      const ConvParams& params,
      const ideep::tensor& bias) {
    this->key = key;
    this->params = params;
    if (!bias.is_empty()) {
      this->expected_bias =
          bias.reorder_if_differ_in(params.pd.bias_desc(), params.bias_attr);
    }
  }

  ideep::tensor expected_bias;
  ConvParams params;

  ConvParams& get_params() {
    return params;
  }

  ideep::tensor& get_bias() {
    return expected_bias;
  }
};

struct DeconvPrimitiveCache : PrimitiveCache {
  DeconvPrimitiveCache() {}

  DeconvPrimitiveCache(
      const PrimitiveCacheKey& key,
      const DeconvParams& params,
      const ideep::tensor& bias) {
    this->key = key;
    this->params = params;
    if (!bias.is_empty()) {
      this->expected_bias =
          bias.reorder_if_differ_in(params.pd.bias_desc(), params.bias_attr);
    }
  }

  DeconvParams params;
  ideep::tensor expected_bias;

  DeconvParams& get_params() {
    return params;
  }

  ideep::tensor& get_bias() {
    return expected_bias;
  }
};

enum PostOps {
  NoPostOp,
  Relu,
  LeakyRelu,
  Tanh,
};

struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
  PackedLinearWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      c10::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      c10::optional<at::Tensor> orig_bias)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)) {
    cache_initialized_flag = std::make_unique<c10::once_flag>();
  }

  std::unique_ptr<ideep::tensor> weight_;
  c10::optional<ideep::tensor> bias_;
  at::Tensor orig_weight_;
  c10::optional<at::Tensor> orig_bias_;

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range = false) override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range = false) override;

  at::Tensor apply_leaky_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point,
      double negative_slope);

  at::Tensor apply_tanh(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point);

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  c10::optional<at::Tensor> bias() override {
    return orig_bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias);

 private:
  LinearPrimitiveCache prim_cache;
  std::unique_ptr<c10::once_flag> cache_initialized_flag;

  template <PostOps post_op>
  at::Tensor apply_impl(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point,
      torch::List<at::Scalar> post_op_args = torch::List<at::Scalar>());

  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range = false);

  LinearPrimitiveCache& get_cache() {
    return prim_cache;
  }
};
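
// Usage sketch (hypothetical; `qweight`, `bias`, and `qinput` are
// illustrative names): pack a quantized weight once, then run the quantized
// linear op through the packed params.
//
//   auto packed = PackedLinearWeightsOnednn::prepack(qweight, bias);
//   at::Tensor out = packed->apply(qinput, output_scale, output_zero_point);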

template <int kSpatialDim = 2>
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      c10::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      c10::optional<at::Tensor> orig_bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      uint8_t transpose)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose) {
    cache_initialized_flag = std::make_unique<c10::once_flag>();
  }

  std::unique_ptr<ideep::tensor> weight_;
  c10::optional<ideep::tensor> bias_;
  at::Tensor orig_weight_;
  c10::optional<at::Tensor> orig_bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  uint8_t transpose_;

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) override;

  at::Tensor apply_add(
      const at::Tensor& input,
      const at::Tensor& accum,
      double output_scale,
      int64_t output_zero_point);

  at::Tensor apply_add_relu(
      const at::Tensor& input,
      const at::Tensor& accum,
      double output_scale,
      int64_t output_zero_point);

  std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return (bool)transpose_;
  }

 private:
  ConvPrimitiveCache conv_prim_cache;
  DeconvPrimitiveCache deconv_prim_cache;
  std::unique_ptr<c10::once_flag> cache_initialized_flag;

  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      const c10::optional<at::Tensor>& accum,
      double output_scale,
      int64_t output_zero_point);

  ConvPrimitiveCache& get_conv_cache() {
    assert(!transpose());
    return conv_prim_cache;
  }

  DeconvPrimitiveCache& get_deconv_cache() {
    assert(transpose());
    return deconv_prim_cache;
  }
};
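
// Usage sketch (hypothetical; argument names are illustrative): pack a
// quantized 2D conv weight once, then run the quantized convolution.
//
//   auto packed = PackedConvWeightsOnednn<2>::prepack(
//       qweight, bias, stride, padding, output_padding, dilation,
//       groups, /*transpose=*/false);
//   at::Tensor out = packed->apply(qinput, output_scale, output_zero_point);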

namespace onednn_utils {

// Try to reorder a tensor to the expected desc at runtime.
// Done in a `try...catch` so that oneDNN's reorder errors do not propagate.
// TODO: Move it to third_party/ideep
static void try_reorder(
    ideep::tensor& t,
    const ideep::tensor::desc&& desc,
    ideep::scale_t scales) {
  if (t.get_desc() != desc) {
    try {
      t = t.reorder_if_differ_in(desc);
    } catch (...) {
      // Direct reorder failed: go through a plain (public) format first,
      // then reorder from there.
      ideep::tensor&& plain = t.to_public(nullptr, t.get_data_type());
      t = plain.reorder_if_differ_in(desc);
    }
    t.set_scale(scales);
  }
}
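
// Usage sketch (hypothetical; `expected_desc` is an illustrative variable,
// not part of this header): coerce a tensor into the layout a primitive
// expects, relying on the plain-format fallback for unsupported reorders.
//
//   ideep::tensor w = /* packed weight */;
//   try_reorder(w, std::move(expected_desc), /*scales=*/{1.0f});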

// oneDNN requires symmetric quantization of weights.
// Use this utility function to check.
static bool is_weight_symmetric_quant(
    const at::Tensor& weight,
    bool is_transposed_conv) {
  bool is_symmetric = true;
  const auto qtype = weight.qscheme();
  if (qtype == c10::kPerTensorAffine) {
    is_symmetric &= (weight.q_zero_point() == 0);
  } else if (qtype == c10::kPerChannelAffine) {
    if (is_transposed_conv) {
      // This case is currently not supported in PyTorch,
      // but we do not want to raise an error in this util function.
      is_symmetric = false;
    } else {
      auto output_channels = weight.size(0);
      for (int i = 0; i < output_channels; ++i) {
        auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
        is_symmetric &= (zp == 0);
      }
    }
  } else {
    // This case is currently not supported in PyTorch,
    // but we do not want to raise an error in this util function.
    is_symmetric = false;
  }
  return is_symmetric;
}
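
// Example (hypothetical): a per-tensor qint8 weight quantized with
// zero_point == 0 is symmetric; any nonzero zero point is not.
//
//   at::Tensor w = at::quantize_per_tensor(
//       at::randn({16, 8}), /*scale=*/0.1, /*zero_point=*/0, at::kQInt8);
//   bool ok = is_weight_symmetric_quant(w, /*is_transposed_conv=*/false);
//   // ok == true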

// When the qengine is x86, use this utility to check whether the onednn
// kernel is preferred over fbgemm's for better performance.
static bool should_use_onednn_quant(
    const at::Tensor& weight,
    bool is_transposed_conv,
    int groups,
    torch::List<int64_t> output_padding) {
  // Performance of onednn is only validated on Linux right now.
  // Also, the heuristics for dispatching are based on perf data on Linux.
  // So, for the x86 qengine, we always use fbgemm kernels if the OS is not Linux.
  // TODO: Support more OSs.
#if !defined(__linux__)
  return false;
#else
  bool vnni_available = cpuinfo_has_x86_avx512vnni();
  bool w_sym_quant =
      is_weight_symmetric_quant(weight, is_transposed_conv);
  bool opad_all_zero =
      std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i == 0; });
  return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
#endif
}
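
// Dispatch sketch (hypothetical caller): onednn is chosen only when every
// heuristic passes (Linux, AVX512-VNNI, groups <= 100, symmetric weights,
// all-zero output padding); otherwise the x86 qengine falls back to fbgemm.
//
//   bool use_onednn = should_use_onednn_quant(
//       qweight, /*is_transposed_conv=*/false, groups, output_padding);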

} // namespace onednn_utils

#endif // #if AT_MKLDNN_ENABLED()