Randperm.cuh 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #include <ATen/cuda/CUDAGeneratorImpl.h>
  2. #include <ATen/cuda/CUDAGraphsUtils.cuh>
  3. #include <ATen/Utils.h>
  4. #include <curand.h>
  5. #include <curand_kernel.h>
  6. #include <curand_philox4x32_x.h>
  7. namespace {
  // See note [Algorithm of randperm]
  //
  // Re-shuffles, in place, every run of adjacent elements whose masked keys
  // are equal (an "island").
  //
  // Launch layout: 1-D grid of 1-D blocks; global thread `tid` inspects
  // element `tid`. Only the thread sitting on the first element of an island
  // does any work: it measures the island and then serially runs a
  // Fisher-Yates shuffle over the matching slice of `data`. All other threads
  // return immediately. No shared memory is used.
  //
  //   keys        - keys compared under `mask`; read only.
  //   data        - payload, permuted in place inside each island.
  //   mask        - bit mask applied to every key before comparison.
  //   n           - number of elements in `keys` / `data`.
  //   philox_args - packed Philox state; unpacked here and fed to curand_init.
  template<typename T, typename scalar_t>
  __global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int64_t n, at::PhiloxCudaState philox_args) {
    // The host wrapper passes a 64-bit element count, so the index is computed
    // in 64 bits as well; a 32-bit product of blockIdx.x * blockDim.x would
    // wrap for very large inputs.
    const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    // find the beginning of islands
    if (tid >= n - 1) return;  // out of range (the last element cannot start an island)
    const T key = keys[tid] & mask;
    if (key != (keys[tid + 1] & mask)) return;  // not in an island
    if (tid != 0 && key == (keys[tid - 1] & mask)) return;  // not the beginning of an island

    // find the size of islands
    int island_size = 0;
    do { island_size++; }
    while ((tid + island_size < n) && (keys[tid + island_size] & mask) == key);

    // do random permutation inside each island.
    data += tid;
    auto seeds = at::cuda::philox::unpack(philox_args);
    curandStatePhilox4_32_10_t state;
    // `tid` is passed as the Philox subsequence, so each island start draws
    // from its own stream.
    curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
    for (int i = island_size - 1; i > 0; i--) {
      // r lies in [0, i], so it always fits in an int; keeping it signed
      // avoids a signed/unsigned comparison with i.
      const int r = static_cast<int>(curand(&state) % (i + 1));
      if (i != r) {
        scalar_t tmp = data[i];
        data[i] = data[r];
        data[r] = tmp;
      }
    }
  }
  // See note [Algorithm of randperm]
  //
  // Host-side launcher for randperm_handle_duplicate_keys_kernel on the
  // current CUDA stream.
  //
  //   keys - keys, compared on their low `bits` bits.
  //   data - payload that is re-shuffled in place wherever adjacent keys
  //          collide under the mask.
  //   bits - number of low-order key bits that take part in the comparison.
  //   n    - number of elements in `keys` / `data`.
  //   gen_ - optional generator; the default CUDA generator is used when it
  //          is not set.
  //
  // The launch is asynchronous; only launch-configuration errors are checked
  // here (via C10_CUDA_KERNEL_LAUNCH_CHECK).
  template<typename T, typename scalar_t>
  void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, c10::optional<at::Generator> &gen_) {
    // Nothing to fix up for an empty input, and a zero-sized grid is not a
    // valid launch configuration.
    if (n <= 0) return;

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
    // NOTE(review): n is handed to philox_cuda_state as the counter offset,
    // presumably as an upper bound on the draws of a single thread - confirm
    // against CUDAGeneratorImpl.
    int64_t counter_offset = n;
    at::PhiloxCudaState rng_engine_inputs;
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen->mutex_);
      rng_engine_inputs = gen->philox_cuda_state(counter_offset);
    }

    // Build the low-`bits` mask in a 64-bit unsigned type. Shifting by the
    // full width of the operand is undefined behaviour (and `unsigned long`
    // is only 32 bits wide on LLP64 platforms), so the all-ones case is
    // handled explicitly instead of being computed with a shift.
    constexpr int kWordBits = 8 * static_cast<int>(sizeof(unsigned long long));
    const unsigned long long low_bits =
        bits >= kWordBits ? ~0ULL : (1ULL << bits) - 1;
    T mask = static_cast<T>(low_bits);

    // One thread per element, rounded up to whole blocks.
    constexpr int kBlockSize = 512;
    const int64_t num_blocks = (n + kBlockSize - 1) / kBlockSize;
    randperm_handle_duplicate_keys_kernel<<<num_blocks, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
      keys, data, mask, n, rng_engine_inputs);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  50. }