// vec256.h
  1. #pragma once
  2. // DO NOT DEFINE STATIC DATA IN THIS HEADER!
  3. // See Note [Do not compile initializers with AVX]
  4. #include <ATen/cpu/vec/intrinsics.h>
  5. #include <ATen/cpu/vec/vec_base.h>
  6. #if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR))
  7. #include <ATen/cpu/vec/vec256/vec256_float.h>
  8. #include <ATen/cpu/vec/vec256/vec256_float_neon.h>
  9. #include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
  10. #include <ATen/cpu/vec/vec256/vec256_double.h>
  11. #include <ATen/cpu/vec/vec256/vec256_int.h>
  12. #include <ATen/cpu/vec/vec256/vec256_qint.h>
  13. #include <ATen/cpu/vec/vec256/vec256_complex_float.h>
  14. #include <ATen/cpu/vec/vec256/vec256_complex_double.h>
  15. #elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX)
  16. #include <ATen/cpu/vec/vec256/vsx/vec256_common_vsx.h>
  17. #else
  18. #include <ATen/cpu/vec/vec256/zarch/vec256_zarch.h>
  19. #include <ATen/cpu/vec/vec256/vec256_bfloat16.h>
  20. #endif
  21. #include <algorithm>
  22. #include <cstddef>
  23. #include <cstdint>
  24. #include <cstring>
  25. #include <ostream>
  26. namespace at {
  27. namespace vec {
  28. // Note [CPU_CAPABILITY namespace]
  29. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  30. // This header, and all of its subheaders, will be compiled with
  31. // different architecture flags for each supported set of vector
  32. // intrinsics. So we need to make sure they aren't inadvertently
  33. // linked together. We do this by declaring objects in an `inline
  34. // namespace` which changes the name mangling, but can still be
  35. // accessed as `at::vec`.
  36. inline namespace CPU_CAPABILITY {
  37. inline std::ostream& operator<<(std::ostream& stream, const c10::qint32& val) {
  38. stream << val.val_;
  39. return stream;
  40. }
  41. inline std::ostream& operator<<(std::ostream& stream, const c10::qint8& val) {
  42. stream << static_cast<int>(val.val_);
  43. return stream;
  44. }
  45. inline std::ostream& operator<<(std::ostream& stream, const c10::quint8& val) {
  46. stream << static_cast<unsigned int>(val.val_);
  47. return stream;
  48. }
  49. template <typename T>
  50. std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
  51. T buf[Vectorized<T>::size()];
  52. vec.store(buf);
  53. stream << "vec[";
  54. for (int i = 0; i != Vectorized<T>::size(); i++) {
  55. if (i != 0) {
  56. stream << ", ";
  57. }
  58. stream << buf[i];
  59. }
  60. stream << "]";
  61. return stream;
  62. }
  63. #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
  64. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  65. template<>
  66. inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
  67. return _mm256_castpd_ps(src);
  68. }
  69. template<>
  70. inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
  71. return _mm256_castps_pd(src);
  72. }
  73. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Gather 4 doubles: lane i is loaded from base_addr + vindex[i] * scale,
// where scale is a byte multiplier restricted to 1/2/4/8 by the intrinsic.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
  return _mm256_i64gather_pd(base_addr, vindex, scale);
}
// Gather 8 floats: lane i is loaded from base_addr + vindex[i] * scale,
// where scale is a byte multiplier restricted to 1/2/4/8 by the intrinsic.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
  return _mm256_i32gather_ps(base_addr, vindex, scale);
}
  84. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Masked gather of 4 doubles: lane i is loaded from
// base_addr + vindex[i] * scale only where mask's lane has its top bit set;
// otherwise the corresponding lane of src is passed through unchanged.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
                   const Vectorized<int64_t>& vindex, const Vectorized<double>& mask) {
  return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale);
}
// Masked gather of 8 floats: lane i is loaded from
// base_addr + vindex[i] * scale only where mask's lane has its top bit set;
// otherwise the corresponding lane of src is passed through unchanged.
template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
                   const Vectorized<int32_t>& vindex, const Vectorized<float>& mask) {
  return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
}
  97. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  98. // Only works for inputs in the range: [-2^51, 2^51]
  99. // From: https://stackoverflow.com/a/41148578
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
  // AVX2 has no double->int64 conversion, so use the magic-number trick:
  // adding 0x0018000000000000 (= 2^52 + 2^51 as a double) forces the integer
  // part of each input into the low mantissa bits; subtracting the constant's
  // bit pattern as an integer then recovers the signed int64 result.
  auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000));
  return _mm256_sub_epi64(
      _mm256_castpd_si256(x),
      _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))
  );
}
// Convert 8 floats to 8 int32_t, truncating toward zero (cvttps).
template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
  return _mm256_cvttps_epi32(src);
}
  114. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Interleave two vectors element-wise: {a0, a1, a2, a3} and {b0, b1, b2, b3}
