vec512.h 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. #pragma once
  2. // DO NOT DEFINE STATIC DATA IN THIS HEADER!
  3. // See Note [Do not compile initializers with AVX]
  4. #include <ATen/cpu/vec/intrinsics.h>
  5. #include <ATen/cpu/vec/vec_base.h>
  6. #include <ATen/cpu/vec/vec512/vec512_float.h>
  7. #include <ATen/cpu/vec/vec512/vec512_bfloat16.h>
  8. #include <ATen/cpu/vec/vec512/vec512_double.h>
  9. #include <ATen/cpu/vec/vec512/vec512_int.h>
  10. #include <ATen/cpu/vec/vec512/vec512_qint.h>
  11. #include <ATen/cpu/vec/vec512/vec512_complex_float.h>
  12. #include <ATen/cpu/vec/vec512/vec512_complex_double.h>
  #include <algorithm>
  #include <cstddef>
  #include <cstdint>
  #include <cstring>
  #include <iostream>
  #include <type_traits>
  18. namespace at {
  19. namespace vec {
  20. // See Note [CPU_CAPABILITY namespace]
  21. inline namespace CPU_CAPABILITY {
  22. inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
  23. stream << val.val_;
  24. return stream;
  25. }
  26. inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
  27. stream << static_cast<int>(val.val_);
  28. return stream;
  29. }
  30. inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
  31. stream << static_cast<unsigned int>(val.val_);
  32. return stream;
  33. }
  34. template <typename T>
  35. std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
  36. T buf[Vectorized<T>::size()];
  37. vec.store(buf);
  38. stream << "vec[";
  39. for (int i = 0; i != Vectorized<T>::size(); i++) {
  40. if (i != 0) {
  41. stream << ", ";
  42. }
  43. stream << buf[i];
  44. }
  45. stream << "]";
  46. return stream;
  47. }
  48. #if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
  49. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  50. template<>
  51. inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
  52. return _mm512_castpd_ps(src);
  53. }
  54. template<>
  55. inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
  56. return _mm512_castps_pd(src);
  57. }
  58. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  59. template<int64_t scale = 1>
  60. std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
  61. inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
  62. return _mm512_i64gather_pd(vindex, base_addr, scale);
  63. }
  64. template<int64_t scale = 1>
  65. std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
  66. inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
  67. return _mm512_i32gather_ps(vindex, base_addr, scale);
  68. }
  69. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  70. template<int64_t scale = 1>
  71. std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
  72. inline mask_gather(const Vectorized<double>& src, const double* base_addr,
  73. const Vectorized<int64_t>& vindex, const Vectorized<double>& mask) {
  74. auto all_ones = _mm512_castsi512_pd(_mm512_set1_epi64(0xFFFFFFFFFFFFFFFF));
  75. auto mask_ = _mm512_cmp_pd_mask(all_ones, mask.values, _CMP_EQ_OQ);
  76. return _mm512_mask_i64gather_pd(src, mask_, vindex, base_addr, scale);
  77. }
  78. template<int64_t scale = 1>
  79. std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
  80. inline mask_gather(const Vectorized<float>& src, const float* base_addr,
  81. const Vectorized<int32_t>& vindex, const Vectorized<float>& mask) {
  82. auto all_ones = _mm512_castsi512_ps(_mm512_set1_epi32(0xFFFFFFFF));
  83. auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
  84. return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
  85. }
  86. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  87. template<>
  88. Vectorized<int64_t>
  89. inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
  90. return _mm512_cvtpd_epi64(src);
  91. }
  92. template<>
  93. Vectorized<int32_t>
  94. inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
  95. return _mm512_cvttps_epi32(src);
  96. }
  97. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  98. template <>
  99. std::pair<Vectorized<double>, Vectorized<double>>
  100. inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  101. // inputs:
  102. // a = {a0, a1, a3, a3, a4, a5, a6, a7}
  103. // b = {b0, b1, b2, b3, b4, b5, b6, b7}
  104. // group cols crossing lanes:
  105. // return {a0, b0, a1, b1, a2, b2, a3, b3}
  106. // {a4, b4, a5, b5, a6, b6, a7, b7}
  107. __m512i idx1 = _mm512_set_epi64(11, 3, 10, 2, 9, 1, 8, 0);
  108. __m512i idx2 = _mm512_set_epi64(15, 7, 14, 6, 13, 5, 12, 4);
  109. return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
  110. _mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
  111. }
  112. template <>
  113. std::pair<Vectorized<float>, Vectorized<float>>
  114. inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  115. // inputs:
  116. // a = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
  117. // b = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
  118. //
  119. // return:
  120. // {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
  121. // {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
  122. __m512i idx1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4,
  123. 19, 3, 18, 2, 17, 1, 16, 0);
  124. __m512i idx2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12,
  125. 27, 11, 26, 10, 25, 9, 24, 8);
  126. return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
  127. _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
  128. }
  129. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  130. template <>
  131. std::pair<Vectorized<double>, Vectorized<double>>
  132. inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  133. // inputs:
  134. // a = {a0, b0, a1, b1, a2, b2, a3, b3}
  135. // b = {a4, b4, a5, b5, a6, b6, a7, b7}
  136. // output:
  137. // return {a0, a1, a2, a3, a4, a5, a6, a7}
  138. // {b0, b1, b2, b3, b4, b5, b6, b7}
  139. // The members of indices have been written in binary format for better understandability
  140. __m512i idx1 = _mm512_set_epi64(14, 12, 10, 8, 6, 4, 2, 0);
  141. __m512i idx2 = _mm512_set_epi64(15, 13, 11, 9, 7, 5, 3, 1);
  142. return std::make_pair(_mm512_mask_permutex2var_pd(a, 0xff, idx1, b),
  143. _mm512_mask_permutex2var_pd(a, 0xff, idx2, b));
  144. }
  145. template <>
  146. std::pair<Vectorized<float>, Vectorized<float>>
  147. inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  148. // inputs:
  149. // a = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
  150. // b = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
  151. // output:
  152. // return {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
  153. // {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
  154. __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16,
  155. 14, 12, 10, 8, 6, 4, 2, 0);
  156. __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17,
  157. 15, 13, 11, 9, 7, 5, 3, 1);
  158. return std::make_pair(_mm512_mask_permutex2var_ps(a, 0xffff, idx1, b),
  159. _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b));
  160. }
  161. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  162. template<>
  163. inline Vectorized<float> flip(const Vectorized<float> & v) {
  164. const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7,
  165. 8, 9, 10, 11, 12, 13, 14, 15);
  166. return _mm512_permutexvar_ps(mask, v);
  167. }
  168. template<>
  169. inline Vectorized<double> flip(const Vectorized<double> & v) {
  170. const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
  171. return _mm512_permutexvar_pd(mask, v);
  172. }
  173. template<>
  174. inline Vectorized<int64_t> flip(const Vectorized<int64_t> & v) {
  175. const __m512i mask = _mm512_set_epi64(0, 1, 2, 3, 4, 5, 6, 7);
  176. return _mm512_permutexvar_epi64(mask, v);
  177. }
  178. template<>
  179. inline Vectorized<int32_t> flip(const Vectorized<int32_t> & v) {
  180. const __m512i mask = _mm512_set_epi32(0, 1, 2, 3, 4, 5, 6, 7,
  181. 8, 9, 10, 11, 12, 13, 14, 15);
  182. return _mm512_permutexvar_epi32(mask, v);
  183. }
  184. template<>
  185. inline Vectorized<int16_t> flip(const Vectorized<int16_t> & v) {
  186. const __m512i mask = _mm512_set_epi16(
  187. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
  188. 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
  189. );
  190. return _mm512_permutexvar_epi16(mask, v);
  191. }
  // Reverses the order of all 64 bytes of `v` (result byte i = v byte 63 - i).
  // Shared by the int8_t and uint8_t flip specializations.
  inline __m512i flip8(const __m512i & v) {
    // Step 1: reverse the 16 bytes inside each 128-bit lane (byte j <- byte 15 - j).
    // The same in-lane shuffle control is replicated into all four lanes.
    const __m512i in_lane_reverse = _mm512_broadcast_i32x4(
        _mm_setr_epi8(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
    const __m512i bytes_reversed_per_lane = _mm512_shuffle_epi8(v, in_lane_reverse);
    // Step 2: reverse the order of the four 128-bit lanes. Immediate 0x1B picks
    // source lanes 3, 2, 1, 0 for destination lanes 0, 1, 2, 3.
    return _mm512_shuffle_i64x2(bytes_reversed_per_lane, bytes_reversed_per_lane, 0x1B);
  }
  203. template<>
  204. inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
  205. return flip8(v);
  206. }
  207. template<>
  208. inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
  209. return flip8(v);
  210. }
  211. #endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
  212. }}}