// ScanUtils.cuh
#pragma once
#include <ATen/NumericUtils.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Load.h>
#include <c10/util/Exception.h>   // TORCH_CHECK, TORCH_INTERNAL_ASSERT
#include <c10/util/accumulate.h>  // c10::multiply_integers

#include <cmath>
#include <limits>

namespace at {
namespace native {
template <typename integer>
constexpr inline integer ceil_div(integer n, integer m) {
  return (n + m - 1) / m;
}
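
// binary_op_update maintains the running result of an argmax/argmin-style scan in
// (rhs, rhs_idx), with (lhs, lhs_idx) as the incoming candidate: rhs is replaced by
// lhs when lhs is NaN (and rhs is not), or when binary_op(rhs, lhs) is false (e.g.
// with a greater-equal comparator, when lhs is strictly greater). A NaN already
// stored in rhs is never replaced, so NaNs propagate through the scan. For example,
// with a greater-equal comparator, lhs = 5 (index 1) and rhs = 2 (index 0) update
// the pair to (5, 1).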
template <typename scalar_t, typename idx_t, typename BinaryOperation>
__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) {
  if (!at::_isnan(rhs) && (at::_isnan(lhs) || !binary_op(rhs, lhs))) {
    rhs = lhs;
    rhs_idx = lhs_idx;
  }
}
/* Perform an inclusive scan along the innermost dimension of a tensor.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 */
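/* As an illustration of the (values, indices) outputs: with a greater-or-equal
 * binary_op and init set to the lowest representable value (a cummax-style scan),
 * the row [2, 5, 3, 7] yields values [2, 5, 5, 7] and indices [0, 1, 1, 3];
 * each position holds the best value seen so far in the row and where it occurred.
 */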
template <typename scalar_t, int num_threads_x, int num_threads_y, class BinaryFunction>
__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
                                                              int num_rows, int row_size,
                                                              scalar_t init, BinaryFunction binary_op) {
  __shared__ scalar_t vbuf[num_threads_y][2 * num_threads_x];
  __shared__ int64_t ibuf[num_threads_y][2 * num_threads_x];
  scalar_t* row_buf = vbuf[threadIdx.y];
  int64_t* row_idx_buf = ibuf[threadIdx.y];

  for (int block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    int row = block_row + threadIdx.y;
    const scalar_t *row_self = self_ + row * row_size;
    scalar_t *row_values = values_ + row * row_size;
    int64_t *row_indices = indices_ + row * row_size;
    scalar_t block_total = init;
    int64_t block_idx_final = 0;

    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      int col1 = block_col + threadIdx.x;
      int col2 = block_col + num_threads_x + threadIdx.x;
      if (row < num_rows) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = c10::load(&row_self[col1]);
          row_idx_buf[threadIdx.x] = col1;
        } else {
          row_buf[threadIdx.x] = init;
          // No need to set the index here as the value in init will never be selected
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = c10::load(&row_self[col2]);
          row_idx_buf[num_threads_x + threadIdx.x] = col2;
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
          // No need to set the index here as the value in init will never be selected
        }

        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op);
        }
      }
      __syncthreads();

      // Parallel reduction (up-sweep).
      for (int s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < num_rows && threadIdx.x < s) {
          int offset = (2 * threadIdx.x + 1) * d - 1;
          binary_op_update(row_buf[offset], row_buf[offset + d], row_idx_buf[offset], row_idx_buf[offset + d], binary_op);
        }
        __syncthreads();
      }

      // Down-sweep.
      for (int s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < num_rows && threadIdx.x < s - 1) {
          int offset = 2 * (threadIdx.x + 1) * d - 1;
          binary_op_update(row_buf[offset], row_buf[offset + d], row_idx_buf[offset], row_idx_buf[offset + d], binary_op);
        }
        __syncthreads();
      }

      // Write back to output.
      if (row < num_rows) {
        if (col1 < row_size) {
          row_values[col1] = row_buf[threadIdx.x];
          row_indices[col1] = row_idx_buf[threadIdx.x];
        }
        if (col2 < row_size) {
          row_values[col2] = row_buf[num_threads_x + threadIdx.x];
          row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
        }
      }
      block_total = row_buf[2 * num_threads_x - 1];
      block_idx_final = row_idx_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}
/* Perform an inclusive scan along an outer dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
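/* For example, for a contiguous tensor of shape (4, 5, 6, 7) scanned over dim = 1:
 * num_orows = 4, row_size = 5, and num_irows = 6 * 7 = 42. In this flattened view the
 * element (orow, col, irow) lives at linear offset
 * orow * row_size * num_irows + col * num_irows + irow, which is why each thread
 * advances its pointers by num_irows per step of the scan.
 */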
template <typename scalar_t, class BinaryFunction>
__global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scalar_t *values_, int64_t *indices_,
                                                          const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
                                                          scalar_t init, BinaryFunction binary_op) {
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      scalar_t *self = self_ + orow * row_size * num_irows + irow;
      scalar_t *values = values_ + orow * row_size * num_irows + irow;
      int64_t *indices = indices_ + orow * row_size * num_irows + irow;
      scalar_t out = init;
      int64_t out_idx = 0;

      for (auto col = decltype(row_size){0}; col < row_size; ++col) {
        const auto val = c10::load(self);
        if (at::_isnan(val) || (!at::_isnan(out) && binary_op(val, out))) {
          out = val;
          out_idx = col;
        }
        *values = out;
        *indices = out_idx;
        self += num_irows;
        values += num_irows;
        indices += num_irows;
      }
    }
  }
}
inline void check_fits_in_unsigned(int64_t val, const char* name) {
  constexpr auto umax = std::numeric_limits<uint32_t>::max();
  TORCH_CHECK(
      val >= 0 && val <= umax, name, " must fit in a 32-bit unsigned (uint32_t) value");
}
template <typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    int dim, scalar_t init, BinaryFunction binary_op) {
  int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dimensions before dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dimensions after dim) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  // For performance reasons, the CUDA kernels use uint32_t for loops over irows, orows
  // and row, so make sure the input does not exceed what uint32_t can represent.
  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      self.data_ptr<scalar_t>(), values.data_ptr<scalar_t>(), indices.data_ptr<int64_t>(),
      num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, class BinaryFunction>
__host__ void scan_innermost_dim_with_indices(
    const TensorBase& self, const TensorBase& values, const TensorBase& indices,
    scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int row_size = self.size(ndim - 1);
  int num_rows = self.numel() / row_size;
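
  // Each block uses a 16 x 32 layout: threadIdx.y selects one of the 32 rows handled by
  // the block, threadIdx.x is one of the 16 threads cooperating on that row, and each
  // thread loads two elements per chunk, so every pass over the inner dimension covers
  // 2 * 16 = 32 columns.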
  dim3 threads(16, 32);
  dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y))));

  tensor_kernel_scan_innermost_dim_with_indices<scalar_t, 16, 32><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      self.data_ptr<scalar_t>(), values.data_ptr<scalar_t>(), indices.data_ptr<int64_t>(),
      num_rows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, typename BinaryFunction>
void scan_dim_with_indices(const TensorBase& self, const TensorBase& values, const TensorBase& indices,
                           int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(values.is_contiguous() && indices.is_contiguous());
  if (dim == ndim - 1) {
    scan_innermost_dim_with_indices<scalar_t>(*self_, values, indices, init, binary_op);
  } else {
    scan_outer_dim_with_indices<scalar_t>(*self_, values, indices, dim, init, binary_op);
  }
}
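
// A minimal sketch of how a caller could drive scan_dim_with_indices for a cummax-style
// op. Illustrative only (this is not the dispatch code that ships with ATen); it assumes
// float inputs and pre-allocated, contiguous `values`/`indices` outputs with the same
// shape as `self`.
//
//   #include <functional>
//   #include <limits>
//
//   void cummax_float_sketch(const at::TensorBase& self,
//                            const at::TensorBase& values,
//                            const at::TensorBase& indices,
//                            int64_t dim) {
//     // Start the running value at -inf so every real element wins the first
//     // comparison; greater_equal makes the scan track the running maximum.
//     at::native::scan_dim_with_indices<float>(
//         self, values, indices, dim,
//         -std::numeric_limits<float>::infinity(),
//         std::greater_equal<float>());
//   }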
// TODO: The implementations of `tensor_kernel_scan_outer_dim` and
// `tensor_kernel_scan_innermost_dim` are similar to
// `tensor_kernel_scan_outer_dim_with_indices` and
// `tensor_kernel_scan_innermost_dim_with_indices` and should be refactored to
// remove the duplication.
/* Perform an inclusive scan along an outer dimension of a tensor.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
template <typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_,
                                             const uint32_t num_orows, const uint32_t num_irows, const uint32_t row_size,
                                             const scalar_t init, BinaryOp binary_op) {
  for (uint32_t orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (uint32_t irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      scalar_t *src = src_ + orow * row_size * num_irows + irow;
      scalar_t *tgt = tgt_ + orow * row_size * num_irows + irow;
      scalar_t acc = init;

      for (uint32_t col = 0; col < row_size; ++col) {
        acc = binary_op(acc, c10::load(src));
        *tgt = acc;
        src += num_irows;
        tgt += num_irows;
      }
    }
  }
}
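
// As a concrete illustration: with an addition binary_op and init = 0, a row
// [1, 2, 3, 4] is written out as its running totals [1, 3, 6, 10]; with a product
// op and init = 1 the same row becomes [1, 2, 6, 24].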
/* Perform an inclusive scan along the innermost dimension of a tensor.
 *
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 */
template <typename T, int num_threads_x, int num_threads_y, class BinaryFunction>
__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *src_,
                                                      const uint32_t num_rows, const uint32_t row_size,
                                                      T init, BinaryFunction binary_op) {
  for (uint32_t block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    uint32_t row = block_row + threadIdx.y;
    T block_total = init;

    T *row_src = src_ + row * row_size;
    T *row_tgt = tgt_ + row * row_size;

    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (uint32_t block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      uint32_t col1 = block_col + threadIdx.x;
      uint32_t col2 = block_col + num_threads_x + threadIdx.x;
      if (row < num_rows) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_src[col1];
        } else {
          row_buf[threadIdx.x] = init;
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = row_src[col2];
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
        }

        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          row_buf[0] = binary_op(row_buf[0], block_total);
        }
      }
      __syncthreads();

      // Parallel reduction (up-sweep).
      for (uint32_t s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < num_rows && threadIdx.x < s) {
          uint32_t offset = (2 * threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      // Down-sweep.
      for (uint32_t s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < num_rows && threadIdx.x < s - 1) {
          uint32_t offset = 2 * (threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      // Write back to output.
      if (row < num_rows) {
        if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
        if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
      }
      block_total = row_buf[2 * num_threads_x - 1];
      __syncthreads();
    }
  }
}
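
// The up-sweep/down-sweep above is a work-efficient (Blelloch-style) block scan.
// A small trace with num_threads_x = 2 (a 4-element shared buffer) and an addition
// binary_op, starting from [a, b, c, d]:
//   up-sweep,   s = 2, d = 1: [a, a+b, c, c+d]
//   up-sweep,   s = 1, d = 2: [a, a+b, c, a+b+c+d]
//   down-sweep, s = 2, d = 1: [a, a+b, a+b+c, a+b+c+d]
// After the two phases the buffer holds the inclusive prefix results, which are then
// written back to row_tgt.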
template <
    typename T,
    int num_threads_x,
    int num_threads_y,
    class BinaryFunction>
__global__ typename std::enable_if<!c10::is_complex<T>::value, void>::type
tensor_kernel_scan_innermost_dim(
    T* tgt_,
    T* src_,
    const uint32_t num_rows,
    const uint32_t row_size,
    T init,
    BinaryFunction binary_op) {
  __shared__ T sbuf[num_threads_y][2 * num_threads_x];
  T* row_buf = sbuf[threadIdx.y];

  tensor_kernel_scan_innermost_dim_impl<T, num_threads_x, num_threads_y>(
      row_buf, tgt_, src_, num_rows, row_size, init, binary_op);
}
template <
    typename T,
    int num_threads_x,
    int num_threads_y,
    class BinaryFunction>
__global__ typename std::enable_if<c10::is_complex<T>::value, void>::type
tensor_kernel_scan_innermost_dim(
    T* tgt_,
    T* src_,
    const uint32_t num_rows,
    const uint32_t row_size,
    T init,
    BinaryFunction binary_op) {
  // We cannot directly initialize a shared array of a complex type (the compiler rejects
  // it with `error: initializer not allowed for __shared__ variable`). Instead, we
  // allocate twice the number of elements of the underlying base scalar type and
  // reinterpret that storage as complex.
  using base_t = typename scalar_value_type<T>::type;
  __shared__ base_t sbuf[num_threads_y][4 * num_threads_x];

  T* row_buf = reinterpret_cast<T*>(sbuf[threadIdx.y]);

  tensor_kernel_scan_innermost_dim_impl<T, num_threads_x, num_threads_y>(
      row_buf, tgt_, src_, num_rows, row_size, init, binary_op);
}
template <typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,
                             int dim, scalar_t init, BinaryFunction binary_op) {
  const int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dimensions before dim) as one.
  const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);

  // Treat all inner dimensions (i.e. dimensions after dim) as one.
  const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());

  dim3 threads(std::min(512, int(num_irows)));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      result.data_ptr<scalar_t>(), self.data_ptr<scalar_t>(),
      num_orows, num_irows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,
                        scalar_t init, BinaryFunction binary_op) {
  int64_t ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  int64_t row_size = self.size(ndim - 1);
  int64_t num_rows = self.numel() / row_size;

  dim3 threads(16, 32);
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));

  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_innermost_dim<scalar_t, 16, 32><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
      result.data_ptr<scalar_t>(), self.data_ptr<scalar_t>(),
      num_rows, row_size, init, binary_op);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,
              int64_t dim, scalar_t init, BinaryFunction binary_op) {
  int ndim = self.dim();
  auto self_ = self.expect_contiguous();
  TORCH_INTERNAL_ASSERT(result.is_contiguous());

  if (self.numel() == self.size(dim)) {
    // Fast path: the scanned dimension accounts for every element, so a single flat
    // CUB scan over the contiguous data is enough.
    cuda::cub::inclusive_scan(self_->data_ptr<scalar_t>(), result.data_ptr<scalar_t>(), binary_op, self.numel());
  } else if (dim == ndim - 1) {
    scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);
  } else {
    scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);
  }
}
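
// A minimal usage sketch for scan_dim, illustrative only (not the cumsum dispatch that
// ships with ATen). It assumes a float tensor and a pre-allocated, contiguous `result`
// of the same shape as `self`:
//
//   #include <functional>
//
//   void cumsum_float_sketch(const at::TensorBase& self,
//                            const at::TensorBase& result,
//                            int64_t dim) {
//     // init = 0 is the identity for addition, so the scan reduces to a running sum
//     // along `dim`.
//     at::native::scan_dim<float>(self, result, dim, 0.0f, std::plus<float>());
//   }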
}}  // namespace at::native