// become {a0, b0, a1, b1} and {a2, b2, a3, b3}.
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3}
  //   b = {b0, b1, b2, b3}
  // swap lanes:
  //   a_swapped = {a0, a1, b0, b1}
  //   b_swapped = {a2, a3, b2, b3}
  auto a_swapped = _mm256_permute2f128_pd(a, b, 0b0100000);  // 0, 2.   4 bits apart
  auto b_swapped = _mm256_permute2f128_pd(a, b, 0b0110001);  // 1, 3.   4 bits apart
  // group cols crossing lanes:
  //   return {a0, b0, a1, b1}
  //          {a2, b2, a3, b3}
  return std::make_pair(_mm256_permute4x64_pd(a_swapped, 0b11011000),  // 0, 2, 1, 3
                        _mm256_permute4x64_pd(b_swapped, 0b11011000)); // 0, 2, 1, 3
}
// Interleave two vectors element-wise: {a0..a7} and {b0..b7} become
// {a0, b0, ..., a3, b3} and {a4, b4, ..., a7, b7}.
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}
  // swap lanes:
  //   a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  auto a_swapped = _mm256_permute2f128_ps(a, b, 0b0100000);  // 0, 2.   4 bits apart
  auto b_swapped = _mm256_permute2f128_ps(a, b, 0b0110001);  // 1, 3.   4 bits apart
  // group cols crossing lanes:
  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
  //          {a4, b4, a5, b5, a6, b6, a7, b7}
  const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  return std::make_pair(_mm256_permutevar8x32_ps(a_swapped, group_ctrl),
                        _mm256_permutevar8x32_ps(b_swapped, group_ctrl));
}
  151. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Inverse of interleave2<double>: {a0, b0, a1, b1} and {a2, b2, a3, b3}
// become {a0, a1, a2, a3} and {b0, b1, b2, b3}.
template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1}
  //   b = {a2, b2, a3, b3}
  // group cols crossing lanes:
  //   a_grouped = {a0, a1, b0, b1}
  //   b_grouped = {a2, a3, b2, b3}
  auto a_grouped = _mm256_permute4x64_pd(a, 0b11011000);  // 0, 2, 1, 3
  auto b_grouped = _mm256_permute4x64_pd(b, 0b11011000);  // 0, 2, 1, 3
  // swap lanes:
  //   return {a0, a1, a2, a3}
  //          {b0, b1, b2, b3}
  return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, 0b0100000),  // 0, 2.   4 bits apart
                        _mm256_permute2f128_pd(a_grouped, b_grouped, 0b0110001)); // 1, 3.   4 bits apart
}
// Inverse of interleave2<float>: {a0, b0, ..., a3, b3} and {a4, b4, ..., a7, b7}
// become {a0..a7} and {b0..b7}.
template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}
  // group cols crossing lanes:
  //   a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl);
  auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl);
  // swap lanes:
  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
  //          {b0, b1, b2, b3, b4, b5, b6, b7}
  return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, 0b0100000),  // 0, 2.   4 bits apart
                        _mm256_permute2f128_ps(a_grouped, b_grouped, 0b0110001)); // 1, 3.   4 bits apart
}
  188. // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLIP ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  189. template<>
  190. inline Vectorized<float> flip(const Vectorized<float> & v) {
  191. const __m256i mask_float = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  192. return _mm256_permutevar8x32_ps(v, mask_float);
  193. }
  194. template<>
  195. inline Vectorized<double> flip(const Vectorized<double> & v) {
  196. return _mm256_permute4x64_pd(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3)
  197. }
  198. template<>
  199. inline Vectorized<int64_t> flip(const Vectorized<int64_t> & v) {
  200. return _mm256_permute4x64_epi64(v, 27); // 27 == _MM_SHUFFLE(0, 1, 2, 3)
  201. }
  202. template<>
  203. inline Vectorized<int32_t> flip(const Vectorized<int32_t> & v) {
  204. const __m256i mask_int32 = _mm256_set_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  205. return _mm256_permutevar8x32_epi32(v, mask_int32);
  206. }
// Reverse the order of the 16 int16_t lanes.
template<>
inline Vectorized<int16_t> flip(const Vectorized<int16_t> & v) {
  // _mm256_shuffle_epi8 shuffles bytes within each 128-bit half, so this
  // control reverses the eight 16-bit elements of each half while keeping
  // each element's two bytes together.
  const __m256i mask = _mm256_set_epi8(
      1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14,
      1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14
  );
  auto reversed = _mm256_shuffle_epi8(v, mask);
  // Swap the two 128-bit halves to complete the full-register reversal.
  return _mm256_permute2x128_si256(reversed, reversed, 1);
}
// Reverse the order of the 32 bytes in v; shared helper for the int8_t and
// uint8_t flip specializations below.
inline __m256i flip8(const __m256i & v) {
  // _mm256_shuffle_epi8 shuffles bytes within each 128-bit half, so this
  // control reverses the 16 bytes of each half.
  const __m256i mask_int8 = _mm256_set_epi8(
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
      0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
  );
  auto reversed = _mm256_shuffle_epi8(v, mask_int8);
  // Swap the two 128-bit halves to complete the full-register reversal.
  return _mm256_permute2x128_si256(reversed, reversed, 1);
}
// Reverse the order of the 32 int8_t lanes (delegates to flip8).
template<>
inline Vectorized<int8_t> flip(const Vectorized<int8_t> & v) {
  return flip8(v);
}
// Reverse the order of the 32 uint8_t lanes (delegates to flip8).
template<>
inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
  return flip8(v);
}
#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
  233. }}}