12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- #include <ATen/cuda/CUDAGeneratorImpl.h>
- #include <ATen/cuda/CUDAGraphsUtils.cuh>
- #include <ATen/Utils.h>
- #include <curand.h>
- #include <curand_kernel.h>
- #include <curand_philox4x32_x.h>
namespace {

// Returns a value of type T whose low `bits` bits are set and whose other bits
// are clear; `bits` is clamped to [0, 64].
//
// The shift is performed in a fixed 64-bit unsigned type and is never by >= 64.
// The previous `(1UL << bits) - 1` was undefined behavior for bits == 64 on every
// platform, and for bits >= 32 on LLP64 platforms (e.g. Windows) where
// `unsigned long` is only 32 bits wide.
template<typename T>
T randperm_low_bits_mask(int bits) {
  if (bits <= 0) {
    return static_cast<T>(0);
  }
  if (bits >= 64) {
    return static_cast<T>(~uint64_t{0});
  }
  return static_cast<T>((uint64_t{1} << bits) - 1);
}

// See note [Algorithm of randperm]
//
// One thread per index. A maximal run of consecutive positions whose keys are
// equal after `& mask` is an "island". The thread at the first position of an
// island (the "island leader") measures the island's length and shuffles the
// corresponding slice of `data` in place with a serial Fisher-Yates pass, using
// Philox subsequence `tid` for its random numbers. All other threads exit
// without touching memory beyond reading up to three keys.
//
// Only neighbouring keys are compared, so equal masked keys are treated as one
// island only when they are adjacent -- presumably `keys` has already been
// sorted by the caller (NOTE(review): confirm at the call site).
//
// Launch: 1-D grid, any block size, at least `n - 1` threads in total; surplus
// threads exit immediately. `n` is int64_t so the host wrapper's int64_t length
// is not silently narrowed; for indices that fit in `int`, the Philox
// subsequence ids and drawn values are the same as with 32-bit indexing.
template<typename T, typename scalar_t>
__global__ void randperm_handle_duplicate_keys_kernel(T *keys, scalar_t *data, T mask, int64_t n, at::PhiloxCudaState philox_args) {
  const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  // Decide whether this thread is an island leader.
  if (tid >= n - 1) return;  // the last element (or anything past it) cannot start an island
  const auto key = keys[tid] & mask;
  if (key != (keys[tid + 1] & mask)) return;  // not in an island
  if (tid != 0 && key == (keys[tid - 1] & mask)) return;  // in an island, but not at its beginning

  // Measure the island. Positions tid and tid + 1 are already known to match.
  int64_t island_size = 2;
  while (tid + island_size < n && (keys[tid + island_size] & mask) == key) {
    ++island_size;
  }

  // Serial Fisher-Yates shuffle of data[tid, tid + island_size).
  scalar_t *island = data + tid;
  auto seeds = at::cuda::philox::unpack(philox_args);
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
  for (int64_t i = island_size - 1; i > 0; i--) {
    const int64_t r = curand(&state) % (i + 1);
    if (i != r) {
      scalar_t tmp = island[i];
      island[i] = island[r];
      island[r] = tmp;
    }
  }
}

// See note [Algorithm of randperm]
//
// Host launcher for randperm_handle_duplicate_keys_kernel.
//   keys : device-accessible array of n keys; only the low `bits` bits are compared.
//   data : device-accessible array of n values; slices whose masked keys tie are
//          shuffled in place.
//   bits : number of significant low key bits; clamped to [0, 64].
//   n    : number of elements. For n < 2 nothing is launched.
//   gen_ : optional generator; the default CUDA generator is used when empty.
// The Philox state is taken from the generator while holding its mutex. The
// kernel is enqueued on the current CUDA stream; nothing here waits for it.
template<typename T, typename scalar_t>
void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, c10::optional<at::Generator> &gen_) {
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
  // No thread calls curand() more than island_size - 1 <= n - 1 times.
  int64_t counter_offset = n;
  at::PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  // With fewer than two elements there are no islands to shuffle. Returning here
  // also avoids launching a 0-block grid for n == 0, which CUDA rejects as an
  // invalid configuration. philox_cuda_state is still called above with the
  // same increment, so the generator's state is affected exactly as before.
  if (n < 2) {
    return;
  }

  const T mask = randperm_low_bits_mask<T>(bits);
  constexpr int threads_per_block = 512;
  const int64_t num_blocks = (n + threads_per_block - 1) / threads_per_block;
  randperm_handle_duplicate_keys_kernel<<<static_cast<unsigned int>(num_blocks), threads_per_block, 0, at::cuda::getCurrentCUDAStream()>>>(
      keys, data, mask, n, rng_engine_inputs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

}
|