#pragma once

#include <ATen/Parallel.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/llvmMathExtras.h>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#endif

namespace at { namespace native {

inline namespace CPU_CAPABILITY {

template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  offset = data_index_init(offset, std::forward<Args>(args)...);
  x = offset % X;
  return offset / X;
}

inline bool data_index_step() {
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (data_index_step(std::forward<Args>(args)...)) {
    x = ((x + 1) == X) ? 0 : (x + 1);
    return x == 0;
  }
  return false;
}
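// Usage sketch (illustrative, not part of the original header): iterate a
// contiguous buffer of shape [N, C, H, W] over the flat range [begin, end)
// while tracking the multi-dimensional indices; the shape and variable names
// here are assumptions for the example.
//
//   int64_t n = 0;
//   int64_t c = 0;
//   int64_t h = 0;
//   int64_t w = 0;
//   data_index_init(begin, n, N, c, C, h, H, w, W);
//   for (const auto i : c10::irange(begin, end)) {
//     // ... use the flat index i and the indices n, c, h, w ...
//     data_index_step(n, N, c, C, h, H, w, W);
//   }
//
// data_index_init decomposes `begin` into per-dimension indices (the
// innermost, fastest-moving dimension is listed last); data_index_step
// advances them like an odometer, carrying into the next dimension
// whenever a dimension wraps back to zero.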
// Helper struct for bfloat16 vectorization
// Useful when you need float as the intermediate dtype or accumulate dtype
using namespace vec;
struct Vec2 {
  Vectorized<float> val0, val1;
  Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
  Vec2(float v) : val0(v), val1(v) {}
  static Vec2 loadu(const BFloat16* ptr) {
    Vectorized<float> v0, v1;
    std::tie(v0, v1) = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
    return {v0, v1};
  }
  void store(BFloat16* ptr) const {
    Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
    val.store(ptr);
  }
};
inline Vec2 operator+(const Vec2& a, const Vec2& b) {
  return {a.val0 + b.val0, a.val1 + b.val1};
}
inline Vec2 operator*(const Vec2& a, const Vec2& b) {
  return {a.val0 * b.val0, a.val1 * b.val1};
}

template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;

// Helper for mixed data type parameter Vec::load
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
  using Vec = Vectorized<float>;
  return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
  using Vec = Vectorized<float>;
  if (count > Vec::size()) {
    return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
  } else {
    return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
  }
}

} // namespace CPU_CAPABILITY

namespace utils {

template <typename T>
T CeilLog2(const T& x) {
  if (x <= 2) {
    return 1;
  }
  // The last set bit is floor(log2(x)); floor + 1 gives ceil,
  // except when x is an exact power of 2, so subtract 1 first.
  return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
}

// matrix transpose:
//   src has shape of M by N, with leading dimension of ld_src
//   dst has shape of N by M, with leading dimension of ld_dst
template <typename T>
inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (int64_t j = 0; j < N; j++) {
    for (int64_t i = 0; i < M; i++) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}

#ifdef USE_FBGEMM
template <>
inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd(M, N, src, ld_src, dst, ld_dst);
}
#endif

template <typename scalar_t, typename F>
inline void parallel_sparse_csr(
    const TensorAccessor<scalar_t, 1>& crow_acc,
    const int64_t M,
    const int64_t nnz,
    const F& f) {
  TORCH_CHECK(crow_acc.size(0) == M + 1);

  // Parallelizing directly on `M` may lead to load imbalance;
  // statically determine the thread partition here so that each
  // thread gets an approximately equal share of the nonzeros.
  int num_threads = at::get_num_threads();
  std::vector<int64_t> thread_splits(num_threads + 1, M);

  int64_t thread_average_payload = std::max((int64_t)1, divup(nnz, num_threads));

  thread_splits[0] = 0;
  int64_t sum = 0;
  int64_t t = 1;
  for (const auto m : c10::irange(M)) {
    int64_t row_start = crow_acc[m];
    int64_t row_end = crow_acc[m + 1];
    sum += row_end - row_start;
    if (sum > t * thread_average_payload) {
      thread_splits[t] = m;
      t++;
    }
  }
  // Restore the last index: rounding when computing
  // `thread_average_payload` may have overwritten it above.
  thread_splits[num_threads] = M;

  // With grain size 1 each thread handles a single chunk; the chunk
  // bounds are unused because every thread looks up its row range by
  // its thread id instead.
  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
    int tid = at::get_thread_num();
    int64_t begin = thread_splits[tid];
    int64_t end = thread_splits[tid + 1];
    f(begin, end);
  });
}

} // namespace utils

} // namespace native
} // namespace at
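// Usage sketch for utils::parallel_sparse_csr (illustrative only; `crow_acc`,
// `values_acc`, and `out` are assumed caller-side accessors/buffers, not part
// of this header): sum the values of each CSR row in parallel, with rows
// partitioned so that every thread sees a similar number of nonzeros.
//
//   at::native::utils::parallel_sparse_csr(
//       crow_acc, M, nnz, [&](int64_t row_begin, int64_t row_end) {
//         for (const auto m : c10::irange(row_begin, row_end)) {
//           float acc = 0.f;
//           for (const auto e : c10::irange(crow_acc[m], crow_acc[m + 1])) {
//             acc += values_acc[e];
//           }
//           out[m] = acc;
//         }
//       });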