CUDAMathCompat.h 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. #pragma once
  /* This file defines math functions that are compatible across different GPU
   * platforms (currently CUDA and HIP).
   */
  5. #if defined(__CUDACC__) || defined(__HIPCC__)
  6. #include <c10/macros/Macros.h>
  7. #include <c10/util/Exception.h>
  // Declaration specifiers applied to every wrapper in c10::cuda::compat.
  //   - HIP:    `inline` + C10_DEVICE (device-only functions).
  //   - NVRTC:  C10_HOST_DEVICE alone (no `static inline`).
  //   - Otherwise (nvcc): `static inline` + C10_HOST_DEVICE.
  #if defined(__HIPCC__)
  #define __MATH_FUNCTIONS_DECL__ inline C10_DEVICE
  #elif defined(__CUDACC_RTC__)
  #define __MATH_FUNCTIONS_DECL__ C10_HOST_DEVICE
  #else /* neither __HIPCC__ nor __CUDACC_RTC__ */
  #define __MATH_FUNCTIONS_DECL__ static inline C10_HOST_DEVICE
  #endif /* __HIPCC__ / __CUDACC_RTC__ */
  17. namespace c10 {
  18. namespace cuda {
  19. namespace compat {
  20. __MATH_FUNCTIONS_DECL__ float abs(float x) {
  21. return ::fabsf(x);
  22. }
  23. __MATH_FUNCTIONS_DECL__ double abs(double x) {
  24. return ::fabs(x);
  25. }
  26. __MATH_FUNCTIONS_DECL__ float exp(float x) {
  27. return ::expf(x);
  28. }
  29. __MATH_FUNCTIONS_DECL__ double exp(double x) {
  30. return ::exp(x);
  31. }
  32. __MATH_FUNCTIONS_DECL__ float ceil(float x) {
  33. return ::ceilf(x);
  34. }
  35. __MATH_FUNCTIONS_DECL__ double ceil(double x) {
  36. return ::ceil(x);
  37. }
  38. __MATH_FUNCTIONS_DECL__ float copysign(float x, float y) {
  39. #if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  40. return ::copysignf(x, y);
  41. #else
  42. // std::copysign gets ICE/Segfaults with gcc 7.5/8 on arm64
  43. // (e.g. Jetson), see PyTorch PR #51834
  44. // This host function needs to be here for the compiler but is never used
  45. TORCH_INTERNAL_ASSERT(
  46. false, "CUDAMathCompat copysign should not run on the CPU");
  47. #endif
  48. }
  49. __MATH_FUNCTIONS_DECL__ double copysign(double x, double y) {
  50. #if defined(__CUDA_ARCH__) || defined(__HIPCC__)
  51. return ::copysign(x, y);
  52. #else
  53. // see above
  54. TORCH_INTERNAL_ASSERT(
  55. false, "CUDAMathCompat copysign should not run on the CPU");
  56. #endif
  57. }
  58. __MATH_FUNCTIONS_DECL__ float floor(float x) {
  59. return ::floorf(x);
  60. }
  61. __MATH_FUNCTIONS_DECL__ double floor(double x) {
  62. return ::floor(x);
  63. }
  64. __MATH_FUNCTIONS_DECL__ float log(float x) {
  65. return ::logf(x);
  66. }
  67. __MATH_FUNCTIONS_DECL__ double log(double x) {
  68. return ::log(x);
  69. }
  70. __MATH_FUNCTIONS_DECL__ float log1p(float x) {
  71. return ::log1pf(x);
  72. }
  73. __MATH_FUNCTIONS_DECL__ double log1p(double x) {
  74. return ::log1p(x);
  75. }
  76. __MATH_FUNCTIONS_DECL__ float max(float x, float y) {
  77. return ::fmaxf(x, y);
  78. }
  79. __MATH_FUNCTIONS_DECL__ double max(double x, double y) {
  80. return ::fmax(x, y);
  81. }
  82. __MATH_FUNCTIONS_DECL__ float min(float x, float y) {
  83. return ::fminf(x, y);
  84. }
  85. __MATH_FUNCTIONS_DECL__ double min(double x, double y) {
  86. return ::fmin(x, y);
  87. }
  88. __MATH_FUNCTIONS_DECL__ float pow(float x, float y) {
  89. return ::powf(x, y);
  90. }
  91. __MATH_FUNCTIONS_DECL__ double pow(double x, double y) {
  92. return ::pow(x, y);
  93. }
  94. __MATH_FUNCTIONS_DECL__ void sincos(float x, float* sptr, float* cptr) {
  95. return ::sincosf(x, sptr, cptr);
  96. }
  97. __MATH_FUNCTIONS_DECL__ void sincos(double x, double* sptr, double* cptr) {
  98. return ::sincos(x, sptr, cptr);
  99. }
  100. __MATH_FUNCTIONS_DECL__ float sqrt(float x) {
  101. return ::sqrtf(x);
  102. }
  103. __MATH_FUNCTIONS_DECL__ double sqrt(double x) {
  104. return ::sqrt(x);
  105. }
  106. __MATH_FUNCTIONS_DECL__ float rsqrt(float x) {
  107. return ::rsqrtf(x);
  108. }
  109. __MATH_FUNCTIONS_DECL__ double rsqrt(double x) {
  110. return ::rsqrt(x);
  111. }
  112. __MATH_FUNCTIONS_DECL__ float tan(float x) {
  113. return ::tanf(x);
  114. }
  115. __MATH_FUNCTIONS_DECL__ double tan(double x) {
  116. return ::tan(x);
  117. }
  118. __MATH_FUNCTIONS_DECL__ float tanh(float x) {
  119. return ::tanhf(x);
  120. }
  121. __MATH_FUNCTIONS_DECL__ double tanh(double x) {
  122. return ::tanh(x);
  123. }
  124. __MATH_FUNCTIONS_DECL__ float normcdf(float x) {
  125. return ::normcdff(x);
  126. }
  127. __MATH_FUNCTIONS_DECL__ double normcdf(double x) {
  128. return ::normcdf(x);
  129. }
  130. } // namespace compat
  131. } // namespace cuda
  132. } // namespace c10
  133. #endif