#pragma once // DO NOT DEFINE STATIC DATA IN THIS HEADER! // See Note [Do not compile initializers with AVX] // // Note [Do not compile initializers with AVX] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // If you define a static initializer in this file, the initialization will use // AVX instructions because these object files are compiled with AVX enabled. // We need to avoid non-trivial global data in these architecture specific files // because there's no way to guard the global initializers with CPU capability // detection. // // See https://github.com/pytorch/pytorch/issues/37577 for an instance // of this bug in the past. #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // These macros helped us unify vec_base.h #ifdef CPU_CAPABILITY_AVX512 #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(64))) #elif defined(_WIN32) #define __at_align__ __declspec(align(64)) #else #define __at_align__ #endif #define VECTOR_WIDTH 64 #define int_vector __m512i #else // CPU_CAPABILITY_AVX512 #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(32))) #elif defined(_WIN32) #define __at_align__ __declspec(align(32)) #else #define __at_align__ #endif #define VECTOR_WIDTH 32 #define int_vector __m256i #endif // CPU_CAPABILITY_AVX512 namespace at { namespace vec { // See Note [CPU_CAPABILITY namespace] inline namespace CPU_CAPABILITY { // at::Half and at::BFloat16 should be treated as floating point template struct is_floating_point: std::integral_constant::value || std::is_same::value || std::is_same::value> { }; template struct int_of_size; #define DEFINE_INT_OF_SIZE(int_t) \ template<> struct int_of_size { using type = int_t; } DEFINE_INT_OF_SIZE(int64_t); DEFINE_INT_OF_SIZE(int32_t); DEFINE_INT_OF_SIZE(int16_t); DEFINE_INT_OF_SIZE(int8_t); #undef DEFINE_INT_OF_SIZE template using int_same_size_t = typename int_of_size::type; // NOTE: If you specialize on a type, you must define all operations! // emulates Vectorized types #if defined(__s390x__) template #else template #endif struct Vectorized { private: __at_align__ T values[VECTOR_WIDTH / sizeof(T)]; public: using value_type = T; using size_type = int; // Note [constexpr static function to avoid odr-usage compiler bug] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Why, you might ask, is size defined to be a static constexpr function, // rather than a more ordinary 'static constexpr int size;' variable? // The problem lies within ODR rules for static constexpr members versus // static constexpr functions. First, recall that this class (along with all // of its derivations) live in an anonymous namespace: they are intended to be // *completely* inlined at their use-sites, because we need to compile it // multiple times for different instruction sets. // // Because of this constraint, we CANNOT provide a single definition for // any static members in this class; since we want to compile the class // multiple times, there wouldn't actually be any good place to put the // definition. Now here is the problem: if we ODR-use a static constexpr // member, we are *obligated* to provide a definition. Without the // definition, you get a compile error like: // // relocation R_X86_64_PC32 against undefined symbol // `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making // a shared object; recompile with -fPIC // // If this were C++17, we could replace a static constexpr variable with // an inline variable which doesn't require one definition. But we are not // C++17. So the next best thing is to replace the member with a static // constexpr (and therefore inline) function, which does not require ODR // either. // // Also, technically according to the C++ standard, we don't have to define // a constexpr variable if we never odr-use it. But it seems that some // versions GCC/Clang have buggy determinations on whether or not an // identifier is odr-used or not, and in any case it's hard to tell if // a variable is odr-used or not. So best to just cut the problem at the root. static constexpr size_type size_T = sizeof(T); // Workaround to compile with VS2022. static constexpr size_type size() { return VECTOR_WIDTH / size_T; } Vectorized() : values{static_cast(0)} {} Vectorized(T val) { for (int i = 0; i != size(); i++) { values[i] = val; } } template> Vectorized(Args... vals) : values{vals...}{ } // This also implies const T& operator[](int idx) const inline operator const T*() const { return values; } // This also implies T& operator[](int idx) inline operator T*() { return values; } // Return the values as char* for type punning auto as_bytes() const -> const char* { return reinterpret_cast(values); } template static Vectorized blend(const Vectorized& a, const Vectorized& b) { int64_t mask = mask_; Vectorized vector; for (const auto i : c10::irange(size())) { if (mask & 0x01) { vector[i] = b[i]; } else { vector[i] = a[i]; } mask = mask >> 1; } return vector; } static Vectorized blendv(const Vectorized& a, const Vectorized& b, const Vectorized& mask) { Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); for (const auto i : c10::irange(size())) { if (buffer[i] & 0x01) { vector[i] = b[i]; } else { vector[i] = a[i]; } } return vector; } template // step sometimes requires a higher precision type (e.g., T=int, step_t=double) static Vectorized arange(T base = static_cast(0), step_t step = static_cast(1)) { Vectorized vector; for (const auto i : c10::irange(size())) { vector.values[i] = base + i * step; } return vector; } static Vectorized set(const Vectorized& a, const Vectorized& b, int64_t count = size()) { Vectorized vector; for (const auto i : c10::irange(size())) { if (i < count) { vector[i] = b[i]; } else { vector[i] = a[i]; } } return vector; } static Vectorized loadu(const void* ptr) { Vectorized vector; std::memcpy(vector.values, ptr, VECTOR_WIDTH); return vector; } static Vectorized loadu(const void* ptr, int64_t count) { Vectorized vector; std::memcpy(vector.values, ptr, count * sizeof(T)); return vector; } void store(void* ptr, int count = size()) const { std::memcpy(ptr, values, count * sizeof(T)); } int zero_mask() const { // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit int mask = 0; for (int i = 0; i < size(); ++ i) { if (values[i] == static_cast(0)) { mask |= (1 << i); } } return mask; } Vectorized isnan() const { Vectorized vector; for (int64_t i = 0; i != size(); i++) { if (_isnan(values[i])) { std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); } else { std::memset(static_cast(vector.values + i), 0, sizeof(T)); } } return vector; } Vectorized map(T (*const f)(T)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { ret[i] = f(values[i]); } return ret; } Vectorized map(T (*const f)(const T &)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { ret[i] = f(values[i]); } return ret; } template ::value && !c10::is_complex::value, int>::type = 0> Vectorized abs() const { // other_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "other_t_abs must be T"); return map([](T x) -> T { return x < static_cast(0) ? -x : x; }); } template ::value, int>::type = 0> Vectorized abs() const { // float_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "float_t_abs must be T"); // Specifically deal with floating-point because the generic code above won't handle -0.0 (which should result in // 0.0) properly. return map([](T x) -> T { return std::abs(x); }); } template ::value, int>::type = 0> Vectorized abs() const { // complex_t_abs is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "complex_t_abs must be T"); // Specifically map() does not perform the type conversion needed by abs. return map([](T x) { return static_cast(std::abs(x)); }); } template ::value, int>::type = 0> Vectorized sgn() const { return map(at::native::sgn_impl); } template ::value, int>::type = 0> Vectorized angle() const { // other_t_angle is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "other_t_angle must be T"); return map(at::native::angle_impl); // compiler is unable to resolve the overload without } template ::value, int>::type = 0> Vectorized angle() const { // complex_t_angle is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "complex_t_angle must be T"); return map([](T x) { return static_cast(std::arg(x)); }); } template ::value, int>::type = 0> Vectorized real() const { // other_t_real is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "other_t_real must be T"); return *this; } template ::value, int>::type = 0> Vectorized real() const { // complex_t_real is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "complex_t_real must be T"); return map([](T x) { return static_cast(x.real()); }); } template ::value, int>::type = 0> Vectorized imag() const { // other_t_imag is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "other_t_imag must be T"); return Vectorized(0); } template ::value, int>::type = 0> Vectorized imag() const { // complex_t_imag is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "complex_t_imag must be T"); return map([](T x) { return static_cast(x.imag()); }); } template ::value, int>::type = 0> Vectorized conj() const { // other_t_conj is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "other_t_conj must be T"); return *this; } template ::value, int>::type = 0> Vectorized conj() const { // complex_t_conj is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "complex_t_conj must be T"); return map([](T x) { return static_cast(std::conj(x)); }); } Vectorized acos() const { return map(std::acos); } Vectorized asin() const { return map(std::asin); } Vectorized atan() const { return map(std::atan); } Vectorized atan2(const Vectorized &exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::atan2(values[i], exp[i]); } return ret; } template < typename U = T, typename std::enable_if_t::value, int> = 0> Vectorized copysign(const Vectorized &sign) const { Vectorized ret; for (size_type i = 0; i < size(); i++) { ret[i] = c10::copysign(values[i], sign[i]); } return ret; } Vectorized erf() const { return map(std::erf); } Vectorized erfc() const { return map(std::erfc); } Vectorized erfinv() const { return map(calc_erfinv); } Vectorized exp() const { return map(std::exp); } Vectorized exp2() const { return map(exp2_impl); } Vectorized expm1() const { return map(std::expm1); } Vectorized frac() const { return *this - this->trunc(); } template < typename U = T, typename std::enable_if_t::value, int> = 0> Vectorized fmod(const Vectorized& q) const { // U is for SFINAE purposes only. Make sure it is not changed. static_assert(std::is_same::value, "U must be T"); Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::fmod(values[i], q[i]); } return ret; } Vectorized log() const { return map(std::log); } Vectorized log10() const { return map(std::log10); } Vectorized log1p() const { return map(std::log1p); } template ::value, int>::type = 0> Vectorized log2() const { // other_t_log2 is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "other_t_log2 must be T"); return map(std::log2); } template ::value, int>::type = 0> Vectorized log2() const { // complex_t_log2 is for SFINAE and clarity. Make sure it is not changed. static_assert(std::is_same::value, "complex_t_log2 must be T"); const T log_2 = T(std::log(2.0)); return Vectorized(map(std::log))/Vectorized(log_2); } Vectorized ceil() const { return map(at::native::ceil_impl); } Vectorized cos() const { return map(std::cos); } Vectorized cosh() const { return map(std::cosh); } Vectorized floor() const { return map(at::native::floor_impl); } Vectorized hypot(const Vectorized &b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::hypot(values[i], b[i]); } return ret; } Vectorized i0() const { return map(calc_i0); } Vectorized i0e() const { return map(calc_i0e); } Vectorized igamma(const Vectorized &x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igamma(values[i], x[i]); } return ret; } Vectorized igammac(const Vectorized &x) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = calc_igammac(values[i], x[i]); } return ret; } Vectorized neg() const { // NB: the trailing return type is needed because we need to coerce the // return value back to T in the case of unary operator- incuring a // promotion return map([](T x) -> T { return -x; }); } Vectorized nextafter(const Vectorized &b) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::nextafter(values[i], b[i]); } return ret; } Vectorized round() const { // We do not use std::round because we would like to round midway numbers to the nearest even integer. return map(at::native::round_impl); } Vectorized sin() const { return map(std::sin); } Vectorized sinh() const { return map(std::sinh); } Vectorized tan() const { return map(std::tan); } Vectorized tanh() const { return map(std::tanh); } Vectorized trunc() const { return map(at::native::trunc_impl); } Vectorized lgamma() const { return map(std::lgamma); } Vectorized sqrt() const { return map(std::sqrt); } Vectorized reciprocal() const { return map([](T x) { return (T)(1) / x; }); } Vectorized rsqrt() const { return map([](T x) { return (T)1 / std::sqrt(x); }); } Vectorized pow(const Vectorized &exp) const { Vectorized ret; for (const auto i : c10::irange(size())) { ret[i] = std::pow(values[i], exp[i]); } return ret; } private: template inline Vectorized binary_pred(const Vectorized& other, Op op) const { // All bits are set to 1 if the pred is true, otherwise 0. Vectorized vector; for (int64_t i = 0; i != size(); i++) { if (op(values[i], other.values[i])) { std::memset(static_cast(vector.values + i), 0xFF, sizeof(T)); } else { std::memset(static_cast(vector.values + i), 0, sizeof(T)); } } return vector; } public: Vectorized operator==(const Vectorized& other) const { return binary_pred(other, std::equal_to()); } Vectorized operator!=(const Vectorized& other) const { return binary_pred(other, std::not_equal_to()); } Vectorized operator>=(const Vectorized& other) const { return binary_pred(other, std::greater_equal()); } Vectorized operator<=(const Vectorized& other) const { return binary_pred(other, std::less_equal()); } Vectorized operator>(const Vectorized& other) const { return binary_pred(other, std::greater()); } Vectorized operator<(const Vectorized& other) const { return binary_pred(other, std::less()); } private: template inline Vectorized binary_pred_bool(const Vectorized& other, Op op) const { // 1 if the pred is true, otherwise 0. Vectorized vector; for (int i = 0; i != size(); ++ i) { vector[i] = static_cast(op(values[i], other.values[i])); } return vector; } public: Vectorized eq(const Vectorized& other) const { return binary_pred_bool(other, std::equal_to()); } Vectorized ne(const Vectorized& other) const { return binary_pred_bool(other, std::not_equal_to()); } Vectorized gt(const Vectorized& other) const { return binary_pred_bool(other, std::greater()); } Vectorized ge(const Vectorized& other) const { return binary_pred_bool(other, std::greater_equal()); } Vectorized lt(const Vectorized& other) const { return binary_pred_bool(other, std::less()); } Vectorized le(const Vectorized& other) const { return binary_pred_bool(other, std::less_equal()); } }; template Vectorized inline operator+(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] + b[i]; } return c; } template Vectorized inline operator-(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] - b[i]; } return c; } template Vectorized inline operator*(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] * b[i]; } return c; } template Vectorized inline operator/(const Vectorized &a, const Vectorized &b) __ubsan_ignore_float_divide_by_zero__ { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] / b[i]; } return c; } template Vectorized inline operator||( const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] || b[i]; } return c; } // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if // either input is a NaN. template ::value, int>::type = 0> Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] > b[i]) ? a[i] : b[i]; if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. c[i] = a[i]; } } return c; } template ::value, int>::type = 0> Vectorized inline maximum(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) > std::abs(b[i])) ? a[i] : b[i]; if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. c[i] = a[i]; } } return c; } // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if // either input is a NaN. template ::value, int>::type = 0> Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (a[i] < b[i]) ? a[i] : b[i]; if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. c[i] = a[i]; } } return c; } template ::value, int>::type = 0> Vectorized inline minimum(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = (std::abs(a[i]) < std::abs(b[i])) ? a[i] : b[i]; if (_isnan(a[i])) { // If either input is NaN, propagate a NaN. // NOTE: The case where b[i] was NaN is handled correctly by the naive // ternary operator above. c[i] = a[i]; } } return c; } template ::value, int>::type = 0> Vectorized inline clamp(const Vectorized &a, const Vectorized &min_vec, const Vectorized &max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = std::min(std::max(a[i], min_vec[i]), max_vec[i]); } return c; } template ::value, int>::type = 0> Vectorized inline clamp_max(const Vectorized &a, const Vectorized &max_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] > max_vec[i] ? max_vec[i] : a[i]; } return c; } template ::value, int>::type = 0> Vectorized inline clamp_min(const Vectorized &a, const Vectorized &min_vec) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] < min_vec[i] ? min_vec[i] : a[i]; } return c; } struct Vectorizedi; #if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { int_vector buffer; #if defined(CPU_CAPABILITY_AVX2) int_vector a_buffer = _mm256_load_si256(reinterpret_cast((const T*)a)); int_vector b_buffer = _mm256_load_si256(reinterpret_cast((const T*)b)); #elif defined(CPU_CAPABILITY_AVX512) int_vector a_buffer = _mm512_load_si512(reinterpret_cast((const T*)a)); int_vector b_buffer = _mm512_load_si512(reinterpret_cast((const T*)b)); #endif buffer = op(a_buffer, b_buffer); __at_align__ T results[Vectorized::size()]; #if defined(CPU_CAPABILITY_AVX2) _mm256_store_si256(reinterpret_cast(results), buffer); #elif defined(CPU_CAPABILITY_AVX512) _mm512_store_si512(reinterpret_cast(results), buffer); #endif return Vectorized::loadu(results); } template>::value, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { // We enclose _mm512_and_si512 or _mm256_and_si256 with lambda because it is always_inline #if defined(CPU_CAPABILITY_AVX2) return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_and_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_and_si512(a, b); }); #endif } template>::value, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { // We enclose _mm512_or_si512 or _mm256_or_si256 with lambda because it is always_inline #if defined(CPU_CAPABILITY_AVX2) return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_or_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_or_si512(a, b); }); #endif } template>::value, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { // We enclose _mm512_xor_si512 or _mm256_xor_si256 with lambda because it is always_inline #if defined(CPU_CAPABILITY_AVX2) return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm256_xor_si256(a, b); }); #elif defined(CPU_CAPABILITY_AVX512) return bitwise_binary_op(a, b, [](int_vector a, int_vector b) { return _mm512_xor_si512(a, b); }); #endif } #else template auto load(char const* data) -> T { T ret; std::memcpy(&ret, data, sizeof(ret)); return ret; } template static inline Vectorized bitwise_binary_op(const Vectorized &a, const Vectorized &b, Op op) { static constexpr uint32_t element_no = VECTOR_WIDTH / sizeof(intmax_t); __at_align__ intmax_t buffer[element_no]; static_assert(VECTOR_WIDTH % sizeof(intmax_t) == 0, "VECTOR_WIDTH not a multiple of sizeof(intmax_t)"); static_assert(sizeof(buffer) == sizeof(Vectorized), "sizeof(buffer) must match sizeof(Vectorized)"); // We should be using memcpy in order to respect the strict aliasing rule // see: https://github.com/pytorch/pytorch/issues/66119 // Using char* is defined in the C11 standard 6.5 Expression paragraph 7 // (http://www.open-std.org/jtc1/sc22/wg14/www/docs/n1570.pdf) const auto* a_data = a.as_bytes(); const auto* b_data = b.as_bytes(); // load each intmax_t chunk and process; increase pointers by sizeof(intmax_t) for (auto& out : buffer) { out = op(load(a_data), load(b_data)); a_data += sizeof(intmax_t); b_data += sizeof(intmax_t); } assert(a_data == a.as_bytes() + sizeof(a)); assert(b_data == b.as_bytes() + sizeof(b)); return Vectorized::loadu(buffer); } template>::value, int> = 0> inline Vectorized operator&(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_and()); } template>::value, int> = 0> inline Vectorized operator|(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_or()); } template>::value, int> = 0> inline Vectorized operator^(const Vectorized& a, const Vectorized& b) { return bitwise_binary_op(a, b, std::bit_xor()); } #endif // defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512) template>::value, int> = 0> inline Vectorized operator~(const Vectorized& a) { Vectorized ones; // All bits are 1 memset((T*) ones, 0xFF, VECTOR_WIDTH); return a ^ ones; } template Vectorized inline operator<<(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] << b[i]; } return c; } template Vectorized inline operator>>(const Vectorized &a, const Vectorized &b) { Vectorized c; for (int i = 0; i != Vectorized::size(); i++) { c[i] = a[i] >> b[i]; } return c; } template inline Vectorized& operator += (Vectorized& a, const Vectorized& b) { a = a + b; return a; } template inline Vectorized& operator -= (Vectorized& a, const Vectorized& b) { a = a - b; return a; } template inline Vectorized& operator /= (Vectorized& a, const Vectorized& b) { a = a / b; return a; } template inline Vectorized& operator %= (Vectorized& a, const Vectorized& b) { a = a % b; return a; } template inline Vectorized& operator *= (Vectorized& a, const Vectorized& b) { a = a * b; return a; } template inline Vectorized& operator <<= (Vectorized& a, const Vectorized& b) { a = a << b; return a; } template inline Vectorized& operator >>= (Vectorized& a, const Vectorized& b) { a = a >> b; return a; } template inline Vectorized fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b + c; } template inline Vectorized fmsub(const Vectorized& a, const Vectorized& b, const Vectorized& c) { return a * b - c; } template std::enable_if_t> inline gather(T const* base_addr, const Vectorized>& vindex) { static constexpr int size = Vectorized::size(); int_same_size_t index_arr[size]; vindex.store(static_cast(index_arr)); T buffer[size]; for (const auto i : c10::irange(size)) { buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; } return Vectorized::loadu(static_cast(buffer)); } template std::enable_if_t> inline mask_gather(const Vectorized& src, T const* base_addr, const Vectorized>& vindex, Vectorized& mask) { static constexpr int size = Vectorized::size(); T src_arr[size]; int_same_size_t mask_arr[size]; // use int type so we can logical and int_same_size_t index_arr[size]; src.store(static_cast(src_arr)); mask.store(static_cast(mask_arr)); vindex.store(static_cast(index_arr)); T buffer[size]; for (const auto i : c10::irange(size)) { if (mask_arr[i] & 0x01) { // check highest bit buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)]; } else { buffer[i] = src_arr[i]; } } mask = Vectorized(); // "zero out" mask return Vectorized::loadu(static_cast(buffer)); } // Cast a given vector to another type without changing the bits representation. // So a Vectorized of 512 bits containing all ones can be cast to a // Vectorized of 512 bits containing all ones (i.e., eight negative 1s). // A Vec of 256 bits containing all ones can be cast to a // Vec of 256 bits containing all ones (i.e., four negative 1s). // There is a struct here because we don't have static_if and I can't // partially specialize a templated function. template struct CastImpl { static inline Vectorized apply(const Vectorized& src) { src_t src_arr[Vectorized::size()]; src.store(static_cast(src_arr)); return Vectorized::loadu(static_cast(src_arr)); } }; template struct CastImpl { static inline Vectorized apply(const Vectorized& src) { return src; } }; template inline Vectorized cast(const Vectorized& src) { return CastImpl::apply(src); } template inline Vectorized> convert_to_int_of_same_size(const Vectorized& src) { static constexpr int size = Vectorized::size(); T src_arr[size]; src.store(static_cast(src_arr)); int_same_size_t buffer[size]; for (const auto i : c10::irange(size)) { buffer[i] = static_cast>(src_arr[i]); } return Vectorized>::loadu(static_cast(buffer)); } // Example inputs for AVX512: // a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} // b Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} // returns: // Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} // Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} // Example inputs for AVX2: a Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} // b Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} // returns: Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} // Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} template inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>> deinterleave2(const Vectorized& a, const Vectorized& b) { static constexpr int size = Vectorized::size(); static constexpr int half_size = size / 2; T a_arr[size]; T b_arr[size]; T buffer1[size]; T buffer2[size]; a.store(static_cast(a_arr)); b.store(static_cast(b_arr)); for (const auto i : c10::irange(half_size)) { buffer1[i] = a_arr[i * 2]; buffer1[half_size + i] = b_arr[i * 2]; buffer2[i] = a_arr[i * 2 + 1]; buffer2[half_size + i] = b_arr[i * 2 + 1]; } return std::make_pair(Vectorized::loadu(static_cast(buffer1)), Vectorized::loadu(static_cast(buffer2))); } // inverse operation of deinterleave2 // Example inputs for AVX512: // a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15} // b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, b10, b11, b12, b13, b14, b15} // returns, for AVX512: // Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3, a4, b4, a5, b5, a6, b6, a7, b7} // Vectorized = {a8, b8, a9, b9, a10, b10, a11, b11, a12, b12, a13, b13, a14, b14, a15, b15} // Example inputs for AVX2 : a Vectorized = {a0, a1, a2, a3, a4, a5, a6, a7} // b Vectorized = {b0, b1, b2, b3, b4, b5, b6, b7} // returns: Vectorized = {a0, b0, a1, b1, a2, b2, a3, b3} // Vectorized = {a4, b4, a5, b5, a6, b6, a7, b7} template inline std::enable_if_t::size() % 2 == 0, std::pair, Vectorized>> interleave2(const Vectorized& a, const Vectorized& b) { static constexpr int size = Vectorized::size(); static constexpr int half_size = size / 2; T a_arr[size]; T b_arr[size]; T buffer1[size]; T buffer2[size]; a.store(static_cast(a_arr)); b.store(static_cast(b_arr)); for (const auto i : c10::irange(half_size)) { buffer1[i * 2] = a_arr[i]; buffer1[i * 2 + 1] = b_arr[i]; buffer2[i * 2] = a_arr[half_size + i]; buffer2[i * 2 + 1] = b_arr[half_size + i]; } return std::make_pair(Vectorized::loadu(static_cast(buffer1)), Vectorized::loadu(static_cast(buffer2))); } template inline void convert(const src_T *src, dst_T *dst, int64_t n) { #ifndef _MSC_VER # pragma unroll #endif for (const auto i : c10::irange(n)) { (void)i; //Suppress unused variable warning *dst = c10::convert(c10::load(src)); src++; dst++; } } template inline Vectorized flip(const Vectorized & data) { static constexpr int size = Vectorized::size(); T output[size]; T buffer[size]; data.store(static_cast(buffer)); for (const auto i : c10::irange(size)) { output[i] = buffer[size - i - 1]; } return Vectorized::loadu(static_cast(output)); } // Transpose the `src` buffer of type `T` and size (M,N) into the `dst` buffer. `ld_src` is the leading // dimension of `src` and `ld_dst` is the leading dimension of `dst`. template inline void transpose_mxn(const T* src, int64_t ld_src, T* dst, int64_t ld_dst) { for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { dst[j*ld_dst + i] = src[i*ld_src + j]; } } } }}}