#pragma once #include #include namespace at { namespace native { namespace { // SFINAE doesn't work well with NVCC under Windows for math functions like pow and sqrt. // So we need to define the functions with the explicit function signatures. // As for pow, the following signatures are defined as the device function: // pow(float, int) // pow(double, int) // pow(float, float) // pow(double, double) #ifdef _MSC_VER // Functions for pow // pow for at::Half static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) { return static_cast(std::pow(static_cast(base), static_cast(exp))); } // pow for at::BFloat16 static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) { return static_cast(std::pow(static_cast(base), static_cast(exp))); } // pow (floating, floating/int) template static inline __host__ __device__ typename std::enable_if::value && (std::is_same::value || std::is_same::value), Base_type>::type pow_(Base_type base, Exp_type exp) { return std::pow(base, exp); } // pow (Otherwise) template static inline __host__ __device__ typename std::enable_if::value && !std::is_same::value, Base_type>::type pow_(Base_type base, Exp_type exp) { return static_cast(std::pow(static_cast(base), static_cast(exp))); } #else template static inline __host__ __device__ Base_type pow_(Base_type base, Exp_type exp) { return ::pow(base, exp); } #endif template static inline __host__ __device__ std::enable_if_t::value, T> pow_( T base, T exp) { return at::native::powi(base, exp); } template static inline __host__ __device__ c10::complex pow_(c10::complex base, c10::complex exp) { return c10_complex_math::pow(base, exp); } } // namespace }} // namespace at::native