block_reduce.cuh

#pragma once

#include <thrust/tuple.h>

#include <ATen/native/SharedReduceOps.h>
#include <ATen/cuda/DeviceUtils.cuh>

namespace at {
namespace native {
namespace cuda_utils {

constexpr int kCUDABlockReduceNumThreads = 512;
// Algorithmic limitation: BlockReduce does two WarpReduce calls, each
// of which reduces C10_WARP_SIZE elements. So, at most
// C10_WARP_SIZE**2 elements can be reduced at a time.
// NOTE: This is >= the max block size on current hardware anyway (1024).
constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;

// Sums `val` across all threads in a warp.
//
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
template <typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val += WARP_SHFL_DOWN(val, offset);
  }
  return val;
}
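
// Example (illustrative sketch, not part of the original header): summing one
// value per lane of a single warp. `data`, `n`, and `out` are hypothetical
// parameters, and the kernel is assumed to be launched with a block width of
// C10_WARP_SIZE; the full sum is valid only on lane 0 after the reduction.
//
//   __global__ void warp_sum_kernel(const float* data, int n, float* out) {
//     const int t = threadIdx.x;
//     float v = (t < n) ? data[t] : 0.f;
//     v = at::native::cuda_utils::WarpReduceSum(v);
//     if (t == 0) {
//       *out = v;
//     }
//   }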

struct Block1D {
  static __forceinline__ __device__ int Tid() { return threadIdx.x; }

  static __forceinline__ __device__ int Warps() {
    return blockDim.x / C10_WARP_SIZE;
  }
};

struct Block2D {
  static __forceinline__ __device__ int Tid() {
    return threadIdx.x + threadIdx.y * blockDim.x;
  }

  static __forceinline__ __device__ int Warps() {
    return blockDim.x * blockDim.y / C10_WARP_SIZE;
  }
};

// Sums `val` across all threads in a block.
//
// Warning: the return value is only valid for thread 0.
//
// Assumptions:
//   - The size of each block should be a multiple of `C10_WARP_SIZE`
//   - `shared` should be a pointer to shared memory with a size of at least
//     `sizeof(T) * number_of_warps`
template <typename T, typename B = Block1D>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
  const int tid = B::Tid();
  const int lid = tid % C10_WARP_SIZE;
  const int wid = tid / C10_WARP_SIZE;
  val = WarpReduceSum(val);
  __syncthreads(); // prevent races when BlockReduces are called in a row.
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = (tid < B::Warps()) ? shared[lid] : T(0);
  if (wid == 0) {
    val = WarpReduceSum(val);
  }
  return val;
}
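
// Example (illustrative sketch, not part of the original header): each block
// computes a partial sum over a hypothetical float buffer `in` of length `n`
// and accumulates it into `*out` (assumed zero-initialized). The kernel is
// assumed to be launched with kCUDABlockReduceNumThreads threads per block,
// so the shared buffer holds one element per warp, as required above.
//
//   namespace cu = at::native::cuda_utils;
//
//   __global__ void block_sum_kernel(const float* in, float* out, int n) {
//     __shared__ float shared[cu::kCUDABlockReduceNumThreads / C10_WARP_SIZE];
//     float v = 0.f;
//     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
//          i += blockDim.x * gridDim.x) {
//       v += in[i];
//     }
//     v = cu::BlockReduceSum(v, shared);
//     if (threadIdx.x == 0) { // only thread 0 holds the final block sum
//       atomicAdd(out, v);
//     }
//   }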

template <typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
#pragma unroll
  for (int offset = (C10_WARP_SIZE >> 1); offset > 0; offset >>= 1) {
    val = op.combine(val, op.warp_shfl_down(val, offset));
  }
  return val;
}

template <typename T, class ReduceOp, typename B = Block1D>
__inline__ __device__ T
BlockReduce(T val, const ReduceOp& op, const T& identity_element, T* shared) {
  const int tid = B::Tid();
  const int lid = tid % C10_WARP_SIZE;
  const int wid = tid / C10_WARP_SIZE;
  val = WarpReduce(val, op);
  __syncthreads(); // prevent races when BlockReduces are called in a row.
  if (lid == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  val = (tid < B::Warps()) ? shared[lid] : identity_element;
  if (wid == 0) {
    val = WarpReduce(val, op);
  }
  return val;
}
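
// Example (illustrative sketch, not part of the original header): a block-wide
// max via the generic BlockReduce. `MaxOp` is a hypothetical functor; the only
// interface BlockReduce relies on is `combine` and `warp_shfl_down`. -FLT_MAX
// (from <cfloat>) serves as the identity element for max over floats.
//
//   namespace cu = at::native::cuda_utils;
//
//   struct MaxOp {
//     __device__ float combine(float a, float b) const { return a > b ? a : b; }
//     __device__ float warp_shfl_down(float v, int offset) const {
//       return WARP_SHFL_DOWN(v, offset);
//     }
//   };
//
//   __global__ void block_max_kernel(const float* in, float* out, int n) {
//     __shared__ float shared[cu::kCUDABlockReduceNumThreads / C10_WARP_SIZE];
//     float v = -FLT_MAX;
//     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
//          i += blockDim.x * gridDim.x) {
//       v = v > in[i] ? v : in[i];
//     }
//     v = cu::BlockReduce(v, MaxOp{}, -FLT_MAX, shared);
//     if (threadIdx.x == 0) { // only thread 0 holds the final block max
//       out[blockIdx.x] = v;
//     }
//   }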

} // namespace cuda_utils
} // namespace native
} // namespace at
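
// Host-side launch sketch (illustrative; assumes the hypothetical
// `block_sum_kernel` above, a device float buffer `d_in` of length `n`, and a
// zero-initialized device float `d_out`):
//
//   const int threads = at::native::cuda_utils::kCUDABlockReduceNumThreads;
//   const int blocks = (n + threads - 1) / threads;
//   block_sum_kernel<<<blocks, threads>>>(d_in, d_out, n);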