functional_base.h 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. #pragma once
  2. // DO NOT DEFINE STATIC DATA IN THIS HEADER!
  3. // See Note [Do not compile initializers with AVX]
  4. #include <ATen/cpu/vec/vec.h>
  5. #include <c10/util/irange.h>
  6. namespace at { namespace vec {
  7. // slow path
  8. template <typename scalar_t, typename Op>
  9. inline scalar_t vec_reduce_all(
  10. const Op& vec_fun,
  11. vec::Vectorized<scalar_t> acc_vec,
  12. int64_t size) {
  13. using Vec = vec::Vectorized<scalar_t>;
  14. scalar_t acc_arr[Vec::size()];
  15. acc_vec.store(acc_arr);
  16. for (const auto i : c10::irange(1, size)) {
  17. std::array<scalar_t, Vec::size()> acc_arr_next = {0};
  18. acc_arr_next[0] = acc_arr[i];
  19. Vec acc_vec_next = Vec::loadu(acc_arr_next.data());
  20. acc_vec = vec_fun(acc_vec, acc_vec_next);
  21. }
  22. acc_vec.store(acc_arr);
  23. return acc_arr[0];
  24. }
  25. template <typename scalar_t, typename Op>
  26. struct VecReduceAllSIMD {
  27. static inline scalar_t apply(const Op& vec_fun, Vectorized<scalar_t> acc_vec) {
  28. return vec_reduce_all(vec_fun, acc_vec, Vectorized<scalar_t>::size());
  29. }
  30. };
  31. #if defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
  32. #if defined(CPU_CAPABILITY_AVX2)
  33. template <typename Op>
  34. struct VecReduceAllSIMD<float, Op> {
  35. static inline float apply(const Op& vec_fun, Vectorized<float> acc_vec) {
  36. using Vec = Vectorized<float>;
  37. Vec v = acc_vec;
  38. // 128-bit shuffle
  39. Vec v1 = _mm256_permute2f128_ps(v, v, 0x1);
  40. v = vec_fun(v, v1);
  41. // 64-bit shuffle
  42. v1 = _mm256_shuffle_ps(v, v, 0x4E);
  43. v = vec_fun(v, v1);
  44. // 32-bit shuffle
  45. v1 = _mm256_shuffle_ps(v, v, 0xB1);
  46. v = vec_fun(v, v1);
  47. return _mm256_cvtss_f32(v);
  48. }
  49. };
  50. #endif // defined(CPU_CAPABILITY_AVX2)
  51. #if defined(CPU_CAPABILITY_AVX512)
  52. template <typename Op>
  53. struct VecReduceAllSIMD<float, Op> {
  54. static inline float apply(const Op& vec_fun, Vectorized<float> acc_vec) {
  55. using Vec = Vectorized<float>;
  56. Vec v = acc_vec;
  57. // 256-bit shuffle
  58. Vec v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
  59. v = vec_fun(v, v1);
  60. // 128-bit shuffle
  61. v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
  62. v = vec_fun(v, v1);
  63. // 64-bit shuffle
  64. v1 = _mm512_shuffle_ps(v, v, 0x4E);
  65. v = vec_fun(v, v1);
  66. // 32-bit shuffle
  67. v1 = _mm512_shuffle_ps(v, v, 0xB1);
  68. v = vec_fun(v, v1);
  69. return _mm512_cvtss_f32(v);
  70. }
  71. };
  72. #endif // defined(CPU_CAPABILITY_AVX512)
  73. #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)
  74. template <typename scalar_t, typename Op>
  75. inline scalar_t vec_reduce_all(const Op& vec_fun, Vectorized<scalar_t> acc_vec) {
  76. return VecReduceAllSIMD<scalar_t, Op>::apply(vec_fun, acc_vec);
  77. }
  78. template <typename scalar_t, typename Op>
  79. inline scalar_t reduce_all(const Op& vec_fun, const scalar_t* data, int64_t size) {
  80. using Vec = vec::Vectorized<scalar_t>;
  81. if (size < Vec::size())
  82. return vec_reduce_all(vec_fun, Vec::loadu(data, size), size);
  83. int64_t d = Vec::size();
  84. Vec acc_vec = Vec::loadu(data);
  85. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  86. Vec data_vec = Vec::loadu(data + d);
  87. acc_vec = vec_fun(acc_vec, data_vec);
  88. }
  89. if (size - d > 0) {
  90. Vec data_vec = Vec::loadu(data + d, size - d);
  91. acc_vec = Vec::set(acc_vec, vec_fun(acc_vec, data_vec), size - d);
  92. }
  93. return vec_reduce_all(vec_fun, acc_vec);
  94. }
  95. // similar to reduce_all, but reduces into two outputs
  96. template <typename scalar_t, typename Op1, typename Op2>
  97. inline std::pair<scalar_t, scalar_t> reduce2_all(const Op1& vec_fun1, const Op2& vec_fun2,
  98. const scalar_t* data, int64_t size) {
  99. using Vec = vec::Vectorized<scalar_t>;
  100. if (size < Vec::size()) {
  101. auto loaded_data = Vec::loadu(data, size);
  102. return std::pair<scalar_t, scalar_t>(
  103. vec_reduce_all(vec_fun1, loaded_data, size),
  104. vec_reduce_all(vec_fun2, loaded_data, size));
  105. }
  106. int64_t d = Vec::size();
  107. Vec acc_vec1 = Vec::loadu(data);
  108. Vec acc_vec2 = Vec::loadu(data);
  109. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  110. Vec data_vec = Vec::loadu(data + d);
  111. acc_vec1 = vec_fun1(acc_vec1, data_vec);
  112. acc_vec2 = vec_fun2(acc_vec2, data_vec);
  113. }
  114. if (size - d > 0) {
  115. Vec data_vec = Vec::loadu(data + d, size - d);
  116. acc_vec1 = Vec::set(acc_vec1, vec_fun1(acc_vec1, data_vec), size - d);
  117. acc_vec2 = Vec::set(acc_vec2, vec_fun2(acc_vec2, data_vec), size - d);
  118. }
  119. return std::pair<scalar_t, scalar_t>(
  120. vec_reduce_all(vec_fun1, acc_vec1),
  121. vec_reduce_all(vec_fun2, acc_vec2));
  122. }
  123. template <typename scalar_t, typename MapOp, typename ReduceOp>
  124. inline scalar_t map_reduce_all(
  125. const MapOp& map_fun,
  126. const ReduceOp& red_fun,
  127. const scalar_t* data,
  128. int64_t size) {
  129. using Vec = vec::Vectorized<scalar_t>;
  130. if (size < Vec::size())
  131. return vec_reduce_all(red_fun, map_fun(Vec::loadu(data, size)), size);
  132. int64_t d = Vec::size();
  133. Vec acc_vec = map_fun(Vec::loadu(data));
  134. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  135. Vec data_vec = Vec::loadu(data + d);
  136. data_vec = map_fun(data_vec);
  137. acc_vec = red_fun(acc_vec, data_vec);
  138. }
  139. if (size - d > 0) {
  140. Vec data_vec = Vec::loadu(data + d, size - d);
  141. data_vec = map_fun(data_vec);
  142. acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  143. }
  144. return vec_reduce_all(red_fun, acc_vec);
  145. }
  146. template <typename scalar_t, typename MapOp, typename ReduceOp>
  147. inline scalar_t map2_reduce_all(
  148. const MapOp& map_fun,
  149. const ReduceOp& red_fun,
  150. const scalar_t* data,
  151. const scalar_t* data2,
  152. int64_t size) {
  153. using Vec = vec::Vectorized<scalar_t>;
  154. if (size < Vec::size()) {
  155. Vec data_vec = Vec::loadu(data, size);
  156. Vec data2_vec = Vec::loadu(data2, size);
  157. data_vec = map_fun(data_vec, data2_vec);
  158. return vec_reduce_all(red_fun, data_vec, size);
  159. }
  160. int64_t d = Vec::size();
  161. Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2));
  162. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  163. Vec data_vec = Vec::loadu(data + d);
  164. Vec data2_vec = Vec::loadu(data2 + d);
  165. data_vec = map_fun(data_vec, data2_vec);
  166. acc_vec = red_fun(acc_vec, data_vec);
  167. }
  168. if (size - d > 0) {
  169. Vec data_vec = Vec::loadu(data + d, size - d);
  170. Vec data2_vec = Vec::loadu(data2 + d, size - d);
  171. data_vec = map_fun(data_vec, data2_vec);
  172. acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  173. }
  174. return vec_reduce_all(red_fun, acc_vec);
  175. }
  176. template <typename scalar_t, typename MapOp, typename ReduceOp>
  177. inline scalar_t map3_reduce_all(
  178. const MapOp& map_fun,
  179. const ReduceOp& red_fun,
  180. const scalar_t* data,
  181. const scalar_t* data2,
  182. const scalar_t* data3,
  183. int64_t size) {
  184. using Vec = vec::Vectorized<scalar_t>;
  185. if (size < Vec::size()) {
  186. Vec data_vec = Vec::loadu(data, size);
  187. Vec data2_vec = Vec::loadu(data2, size);
  188. Vec data3_vec = Vec::loadu(data3, size);
  189. data_vec = map_fun(data_vec, data2_vec, data3_vec);
  190. return vec_reduce_all(red_fun, data_vec, size);
  191. }
  192. int64_t d = Vec::size();
  193. Vec acc_vec = map_fun(Vec::loadu(data), Vec::loadu(data2), Vec::loadu(data3));
  194. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  195. Vec data_vec = Vec::loadu(data + d);
  196. Vec data2_vec = Vec::loadu(data2 + d);
  197. Vec data3_vec = Vec::loadu(data3 + d);
  198. data_vec = map_fun(data_vec, data2_vec, data3_vec);
  199. acc_vec = red_fun(acc_vec, data_vec);
  200. }
  201. if (size - d > 0) {
  202. Vec data_vec = Vec::loadu(data + d, size - d);
  203. Vec data2_vec = Vec::loadu(data2 + d, size - d);
  204. Vec data3_vec = Vec::loadu(data3 + d, size - d);
  205. data_vec = map_fun(data_vec, data2_vec, data3_vec);
  206. acc_vec = Vec::set(acc_vec, red_fun(acc_vec, data_vec), size - d);
  207. }
  208. return vec_reduce_all(red_fun, acc_vec);
  209. }
  210. template <typename scalar_t, typename Op>
  211. inline void map(
  212. const Op& vec_fun,
  213. scalar_t* output_data,
  214. const scalar_t* input_data,
  215. int64_t size) {
  216. using Vec = vec::Vectorized<scalar_t>;
  217. int64_t d = 0;
  218. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  219. Vec output_vec = vec_fun(Vec::loadu(input_data + d));
  220. output_vec.store(output_data + d);
  221. }
  222. if (size - d > 0) {
  223. Vec output_vec = vec_fun(Vec::loadu(input_data + d, size - d));
  224. output_vec.store(output_data + d, size - d);
  225. }
  226. }
  227. template <typename scalar_t, typename Op>
  228. inline void map2(
  229. const Op& vec_fun,
  230. scalar_t* output_data,
  231. const scalar_t* input_data,
  232. const scalar_t* input_data2,
  233. int64_t size) {
  234. using Vec = vec::Vectorized<scalar_t>;
  235. int64_t d = 0;
  236. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  237. Vec data_vec = Vec::loadu(input_data + d);
  238. Vec data_vec2 = Vec::loadu(input_data2 + d);
  239. Vec output_vec = vec_fun(data_vec, data_vec2);
  240. output_vec.store(output_data + d);
  241. }
  242. if (size - d > 0) {
  243. Vec data_vec = Vec::loadu(input_data + d, size - d);
  244. Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
  245. Vec output_vec = vec_fun(data_vec, data_vec2);
  246. output_vec.store(output_data + d, size - d);
  247. }
  248. }
  249. template <typename scalar_t, typename Op>
  250. inline void map3(
  251. const Op& vec_fun,
  252. scalar_t* output_data,
  253. const scalar_t* input_data1,
  254. const scalar_t* input_data2,
  255. const scalar_t* input_data3,
  256. int64_t size) {
  257. using Vec = vec::Vectorized<scalar_t>;
  258. int64_t d = 0;
  259. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  260. Vec data_vec1 = Vec::loadu(input_data1 + d);
  261. Vec data_vec2 = Vec::loadu(input_data2 + d);
  262. Vec data_vec3 = Vec::loadu(input_data3 + d);
  263. Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
  264. output_vec.store(output_data + d);
  265. }
  266. if (size - d > 0) {
  267. Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
  268. Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
  269. Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
  270. Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3);
  271. output_vec.store(output_data + d, size - d);
  272. }
  273. }
  274. template <typename scalar_t, typename Op>
  275. inline void map4(
  276. const Op& vec_fun,
  277. scalar_t* output_data,
  278. const scalar_t* input_data1,
  279. const scalar_t* input_data2,
  280. const scalar_t* input_data3,
  281. const scalar_t* input_data4,
  282. int64_t size) {
  283. using Vec = vec::Vectorized<scalar_t>;
  284. int64_t d = 0;
  285. for (; d < size - (size % Vec::size()); d += Vec::size()) {
  286. Vec data_vec1 = Vec::loadu(input_data1 + d);
  287. Vec data_vec2 = Vec::loadu(input_data2 + d);
  288. Vec data_vec3 = Vec::loadu(input_data3 + d);
  289. Vec data_vec4 = Vec::loadu(input_data4 + d);
  290. Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
  291. output_vec.store(output_data + d);
  292. }
  293. if (size - d > 0) {
  294. Vec data_vec1 = Vec::loadu(input_data1 + d, size - d);
  295. Vec data_vec2 = Vec::loadu(input_data2 + d, size - d);
  296. Vec data_vec3 = Vec::loadu(input_data3 + d, size - d);
  297. Vec data_vec4 = Vec::loadu(input_data4 + d, size - d);
  298. Vec output_vec = vec_fun(data_vec1, data_vec2, data_vec3, data_vec4);
  299. output_vec.store(output_data + d, size - d);
  300. }
  301. }
  302. }} // namespace at::vec