ReduceUtils.h 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. #pragma once
  2. #include <ATen/Parallel.h>
  3. #include <ATen/NumericUtils.h>
  4. #include <ATen/cpu/vec/vec.h>
  5. #include <ATen/cpu/vec/functional.h>
  6. #include <ATen/native/ReductionType.h>
  7. #include <c10/util/irange.h>
  8. namespace at::native {
  9. inline namespace CPU_CAPABILITY {
  10. using namespace vec;
  // Dispatches a runtime reduction kind `op` to a callable that can read the
  // selected kind as a compile-time constant.
  //
  // Expands to an immediately-invoked `[&]` lambda that switches on `op`. Each
  // case declares `static constexpr ReductionType reduce` holding the matched
  // enumerator and then invokes the trailing macro argument(s) with no
  // arguments, returning whatever that call returns. Because the callable's
  // text is pasted inside the case scope, it can name `reduce` directly (for
  // example as a template argument).
  //
  // A value of `op` that matches none of the listed enumerators hits the
  // `default` branch and trips TORCH_INTERNAL_ASSERT instead of flowing off the
  // end of a (possibly value-returning) lambda.
  #define AT_DISPATCH_REDUCTION_TYPES(op, ...) \
    [&] { \
      switch (op) { \
        case SUM: { \
          static constexpr ReductionType reduce = SUM; \
          return __VA_ARGS__(); \
        } \
        case MEAN: { \
          static constexpr ReductionType reduce = MEAN; \
          return __VA_ARGS__(); \
        } \
        case MIN: { \
          static constexpr ReductionType reduce = MIN; \
          return __VA_ARGS__(); \
        } \
        case MAX: { \
          static constexpr ReductionType reduce = MAX; \
          return __VA_ARGS__(); \
        } \
        case PROD: { \
          static constexpr ReductionType reduce = PROD; \
          return __VA_ARGS__(); \
        } \
        default: { \
          TORCH_INTERNAL_ASSERT( \
              false, "AT_DISPATCH_REDUCTION_TYPES: unexpected reduction type"); \
        } \
      } \
    }()
  36. template <typename scalar_t, ReductionType reduce>
  37. inline vec_scalar_t<scalar_t> init_value() {
  38. using acc_t = vec_scalar_t<scalar_t>;
  39. acc_t val;
  40. if (reduce == ReductionType::SUM ||
  41. reduce == ReductionType::MEAN) {
  42. val = static_cast<acc_t>(0);
  43. } else if (reduce == ReductionType::PROD) {
  44. val = static_cast<acc_t>(1);
  45. } else if (reduce == ReductionType::MAX) {
  46. val = -std::numeric_limits<acc_t>::infinity();
  47. } else {
  48. TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
  49. val = std::numeric_limits<acc_t>::infinity();
  50. }
  51. return val;
  52. }
  53. template <typename scalar_t, ReductionType reduce>
  54. inline vec_scalar_t<scalar_t> init_value(const c10::optional<Scalar>& initial) {
  55. using acc_t = vec_scalar_t<scalar_t>;
  56. if (initial.has_value()) {
  57. return initial.value().to<acc_t>();
  58. } else {
  59. return init_value<scalar_t, reduce>();
  60. }
  61. }
  62. template <typename scalar_t>
  63. inline void init(scalar_t* out, int64_t size, const vec_scalar_t<scalar_t>& val) {
  64. using Vec = Vectorized<vec_scalar_t<scalar_t>>;
  65. map<scalar_t>(
  66. [val](Vec x) { return Vec(val); },
  67. out,
  68. out,
  69. size);
  70. }
  71. template <typename scalar_t, ReductionType reduce>
  72. inline void init(scalar_t* out, int64_t size, const c10::optional<Scalar>& initial) {
  73. using acc_t = vec_scalar_t<scalar_t>;
  74. acc_t val = init_value<scalar_t, reduce>(initial);
  75. init(out, size, val);
  76. }
  77. // overload with `include_self`, used by scatter_reduce
  78. template <typename scalar_t, ReductionType reduce>
  79. inline void init(scalar_t* out, int64_t size, bool include_self = false) {
  80. using acc_t = vec_scalar_t<scalar_t>;
  81. if (!include_self) {
  82. acc_t val = init_value<scalar_t, reduce>();
  83. init(out, size, val);
  84. }
  85. }
  86. template <typename scalar_t>
  87. inline scalar_t _max(const scalar_t& x, const scalar_t& y) {
  88. return at::_isnan(y) ? y : std::max(x, y);
  89. }
  90. template <typename scalar_t>
  91. inline Vectorized<scalar_t> _max(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  92. // vec::maximum propagates NaN
  93. return vec::maximum(x, y);
  94. }
  95. template <typename scalar_t>
  96. inline scalar_t _min(const scalar_t& x, const scalar_t& y) {
  97. return at::_isnan(y) ? y : std::min(x, y);
  98. }
  99. template <typename scalar_t>
  100. inline Vectorized<scalar_t> _min(const Vectorized<scalar_t>& x, const Vectorized<scalar_t>& y) {
  101. // vec::minimum propagates NaN
  102. return vec::minimum(x, y);
  103. }
  104. // for Max and Min, propagate NaN:
  105. template <typename T, ReductionType reduce>
  106. inline T update(const T& x, const T& y) {
  107. if (reduce == ReductionType::SUM ||
  108. reduce == ReductionType::MEAN) {
  109. return x + y;
  110. } else if (reduce == ReductionType::PROD) {
  111. return x * y;
  112. } else if (reduce == ReductionType::MAX) {
  113. return _max(x, y);
  114. } else {
  115. TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN);
  116. return _min(x, y);
  117. }
  118. }
  119. template <typename scalar_t, ReductionType reduce>
  120. inline void update(scalar_t* out, scalar_t* data, int64_t K) {
  121. using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  122. map2<scalar_t>(
  123. [](Vec x, Vec y) { return update<Vec, reduce>(x, y); },
  124. out,
  125. out,
  126. data,
  127. K);
  128. }
  129. template <typename scalar_t, ReductionType reduce>
  130. inline void write(scalar_t* out, int64_t count, int64_t K) {
  131. using Vec = vec::Vectorized<vec_scalar_t<scalar_t>>;
  132. if (reduce == ReductionType::MEAN) {
  133. if (count > 0) {
  134. vec::map<scalar_t>(
  135. [count](Vec x) { return x / Vec(count); },
  136. out,
  137. out,
  138. K);
  139. }
  140. }
  141. }
  142. } // namespace CPU_CAPABILITY
  143. } // namespace at::native