- #pragma once
- #include <ATen/cuda/detail/IndexUtils.cuh>
- #include <ATen/native/cuda/Loops.cuh>
- #include <ATen/native/cuda/SortingCommon.cuh>
- #include <ATen/native/cuda/block_reduce.cuh>
- namespace at {
- namespace native {
- // Used for a segmented reduction
- struct ModeUnsignedBoolPair {
- unsigned int val;
- bool flag;
- };
- // In the kernel below, we have a common pattern of reducing (unsigned int,
- // unsigned int) pairs of data
- struct ModeUnsignedPair {
- unsigned int val;
- unsigned int index;
- };
- // Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
- //
- // 1. Power2ScanSize is a power of 2. This code still works for collections
- // that do not contain an exact power-of-2 number of elements: simply round
- // up to the nearest power of 2 and then call.
- //
- // 2. That there are two elements per thread, i.e. the size of the smem
- // storage is 2 * blockDim.x * sizeof(T).
- //
- // Consider a (+)-Scan on the following elements:
- //
- // Upsweep:
- //
- //    0  1  2  3  4  5  6  7
- //       1     5     9    13
- //             6          22
- //                        28
- //
- // Downsweep:
- //                  15
- //         3  10    21
- template <int Power2ScanSize, typename T, class BinaryOp>
- __device__ void inclusivePrefixScan(T* smem, BinaryOp binop) {
- // Reduce step ("upsweep")
- #pragma unroll
- for (int stride = 1; stride < Power2ScanSize; stride <<= 1) {
- int index = (threadIdx.x + 1) * stride * 2 - 1;
- if (index < Power2ScanSize) {
- smem[index] = binop(smem[index], smem[index - stride]);
- }
- __syncthreads();
- }
- // Post-reduce step ("downsweep")
- #pragma unroll
- for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) {
- int index = (threadIdx.x + 1) * stride * 2 - 1;
- if ((index + stride) < Power2ScanSize) {
- smem[index + stride] = binop(smem[index + stride], smem[index]);
- }
- __syncthreads();
- }
- }
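- // For illustration only: a minimal, hypothetical driver (not part of this
- // file's API) showing how inclusivePrefixScan could perform the 8-element
- // (+)-scan pictured above. Launch as inclusivePrefixScanExample<<<1, 4>>> so
- // each of the 4 threads owns two of the 8 shared-memory slots.
- template <typename T>
- __global__ void inclusivePrefixScanExample(T* data) {
- __shared__ T smem[8];
- // Each thread stages its two elements into shared memory
- smem[threadIdx.x * 2] = data[threadIdx.x * 2];
- smem[threadIdx.x * 2 + 1] = data[threadIdx.x * 2 + 1];
- __syncthreads();
- // A lambda defined in device code is implicitly a device lambda
- inclusivePrefixScan<8>(smem, [](T a, T b) { return a + b; });
- // The scan __syncthreads after its final write, so smem is safe to read back
- data[threadIdx.x * 2] = smem[threadIdx.x * 2];
- data[threadIdx.x * 2 + 1] = smem[threadIdx.x * 2 + 1];
- }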
- // Block-wide reduction where each thread locally reduces N
- // values before letting a single warp take over - assumes
- // threadVals is in registers, not shared memory
- //
- // If smem is not used again, there is no need to __syncthreads before this
- // call. However, if smem will be used, e.g., this function is called in a
- // loop, then __syncthreads is needed either before or afterwards to prevent
- // non-0 threads from overwriting smem in the next loop before thread 0 reads
- // from it.
- template <int N, typename T, typename ReduceOp>
- __device__ T reduceBlockWithNThreadLocalReductions(
- T* smem,
- T threadVals[N],
- const unsigned int numVals,
- ReduceOp reduceOp,
- T init) {
- int offset = threadIdx.x * N;
- T local = offset < numVals ? threadVals[0] : init;
- #pragma unroll
- for (int i = 1; i < N; ++i) {
- ++offset;
- T next = offset < numVals ? threadVals[i] : init;
- local = reduceOp.combine(local, next);
- }
- return cuda_utils::BlockReduce(local, reduceOp, init, smem);
- }
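- // cuda_utils::BlockReduce consumes the ReduceOp through its combine() and
- // warp_shfl_down() members (see MaxOp and MaxIndexOp in the kernel below).
- // For illustration, a hypothetical (+)-reduction functor satisfying that
- // interface:
- struct ExampleSumOp {
- inline __device__ unsigned combine(unsigned a, unsigned b) const {
- return a + b;
- }
- inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
- return WARP_SHFL_DOWN(acc, offset);
- }
- };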
- template <typename T>
- __device__ inline void swapVars(T& t1, T& t2) {
- T tmp = t1;
- t1 = t2;
- t2 = tmp;
- }
- template <typename Comparator, typename K, typename V>
- __device__ inline void bitonicSwap(
- K& kA,
- V& vA,
- bool& validA,
- K& kB,
- V& vB,
- bool& validB,
- bool dir,
- const Comparator& comp) {
- // Invalid entries always sort to the end
- bool swap = (comp(kA, kB) && validA) || !validB;
- if (swap == dir) {
- swapVars(kA, kB);
- swapVars(vA, vB);
- swapVars(validA, validB);
- }
- }
- template <typename Comparator, typename K>
- __device__ inline void bitonicSwapKeys(
- K& kA,
- bool& validA,
- K& kB,
- bool& validB,
- bool dir,
- const Comparator& comp) {
- bool swap = (comp(kA, kB) && validA) || !validB;
- if (swap == dir) {
- swapVars(kA, kB);
- swapVars(validA, validB);
- }
- }
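- // Sorts Power2SortSize keys resident in shared memory, pushing invalid
- // entries (valid[i] == false) to the end. Each thread performs one
- // compare-exchange per stage, so callers are expected to launch with
- // blockDim.x == Power2SortSize / 2, as compute_mode below does.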
- template <
- typename K,
- typename IndexType,
- int Power2SortSize,
- typename Comparator>
- __device__ inline void bitonicSortKeys(
- K keys[Power2SortSize],
- bool valid[Power2SortSize],
- const Comparator& comp) {
- #if !defined(USE_ROCM)
- #pragma unroll
- #endif
- for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
- bool flag = ((threadIdx.x & (size / 2)) != 0);
- #if !defined(USE_ROCM)
- #pragma unroll
- #endif
- for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
- __syncthreads();
- unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
- bitonicSwapKeys<Comparator, K>(
- keys[pos],
- valid[pos],
- keys[pos + stride],
- valid[pos + stride],
- flag,
- comp);
- }
- }
- #if !defined(USE_ROCM)
- #pragma unroll
- #endif
- for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
- __syncthreads();
- unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
- bitonicSwapKeys<Comparator, K>(
- keys[pos],
- valid[pos],
- keys[pos + stride],
- valid[pos + stride],
- false,
- comp);
- }
- __syncthreads();
- }
- // The mode kernel has the following characteristics: It uses internal shared
- // memory buffers of Power2Size, which must be greater than the number of
- // elements. Additionally, there is one block for every slice to calculate the
- // mode for, and in each block there is one thread for every two elements.
- //
- // Both sorted and positions are assumed to be contiguous Tensors with the mode
- // dimension as the innermost dim, such that we can get the particular slice for
- // a Tensor via its linear block dimension * the slice size.
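- //
- // As a hypothetical sketch (the actual launcher lives in the corresponding
- // .cu file), a launch for this kernel pairs one block per slice with one
- // thread per two slice elements, plus the shared-memory size documented in
- // the kernel body:
- //
- // const dim3 block(Power2Size / 2);
- // const auto memsize =
- // sizeof(scalar_t) * Power2Size + 2 * Power2Size * sizeof(unsigned int);
- // compute_mode<scalar_t, Power2Size>
- // <<<grid, block, memsize, stream>>>(
- // data, valuesInfo, indicesInfo, sliceSize, slices);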
- template <typename T, unsigned int Power2Size>
- #if defined(CUDA_VERSION) && CUDA_VERSION >= 11070
- __launch_bounds__(1024, 1)
- #endif
- __global__ void compute_mode(
- T* input,
- at::cuda::detail::TensorInfo<T, unsigned int> values,
- at::cuda::detail::TensorInfo<int64_t, unsigned int> indices,
- int64_t sliceSize,
- int64_t slices) {
- int tidx = threadIdx.x;
- int stidx = blockDim.x + threadIdx.x; // Second index this thread is responsible for
- // First, we need to calculate the offset into the sorted Tensor that
- // represents the start of the slice for this block to calculate the mode for.
- // This offset is a combination of the linear block index and the number of
- // elements in the slice.
- unsigned int blockId = getLinearBlockId<unsigned int>();
- unsigned int linearOffset = blockId * sliceSize;
- if (blockId >= slices) {
- return;
- }
- // shmem is a dynamically sized buffer we will use throughout the kernel to
- // handle computation efficiently. The size of this shmem must be
- // sizeof(T) * Power2Size + (2 * sizeof(unsigned int) * Power2Size)
- //
- // Initially, the buffer will be organized as follows:
- //
- // [smem (slice elements) | bmem (valid indices) | <scratch space>]
- extern __shared__ char shmem[];
- // smem represents a proportion of the shared memory buffer that is used to
- // store the elements from the slice:
- T* smem = reinterpret_cast<T*>(shmem);
- // Each thread loads up to two elements from the Tensor into shared memory
- if (tidx < sliceSize) {
- smem[tidx] = c10::load(&input[linearOffset + tidx]);
- }
- if (stidx < sliceSize) {
- smem[stidx] = c10::load(&input[linearOffset + stidx]);
- }
- // Next, we initialize a boolean region of the buffer, offset by the loaded
- // element smem region
- bool* bmem = reinterpret_cast<bool*>(&smem[Power2Size]);
- // The first use of this region stores bmem[i] = i < sliceSize to mark the
- // valid components in the smem buffer
- bmem[tidx] = tidx < sliceSize;
- bmem[stidx] = stidx < sliceSize;
- __syncthreads(); // barrier for smem, bmem initialization
- // First, sort the input slice in ascending order. smem contains the input
- // elements, and bmem marks the valid indices
- bitonicSortKeys<T, unsigned int, Power2Size>(
- smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) {
- return a < b;
- });
- __syncthreads(); // make no assumptions that the sort syncs at end
- // The next step of our algorithm is performing a block-wide comparison of
- // neighboring elements. In particular, given a sorted input slice A, we
- // produce an output slice B, such that B[i] = 1 if A[i-1] != A[i] (with
- // B[0] = 1 by convention), otherwise 0.
- //
- // Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8]
- // B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
- //
- // In particular, we can think of B[i] = true as indicating the start of a
- // sequence of equal values in the sorted list. Similarly, we will also store
- // the negation of B, which we'll call C. That is, C[i] = true iff
- // A[i-1] == A[i] in our original sorted slice.
- //
- // C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
- // We overwrite bmem, and treat the rest of shared memory as a buffer of
- // (index, flag) pairs where the index represents values from C, and the flag
- // represents values from B.
- //
- // [smem (sorted slice) | ubpmem (index, flag pairs)]
- struct ModeUnsignedBoolPair* ubpmem =
- reinterpret_cast<struct ModeUnsignedBoolPair*>(&smem[Power2Size]);
- if (tidx == 0) {
- ubpmem[0].flag = true;
- ubpmem[0].val = 0;
- }
- // Compares elements (0, 1), (2, 3), ... and sets 1, 3, ...
- ubpmem[tidx * 2 + 1].flag =
- smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (2, 3), etc.
- ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag;
- // Compares elements (1, 2), (3, 4), ... and sets 2, 4, ...
- if (((tidx + 1) * 2) < Power2Size) {
- ubpmem[(tidx + 1) * 2].flag =
- smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2];
- ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag;
- }
- __syncthreads(); // barrier for ubpmem initialization
- // Next, we perform a segmented prefix sum on the neighboring elements, where
- // the presence of a one indicates the start of a segment. In this case B acts
- // as the segment start flags, and C is the buffer to be summed:
- //
- // Input (C) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
- // Flag (B) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
- // Output (C) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0]
- //
- // Afterwards, the (index) components of the ubpmem buffer contain the lengths
- // of the segments (minus 1), i.e. the counts of each element in the original
- // input.
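- //
- // Note that the scan invokes its binary op as binop(right, left) -- see
- // inclusivePrefixScan above -- so `a` below is the higher-indexed element:
- // if `a` starts a segment (a.flag is set), its running count is kept as-is;
- // otherwise the count accumulated to its left is added in.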
- inclusivePrefixScan<Power2Size>(
- ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) {
- ModeUnsignedBoolPair c;
- c.val = a.flag ? a.val : a.val + b.val;
- c.flag = a.flag | b.flag;
- return c;
- });
- // assumes scan syncs at the end
- // Next, we reinterpret the ubpmem buffer as pairs of unsigned integers (i.e.
- // we treat the boolean flag regions as integers). We initialize these to
- // represent indices, and we'll call this buffer I
- struct ModeUnsignedPair* uupmem =
- reinterpret_cast<struct ModeUnsignedPair*>(ubpmem);
- // At this point, we need to find the maximum element in lengths buffer C.
- // This element will represent the count (-1) of the mode. Because of the
- // way we have set up the problem, the index where this mode occurs will
- // also be the location of the mode value in the sorted array, e.g.
- //
- // smem = [0, 0, 1, 1, 1, 2]
- // C = [0, 1, 0, 1, 2, 0]
- // I = [0, 1, 2, 3, 4, 5]
- // ^
- // maximum value, also aligned with mode = 1
- //
- // We perform a block-wide max-reduction of the C buffer, but we also need the
- // indices to come along with it, so we utilize the uupmem construction.
- //
- // At the end we need to return the ModeUnsignedPair containing index = 4, val
- // = 2, which represents the max
- // In practice, we will make each thread locally reduce 2 values in its
- // registers prior to the global block-wide reduction. Note that instead of
- // tidx/stidx, we utilize tidx * 2, tidx * 2 + 1, so each thread deals with
- // adjacent elements. This is because the reduce code below relies on thread
- // elements to be adjacent.
- struct ModeUnsignedPair uup[2];
- uup[0].index = tidx * 2;
- uup[0].val = ubpmem[tidx * 2].val;
- uup[1].index = tidx * 2 + 1;
- uup[1].val = ubpmem[tidx * 2 + 1].val;
- __syncthreads();
- struct ModeUnsignedPair max = {0, 0};
- struct MaxOp {
- inline __device__ ModeUnsignedPair combine(ModeUnsignedPair a, ModeUnsignedPair b) const {
- return b.val > a.val ? b : a;
- }
- inline __device__ ModeUnsignedPair warp_shfl_down(ModeUnsignedPair acc, int offset) const {
- ModeUnsignedPair ret;
- ret.index = WARP_SHFL_DOWN(acc.index, offset);
- ret.val = WARP_SHFL_DOWN(acc.val, offset);
- return ret;
- }
- } max_op;
- max = reduceBlockWithNThreadLocalReductions<2>(
- uupmem,
- uup,
- sliceSize,
- max_op,
- max);
- // Store the mode in shared memory for use in finding the mode in the input
- // slice
- __shared__ T mode;
- // Given the above constraints, the mode is the value at the reduced index in
- // the original sorted element buffer
- if (tidx == 0) {
- mode = smem[max.index];
- }
- __syncthreads(); // broadcast mode
- // Finally, we need to find "an" index of the mode in the input
- // Tensor. The API does not constrain which index we pick, but here
- // we always pick the largest index. We store the index if the value
- // is the mode, or 0 otherwise. Then find the maximum value.
- //
- // Again we reduce 2 elements in the thread's registers prior to the
- // block-wide reduction
- unsigned mode_index[2] = {0u, 0u};
- if (tidx * 2 < sliceSize) {
- const unsigned idx = tidx * 2;
- mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
- }
- if (tidx * 2 + 1 < sliceSize) {
- const unsigned idx = tidx * 2 + 1;
- mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
- }
- struct MaxIndexOp {
- inline __device__ unsigned combine(unsigned a, unsigned b) const {
- return b > a ? b : a;
- }
- inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
- return WARP_SHFL_DOWN(acc, offset);
- }
- } max_index_op;
- int64_t index = reduceBlockWithNThreadLocalReductions<2>(
- reinterpret_cast<unsigned*>(&shmem[0]),
- mode_index,
- sliceSize,
- max_index_op,
- 0u);
- // Finally, we have the mode, and an index where it occurs. We use a single
- // thread to place this in the appropriate output position
- if (tidx == 0) {
- unsigned int outputOffset =
- at::cuda::detail::IndexToOffset<T, unsigned int, -1>::get(
- blockId, values);
- values.data[outputOffset] = mode;
- indices.data[outputOffset] = index;
- }
- }
- } // namespace native
- } // namespace at