DeviceUtils.cuh 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. #pragma once
  2. #include <cuda.h>
  3. #include <c10/util/complex.h>
  4. #include <c10/util/Half.h>
  /**
   * Returns the lane mask to hand to the warp helpers in this header.
   *
   * CUDA builds query `__activemask()`. ROCm builds return a constant
   * all-ones 32-bit value instead.
   */
  __device__ __forceinline__ unsigned int ACTIVE_MASK()
  {
  #if defined(USE_ROCM)
    // Placeholder only: per the original author's note, the value is ignored
    // on this path anyway.
    return 0xffffffff;
  #else
    return __activemask();
  #endif
  }
  #if defined(USE_ROCM)
  /**
   * Warp ballot, ROCm build.
   *
   * Forwards `predicate` to the mask-less `__ballot` intrinsic and returns its
   * result as an `unsigned long long int`. This overload takes no mask argument.
   */
  __device__ __forceinline__ unsigned long long int WARP_BALLOT(int predicate)
  {
    return __ballot(predicate);
  }
  #else
  /**
   * Warp ballot, CUDA build.
   *
   * @param predicate value passed to `__ballot_sync` for the calling lane.
   * @param mask      participation mask passed to `__ballot_sync`
   *                  (defaults to 0xffffffff).
   * @return the `unsigned int` produced by `__ballot_sync(mask, predicate)`.
   */
  __device__ __forceinline__ unsigned int WARP_BALLOT(int predicate, unsigned int mask = 0xffffffff)
  {
    // This definition is only compiled when USE_ROCM is NOT defined (we are in
    // the #else of the outer guard), so no inner USE_ROCM check is needed.
    return __ballot_sync(mask, predicate);
  }
  #endif
  /**
   * Portable wrapper around the XOR-pattern warp shuffle.
   *
   * @param value    value contributed by the calling lane.
   * @param laneMask lane mask forwarded to the shuffle intrinsic.
   * @param width    shuffle width forwarded to the intrinsic (defaults to warpSize).
   * @param mask     participation mask; forwarded to `__shfl_xor_sync` on CUDA,
   *                 not used on ROCm.
   * @return whatever the underlying shuffle intrinsic returns for this lane.
   */
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if defined(USE_ROCM)
    // ROCm path: the mask-less intrinsic is used, so `mask` is dropped.
    return __shfl_xor(value, laneMask, width);
  #else
    return __shfl_xor_sync(mask, value, laneMask, width);
  #endif
  }
  /**
   * Portable wrapper around the indexed warp shuffle.
   *
   * @param value   value contributed by the calling lane.
   * @param srcLane source lane forwarded to the shuffle intrinsic.
   * @param width   shuffle width forwarded to the intrinsic (defaults to warpSize).
   * @param mask    participation mask; forwarded to `__shfl_sync` on CUDA,
   *                not used on ROCm.
   * @return whatever the underlying shuffle intrinsic returns for this lane.
   */
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL(T value, int srcLane, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if defined(USE_ROCM)
    // ROCm path: the mask-less intrinsic is used, so `mask` is dropped.
    return __shfl(value, srcLane, width);
  #else
    return __shfl_sync(mask, value, srcLane, width);
  #endif
  }
  /**
   * Portable wrapper around the "shuffle up" warp intrinsic.
   *
   * @param value value contributed by the calling lane.
   * @param delta lane offset forwarded to the shuffle intrinsic.
   * @param width shuffle width forwarded to the intrinsic (defaults to warpSize).
   * @param mask  participation mask; forwarded to `__shfl_up_sync` on CUDA,
   *              not used on ROCm.
   * @return whatever the underlying shuffle intrinsic returns for this lane.
   */
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_UP(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if defined(USE_ROCM)
    // ROCm path: the mask-less intrinsic is used, so `mask` is dropped.
    return __shfl_up(value, delta, width);
  #else
    return __shfl_up_sync(mask, value, delta, width);
  #endif
  }
  /**
   * Portable wrapper around the "shuffle down" warp intrinsic.
   *
   * @param value value contributed by the calling lane.
   * @param delta lane offset forwarded to the shuffle intrinsic.
   * @param width shuffle width forwarded to the intrinsic (defaults to warpSize).
   * @param mask  participation mask; forwarded to `__shfl_down_sync` on CUDA,
   *              not used on ROCm.
   * @return whatever the underlying shuffle intrinsic returns for this lane.
   */
  template <typename T>
  __device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if defined(USE_ROCM)
    // ROCm path: the mask-less intrinsic is used, so `mask` is dropped.
    return __shfl_down(value, delta, width);
  #else
    return __shfl_down_sync(mask, value, delta, width);
  #endif
  }
  #if defined(USE_ROCM)
  /**
   * ROCm-only specialization of WARP_SHFL_DOWN for int64_t.
   *
   * The 64-bit value is viewed as two 32-bit ints, each half is shuffled
   * separately with `__shfl_down`, and the halves are reassembled.
   *
   * @param value 64-bit value contributed by the calling lane.
   * @param delta lane offset forwarded to `__shfl_down`.
   * @param width shuffle width forwarded to `__shfl_down`.
   * @param mask  accepted for signature compatibility with the generic
   *              template; not used by the mask-less ROCm intrinsic.
   * @return the 64-bit value rebuilt from the two shuffled halves.
   */
  template<>
  __device__ __forceinline__ int64_t WARP_SHFL_DOWN<int64_t>(int64_t value, unsigned int delta, int width , unsigned int mask)
  {
    //(HIP doesn't support int64_t). Trick from https://devblogs.nvidia.com/faster-parallel-reductions-kepler/
    int2 a = *reinterpret_cast<int2*>(&value);
    // Forward `width` to both half-shuffles so a caller-supplied sub-warp width
    // is honored, matching the generic template, instead of being silently
    // replaced by the intrinsic's default.
    a.x = __shfl_down(a.x, delta, width);
    a.y = __shfl_down(a.y, delta, width);
    return *reinterpret_cast<int64_t*>(&a);
  }
  #endif
  /**
   * Specialization of WARP_SHFL_DOWN for c10::Half.
   *
   * Shuffles the raw `value.x` member through the `unsigned short`
   * instantiation, then rebuilds a c10::Half from the received bits using the
   * `from_bits_t` tag constructor.
   */
  template<>
  __device__ __forceinline__ c10::Half WARP_SHFL_DOWN<c10::Half>(c10::Half value, unsigned int delta, int width, unsigned int mask)
  {
    const unsigned short shuffledBits =
        WARP_SHFL_DOWN<unsigned short>(value.x, delta, width, mask);
    return c10::Half(shuffledBits, c10::Half::from_bits_t{});
  }
  /**
   * Overload of WARP_SHFL_DOWN for c10::complex<T>.
   *
   * The `real_` and `imag_` members are shuffled independently with the same
   * `delta` and `width`, and a new c10::complex<T> is built from the results.
   *
   * @param value complex value contributed by the calling lane.
   * @param delta lane offset forwarded to the shuffle intrinsic.
   * @param width shuffle width forwarded to the intrinsic (defaults to warpSize).
   * @param mask  participation mask; forwarded to `__shfl_down_sync` on CUDA,
   *              not used on ROCm.
   */
  template <typename T>
  __device__ __forceinline__ c10::complex<T> WARP_SHFL_DOWN(c10::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
  {
  #if defined(USE_ROCM)
    // ROCm path: mask-less intrinsic, `mask` is dropped.
    const auto shuffledReal = __shfl_down(value.real_, delta, width);
    const auto shuffledImag = __shfl_down(value.imag_, delta, width);
  #else
    const auto shuffledReal = __shfl_down_sync(mask, value.real_, delta, width);
    const auto shuffledImag = __shfl_down_sync(mask, value.imag_, delta, width);
  #endif
    return c10::complex<T>(shuffledReal, shuffledImag);
  }
  /**
   * Loads `*p`, going through `__ldg` when compiling CUDA device code for
   * compute capability 3.5 or newer.
   *
   * ROCm builds, and CUDA builds where `__CUDA_ARCH__` is below 350 (or not
   * defined), fall back to a plain dereference.
   *
   * @param p pointer to the value to read.
   * @return the value read from `p`.
   */
  template <typename T>
  __device__ __forceinline__ T doLdg(const T* p) {
  #if defined(USE_ROCM) || __CUDA_ARCH__ < 350
    return *p;
  #else
    return __ldg(p);
  #endif
  }