TensorModeKernel.cuh

#pragma once

#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

namespace at {
namespace native {

// Used for a segmented reduction
struct ModeUnsignedBoolPair {
  unsigned int val;
  bool flag;
};

// In the kernel below, we have a common pattern of reducing (unsigned int,
// unsigned int) pairs of data
struct ModeUnsignedPair {
  unsigned int val;
  unsigned int index;
};

// Inclusive Scan via an upsweep/downsweep mechanism. Assumes:
//
// 1. Power2ScanSize is a power of 2. This code still works for collections
// that do not contain exactly a power-of-2 number of elements; simply round up
// to the nearest power of 2 and then call.
//
// 2. There are two elements per thread, i.e. the size of the smem storage
// is 2 * blockDim.x * sizeof(T).
//
// Consider a (+)-Scan on the following elements:
//
// Upsweep:
//
//    0  1  2  3  4  5  6  7
//       1     5     9    13
//             6          22
//                        28
//
// Downsweep:
//                    15
//          3    10    21
template <int Power2ScanSize, typename T, class BinaryOp>
__device__ void inclusivePrefixScan(T* smem, BinaryOp binop) {
  // Reduce step ("upsweep")
#pragma unroll
  for (int stride = 1; stride < Power2ScanSize; stride <<= 1) {
    int index = (threadIdx.x + 1) * stride * 2 - 1;
    if (index < Power2ScanSize) {
      smem[index] = binop(smem[index], smem[index - stride]);
    }
    __syncthreads();
  }

  // Post-reduce step ("downsweep")
#pragma unroll
  for (int stride = Power2ScanSize / 4; stride > 0; stride >>= 1) {
    int index = (threadIdx.x + 1) * stride * 2 - 1;
    if ((index + stride) < Power2ScanSize) {
      smem[index + stride] = binop(smem[index + stride], smem[index]);
    }
    __syncthreads();
  }
}
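
// Illustrative sketch (not part of the upstream file): a minimal kernel showing
// the calling convention inclusivePrefixScan assumes, i.e. Power2ScanSize
// elements staged in shared memory and one thread per two elements. The kernel
// name and layout are assumptions for demonstration only; launch it with
// Power2ScanSize / 2 threads per block.
template <int Power2ScanSize>
__global__ void example_inclusive_scan_kernel(unsigned int* data) {
  __shared__ unsigned int smem[Power2ScanSize];
  // Each thread stages two adjacent elements into shared memory.
  smem[threadIdx.x * 2] = data[threadIdx.x * 2];
  smem[threadIdx.x * 2 + 1] = data[threadIdx.x * 2 + 1];
  __syncthreads();
  // (+)-scan, as in the worked example above.
  inclusivePrefixScan<Power2ScanSize>(
      smem, [] GPU_LAMBDA(unsigned int a, unsigned int b) { return a + b; });
  // The scan synchronizes after its final downsweep step, so the results can
  // be written back directly.
  data[threadIdx.x * 2] = smem[threadIdx.x * 2];
  data[threadIdx.x * 2 + 1] = smem[threadIdx.x * 2 + 1];
}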

// Block-wide reduction where each thread locally reduces N
// values before letting a single warp take over - assumes
// threadVals is in registers, not shared memory
//
// If smem is not used again, there is no need to __syncthreads before this
// call. However, if smem will be used, e.g., this function is called in a loop,
// then __syncthreads is needed either before or afterwards to prevent non-0
// threads from overwriting smem in the next loop before the 0th thread reads
// from it.
template <int N, typename T, typename ReduceOp>
__device__ T reduceBlockWithNThreadLocalReductions(
    T* smem,
    T threadVals[N],
    const unsigned int numVals,
    ReduceOp reduceOp,
    T init) {
  int offset = threadIdx.x * N;
  T local = offset < numVals ? threadVals[0] : init;

#pragma unroll
  for (int i = 1; i < N; ++i) {
    ++offset;
    T next = offset < numVals ? threadVals[i] : init;
    local = reduceOp.combine(local, next);
  }

  return cuda_utils::BlockReduce(local, reduceOp, init, smem);
}
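
// Illustrative sketch (not part of the upstream file): the ReduceOp passed to
// reduceBlockWithNThreadLocalReductions is expected to expose combine() and
// warp_shfl_down(), mirroring the MaxOp / MaxIndexOp structs defined inside
// compute_mode below. A minimal (+)-reduction op could look like this; the
// struct name is an assumption for demonstration only.
struct ExampleSumOp {
  inline __device__ unsigned int combine(unsigned int a, unsigned int b) const {
    return a + b;
  }
  inline __device__ unsigned int warp_shfl_down(unsigned int acc, int offset)
      const {
    return WARP_SHFL_DOWN(acc, offset);
  }
};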

template <typename T>
__device__ inline void swapVars(T& t1, T& t2) {
  T tmp = t1;
  t1 = t2;
  t2 = tmp;
}

template <typename Comparator, typename K, typename V>
__device__ inline void bitonicSwap(
    K& kA,
    V& vA,
    bool& validA,
    K& kB,
    V& vB,
    bool& validB,
    bool dir,
    const Comparator& comp) {
  // Invalid entries always sort to the end
  bool swap = (comp(kA, kB) && validA) || !validB;
  if (swap == dir) {
    swapVars(kA, kB);
    swapVars(vA, vB);
    swapVars(validA, validB);
  }
}

template <typename Comparator, typename K>
__device__ inline void bitonicSwapKeys(
    K& kA,
    bool& validA,
    K& kB,
    bool& validB,
    bool dir,
    const Comparator& comp) {
  bool swap = (comp(kA, kB) && validA) || !validB;
  if (swap == dir) {
    swapVars(kA, kB);
    swapVars(validA, validB);
  }
}

template <
    typename K,
    typename IndexType,
    int Power2SortSize,
    typename Comparator>
__device__ inline void bitonicSortKeys(
    K keys[Power2SortSize],
    bool valid[Power2SortSize],
    const Comparator& comp) {
#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int size = 2; size < Power2SortSize; size *= 2) {
    bool flag = ((threadIdx.x & (size / 2)) != 0);

#if !defined(USE_ROCM)
#pragma unroll
#endif
    for (unsigned int stride = size / 2; stride > 0; stride /= 2) {
      __syncthreads();

      unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      bitonicSwapKeys<Comparator, K>(
          keys[pos],
          valid[pos],
          keys[pos + stride],
          valid[pos + stride],
          flag,
          comp);
    }
  }

#if !defined(USE_ROCM)
#pragma unroll
#endif
  for (unsigned int stride = Power2SortSize / 2; stride > 0; stride /= 2) {
    __syncthreads();

    unsigned int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    bitonicSwapKeys<Comparator, K>(
        keys[pos],
        valid[pos],
        keys[pos + stride],
        valid[pos + stride],
        false,
        comp);
  }

  __syncthreads();
}
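
// Illustrative sketch (not part of the upstream file): a minimal kernel showing
// how bitonicSortKeys is driven, with Power2SortSize key/valid slots in shared
// memory, one thread per two slots, and the validity flags marking the
// non-power-of-2 tail so that padding sorts to the end. The kernel name and
// layout are assumptions for demonstration only; launch it with
// Power2SortSize / 2 threads per block.
template <typename T, int Power2SortSize>
__global__ void example_bitonic_sort_kernel(T* data, int numValid) {
  __shared__ T keys[Power2SortSize];
  __shared__ bool valid[Power2SortSize];
  // Each thread initializes two slots; slots at or beyond numValid are padding.
  for (int i = threadIdx.x; i < Power2SortSize; i += blockDim.x) {
    valid[i] = i < numValid;
    keys[i] = valid[i] ? data[i] : T{};
  }
  __syncthreads();
  // Ascending sort; invalid entries always end up after the valid ones.
  bitonicSortKeys<T, unsigned int, Power2SortSize>(
      keys, valid, [] GPU_LAMBDA(const T& a, const T& b) { return a < b; });
  // bitonicSortKeys synchronizes before returning, so reads are safe here.
  for (int i = threadIdx.x; i < numValid; i += blockDim.x) {
    data[i] = keys[i];
  }
}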

// The mode kernel has the following characteristics: It uses internal shared
// memory buffers of size Power2Size, which must be at least as large as the
// number of elements in a slice. Additionally, there is one block for every
// slice to calculate the mode for, and in each block there is one thread for
// every two elements.
//
// The input is assumed to be a contiguous Tensor with the mode dimension as
// the innermost dim, such that we can get the slice for a given block via its
// linear block id * the slice size.
template <typename T, unsigned int Power2Size>
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11070
__launch_bounds__(1024, 1)
#endif
__global__ void compute_mode(
    T* input,
    at::cuda::detail::TensorInfo<T, unsigned int> values,
    at::cuda::detail::TensorInfo<int64_t, unsigned int> indices,
    int64_t sliceSize,
    int64_t slices) {
  int tidx = threadIdx.x;
  // Second index this thread is responsible for
  int stidx = blockDim.x + threadIdx.x;

  // First, we need to calculate the offset into the sorted Tensor that
  // represents the start of the slice for this block to calculate the mode
  // for. This offset is a combination of the grid indices and the number of
  // elements in the slice.
  unsigned int blockId = getLinearBlockId<unsigned int>();
  unsigned int linearOffset = blockId * sliceSize;

  if (blockId >= slices) {
    return;
  }

  // shmem is a dynamically sized buffer we will use throughout the kernel to
  // handle computation efficiently. The size of this shmem must be
  // sizeof(T) * Power2Size + (2 * sizeof(unsigned int) * Power2Size)
  //
  // Initially, the buffer will be organized as follows:
  //
  // [smem (slice elements) | bmem (valid indices) | <scratch space>]
  extern __shared__ char shmem[];

  // smem represents a proportion of the shared memory buffer that is used to
  // store the elements from the slice:
  T* smem = reinterpret_cast<T*>(shmem);

  // Each thread loads up to two elements from the Tensor into shared memory
  if (tidx < sliceSize) {
    smem[tidx] = c10::load(&input[linearOffset + tidx]);
  }
  if (stidx < sliceSize) {
    smem[stidx] = c10::load(&input[linearOffset + stidx]);
  }

  // Next, we initialize a boolean region of the buffer, offset by the loaded
  // element smem region
  bool* bmem = reinterpret_cast<bool*>(&smem[Power2Size]);

  // The first use of this region stores bmem[i] = i < sliceSize to mark the
  // valid components in the smem buffer
  bmem[tidx] = tidx < sliceSize;
  bmem[stidx] = stidx < sliceSize;
  __syncthreads(); // barrier for smem, bmem initialization

  // First, sort the input slice in ascending order. smem contains the input
  // elements, and bmem marks the valid indices
  bitonicSortKeys<T, unsigned int, Power2Size>(
      smem, bmem, [&] GPU_LAMBDA(const auto& a, const auto& b) {
        return a < b;
      });
  __syncthreads(); // make no assumptions that the sort syncs at end

  // The next step of our algorithm is performing a block-wide comparison of
  // neighboring elements. In particular, given a sorted input slice A, we
  // produce an output slice B, such that B[i] = 1 if A[i-1] != A[i], otherwise
  // 0.
  //
  // Given the input A = [0, 0, 1, 1, 2, 2, 2, 4, 5, 6, 6, 7, 8]
  //                 B = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
  //
  // In particular, we can think of B[i] being true as indicating the start of
  // a sequence of equal values in the sorted list. Similarly, we will also
  // store the negation of B, which we'll call C. In particular, we can think
  // of C[i] = true iff A[i-1] == A[i] in our original sorted slice.
  //
  //                 C = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
  //
  // We overwrite bmem, and treat the rest of shared memory as a buffer of
  // (index, flag) pairs where the index represents values from C, and the flag
  // represents values from B.
  //
  // [smem (sorted slice) | ubpmem (index, flag pairs)]
  struct ModeUnsignedBoolPair* ubpmem =
      reinterpret_cast<struct ModeUnsignedBoolPair*>(&smem[Power2Size]);

  if (tidx == 0) {
    ubpmem[0].flag = true;
    ubpmem[0].val = 0;
  }

  // Compares elements (0, 1), (2, 3), ... and sets 1, 3, ...
  ubpmem[tidx * 2 + 1].flag =
      smem[tidx * 2] != smem[tidx * 2 + 1]; // (0, 1), (2, 3), etc.
  ubpmem[tidx * 2 + 1].val = !ubpmem[tidx * 2 + 1].flag;

  // Compares elements (1, 2), (3, 4), ... and sets 2, 4, ...
  if (((tidx + 1) * 2) < Power2Size) {
    ubpmem[(tidx + 1) * 2].flag =
        smem[((tidx + 1) * 2) - 1] != smem[(tidx + 1) * 2];
    ubpmem[(tidx + 1) * 2].val = !ubpmem[(tidx + 1) * 2].flag;
  }
  __syncthreads(); // barrier for ubpmem initialization

  // Next, we perform a segmented prefix sum on the neighboring elements, where
  // the presence of a one indicates the start of a segment. In this case B
  // acts as the segment start flags, and C is the buffer to be summed:
  //
  // Input  (C) = [0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]
  // Flag   (B) = [1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1]
  // Output (C) = [0, 1, 0, 1, 0, 1, 2, 0, 0, 0, 1, 0, 0]
  //
  // Afterwards, the (index) components of the ubpmem buffer contain, at the
  // end of each segment, the length of that segment minus 1, i.e. one less
  // than the count of that element in the original input.
  inclusivePrefixScan<Power2Size>(
      ubpmem, [=] GPU_LAMBDA(const auto& a, const auto& b) {
        ModeUnsignedBoolPair c;
        c.val = a.flag ? a.val : a.val + b.val;
        c.flag = a.flag | b.flag;
        return c;
      });
  // assumes scan syncs at the end

  // Next, we reinterpret the ubpmem buffer as pairs of unsigned integers (i.e.
  // we treat the boolean flag regions as integers), and we'll call this buffer
  // I. Conceptually, I pairs each count from C with its index in the sorted
  // slice; in practice the (count, index) pairs are built in per-thread
  // registers below, and this shared region is reused as scratch space for
  // the block-wide reduction.
  struct ModeUnsignedPair* uupmem =
      reinterpret_cast<struct ModeUnsignedPair*>(ubpmem);

  // At this point, we need to find the maximum element in the lengths buffer
  // C. This element will represent the count (minus 1) of the mode. Because
  // of the way we have set up the problem, the index where this maximum
  // occurs will also be the location of the mode value in the sorted array,
  // e.g.
  //
  // smem = [0, 0, 1, 1, 1, 2]
  // C    = [0, 1, 0, 1, 2, 0]
  // I    = [0, 1, 2, 3, 4, 5]
  //                     ^
  //                     maximum value, also aligned with mode = 1
  //
  // We perform a block-wide max-reduction of the C buffer, but we also need
  // the indices to come along with it, so we utilize the uupmem construction.
  //
  // At the end we need to return the ModeUnsignedPair containing index = 4,
  // val = 2, which represents the max.

  // In practice, we will make each thread locally reduce 2 values in its
  // registers prior to the global block-wide reduction. Note that instead of
  // tidx/stidx, we utilize tidx * 2, tidx * 2 + 1, so each thread deals with
  // adjacent elements. This is because the reduce code below relies on thread
  // elements to be adjacent.
  struct ModeUnsignedPair uup[2];
  uup[0].index = tidx * 2;
  uup[0].val = ubpmem[tidx * 2].val;
  uup[1].index = tidx * 2 + 1;
  uup[1].val = ubpmem[tidx * 2 + 1].val;
  __syncthreads();

  struct ModeUnsignedPair max = {0, 0};

  struct MaxOp {
    inline __device__ ModeUnsignedPair combine(
        ModeUnsignedPair a,
        ModeUnsignedPair b) const {
      return b.val > a.val ? b : a;
    }

    inline __device__ ModeUnsignedPair warp_shfl_down(
        ModeUnsignedPair acc,
        int offset) const {
      ModeUnsignedPair ret;
      ret.index = WARP_SHFL_DOWN(acc.index, offset);
      ret.val = WARP_SHFL_DOWN(acc.val, offset);
      return ret;
    }
  } max_op;

  max = reduceBlockWithNThreadLocalReductions<2>(
      uupmem, uup, sliceSize, max_op, max);

  // Store the mode in shared memory for use in finding the mode in the input
  // slice
  __shared__ T mode;

  // Given the above constraints, the mode is the value at the reduced index in
  // the original sorted element buffer
  if (tidx == 0) {
    mode = smem[max.index];
  }
  __syncthreads(); // broadcast mode

  // Finally, we need to find "an" index of the mode in the input
  // Tensor. The API does not constrain which index we pick, but here
  // we always pick the largest index. We store the index if the value
  // is the mode, or 0 otherwise. Then find the maximum value.
  //
  // Again we reduce 2 elements in the thread's registers prior to the
  // block-wide reduction
  unsigned mode_index[2] = {0u, 0u};
  if (tidx * 2 < sliceSize) {
    const unsigned idx = tidx * 2;
    mode_index[0] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
  }
  if (tidx * 2 + 1 < sliceSize) {
    const unsigned idx = tidx * 2 + 1;
    mode_index[1] = c10::load(&input[linearOffset + idx]) == mode ? idx : 0u;
  }

  struct MaxIndexOp {
    inline __device__ unsigned combine(unsigned a, unsigned b) const {
      return b > a ? b : a;
    }

    inline __device__ unsigned warp_shfl_down(unsigned acc, int offset) const {
      return WARP_SHFL_DOWN(acc, offset);
    }
  } max_index_op;

  int64_t index = reduceBlockWithNThreadLocalReductions<2>(
      reinterpret_cast<unsigned*>(&shmem[0]),
      mode_index,
      sliceSize,
      max_index_op,
      0u);

  // Finally, we have the mode, and an index where it occurs. We use a single
  // thread to place this in the appropriate output position
  if (tidx == 0) {
    unsigned int outputOffset =
        at::cuda::detail::IndexToOffset<T, unsigned int, -1>::get(
            blockId, values);
    values.data[outputOffset] = mode;
    indices.data[outputOffset] = index;
  }
}
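
// Illustrative sketch (not part of the upstream file): roughly how a caller is
// expected to launch compute_mode, i.e. one block per slice, one thread per
// two elements of the power-of-2-padded slice, and dynamic shared memory sized
// as described at the top of the kernel. The function name and stream argument
// are assumptions for demonstration only; a real launcher may also tile the
// grid across multiple dimensions for very large slice counts and add launch
// error checking (e.g. C10_CUDA_KERNEL_LAUNCH_CHECK).
template <typename scalar_t, unsigned int Power2Size>
void example_launch_compute_mode(
    scalar_t* input,
    at::cuda::detail::TensorInfo<scalar_t, unsigned int> values,
    at::cuda::detail::TensorInfo<int64_t, unsigned int> indices,
    int64_t sliceSize,
    int64_t slices,
    cudaStream_t stream) {
  const dim3 grid(static_cast<unsigned int>(slices)); // one block per slice
  const dim3 block(Power2Size / 2); // one thread per two elements
  // [smem (slice elements) | ubpmem ((unsigned int, bool) pairs)]
  const size_t memsize =
      sizeof(scalar_t) * Power2Size + 2 * sizeof(unsigned int) * Power2Size;
  compute_mode<scalar_t, Power2Size><<<grid, block, memsize, stream>>>(
      input, values, indices, sliceSize, slices);
}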

} // namespace native
} // namespace at