NumericUtils.h 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. #pragma once
  2. #ifdef __HIPCC__
  3. #include <hip/hip_runtime.h>
  4. #endif
  5. #include <c10/macros/Macros.h>
  6. #include <c10/util/BFloat16.h>
  7. #include <c10/util/Half.h>
  8. #include <c10/util/complex.h>
  9. #include <cmath>
  10. #include <type_traits>
  11. namespace at {
  12. // std::isnan isn't performant to use on integral types; it will
  13. // (uselessly) convert to floating point and then do the test.
  14. // This function is.
  15. template <
  16. typename T,
  17. typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  18. inline C10_HOST_DEVICE bool _isnan(T /*val*/) {
  19. return false;
  20. }
  21. template <
  22. typename T,
  23. typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
  24. inline C10_HOST_DEVICE bool _isnan(T val) {
  25. #if defined(__CUDACC__) || defined(__HIPCC__)
  26. return ::isnan(val);
  27. #else
  28. return std::isnan(val);
  29. #endif
  30. }
  31. template <
  32. typename T,
  33. typename std::enable_if<c10::is_complex<T>::value, int>::type = 0>
  34. inline bool _isnan(T val) {
  35. return std::isnan(val.real()) || std::isnan(val.imag());
  36. }
  37. template <
  38. typename T,
  39. typename std::enable_if<std::is_same<T, at::Half>::value, int>::type = 0>
  40. inline C10_HOST_DEVICE bool _isnan(T val) {
  41. return at::_isnan(static_cast<float>(val));
  42. }
  43. template <
  44. typename T,
  45. typename std::enable_if<std::is_same<T, at::BFloat16>::value, int>::type =
  46. 0>
  47. inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
  48. return at::_isnan(static_cast<float>(val));
  49. }
  50. inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) {
  51. return at::_isnan(static_cast<float>(val));
  52. }
  53. // std::isinf isn't performant to use on integral types; it will
  54. // (uselessly) convert to floating point and then do the test.
  55. // This function is.
  56. template <
  57. typename T,
  58. typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
  59. inline C10_HOST_DEVICE bool _isinf(T /*val*/) {
  60. return false;
  61. }
  62. template <
  63. typename T,
  64. typename std::enable_if<std::is_floating_point<T>::value, int>::type = 0>
  65. inline C10_HOST_DEVICE bool _isinf(T val) {
  66. #if defined(__CUDACC__) || defined(__HIPCC__)
  67. return ::isinf(val);
  68. #else
  69. return std::isinf(val);
  70. #endif
  71. }
  72. inline C10_HOST_DEVICE bool _isinf(at::Half val) {
  73. return at::_isinf(static_cast<float>(val));
  74. }
  75. inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) {
  76. return at::_isinf(static_cast<float>(val));
  77. }
  78. template <typename T>
  79. C10_HOST_DEVICE inline T exp(T x) {
  80. static_assert(
  81. !std::is_same<T, double>::value,
  82. "this template must be used with float or less precise type");
  83. #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
  84. // use __expf fast approximation for peak bandwidth
  85. return __expf(x);
  86. #else
  87. return ::exp(x);
  88. #endif
  89. }
  90. template <>
  91. C10_HOST_DEVICE inline double exp<double>(double x) {
  92. return ::exp(x);
  93. }
  94. template <typename T>
  95. C10_HOST_DEVICE inline T log(T x) {
  96. static_assert(
  97. !std::is_same<T, double>::value,
  98. "this template must be used with float or less precise type");
  99. #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
  100. // use __logf fast approximation for peak bandwidth
  101. return __logf(x);
  102. #else
  103. return ::log(x);
  104. #endif
  105. }
  106. template <>
  107. C10_HOST_DEVICE inline double log<double>(double x) {
  108. return ::log(x);
  109. }
  110. template <typename T>
  111. C10_HOST_DEVICE inline T log1p(T x) {
  112. static_assert(
  113. !std::is_same<T, double>::value,
  114. "this template must be used with float or less precise type");
  115. #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
  116. // use __logf fast approximation for peak bandwidth
  117. // NOTE: There is no __log1pf so unfortunately we lose precision.
  118. return __logf(1.0f + x);
  119. #else
  120. return ::log1p(x);
  121. #endif
  122. }
  123. template <>
  124. C10_HOST_DEVICE inline double log1p<double>(double x) {
  125. return ::log1p(x);
  126. }
  127. template <typename T>
  128. C10_HOST_DEVICE inline T tan(T x) {
  129. static_assert(
  130. !std::is_same<T, double>::value,
  131. "this template must be used with float or less precise type");
  132. #if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__)
  133. // use __tanf fast approximation for peak bandwidth
  134. return __tanf(x);
  135. #else
  136. return ::tan(x);
  137. #endif
  138. }
  139. template <>
  140. C10_HOST_DEVICE inline double tan<double>(double x) {
  141. return ::tan(x);
  142. }
  143. } // namespace at