vec_base.h

#pragma once
// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]
//
// Note [Do not compile initializers with AVX]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// If you define a static initializer in this file, the initialization will use
// AVX instructions because these object files are compiled with AVX enabled.
// We need to avoid non-trivial global data in these architecture-specific files
// because there's no way to guard the global initializers with CPU capability
// detection.
//
// See https://github.com/pytorch/pytorch/issues/37577 for an instance
// of this bug in the past.
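//
// For illustration only (a hypothetical example, not code from this header):
// a namespace-scope object such as
//   static const Vectorized<float> kOnes(1.0f);
// would have its constructor compiled with AVX instructions, and that code
// could run before any CPU-capability dispatch gets a chance to reject it.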
#include <cassert>
#include <cstring>
#include <functional>
#include <cmath>
#include <type_traits>
#include <bitset>

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/native/Math.h>
#include <ATen/NumericUtils.h>
#include <c10/util/C++17.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/copysign.h>
#include <c10/util/math_compat.h>
#include <ATen/native/cpu/zmath.h>
#include <c10/util/TypeCast.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
#include <c10/util/Load.h>
// These macros helped us unify vec_base.h
#ifdef CPU_CAPABILITY_AVX512
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(64)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(64))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 64
#define int_vector __m512i
#else // CPU_CAPABILITY_AVX512
#if defined(__GNUC__)
#define __at_align__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align__ __declspec(align(32))
#else
#define __at_align__
#endif
#define VECTOR_WIDTH 32
#define int_vector __m256i
#endif // CPU_CAPABILITY_AVX512
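// Illustrative note (added for clarity, not part of the original comments):
// with CPU_CAPABILITY_AVX512 the emulated register is 64 bytes wide, so
// Vectorized<float> below has 16 lanes; otherwise (AVX2 or the generic
// fallback) it is 32 bytes wide (8 float lanes). Throughout this file the
// lane count is VECTOR_WIDTH / sizeof(T).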
namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

// at::Half and at::BFloat16 should be treated as floating point
template <typename T>
struct is_floating_point:
    std::integral_constant<bool,
      std::is_floating_point<T>::value ||
      std::is_same<T, at::Half>::value ||
      std::is_same<T, at::BFloat16>::value> {
};

template<size_t n> struct int_of_size;

#define DEFINE_INT_OF_SIZE(int_t) \
template<> struct int_of_size<sizeof(int_t)> { using type = int_t; }

DEFINE_INT_OF_SIZE(int64_t);
DEFINE_INT_OF_SIZE(int32_t);
DEFINE_INT_OF_SIZE(int16_t);
DEFINE_INT_OF_SIZE(int8_t);

#undef DEFINE_INT_OF_SIZE

template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
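// For example, int_same_size_t<float> is int32_t and int_same_size_t<double>
// is int64_t; the alias is used below wherever an integer lane type with
// exactly the same byte width as T is needed (masks, gather indices,
// bit-level conversions).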
// NOTE: If you specialize on a type, you must define all operations!
// emulates Vectorized types
#if defined(__s390x__)
template <class T, class TEMP=void>
#else
template <class T>
#endif
struct Vectorized {
private:
  __at_align__ T values[VECTOR_WIDTH / sizeof(T)];
public:
  using value_type = T;
  using size_type = int;
  // Note [constexpr static function to avoid odr-usage compiler bug]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // Why, you might ask, is size defined to be a static constexpr function,
  // rather than a more ordinary 'static constexpr int size;' variable?
  // The problem lies within ODR rules for static constexpr members versus
  // static constexpr functions. First, recall that this class (along with all
  // of its derivations) live in an anonymous namespace: they are intended to be
  // *completely* inlined at their use-sites, because we need to compile it
  // multiple times for different instruction sets.
  //
  // Because of this constraint, we CANNOT provide a single definition for
  // any static members in this class; since we want to compile the class
  // multiple times, there wouldn't actually be any good place to put the
  // definition. Now here is the problem: if we ODR-use a static constexpr
  // member, we are *obligated* to provide a definition. Without the
  // definition, you get a compile error like:
  //
  //    relocation R_X86_64_PC32 against undefined symbol
  //    `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making
  //    a shared object; recompile with -fPIC
  //
  // If this were C++17, we could replace a static constexpr variable with
  // an inline variable which doesn't require one definition. But we are not
  // C++17. So the next best thing is to replace the member with a static
  // constexpr (and therefore inline) function, which does not require ODR
  // either.
  //
  // Also, technically according to the C++ standard, we don't have to define
  // a constexpr variable if we never odr-use it. But it seems that some
  // versions of GCC/Clang have buggy determinations on whether or not an
  // identifier is odr-used or not, and in any case it's hard to tell if
  // a variable is odr-used or not. So best to just cut the problem at the root.
  static constexpr size_type size_T = sizeof(T);  // Workaround to compile with VS2022.
  static constexpr size_type size() {
    return VECTOR_WIDTH / size_T;
  }
  Vectorized() : values{static_cast<T>(0)} {}
  Vectorized(T val) {
    for (int i = 0; i != size(); i++) {
      values[i] = val;
    }
  }
  template<typename... Args,
           typename = std::enable_if_t<(sizeof...(Args) == size())>>
  Vectorized(Args... vals) : values{vals...} {
  }
  // This also implies const T& operator[](int idx) const
  inline operator const T*() const {
    return values;
  }
  // This also implies T& operator[](int idx)
  inline operator T*() {
    return values;
  }
  // Return the values as char* for type punning
  auto as_bytes() const -> const char* {
    return reinterpret_cast<const char*>(values);
  }
  template <int64_t mask_>
  static Vectorized<T> blend(const Vectorized<T>& a, const Vectorized<T>& b) {
    int64_t mask = mask_;
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      if (mask & 0x01) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
      mask = mask >> 1;
    }
    return vector;
  }
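  // Illustrative example (added for clarity): bit i of mask_ picks lane i from
  // b when set and from a when clear, so for an 8-lane vector
  //   Vectorized<float>::blend<0b00000001>(a, b)
  // produces {b[0], a[1], a[2], ..., a[7]}.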
  static Vectorized<T> blendv(const Vectorized<T>& a, const Vectorized<T>& b,
                              const Vectorized<T>& mask) {
    Vectorized vector;
    int_same_size_t<T> buffer[size()];
    mask.store(buffer);
    for (const auto i : c10::irange(size())) {
      if (buffer[i] & 0x01) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
  template<typename step_t>  // step sometimes requires a higher precision type (e.g., T=int, step_t=double)
  static Vectorized<T> arange(T base = static_cast<T>(0), step_t step = static_cast<step_t>(1)) {
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      vector.values[i] = base + i * step;
    }
    return vector;
  }
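  // Illustrative example (added for clarity):
  //   Vectorized<float>::arange(10.f, 2.f)
  // fills lane i with base + i * step, i.e. {10, 12, 14, ...}.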
  static Vectorized<T> set(const Vectorized<T>& a, const Vectorized<T>& b, int64_t count = size()) {
    Vectorized vector;
    for (const auto i : c10::irange(size())) {
      if (i < count) {
        vector[i] = b[i];
      } else {
        vector[i] = a[i];
      }
    }
    return vector;
  }
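  // Illustrative example (added for clarity): set(a, b, 3) takes the first 3
  // lanes from b and the remaining lanes from a; with the default
  // count == size() the result is simply b.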
  static Vectorized<T> loadu(const void* ptr) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, VECTOR_WIDTH);
    return vector;
  }
  static Vectorized<T> loadu(const void* ptr, int64_t count) {
    Vectorized vector;
    std::memcpy(vector.values, ptr, count * sizeof(T));
    return vector;
  }
  void store(void* ptr, int count = size()) const {
    std::memcpy(ptr, values, count * sizeof(T));
  }
  int zero_mask() const {
    // returns an integer mask in which bit i is set iff element i is zero,
    // and cleared otherwise
    int mask = 0;
    for (int i = 0; i < size(); ++i) {
      if (values[i] == static_cast<T>(0)) {
        mask |= (1 << i);
      }
    }
    return mask;
  }
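  // Illustrative example (added for clarity): for a 4-lane vector holding
  // {0, 3, 0, 7}, zero_mask() returns 0b0101 (bits 0 and 2 set, because those
  // lanes are zero).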
  Vectorized<T> isnan() const {
    Vectorized<T> vector;
    for (int64_t i = 0; i != size(); i++) {
      if (_isnan(values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }
  Vectorized<T> map(T (*const f)(T)) const {
    Vectorized<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  Vectorized<T> map(T (*const f)(const T &)) const {
    Vectorized<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
  template <typename other_t_abs = T,
            typename std::enable_if<!is_floating_point<other_t_abs>::value && !c10::is_complex<other_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // other_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
    return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
  }
  template <typename float_t_abs = T,
            typename std::enable_if<is_floating_point<float_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // float_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
    // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in
    // 0.0) properly.
    return map([](T x) -> T { return std::abs(x); });
  }
  template <typename complex_t_abs = T,
            typename std::enable_if<c10::is_complex<complex_t_abs>::value, int>::type = 0>
  Vectorized<T> abs() const {
    // complex_t_abs is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_abs, T>::value, "complex_t_abs must be T");
    // Specifically map() does not perform the type conversion needed by abs.
    return map([](T x) { return static_cast<T>(std::abs(x)); });
  }
  template <typename other_t_sgn = T,
            typename std::enable_if<c10::is_complex<other_t_sgn>::value, int>::type = 0>
  Vectorized<T> sgn() const {
    return map(at::native::sgn_impl);
  }
  template <typename other_t_angle = T,
            typename std::enable_if<!c10::is_complex<other_t_angle>::value, int>::type = 0>
  Vectorized<T> angle() const {
    // other_t_angle is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_angle, T>::value, "other_t_angle must be T");
    return map(at::native::angle_impl<T>);  // compiler is unable to resolve the overload without <T>
  }
  template <typename complex_t_angle = T,
            typename std::enable_if<c10::is_complex<complex_t_angle>::value, int>::type = 0>
  Vectorized<T> angle() const {
    // complex_t_angle is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_angle, T>::value, "complex_t_angle must be T");
    return map([](T x) { return static_cast<T>(std::arg(x)); });
  }
  template <typename other_t_real = T,
            typename std::enable_if<!c10::is_complex<other_t_real>::value, int>::type = 0>
  Vectorized<T> real() const {
    // other_t_real is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_real, T>::value, "other_t_real must be T");
    return *this;
  }
  template <typename complex_t_real = T,
            typename std::enable_if<c10::is_complex<complex_t_real>::value, int>::type = 0>
  Vectorized<T> real() const {
    // complex_t_real is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_real, T>::value, "complex_t_real must be T");
    return map([](T x) { return static_cast<T>(x.real()); });
  }
  template <typename other_t_imag = T,
            typename std::enable_if<!c10::is_complex<other_t_imag>::value, int>::type = 0>
  Vectorized<T> imag() const {
    // other_t_imag is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_imag, T>::value, "other_t_imag must be T");
    return Vectorized(0);
  }
  template <typename complex_t_imag = T,
            typename std::enable_if<c10::is_complex<complex_t_imag>::value, int>::type = 0>
  Vectorized<T> imag() const {
    // complex_t_imag is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_imag, T>::value, "complex_t_imag must be T");
    return map([](T x) { return static_cast<T>(x.imag()); });
  }
  template <typename other_t_conj = T,
            typename std::enable_if<!c10::is_complex<other_t_conj>::value, int>::type = 0>
  Vectorized<T> conj() const {
    // other_t_conj is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_conj, T>::value, "other_t_conj must be T");
    return *this;
  }
  template <typename complex_t_conj = T,
            typename std::enable_if<c10::is_complex<complex_t_conj>::value, int>::type = 0>
  Vectorized<T> conj() const {
    // complex_t_conj is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_conj, T>::value, "complex_t_conj must be T");
    return map([](T x) { return static_cast<T>(std::conj(x)); });
  }
  Vectorized<T> acos() const {
    return map(std::acos);
  }
  Vectorized<T> asin() const {
    return map(std::asin);
  }
  Vectorized<T> atan() const {
    return map(std::atan);
  }
  Vectorized<T> atan2(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::atan2(values[i], exp[i]);
    }
    return ret;
  }
  template <
    typename U = T,
    typename std::enable_if_t<is_floating_point<U>::value, int> = 0>
  Vectorized<T> copysign(const Vectorized<T> &sign) const {
    Vectorized<T> ret;
    for (size_type i = 0; i < size(); i++) {
      ret[i] = c10::copysign(values[i], sign[i]);
    }
    return ret;
  }
  Vectorized<T> erf() const {
    return map(std::erf);
  }
  Vectorized<T> erfc() const {
    return map(std::erfc);
  }
  Vectorized<T> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<T> exp() const {
    return map(std::exp);
  }
  Vectorized<T> exp2() const {
    return map(exp2_impl);
  }
  Vectorized<T> expm1() const {
    return map(std::expm1);
  }
  Vectorized<T> frac() const {
    return *this - this->trunc();
  }
  template <
    typename U = T,
    typename std::enable_if_t<is_floating_point<U>::value, int> = 0>
  Vectorized<T> fmod(const Vectorized<T>& q) const {
    // U is for SFINAE purposes only. Make sure it is not changed.
    static_assert(std::is_same<U, T>::value, "U must be T");
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::fmod(values[i], q[i]);
    }
    return ret;
  }
  Vectorized<T> log() const {
    return map(std::log);
  }
  Vectorized<T> log10() const {
    return map(std::log10);
  }
  Vectorized<T> log1p() const {
    return map(std::log1p);
  }
  template <typename other_t_log2 = T,
            typename std::enable_if<!c10::is_complex<other_t_log2>::value, int>::type = 0>
  Vectorized<T> log2() const {
    // other_t_log2 is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<other_t_log2, T>::value, "other_t_log2 must be T");
    return map(std::log2);
  }
  template <typename complex_t_log2 = T,
            typename std::enable_if<c10::is_complex<complex_t_log2>::value, int>::type = 0>
  Vectorized<T> log2() const {
    // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed.
    static_assert(std::is_same<complex_t_log2, T>::value, "complex_t_log2 must be T");
    const T log_2 = T(std::log(2.0));
    return Vectorized(map(std::log)) / Vectorized(log_2);
  }
  Vectorized<T> ceil() const {
    return map(at::native::ceil_impl);
  }
  Vectorized<T> cos() const {
    return map(std::cos);
  }
  Vectorized<T> cosh() const {
    return map(std::cosh);
  }
  Vectorized<T> floor() const {
    return map(at::native::floor_impl);
  }
  Vectorized<T> hypot(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::hypot(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> i0() const {
    return map(calc_i0);
  }
  Vectorized<T> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<T> igamma(const Vectorized<T> &x) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = calc_igamma(values[i], x[i]);
    }
    return ret;
  }
  Vectorized<T> igammac(const Vectorized<T> &x) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = calc_igammac(values[i], x[i]);
    }
    return ret;
  }
  Vectorized<T> neg() const {
    // NB: the trailing return type is needed because we need to coerce the
    // return value back to T in the case of unary operator- incurring a
    // promotion
    return map([](T x) -> T { return -x; });
  }
  Vectorized<T> nextafter(const Vectorized<T> &b) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::nextafter(values[i], b[i]);
    }
    return ret;
  }
  Vectorized<T> round() const {
    // We do not use std::round because we would like to round midway numbers to the nearest even integer.
    return map(at::native::round_impl);
  }
  Vectorized<T> sin() const {
    return map(std::sin);
  }
  Vectorized<T> sinh() const {
    return map(std::sinh);
  }
  Vectorized<T> tan() const {
    return map(std::tan);
  }
  Vectorized<T> tanh() const {
    return map(std::tanh);
  }
  Vectorized<T> trunc() const {
    return map(at::native::trunc_impl);
  }
  Vectorized<T> lgamma() const {
    return map(std::lgamma);
  }
  Vectorized<T> sqrt() const {
    return map(std::sqrt);
  }
  Vectorized<T> reciprocal() const {
    return map([](T x) { return (T)(1) / x; });
  }
  Vectorized<T> rsqrt() const {
    return map([](T x) { return (T)1 / std::sqrt(x); });
  }
  Vectorized<T> pow(const Vectorized<T> &exp) const {
    Vectorized<T> ret;
    for (const auto i : c10::irange(size())) {
      ret[i] = std::pow(values[i], exp[i]);
    }
    return ret;
  }
private:
  template <typename Op>
  inline Vectorized<T> binary_pred(const Vectorized<T>& other, Op op) const {
    // All bits are set to 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (int64_t i = 0; i != size(); i++) {
      if (op(values[i], other.values[i])) {
        std::memset(static_cast<void*>(vector.values + i), 0xFF, sizeof(T));
      } else {
        std::memset(static_cast<void*>(vector.values + i), 0, sizeof(T));
      }
    }
    return vector;
  }

public:
  Vectorized<T> operator==(const Vectorized<T>& other) const { return binary_pred(other, std::equal_to<T>()); }
  Vectorized<T> operator!=(const Vectorized<T>& other) const { return binary_pred(other, std::not_equal_to<T>()); }
  Vectorized<T> operator>=(const Vectorized<T>& other) const { return binary_pred(other, std::greater_equal<T>()); }
  Vectorized<T> operator<=(const Vectorized<T>& other) const { return binary_pred(other, std::less_equal<T>()); }
  Vectorized<T> operator>(const Vectorized<T>& other) const { return binary_pred(other, std::greater<T>()); }
  Vectorized<T> operator<(const Vectorized<T>& other) const { return binary_pred(other, std::less<T>()); }

private:
  template <typename Op>
  inline Vectorized<T> binary_pred_bool(const Vectorized<T>& other, Op op) const {
    // 1 if the pred is true, otherwise 0.
    Vectorized<T> vector;
    for (int i = 0; i != size(); ++i) {
      vector[i] = static_cast<T>(op(values[i], other.values[i]));
    }
    return vector;
  }

public:
  Vectorized<T> eq(const Vectorized<T>& other) const { return binary_pred_bool(other, std::equal_to<T>()); }
  Vectorized<T> ne(const Vectorized<T>& other) const { return binary_pred_bool(other, std::not_equal_to<T>()); }
  Vectorized<T> gt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater<T>()); }
  Vectorized<T> ge(const Vectorized<T>& other) const { return binary_pred_bool(other, std::greater_equal<T>()); }
  Vectorized<T> lt(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less<T>()); }
  Vectorized<T> le(const Vectorized<T>& other) const { return binary_pred_bool(other, std::less_equal<T>()); }
};
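// Illustrative note (added for clarity): the operator overloads above
// (==, <, ...) produce per-lane *bit masks* (all bits of a lane set when the
// predicate holds), which is the form blendv expects, while eq/ne/gt/ge/lt/le
// produce numeric 0/1 results of type T. For float lanes, (a == b)[i] reads
// back as an all-ones bit pattern (a NaN) when the lanes are equal, whereas
// a.eq(b)[i] is simply 1.0f.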
template <class T> Vectorized<T> inline operator+(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] + b[i];
  }
  return c;
}
template <class T> Vectorized<T> inline operator-(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] - b[i];
  }
  return c;
}
template <class T> Vectorized<T> inline operator*(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] * b[i];
  }
  return c;
}
template <class T> Vectorized<T> inline operator/(const Vectorized<T> &a, const Vectorized<T> &b) __ubsan_ignore_float_divide_by_zero__ {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] / b[i];
  }
  return c;
}
template <class T> Vectorized<T> inline operator||(
    const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] || b[i];
  }
  return c;
}
// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <class T,
          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = (a[i] > b[i]) ? a[i] : b[i];
    if (_isnan(a[i])) {
      // If either input is NaN, propagate a NaN.
      // NOTE: The case where b[i] was NaN is handled correctly by the naive
      // ternary operator above.
      c[i] = a[i];
    }
  }
  return c;
}
template <class T,
          typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
Vectorized<T> inline maximum(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i];
    if (_isnan(a[i])) {
      // If either input is NaN, propagate a NaN.
      // NOTE: The case where b[i] was NaN is handled correctly by the naive
      // ternary operator above.
      c[i] = a[i];
    }
  }
  return c;
}
// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <class T,
          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = (a[i] < b[i]) ? a[i] : b[i];
    if (_isnan(a[i])) {
      // If either input is NaN, propagate a NaN.
      // NOTE: The case where b[i] was NaN is handled correctly by the naive
      // ternary operator above.
      c[i] = a[i];
    }
  }
  return c;
}
template <class T,
          typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
Vectorized<T> inline minimum(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i];
    if (_isnan(a[i])) {
      // If either input is NaN, propagate a NaN.
      // NOTE: The case where b[i] was NaN is handled correctly by the naive
      // ternary operator above.
      c[i] = a[i];
    }
  }
  return c;
}
template <class T,
          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vectorized<T> inline clamp(const Vectorized<T> &a, const Vectorized<T> &min_vec, const Vectorized<T> &max_vec) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]);
  }
  return c;
}
template <class T,
          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vectorized<T> inline clamp_max(const Vectorized<T> &a, const Vectorized<T> &max_vec) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i];
  }
  return c;
}
template <class T,
          typename std::enable_if<!c10::is_complex<T>::value, int>::type = 0>
Vectorized<T> inline clamp_min(const Vectorized<T> &a, const Vectorized<T> &min_vec) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i];
  }
  return c;
}
struct Vectorizedi;

#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
template <class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
  int_vector buffer;
#if defined(CPU_CAPABILITY_AVX2)
  int_vector a_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)a));
  int_vector b_buffer = _mm256_load_si256(reinterpret_cast<const int_vector*>((const T*)b));
#elif defined(CPU_CAPABILITY_AVX512)
  int_vector a_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)a));
  int_vector b_buffer = _mm512_load_si512(reinterpret_cast<const int_vector*>((const T*)b));
#endif
  buffer = op(a_buffer, b_buffer);
  __at_align__ T results[Vectorized<T>::size()];
#if defined(CPU_CAPABILITY_AVX2)
  _mm256_store_si256(reinterpret_cast<int_vector*>(results), buffer);
#elif defined(CPU_CAPABILITY_AVX512)
  _mm512_store_si512(reinterpret_cast<int_vector*>(results), buffer);
#endif
  return Vectorized<T>::loadu(results);
}

template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); });
#endif
}
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); });
#endif
}
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline
#if defined(CPU_CAPABILITY_AVX2)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); });
#elif defined(CPU_CAPABILITY_AVX512)
  return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); });
#endif
}
#else

template <typename T>
auto load(char const* data) -> T {
  T ret;
  std::memcpy(&ret, data, sizeof(ret));
  return ret;
}

template<class T, typename Op>
static inline Vectorized<T> bitwise_binary_op(const Vectorized<T> &a, const Vectorized<T> &b, Op op) {
  static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t);
  __at_align__ intmax_t buffer[element_no];
  static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)");
  static_assert(sizeof(buffer) == sizeof(Vectorized<T>), "sizeof(buffer) must match sizeof(Vectorized<T>)");
  // We should be using memcpy in order to respect the strict aliasing rule
  // see: https://github.com/pytorch/pytorch/issues/66119
  // Using char* is defined in the C11 standard 6.5 Expression paragraph 7
  // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf)
  const auto* a_data = a.as_bytes();
  const auto* b_data = b.as_bytes();
  // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t)
  for (auto& out : buffer) {
    out = op(load<intmax_t>(a_data), load<intmax_t>(b_data));
    a_data += sizeof(intmax_t);
    b_data += sizeof(intmax_t);
  }
  assert(a_data == a.as_bytes() + sizeof(a));
  assert(b_data == b.as_bytes() + sizeof(b));
  return Vectorized<T>::loadu(buffer);
}

template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator&(const Vectorized<T>& a, const Vectorized<T>& b) {
  return bitwise_binary_op(a, b, std::bit_and<intmax_t>());
}
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator|(const Vectorized<T>& a, const Vectorized<T>& b) {
  return bitwise_binary_op(a, b, std::bit_or<intmax_t>());
}
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator^(const Vectorized<T>& a, const Vectorized<T>& b) {
  return bitwise_binary_op(a, b, std::bit_xor<intmax_t>());
}

#endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
template<class T, typename std::enable_if_t<!std::is_base_of<Vectorizedi, Vectorized<T>>::value, int> = 0>
inline Vectorized<T> operator~(const Vectorized<T>& a) {
  Vectorized<T> ones;  // All bits are 1
  memset((T*) ones, 0xFF, VECTOR_WIDTH);
  return a ^ ones;
}
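// Illustrative note (added for clarity): because the comparison operators
// above yield all-ones / all-zeros lane masks, (a == b) and ~(a == b) are
// complementary masks, and either can be passed to Vectorized<T>::blendv.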
template <class T> Vectorized<T> inline operator<<(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] << b[i];
  }
  return c;
}
template <class T> Vectorized<T> inline operator>>(const Vectorized<T> &a, const Vectorized<T> &b) {
  Vectorized<T> c;
  for (int i = 0; i != Vectorized<T>::size(); i++) {
    c[i] = a[i] >> b[i];
  }
  return c;
}
template <typename T>
inline Vectorized<T>& operator += (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a + b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator -= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a - b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator /= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a / b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator %= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a % b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator *= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a * b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator <<= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a << b;
  return a;
}
template <typename T>
inline Vectorized<T>& operator >>= (Vectorized<T>& a, const Vectorized<T>& b) {
  a = a >> b;
  return a;
}

template <typename T>
inline Vectorized<T> fmadd(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
  return a * b + c;
}
template <typename T>
inline Vectorized<T> fmsub(const Vectorized<T>& a, const Vectorized<T>& b, const Vectorized<T>& c) {
  return a * b - c;
}
template <int64_t scale = 1, typename T = void>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
inline gather(T const* base_addr, const Vectorized<int_same_size_t<T>>& vindex) {
  static constexpr int size = Vectorized<T>::size();
  int_same_size_t<T> index_arr[size];
  vindex.store(static_cast<void*>(index_arr));
  T buffer[size];
  for (const auto i : c10::irange(size)) {
    buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
  }
  return Vectorized<T>::loadu(static_cast<void*>(buffer));
}
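// Illustrative example (added for clarity): with T = float and scale = 4
// (i.e. scale == sizeof(float)), gather<4>(base, vindex) reads base[vindex[i]]
// into lane i, since index * scale / sizeof(T) == index.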
template <int64_t scale = 1, typename T = void>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<T>>
inline mask_gather(const Vectorized<T>& src, T const* base_addr,
                   const Vectorized<int_same_size_t<T>>& vindex, Vectorized<T>& mask) {
  static constexpr int size = Vectorized<T>::size();
  T src_arr[size];
  int_same_size_t<T> mask_arr[size];  // use int type so we can logical and
  int_same_size_t<T> index_arr[size];
  src.store(static_cast<void*>(src_arr));
  mask.store(static_cast<void*>(mask_arr));
  vindex.store(static_cast<void*>(index_arr));
  T buffer[size];
  for (const auto i : c10::irange(size)) {
    if (mask_arr[i] & 0x01) {  // check the lowest bit of the mask lane
      buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
    } else {
      buffer[i] = src_arr[i];
    }
  }
  mask = Vectorized<T>();  // "zero out" mask
  return Vectorized<T>::loadu(static_cast<void*>(buffer));
}
// Cast a given vector to another type without changing the bits representation.
// So a Vectorized<double> of 512 bits containing all ones can be cast to a
// Vectorized<int64_t> of 512 bits containing all ones (i.e., eight negative 1s).
// A Vec<double> of 256 bits containing all ones can be cast to a
// Vec<int64_t> of 256 bits containing all ones (i.e., four negative 1s).
// There is a struct here because we don't have static_if and I can't
// partially specialize a templated function.
template<typename dst_t, typename src_t>
struct CastImpl {
  static inline Vectorized<dst_t> apply(const Vectorized<src_t>& src) {
    src_t src_arr[Vectorized<src_t>::size()];
    src.store(static_cast<void*>(src_arr));
    return Vectorized<dst_t>::loadu(static_cast<const void*>(src_arr));
  }
};

template<typename scalar_t>
struct CastImpl<scalar_t, scalar_t> {
  static inline Vectorized<scalar_t> apply(const Vectorized<scalar_t>& src) {
    return src;
  }
};

template<typename dst_t, typename src_t>
inline Vectorized<dst_t> cast(const Vectorized<src_t>& src) {
  return CastImpl<dst_t, src_t>::apply(src);
}
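// Illustrative example (added for clarity): cast<int32_t>(vf) for a
// Vectorized<float> vf reinterprets the same bytes as a Vectorized<int32_t>;
// no numeric conversion happens, in contrast to convert_to_int_of_same_size
// below.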
template <typename T>
inline Vectorized<int_same_size_t<T>> convert_to_int_of_same_size(const Vectorized<T>& src) {
  static constexpr int size = Vectorized<T>::size();
  T src_arr[size];
  src.store(static_cast<void*>(src_arr));
  int_same_size_t<T> buffer[size];
  for (const auto i : c10::irange(size)) {
    buffer[i] = static_cast<int_same_size_t<T>>(src_arr[i]);
  }
  return Vectorized<int_same_size_t<T>>::loadu(static_cast<void*>(buffer));
}
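// Illustrative example (added for clarity):
// convert_to_int_of_same_size(Vectorized<float>(2.5f)) yields a
// Vectorized<int32_t> whose lanes all hold 2 (value conversion via
// static_cast), whereas cast<> above would keep the raw float bit pattern.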
// Example inputs for AVX512:
// a   Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
// b   Vectorized<float>   = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
// returns:
//     Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
//     Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
// Example inputs for AVX2:
// a   Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3}
// b   Vectorized<float>   = {a4, b4, a5, b5, a6, b6, a7, b7}
// returns:
//     Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7}
//     Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7}
template <typename T>
inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
deinterleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
  static constexpr int size = Vectorized<T>::size();
  static constexpr int half_size = size / 2;
  T a_arr[size];
  T b_arr[size];
  T buffer1[size];
  T buffer2[size];
  a.store(static_cast<void*>(a_arr));
  b.store(static_cast<void*>(b_arr));
  for (const auto i : c10::irange(half_size)) {
    buffer1[i] = a_arr[i * 2];
    buffer1[half_size + i] = b_arr[i * 2];
    buffer2[i] = a_arr[i * 2 + 1];
    buffer2[half_size + i] = b_arr[i * 2 + 1];
  }
  return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
                        Vectorized<T>::loadu(static_cast<void*>(buffer2)));
}

// inverse operation of deinterleave2
// Example inputs for AVX512:
// a   Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15}
// b   Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15}
// returns, for AVX512:
//     Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7}
//     Vectorized<float>   = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15}
// Example inputs for AVX2:
// a   Vectorized<float>   = {a0, a1, a2, a3, a4, a5, a6, a7}
// b   Vectorized<float>   = {b0, b1, b2, b3, b4, b5, b6, b7}
// returns:
//     Vectorized<float>   = {a0, b0, a1, b1, a2, b2, a3, b3}
//     Vectorized<float>   = {a4, b4, a5, b5, a6, b6, a7, b7}
template <typename T>
inline std::enable_if_t<Vectorized<T>::size() % 2 == 0, std::pair<Vectorized<T>, Vectorized<T>>>
interleave2(const Vectorized<T>& a, const Vectorized<T>& b) {
  static constexpr int size = Vectorized<T>::size();
  static constexpr int half_size = size / 2;
  T a_arr[size];
  T b_arr[size];
  T buffer1[size];
  T buffer2[size];
  a.store(static_cast<void*>(a_arr));
  b.store(static_cast<void*>(b_arr));
  for (const auto i : c10::irange(half_size)) {
    buffer1[i * 2] = a_arr[i];
    buffer1[i * 2 + 1] = b_arr[i];
    buffer2[i * 2] = a_arr[half_size + i];
    buffer2[i * 2 + 1] = b_arr[half_size + i];
  }
  return std::make_pair(Vectorized<T>::loadu(static_cast<void*>(buffer1)),
                        Vectorized<T>::loadu(static_cast<void*>(buffer2)));
}
template <typename src_T, typename dst_T>
inline void convert(const src_T *src, dst_T *dst, int64_t n) {
#ifndef _MSC_VER
# pragma unroll
#endif
  for (const auto i : c10::irange(n)) {
    (void)i;  // Suppress unused variable warning
    *dst = c10::convert<dst_T>(c10::load(src));
    src++;
    dst++;
  }
}

template <typename T>
inline Vectorized<T> flip(const Vectorized<T> & data) {
  static constexpr int size = Vectorized<T>::size();
  T output[size];
  T buffer[size];
  data.store(static_cast<void*>(buffer));
  for (const auto i : c10::irange(size)) {
    output[i] = buffer[size - i - 1];
  }
  return Vectorized<T>::loadu(static_cast<void*>(output));
}
// Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading
// dimension of `src` and `ld_dst` is the leading dimension of `dst`.
template <typename T, int M, int N>
inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (int i = 0; i < M; i++) {
    for (int j = 0; j < N; j++) {
      dst[j*ld_dst + i] = src[i*ld_src + j];
    }
  }
}
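// Illustrative example (added for clarity): transpose_mxn<float, 2, 3>(src, 3, dst, 2)
// treats src as a row-major 2x3 matrix with leading dimension 3 and writes its
// 3x2 transpose into dst with leading dimension 2.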
} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at