KernelUtils.h

#pragma once
#include <limits>
#include <c10/util/Exception.h>

namespace at { namespace cuda { namespace detail {

// CUDA: grid stride looping
//
// int64_t _i_n_d_e_x specifically prevents overflow in the loop increment.
// If input.numel() < INT_MAX, _i_n_d_e_x < INT_MAX, except after the final
// iteration of the loop where _i_n_d_e_x += blockDim.x * gridDim.x can be
// greater than INT_MAX. But in that case _i_n_d_e_x >= n, so there are no
// further iterations and the overflowed value in i=_i_n_d_e_x is not used.
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type)               \
  int64_t _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
  for (index_type i=_i_n_d_e_x; _i_n_d_e_x < (n); _i_n_d_e_x+=blockDim.x * gridDim.x, i=_i_n_d_e_x)

#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
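// Example usage (illustrative sketch, not part of the original header): a
// hypothetical elementwise kernel that covers n elements with the
// grid-stride loop above; `scale_kernel`, `data`, and `alpha` are assumed
// names for illustration only.
//
//   __global__ void scale_kernel(float* data, int64_t n, float alpha) {
//     CUDA_KERNEL_LOOP(i, n) {  // i is int, so this assumes n < INT_MAX
//       data[i] *= alpha;
//     }
//   }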

// Use 1024 threads per block, which requires CUDA sm_2x or above
constexpr int CUDA_NUM_THREADS = 1024;

// CUDA: number of blocks for threads.
inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block = CUDA_NUM_THREADS) {
  TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
  constexpr int64_t max_int = std::numeric_limits<int>::max();

  // Round-up division for positive N; this form cannot cause integer
  // overflow, unlike (N + max_threads_per_block - 1) / max_threads_per_block.
  auto block_num = (N - 1) / max_threads_per_block + 1;
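  // Worked example (added for illustration): N = 1'000'000 with the default
  // 1024 threads per block gives (999'999 / 1024) + 1 = 976 + 1 = 977 blocks;
  // 977 * 1024 = 1'000'448 >= N, so every element is covered.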
  TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");

  return static_cast<int>(block_num);
}
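
// Example launch (illustrative sketch): pairing GET_BLOCKS with
// CUDA_NUM_THREADS to cover n elements; `scale_kernel`, `data`, `alpha`,
// and `stream` are assumed names, not part of this header.
//
//   const int blocks = GET_BLOCKS(n);
//   scale_kernel<<<blocks, CUDA_NUM_THREADS, 0, stream>>>(data, n, alpha);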

}}} // namespace at::cuda::detail