#pragma once #include #include #include #include #include #include namespace at::native { inline namespace CPU_CAPABILITY { using namespace vec; #define AT_DISPATCH_REDUCTION_TYPES(op, ...) \ [&] { \ switch (op) { \ case SUM: { \ static constexpr ReductionType reduce = SUM; \ return __VA_ARGS__(); \ } \ case MEAN: { \ static constexpr ReductionType reduce = MEAN; \ return __VA_ARGS__(); \ } \ case MIN: { \ static constexpr ReductionType reduce = MIN; \ return __VA_ARGS__(); \ } \ case MAX: { \ static constexpr ReductionType reduce = MAX; \ return __VA_ARGS__(); \ } \ case PROD: { \ static constexpr ReductionType reduce = PROD; \ return __VA_ARGS__(); \ } \ } \ }() template inline vec_scalar_t init_value() { using acc_t = vec_scalar_t; acc_t val; if (reduce == ReductionType::SUM || reduce == ReductionType::MEAN) { val = static_cast(0); } else if (reduce == ReductionType::PROD) { val = static_cast(1); } else if (reduce == ReductionType::MAX) { val = -std::numeric_limits::infinity(); } else { TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN); val = std::numeric_limits::infinity(); } return val; } template inline vec_scalar_t init_value(const c10::optional& initial) { using acc_t = vec_scalar_t; if (initial.has_value()) { return initial.value().to(); } else { return init_value(); } } template inline void init(scalar_t* out, int64_t size, const vec_scalar_t& val) { using Vec = Vectorized>; map( [val](Vec x) { return Vec(val); }, out, out, size); } template inline void init(scalar_t* out, int64_t size, const c10::optional& initial) { using acc_t = vec_scalar_t; acc_t val = init_value(initial); init(out, size, val); } // overload with `include_self`, used by scatter_reduce template inline void init(scalar_t* out, int64_t size, bool include_self = false) { using acc_t = vec_scalar_t; if (!include_self) { acc_t val = init_value(); init(out, size, val); } } template inline scalar_t _max(const scalar_t& x, const scalar_t& y) { return at::_isnan(y) ? y : std::max(x, y); } template inline Vectorized _max(const Vectorized& x, const Vectorized& y) { // vec::maximum propagates NaN return vec::maximum(x, y); } template inline scalar_t _min(const scalar_t& x, const scalar_t& y) { return at::_isnan(y) ? y : std::min(x, y); } template inline Vectorized _min(const Vectorized& x, const Vectorized& y) { // vec::minimum propagates NaN return vec::minimum(x, y); } // for Max and Min, propagate NaN: template inline T update(const T& x, const T& y) { if (reduce == ReductionType::SUM || reduce == ReductionType::MEAN) { return x + y; } else if (reduce == ReductionType::PROD) { return x * y; } else if (reduce == ReductionType::MAX) { return _max(x, y); } else { TORCH_INTERNAL_ASSERT(reduce == ReductionType::MIN); return _min(x, y); } } template inline void update(scalar_t* out, scalar_t* data, int64_t K) { using Vec = vec::Vectorized>; map2( [](Vec x, Vec y) { return update(x, y); }, out, out, data, K); } template inline void write(scalar_t* out, int64_t count, int64_t K) { using Vec = vec::Vectorized>; if (reduce == ReductionType::MEAN) { if (count > 0) { vec::map( [count](Vec x) { return x / Vec(count); }, out, out, K); } } } } // namespace CPU_CAPABILITY } // namespace at::native