#pragma once

#include <ATen/Parallel.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/llvmMathExtras.h>

#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#endif

namespace at {
namespace native {

inline namespace CPU_CAPABILITY {

// Compute the per-dimension indices that correspond to a linear offset
// (`data_index_init`) and step those indices forward by one element
// (`data_index_step`), with the last dimension varying fastest.
template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  offset = data_index_init(offset, std::forward<Args>(args)...);
  x = offset % X;
  return offset / X;
}

inline bool data_index_step() {
  return true;
}

template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (data_index_step(std::forward<Args>(args)...)) {
    x = ((x + 1) == X) ? 0 : (x + 1);
    return x == 0;
  }
  return false;
}
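
// A rough usage sketch (illustrative only; `begin`, `end`, and the N/C/H/W
// extents are assumed to come from the surrounding kernel): convert a linear
// offset over a contiguous [N, C, H, W] range into per-dimension indices,
// then step them with the innermost (W) index varying fastest.
//
//   int64_t n = 0, c = 0, h = 0, w = 0;
//   data_index_init(begin, n, N, c, C, h, H, w, W);
//   for (int64_t i = begin; i < end; i++) {
//     // ... use n, c, h, w ...
//     data_index_step(n, N, c, C, h, H, w, W);
//   }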

// Helper struct for bfloat16 vectorization
// Useful when you need float as the intermediate or accumulation dtype
using namespace vec;
struct Vec2 {
  Vectorized<float> val0, val1;
  Vec2(Vectorized<float> v0, Vectorized<float> v1) : val0(v0), val1(v1) {}
  Vec2(float v) : val0(v), val1(v) {}
  static Vec2 loadu(const BFloat16* ptr) {
    Vectorized<float> v0, v1;
    std::tie(v0, v1) = convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
    return {v0, v1};
  }
  void store(BFloat16* ptr) const {
    Vectorized<BFloat16> val = convert_float_bfloat16(val0, val1);
    val.store(ptr);
  }
};
inline Vec2 operator+(const Vec2& a, const Vec2& b) { return {a.val0 + b.val0, a.val1 + b.val1}; }
inline Vec2 operator*(const Vec2& a, const Vec2& b) { return {a.val0 * b.val0, a.val1 * b.val1}; }

template <typename scalar_t> struct VectorizedType { using type = Vectorized<scalar_t>; };
template <> struct VectorizedType<BFloat16> { using type = Vec2; };
template <typename scalar_t> using VecType = typename VectorizedType<scalar_t>::type;
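
// Illustrative sketch (names `src`, `weight`, `acc`, and `len` are hypothetical,
// and `len` is assumed to be a multiple of Vectorized<BFloat16>::size()):
// accumulate a BFloat16 product in float precision via Vec2.
//
//   Vec2 acc(0.f);
//   for (int64_t d = 0; d < len; d += Vectorized<BFloat16>::size()) {
//     acc = acc + Vec2::loadu(src + d) * Vec2::loadu(weight + d);
//   }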

// Helpers for mixed-data-type parameter loads: each overload returns
// two Vectorized<float>, regardless of whether the source is float or BFloat16.
inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr) {
  using Vec = Vectorized<float>;
  return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size()));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const BFloat16* ptr, int64_t count) {
  return convert_bfloat16_float(Vectorized<BFloat16>::loadu(ptr, count));
}

inline std::tuple<Vectorized<float>, Vectorized<float>> load2f(const float* ptr, int64_t count) {
  using Vec = Vectorized<float>;
  if (count > Vec::size()) {
    return std::make_tuple(Vec::loadu(ptr), Vec::loadu(ptr + Vec::size(), count - Vec::size()));
  } else {
    return std::make_tuple(Vec::loadu(ptr, count), Vec(0));
  }
}
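
// Illustrative sketch (pointer names, `d`, and `rem` are hypothetical): the same
// loop body can consume float or BFloat16 parameters, since every overload
// yields a pair of Vectorized<float>.
//
//   auto [w0, w1] = load2f(weight_ptr + d);        // weight_ptr: float* or BFloat16*
//   auto [b0, b1] = load2f(bias_ptr + d, rem);     // partial load for the tail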

} // namespace CPU_CAPABILITY

namespace utils {

template <typename T>
T CeilLog2(const T& x) {
  if (x <= 2) {
    return 1;
  }
  // The index of the last set bit is floor(log2(x)), and floor + 1 gives ceil,
  // except when x is an exact power of 2, so subtract 1 first.
  return static_cast<T>(llvm::findLastSet(static_cast<uint64_t>(x) - 1)) + 1;
}
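
// For reference: with the definition above, CeilLog2(8) == 3, CeilLog2(9) == 4,
// and any input <= 2 returns 1.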

// Matrix transpose:
//   src has shape M x N, with leading dimension ld_src
//   dst has shape N x M, with leading dimension ld_dst
template <typename T>
inline void transpose(int64_t M, int64_t N, const T* src, int64_t ld_src, T* dst, int64_t ld_dst) {
  for (int64_t j = 0; j < N; j++) {
    for (int64_t i = 0; i < M; i++) {
      dst[j * ld_dst + i] = src[i * ld_src + j];
    }
  }
}

#ifdef USE_FBGEMM
template <>
inline void transpose<float>(int64_t M, int64_t N, const float* src, int64_t ld_src, float* dst, int64_t ld_dst) {
  TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
  fbgemm::transpose_simd<float>(M, N, src, ld_src, dst, ld_dst);
}
#endif
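
// Illustrative sketch (contiguous buffers assumed): transposing a contiguous
// M x N matrix `src` into a contiguous N x M matrix `dst` uses the row widths
// as the leading dimensions.
//
//   transpose(M, N, src, /*ld_src=*/N, dst, /*ld_dst=*/M);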

template <typename index_t, typename F>
inline void parallel_sparse_csr(
    const TensorAccessor<index_t, 1>& crow_acc,
    const int64_t M,
    const int64_t nnz,
    const F& f) {
  TORCH_CHECK(crow_acc.size(0) == M + 1);

  // Parallelizing directly over `M` may lead to load imbalance, so statically
  // determine the thread partition here so that each thread gets roughly the
  // same payload (number of nonzeros).
  int num_threads = at::get_num_threads();
  std::vector<int64_t> thread_splits(num_threads + 1, M);

  int64_t thread_average_payload = std::max((int64_t)1, divup(nnz, num_threads));

  thread_splits[0] = 0;
  int64_t sum = 0;
  int64_t t = 1;
  for (const auto m : c10::irange(M)) {
    int64_t row_start = crow_acc[m];
    int64_t row_end = crow_acc[m + 1];
    sum += row_end - row_start;
    if (sum > t * thread_average_payload) {
      thread_splits[t] = m;
      t++;
    }
  }
  // Need to restore the last index due to the rounding error when
  // calculating `thread_average_payload`.
  thread_splits[num_threads] = M;

  at::parallel_for(0, num_threads, 1, [&](int64_t cbegin, int64_t cend) {
    int tid = at::get_thread_num();
    int64_t begin = thread_splits[tid];
    int64_t end = thread_splits[tid + 1];
    f(begin, end);
  });
}
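
// Illustrative sketch (tensor and lambda names are hypothetical): partition CSR
// rows across threads by nonzero count rather than by row count.
//
//   parallel_sparse_csr(crow_indices.accessor<int64_t, 1>(), M, nnz,
//       [&](int64_t row_begin, int64_t row_end) {
//         for (int64_t m = row_begin; m < row_end; m++) {
//           // process row m
//         }
//       });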

} // namespace utils

} // namespace native
} // namespace at