Pow.cuh 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #pragma once
  2. #include <ATen/native/Pow.h>
  3. #include <c10/core/Scalar.h>
  namespace at { namespace native {

namespace {

// Element-wise power helpers, all named `pow_` and callable from host and
// device code.
//
// NVCC on Windows does not handle SFINAE-dispatched math calls (pow, sqrt, ...)
// reliably, so under _MSC_VER every supported (base, exponent) combination is
// given an explicit signature. The device pow overloads those signatures rely
// on are:
//   pow(float, int)
//   pow(double, int)
//   pow(float, float)
//   pow(double, double)
#ifdef _MSC_VER

// at::Half ** at::Half: computed in float, then narrowed back to at::Half.
static inline __host__ __device__ at::Half pow_(at::Half base, at::Half exp) {
  const float wide_base = static_cast<float>(base);
  const float wide_exp = static_cast<float>(exp);
  return static_cast<at::Half>(std::pow(wide_base, wide_exp));
}

// at::BFloat16 ** at::BFloat16: computed in float, then narrowed back to
// at::BFloat16.
static inline __host__ __device__ at::BFloat16 pow_(at::BFloat16 base, at::BFloat16 exp) {
  const float wide_base = static_cast<float>(base);
  const float wide_exp = static_cast<float>(exp);
  return static_cast<at::BFloat16>(std::pow(wide_base, wide_exp));
}

// Floating-point base whose exponent is either the same type or `int`:
// forwarded directly to std::pow, with the result returned as BaseT.
template <typename BaseT, typename ExpT>
static inline __host__ __device__ std::enable_if_t<
    std::is_floating_point<BaseT>::value &&
        (std::is_same<BaseT, ExpT>::value || std::is_same<ExpT, int>::value),
    BaseT>
pow_(BaseT base, ExpT exp) {
  return std::pow(base, exp);
}

// Remaining mixed-type cases (the two types differ AND the exponent is not
// `int`): both operands are widened to double, and the result is cast back to
// BaseT.
template <typename BaseT, typename ExpT>
static inline __host__ __device__ std::enable_if_t<
    !(std::is_same<BaseT, ExpT>::value || std::is_same<ExpT, int>::value),
    BaseT>
pow_(BaseT base, ExpT exp) {
  const double wide_base = static_cast<double>(base);
  const double wide_exp = static_cast<double>(exp);
  return static_cast<BaseT>(std::pow(wide_base, wide_exp));
}

#else

// Other toolchains: a single catch-all template that defers to the global
// ::pow overloads; the result is converted to BaseT on return.
template <typename BaseT, typename ExpT>
static inline __host__ __device__ BaseT pow_(BaseT base, ExpT exp) {
  return ::pow(base, exp);
}

#endif  // _MSC_VER

// Integral base and exponent of the same type: delegates to at::native::powi.
// On non-MSVC builds the catch-all template above also accepts these
// arguments; partial ordering prefers this (T, T) overload because it is more
// specialized.
template <typename IntT>
static inline __host__ __device__ std::enable_if_t<std::is_integral<IntT>::value, IntT>
pow_(IntT base, IntT exp) {
  return at::native::powi(base, exp);
}

// Complex base and exponent with the same component type: delegates to
// c10_complex_math::pow. As with the integral overload, partial ordering picks
// this over the non-MSVC catch-all template.
template <typename ValueT>
static inline __host__ __device__ c10::complex<ValueT> pow_(
    c10::complex<ValueT> base, c10::complex<ValueT> exp) {
  return c10_complex_math::pow(base, exp);
}

} // namespace
}} // namespace at::native