moments_utils.h 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. #pragma once
  2. #include <array>
  3. #include <cstring>
  4. #include <numeric>
  5. #include <utility>
  6. #include <vector>
  7. #include <ATen/Parallel.h>
  8. #include <ATen/OpMathType.h>
  9. #include <ATen/cpu/vec/vec.h>
  10. #include <ATen/native/cpu/utils.h>
  11. #include <c10/util/SmallVector.h>
  12. #include <c10/util/irange.h>
  13. namespace at {
  14. namespace native {
  15. inline namespace CPU_CAPABILITY {
  16. template<typename T> using acc_t = at::opmath_type<T>;
  // Number of consecutive vectors that UpdateMomentsVec folds with Welford
  // updates before the partial result enters the cascade in
  // RowwiseMomentsImpl. It is also the length of the 1 / (j + 1) coefficient
  // table passed to UpdateMomentsVec.
  constexpr int64_t kChunkSize = 16;
  // Merges one partial summary (count, mean, sum of squared deviations from
  // the mean) into a running summary of the same form.
  //
  //   m0_add, m1_add, m2_add: count, mean and M2 of the group being merged in.
  //   m0, m1, m2:             running count, mean and M2; updated in place.
  //
  // With n = m0 + m0_add and delta = m1_add - m1 the update is
  //   m1 <- m1 + delta * m0_add / n
  //   m2 <- m2 + m2_add + delta^2 * m0 * m0_add / n
  template <typename T>
  void AddMoments(
      int64_t m0_add,
      const T& m1_add,
      const T& m2_add,
      int64_t& m0,
      T& m1,
      T& m2) {
    const int64_t n = m0 + m0_add;
    // Weight of the incoming group. Merging two empty groups must not divide
    // by zero, so the weight is defined as 0 in that case.
    const T c = n == 0 ? static_cast<T>(0)
                       : static_cast<T>(m0_add) / static_cast<T>(n);
    const T delta = m1_add - m1;
    m1 += c * delta;
    // m0 is still the old running count here; it is overwritten last.
    m2 += m2_add + delta * delta * c * static_cast<T>(m0);
    m0 = n;
  }
  33. template <typename T>
  34. C10_ALWAYS_INLINE void AddMomentsVec(
  35. int64_t m0_add,
  36. const vec::Vectorized<T>& m1_add,
  37. const vec::Vectorized<T>& m2_add,
  38. int64_t& m0,
  39. vec::Vectorized<T>& m1,
  40. vec::Vectorized<T>& m2) {
  41. using Vec = vec::Vectorized<T>;
  42. const int64_t n = m0 + m0_add;
  43. const T c = n == 0 ? static_cast<T>(0) : static_cast<T>(m0_add) / static_cast<T>(n);
  44. const Vec c_vec(c);
  45. const Vec delta = m1_add - m1;
  46. m1 += c_vec * delta;
  47. m2 += m2_add + delta * delta * c_vec * Vec(static_cast<T>(m0));
  48. m0 = n;
  49. }
  50. template <typename T>
  51. inline void UpdateMomentsVec(
  52. int64_t m0,
  53. const T* X_ptr,
  54. const std::array<vec::Vectorized<acc_t<T>>, kChunkSize>& c_vecs,
  55. int64_t& m0_stk0,
  56. vec::Vectorized<acc_t<T>>& m1_stk0,
  57. vec::Vectorized<acc_t<T>>& m2_stk0) {
  58. using Vec = vec::Vectorized<acc_t<T>>;
  59. Vec m1_vec(0);
  60. Vec m2_vec(0);
  61. for (const auto j : c10::irange(m0)) {
  62. const Vec x_vec = Vec::loadu(X_ptr + j * Vec::size());
  63. const Vec delta_vec = x_vec - m1_vec;
  64. m1_vec += delta_vec * c_vecs[j];
  65. m2_vec += delta_vec * (x_vec - m1_vec);
  66. }
  67. AddMomentsVec(m0, m1_vec, m2_vec, m0_stk0, m1_stk0, m2_stk0);
  68. }
  69. // each bfloat16 vector will be converted to two float vectors,
  70. // and accumulated successively on m1_stk0/m2_stk0.
  71. template <>
  72. inline void UpdateMomentsVec<BFloat16>(
  73. int64_t m0,
  74. const BFloat16* X_ptr,
  75. const std::array<vec::Vectorized<float>, kChunkSize>& c_vecs,
  76. int64_t& m0_stk0,
  77. vec::Vectorized<float>& m1_stk0,
  78. vec::Vectorized<float>& m2_stk0) {
  79. using bVec = vec::Vectorized<BFloat16>;
  80. using fVec = vec::Vectorized<float>;
  81. fVec m1_fvec0(0), m1_fvec1(0);
  82. fVec m2_fvec0(0), m2_fvec1(0);
  83. for (const auto j : c10::irange(m0)) {
  84. const bVec x_bvec = bVec::loadu(X_ptr + j * bVec::size());
  85. fVec x_fvec0, x_fvec1;
  86. std::tie(x_fvec0, x_fvec1) = convert_bfloat16_float(x_bvec);
  87. const fVec delta_fvec0 = x_fvec0 - m1_fvec0;
  88. const fVec delta_fvec1 = x_fvec1 - m1_fvec1;
  89. m1_fvec0 += delta_fvec0 * c_vecs[j];
  90. m1_fvec1 += delta_fvec1 * c_vecs[j];
  91. m2_fvec0 += delta_fvec0 * (x_fvec0 - m1_fvec0);
  92. m2_fvec1 += delta_fvec1 * (x_fvec1 - m1_fvec1);
  93. }
  94. AddMomentsVec(m0, m1_fvec0, m2_fvec0, m0_stk0, m1_stk0, m2_stk0);
  95. AddMomentsVec(m0, m1_fvec1, m2_fvec1, m0_stk0, m1_stk0, m2_stk0);
  96. }
  97. // Compute rowwise moments by Welford algorithm and cascade sum to improve
  98. // numerical stability.
  99. // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
  100. // https://en.wikipedia.org/wiki/Pairwise_summation
  101. template <typename T, int64_t kMaxDepth>
  102. std::pair<acc_t<T>, acc_t<T>> RowwiseMomentsImpl(const T* X, int64_t N, int64_t ddof = 0) {
  103. using T_ACC = acc_t<T>;
  104. constexpr int64_t kVecSize = vec::Vectorized<T>::size();
  105. constexpr int64_t kAccVecSize = vec::Vectorized<T_ACC>::size();
  106. const int64_t n = N / kVecSize;
  107. const int64_t m = divup(n, kChunkSize);
  108. const int64_t depth = utils::CeilLog2(m);
  109. using Vec = vec::Vectorized<T_ACC>;
  110. const Vec kZeroVec(T_ACC(0));
  111. c10::SmallVector<int64_t, kMaxDepth> m0_stk(depth, 0);
  112. c10::SmallVector<Vec, kMaxDepth> m1_stk(depth, kZeroVec);
  113. c10::SmallVector<Vec, kMaxDepth> m2_stk(depth, kZeroVec);
  114. for (const auto i : c10::irange(m)) {
  115. const T* X_ptr = X + i * kChunkSize * kVecSize;
  116. const int64_t m0 = std::min(kChunkSize, n - i * kChunkSize);
  117. static std::array<Vec, kChunkSize> c_vecs = ([]() {
  118. std::array<Vec, kChunkSize> result;
  119. for (const auto i : c10::irange(kChunkSize)) {
  120. result[i] = Vec(T_ACC(1) / static_cast<T_ACC>(i + 1));
  121. }
  122. return result;
  123. })();
  124. UpdateMomentsVec(m0, X_ptr, c_vecs, m0_stk[0], m1_stk[0], m2_stk[0]);
  125. int64_t mask = i + 1;
  126. for (int64_t j = 1; j < depth && (mask & 1) == 0; ++j) {
  127. AddMomentsVec(
  128. m0_stk[j - 1],
  129. m1_stk[j - 1],
  130. m2_stk[j - 1],
  131. m0_stk[j],
  132. m1_stk[j],
  133. m2_stk[j]);
  134. m0_stk[j - 1] = 0;
  135. m1_stk[j - 1] = kZeroVec;
  136. m2_stk[j - 1] = kZeroVec;
  137. mask >>= 1;
  138. }
  139. }
  140. for (const auto i : c10::irange(1, depth)) {
  141. AddMomentsVec(
  142. m0_stk[i], m1_stk[i], m2_stk[i], m0_stk[0], m1_stk[0], m2_stk[0]);
  143. }
  144. std::array<T_ACC, kAccVecSize> m1_arr{};
  145. std::array<T_ACC, kAccVecSize> m2_arr{};
  146. m1_stk[0].store(m1_arr.data());
  147. m2_stk[0].store(m2_arr.data());
  148. int64_t m0 = 0;
  149. T_ACC m1 = 0;
  150. T_ACC m2 = 0;
  151. for (int64_t i = n * kVecSize; i < N; ++i) {
  152. T_ACC x = static_cast<T_ACC>(X[i]);
  153. const T_ACC delta = x - m1;
  154. ++m0;
  155. m1 += delta / static_cast<T_ACC>(m0);
  156. m2 += delta * (x - m1);
  157. }
  158. // for BFloat16, each vector in m1_arr/m2_arr holds 2*n accumulated result
  159. int64_t m0_add = n * kVecSize / kAccVecSize;
  160. for (const auto i : c10::irange(kAccVecSize)) {
  161. AddMoments(m0_add, m1_arr[i], m2_arr[i], m0, m1, m2);
  162. }
  163. return std::make_pair(m1, m2 / static_cast<T_ACC>(N - ddof));
  164. }
  165. template <typename T>
  166. std::pair<acc_t<T>, acc_t<T>> RowwiseMoments(const T* X, int64_t N, int64_t ddof = 0) {
  167. using Vec = vec::Vectorized<T>;
  168. constexpr int64_t kVecSize = Vec::size();
  169. const int64_t n = N / kVecSize;
  170. const int64_t m = divup(n, kChunkSize);
  171. const int64_t depth = utils::CeilLog2(m);
  172. if (depth <= 4) {
  173. return RowwiseMomentsImpl<T, 4>(X, N, ddof);
  174. } else if (depth <= 8) {
  175. return RowwiseMomentsImpl<T, 8>(X, N, ddof);
  176. } else if (depth <= 16) {
  177. return RowwiseMomentsImpl<T, 16>(X, N, ddof);
  178. } else if (depth <= 32) {
  179. return RowwiseMomentsImpl<T, 32>(X, N, ddof);
  180. } else {
  181. return RowwiseMomentsImpl<T, 64>(X, N, ddof);
  182. }
  183. }
  184. } // namespace CPU_CAPABILITY
  185. } // namespace native
  186. } // namespace at