12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- #pragma once
- #include <ATen/ceil_div.h>
- #include <ATen/cuda/DeviceUtils.cuh>
- #include <ATen/cuda/AsmUtils.cuh>
- #include <c10/macros/Macros.h>
- // Collection of in-kernel scan / prefix sum utilities
- namespace at {
- namespace cuda {
- // Inclusive prefix sum for binary vars using intra-warp voting +
- // shared memory
// Block-wide inclusive scan over one boolean predicate per thread.
//
// Each thread contributes `in`; on return `*out` holds the number of set
// predicates among the lanes of this warp up to and including the caller
// (computed with a warp ballot + popcount), combined via `binop` with the
// scanned per-warp counts of all preceding warps.
//
// Template parameters:
//   T                 - value type used for counts, `smem` and `*out`.
//   KillWARDependency - when true, a trailing __syncthreads() is issued after
//                       the last read of `smem`, so a later write to `smem`
//                       cannot race with the reads done here.
//   BinaryFunction    - callable invoked as binop(T, T); it is only used for
//                       the cross-warp combination. The within-warp part is
//                       always a popcount, so presumably binop is addition
//                       (confirm against callers).
//
// Parameters:
//   smem  - scratch buffer indexed by warp id; needs one T per warp of the
//           block, including a trailing partial warp. After the scan,
//           smem[w] holds the combined count of warps 0..w, so the last
//           slot holds the block-wide total.
//   in    - this thread's predicate.
//   out   - receives this thread's inclusive scan result.
//   binop - combiner for per-warp counts.
//
// Preconditions: every thread of the block must call this function, because
// it contains __syncthreads(). Warp and thread ids are derived from
// threadIdx.x / blockDim.x only, so a 1-D block layout is assumed.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void inclusiveBinaryPrefixScan(T* smem, bool in, T* out, BinaryFunction binop) {
  // Within-warp, we use warp voting. The ballot mask is kept in its native
  // unsigned width rather than in T, so that a narrow or non-integral T
  // cannot mangle the bit mask before the popcount.
#if defined (USE_ROCM)
  unsigned long long int vote = WARP_BALLOT(in);
  T index = __popcll(getLaneMaskLe() & vote);
  T carry = __popcll(vote);
#else
  unsigned int vote = WARP_BALLOT(in);
  T index = __popc(getLaneMaskLe() & vote);
  T carry = __popc(vote);
#endif

  const int warp = threadIdx.x / C10_WARP_SIZE;
  // Round up: a trailing partial warp also publishes a count below, and its
  // slot must be scanned too so that the last slot ends up holding the
  // block-wide total.
  const int numWarps = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE);

  // Per each warp, write out a value
  if (getLaneId() == 0) {
    smem[warp] = carry;
  }

  // Make every warp's count visible before thread 0 scans them.
  __syncthreads();

  // Sum across warps in one thread. This appears to be faster than a
  // warp shuffle scan for CC 3.0+
  if (threadIdx.x == 0) {
    // Accumulate in T (not int) so wide T values are not truncated.
    T current = 0;
    for (int i = 0; i < numWarps; ++i) {
      T v = smem[i];
      smem[i] = binop(v, current);
      current = binop(current, v);
    }
  }

  // Publish the scanned per-warp counts to the whole block.
  __syncthreads();

  // load the carry from the preceding warp
  if (warp >= 1) {
    index = binop(index, smem[warp - 1]);
  }

  *out = index;

  if (KillWARDependency) {
    __syncthreads();
  }
}
- // Exclusive prefix sum for binary vars using intra-warp voting +
- // shared memory
// Block-wide exclusive scan over one boolean predicate per thread.
//
// Runs inclusiveBinaryPrefixScan (without its trailing barrier) and then
// removes the caller's own contribution from the result.
//
// Parameters:
//   smem  - scratch buffer handed through to inclusiveBinaryPrefixScan.
//   in    - this thread's predicate.
//   out   - receives the inclusive result minus this thread's own predicate.
//   carry - receives the value stored in the last per-warp slot of `smem`
//           after the inclusive scan; every thread reads the same slot.
//   binop - combiner forwarded to inclusiveBinaryPrefixScan.
//
// When KillWARDependency is true, a __syncthreads() is issued after `smem`
// has been read, so a later write to `smem` cannot race with that read.
// Every thread of the block must call this function, since barriers are
// executed inside.
template <typename T, bool KillWARDependency, class BinaryFunction>
__device__ void exclusiveBinaryPrefixScan(T* smem, bool in, T* out, T* carry, BinaryFunction binop) {
  // The barrier is deferred to the end of this function, because smem is
  // still read below.
  inclusiveBinaryPrefixScan<T, false, BinaryFunction>(smem, in, out, binop);

  // Inclusive to exclusive: drop this thread's own predicate.
  const T own = static_cast<T>(in);
  *out -= own;

  // The outgoing carry for all threads is the last warp's sum
  const int lastWarp = at::ceil_div<int>(blockDim.x, C10_WARP_SIZE) - 1;
  *carry = smem[lastWarp];

  if (KillWARDependency) {
    __syncthreads();
  }
}
- }} // namespace at::cuda
|