reduction_template.cuh

namespace at {
namespace cuda {
// Windows doesn't like large string literals, so the template is split in two.
const std::string reduction_template_0 = R"ESCAPE(
#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__

#if defined(__clang__) && defined(__HIP__)
#ifndef __forceinline__
#define __forceinline__ inline __attribute__((always_inline))
#endif
// until ROCm support for kernel asserts is restored
#define assert(expr) (static_cast<void>(0))
#endif

template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN(T value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if defined(__clang__) && defined(__HIP__)
  return __shfl_down(value, delta, width);
#else
  return __shfl_down_sync(mask, value, delta, width);
#endif
}
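// When ${complex} is set, provide an overload for std::complex values that
// shuffles the real and imaginary parts separately, since the shuffle
// intrinsics only operate on arithmetic types.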
#if ${complex}
template <typename T>
__device__ __forceinline__ std::complex<T> WARP_SHFL_DOWN(std::complex<T> value, unsigned int delta, int width = warpSize, unsigned int mask = 0xffffffff)
{
  return std::complex<T>(
#if defined(__clang__) && defined(__HIP__)
      __shfl_down(value.real(), delta, width),
      __shfl_down(value.imag(), delta, width));
#else
      __shfl_down_sync(mask, value.real(), delta, width),
      __shfl_down_sync(mask, value.imag(), delta, width));
#endif
}
#endif

// aligned vector generates vectorized load/store on CUDA
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};
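// Reduces numerator/denominator to lowest terms in place,
// e.g. numerator = 8, denominator = 12 becomes numerator = 2, denominator = 3.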
C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
  // get GCD of num and denom using Euclid's algorithm.
  // Can replace this with std::gcd if we ever support c++17.
  size_t a = denominator;
  size_t b = numerator;
  while (b != 0) {
    a %= b;
    // swap(a, b)
    size_t tmp = a;
    a = b;
    b = tmp;
  }
  // a is now the GCD
  numerator /= a;
  denominator /= a;
}
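// ReduceConfig describes how a reduction is mapped onto the launch grid:
// input_mult gives the strides with which threadIdx.x, threadIdx.y and
// blockIdx.y walk the reduced dimension (a zero entry means that axis does
// not participate in the reduction), while output_mult plays the same role
// for the non-reduced (output) dimension.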
struct ReduceConfig {
  // has to match host-side ReduceConfig in the eager code
  static constexpr int BLOCK_X = 0;
  static constexpr int BLOCK_Y = 1;
  static constexpr int CTA = 2;
  static constexpr int input_vec_size = 4;

  int element_size_bytes;
  int num_inputs;
  int num_outputs;
  int step_input = 1;
  int step_output = 1;
  int ctas_per_output = 1;
  int input_mult[3] = {0, 0, 0};
  int output_mult[2] = {0, 0};

  int block_width;
  int block_height;
  int num_threads;

  bool vectorize_input = false;
  int output_vec_size = 1;

  C10_HOST_DEVICE bool should_block_x_reduce() const {
    return input_mult[BLOCK_X] != 0;
  }
  C10_HOST_DEVICE bool should_block_y_reduce() const {
    return input_mult[BLOCK_Y] != 0;
  }
  C10_HOST_DEVICE bool should_global_reduce() const {
    return input_mult[CTA] != 0;
  }
  C10_DEVICE bool should_store(int output_idx) const {
    return output_idx < num_outputs &&
           (!should_block_x_reduce() || threadIdx.x == 0) &&
           (!should_block_y_reduce() || threadIdx.y == 0);
  }
  C10_DEVICE bool should_reduce_tail() const {
    return (!should_block_y_reduce() || threadIdx.y == 0) &&
           (!should_global_reduce() || blockIdx.y == 0);
  }
  C10_HOST_DEVICE int input_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta2 = blockIdx.y;
    return (lane * input_mult[BLOCK_X] +
            warp * input_mult[BLOCK_Y] +
            cta2 * input_mult[CTA]);
  }
  template <int output_vec_size>
  C10_HOST_DEVICE int output_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta1 = blockIdx.x;
    return (lane * output_mult[BLOCK_X] +
            warp * output_mult[BLOCK_Y] +
            cta1 * step_output) * output_vec_size;
  }
  C10_DEVICE int shared_memory_offset(int offset) const {
    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
  }
  C10_DEVICE int staging_memory_offset(int cta2) const {
    int offset = cta2 + blockIdx.x * gridDim.y;
    if (!should_block_x_reduce()) {
      offset = threadIdx.x + offset * blockDim.x;
    }
    return offset;
  }
};

//TODO this will need to be different for more generic reduction functions
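// The ${...} tokens below are placeholders (scalar type, accumulator type,
// result type, the reduction functor, vectorization widths, ...) that the
// host-side code substitutes before this string is JIT-compiled.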
namespace reducer {

  using scalar_t = ${scalar_type};
  using arg_t = ${reduction_accum_type};
  using out_scalar_t = ${result_type};

  inline __device__ ${functor}

  inline __device__ out_scalar_t project(arg_t arg) {
    return (out_scalar_t) arg;
  }

  inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
    return WARP_SHFL_DOWN(arg, offset);
  }

  inline __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
    return acc;
  }

  // wrap a normal reduction that ignores the index
  inline __device__ arg_t reduce(arg_t acc, arg_t val, int64_t idx) {
    return combine(acc, val);
  }
}
struct ReduceJitOp {
  using scalar_t = ${scalar_type};
  using arg_t = ${reduction_accum_type};
  using out_scalar_t = ${result_type};

  using InputCalculator = OffsetCalculator<1>;
  using OutputCalculator = OffsetCalculator<2>;

  // static constexpr bool can_accumulate_in_output =
  //   std::is_convertible<arg_t, out_scalar_t>::value
  //   && std::is_convertible<out_scalar_t, arg_t>::value;

  static constexpr int input_vec_size = ReduceConfig::input_vec_size;

  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  const char* dst[2]; // it accepts at most two destinations
  // acc_buf is used for accumulation among sub-TensorIterators when
  // accumulating directly in the output is not permissible
  void* acc_buf;
  // cta_buf is used for accumulation between blocks during global reduction
  void* cta_buf;
  int* semaphores;
  int64_t base_idx;
  bool accumulate;
  bool final_output;
  int noutputs;
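  // run() is the per-thread driver: each thread first reduces its strided
  // slice of the input (thread_reduce), the partial results are then combined
  // across threadIdx.y and threadIdx.x within the block, and finally either
  // written out or handed to global_reduce when several blocks cooperate on
  // one output.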
  C10_DEVICE void run() const {
    extern __shared__ char shared_memory[];
    uint32_t output_idx = config.output_idx<${output_vec_size}>();
    uint32_t input_idx = config.input_idx();
    auto base_offsets1 = output_calc.get(output_idx)[1];

    using arg_vec_t = Array<arg_t, ${output_vec_size}>;
    arg_vec_t value;

    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
      const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
      value = thread_reduce<${output_vec_size}>(input_slice);
    }

    if (config.should_block_y_reduce()) {
      value = block_y_reduce<${output_vec_size}>(value, shared_memory);
    }
    if (config.should_block_x_reduce()) {
      value = block_x_reduce<${output_vec_size}>(value, shared_memory);
    }

    using out_ptr_vec_t = Array<out_scalar_t*, ${output_vec_size}>;
    using offset_vec_t = Array<uint32_t, ${output_vec_size}>;
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < ${output_vec_size}; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    arg_vec_t* acc = nullptr;
    if (acc_buf != nullptr) {
      size_t numerator = sizeof(arg_t);
      size_t denominator = sizeof(out_scalar_t);
      reduce_fraction(numerator, denominator);
      acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
    }

    if (config.should_global_reduce()) {
      value = global_reduce<${output_vec_size}>(value, acc, shared_memory);
    } else if (config.should_store(output_idx)) {
      if (accumulate) {
        #pragma unroll
        for (int i = 0; i < ${output_vec_size}; i++) {
          value[i] = reducer::translate_idx(value[i], base_idx);
        }
      }

      if (acc == nullptr) {
        if (accumulate) {
          value = accumulate_in_output<${output_vec_size}>(out, value);
        }
        if (final_output) {
          set_results_to_output<${output_vec_size}>(value, base_offsets);
        } else {
          #pragma unroll
          for (int i = 0; i < ${output_vec_size}; i++) {
            *(out[i]) = get_accumulated_output(out[i], value[i]);
          }
        }
      } else {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < ${output_vec_size}; i++) {
            value[i] = reducer::combine((*acc)[i], value[i]);
          }
        }
        if (final_output) {
          set_results_to_output<${output_vec_size}>(value, base_offsets);
        } else {
          *acc = value;
        }
      }
    }
  }
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
    if (config.vectorize_input) {
      assert(output_vec_size == 1);
      // Reduce at the head of input_slice where memory is not aligned,
      // so that thread_reduce will have aligned memory to work on.
      return {input_vectorized_thread_reduce_impl(data)};
    } else {
      uint32_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
      if (is_contiguous) {
        return thread_reduce_impl<output_vec_size>(data, [](uint32_t idx) { return idx; });
      } else if (input_calc.dims == 1) {
        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return idx * element_stride; });
      } else {
        return thread_reduce_impl<output_vec_size>(data, [&](uint32_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
      }
    }
  }
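  // The vectorized path proceeds in three phases: an unaligned head that is
  // reduced one element at a time, an aligned body loaded input_vec_size
  // elements at a time through aligned_vector, and a scalar tail for the
  // leftover elements.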
  C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
    uint32_t end = config.num_inputs;

    // Handle the head of input slice where data is not aligned
    arg_t value = ident;
    constexpr int align_bytes = alignof(aligned_vector<scalar_t, input_vec_size>);
    constexpr int align_elements = align_bytes / sizeof(scalar_t);
    int shift = ((int64_t)data) % align_bytes / sizeof(scalar_t);
    if (shift > 0) {
      data -= shift;
      end += shift;
      if (threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()) {
        value = reducer::reduce(value, data[threadIdx.x], threadIdx.x - shift);
      }
      end -= align_elements;
      data += align_elements;
      shift = align_elements - shift;
    }

    // Do the vectorized reduction
    using load_t = aligned_vector<scalar_t, input_vec_size>;

    uint32_t idx = config.input_idx();
    const uint32_t stride = config.step_input;

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_t value_list[input_vec_size];
    value_list[0] = value;

    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[i] = ident;
    }

    scalar_t values[input_vec_size];
    load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);

    while (idx * input_vec_size + input_vec_size - 1 < end) {
      *values_vector = reinterpret_cast<const load_t*>(data)[idx];
      #pragma unroll
      for (uint32_t i = 0; i < input_vec_size; i++) {
        value_list[i] = reducer::reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
      }
      idx += stride;
    }

    // tail
    uint32_t tail_start = end - end % input_vec_size;
    if (config.should_reduce_tail()) {
      int idx = tail_start + threadIdx.x;
      if (idx < end) {
        value_list[0] = reducer::reduce(value_list[0], data[idx], idx + shift);
      }
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[0] = reducer::combine(value_list[0], value_list[i]);
    }
    return value_list[0];
  }
  template <int output_vec_size, typename offset_calc_t>
  C10_DEVICE Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
    uint32_t idx = config.input_idx();
    const uint32_t end = config.num_inputs;
    const uint32_t stride = config.step_input;
    const int vt0 = ${vt0};

    using arg_vec_t = Array<arg_t, output_vec_size>;
    using load_t = aligned_vector<scalar_t, output_vec_size>;
    const load_t* data = reinterpret_cast<const load_t*>(data_);

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_vec_t value_list[vt0];

    #pragma unroll
    for (int i = 0; i < vt0; i++) {
      #pragma unroll
      for (int j = 0; j < output_vec_size; j++) {
        value_list[i][j] = ident;
      }
    }

    load_t values[vt0];

    while (idx + (vt0 - 1) * stride < end) {
      #pragma unroll
      for (uint32_t i = 0; i < vt0; i++) {
        values[i] = data[calc(idx + i * stride) / output_vec_size];
      }
      #pragma unroll
      for (uint32_t i = 0; i < vt0; i++) {
        #pragma unroll
        for (uint32_t j = 0; j < output_vec_size; j++) {
          value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx + i * stride);
        }
      }
      idx += stride * vt0;
    }

    // tail
    int idx_ = idx;
    #pragma unroll
    for (uint32_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      values[i] = data[calc(idx) / output_vec_size];
      idx += stride;
    }
    idx = idx_;
    #pragma unroll
    for (uint32_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      #pragma unroll
      for (uint32_t j = 0; j < output_vec_size; j++) {
        value_list[i][j] = reducer::reduce(value_list[i][j], values[i].val[j], idx);
      }
      idx += stride;
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < vt0; i++) {
      #pragma unroll
      for (uint32_t j = 0; j < output_vec_size; j++) {
        value_list[0][j] = reducer::combine(value_list[0][j], value_list[i][j]);
      }
    }
    return value_list[0];
  }
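  // block_x_reduce first folds lanes beyond a single warp into shared memory,
  // then finishes the reduction with warp shuffles; the combined result ends
  // up in the threads with threadIdx.x == 0.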
  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> block_x_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = Array<arg_t, output_vec_size>;
    int dim_x = blockDim.x;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    if (dim_x > warpSize) {
      int address_base = threadIdx.x + threadIdx.y * blockDim.x;
      shared[address_base] = value;
      for (int offset = dim_x / 2; offset >= warpSize; offset >>= 1) {
        __syncthreads();
        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
          args_vec_t other = shared[address_base + offset];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], other[i]);
          }
          shared[address_base] = value;
        }
      }
      dim_x = warpSize;
    }

    __syncthreads();

    for (int offset = 1; offset < dim_x; offset <<= 1) {
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = reducer::warp_shfl_down(value[i], offset);
        value[i] = reducer::combine(value[i], other);
      }
    }
    return value;
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> block_y_reduce(Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = Array<arg_t, output_vec_size>;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    shared[config.shared_memory_offset(0)] = value;
    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
      __syncthreads();
      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
        args_vec_t other = shared[config.shared_memory_offset(offset)];
        #pragma unroll
        for (int i = 0; i < output_vec_size; i++) {
          value[i] = reducer::combine(value[i], other[i]);
        }
        shared[config.shared_memory_offset(0)] = value;
      }
    }
    return value;
  }
)ESCAPE";

const std::string reduction_template_1 = R"ESCAPE(
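  // A semaphore per blockIdx.x counts how many of the gridDim.y blocks
  // cooperating on the same outputs have written their partial result to the
  // staging buffer; the block that sees the count reach gridDim.y - 1 is the
  // last one and performs the final accumulation.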
  C10_DEVICE bool mark_block_finished() const {
    __shared__ bool is_last_block_done_shared;

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
    }
    __syncthreads();

    return is_last_block_done_shared;
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> accumulate_in_output(
    Array<out_scalar_t*, output_vec_size> out,
    Array<arg_t, output_vec_size> value
  ) const {
    Array<arg_t, output_vec_size> ret;
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      ret[i] = reducer::combine(*(out[i]), value[i]);
    }
    return ret;
  }

  C10_DEVICE out_scalar_t get_accumulated_output(
    out_scalar_t* out, arg_t value
  ) const {
    assert(!final_output);
    return (out_scalar_t)value;
  }

  template<class T>
  C10_DEVICE void set_results(const T x, const uint32_t base_offset) const {
    assert(noutputs == 1);
    auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
    *res = x;
  }
  // TODO - multi-output reduction: we won't be able to use thrust::pair, so
  // just explicitly specify typed output reads/writes.
  // Currently implemented for at most two outputs.
  // template<class T1, class T2>
  // C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
  //   if (noutputs >= 1) {
  //     auto res0 = (T1*)((char*)dst[0] + base_offset);
  //     *res0 = x.first;
  //   }
  //   if (noutputs >= 2) {
  //     // base offset is computed assuming element size being sizeof(T1), so we need to make a
  //     // correction to obtain the correct base offset
  //     auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
  //     *res1 = x.second;
  //   }
  // }
  template <int output_vec_size>
  C10_DEVICE void set_results_to_output(Array<arg_t, output_vec_size> value, Array<uint32_t, output_vec_size> base_offset) const {
    assert(final_output);
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      set_results(reducer::project(value[i]), base_offset[i]);
    }
  }

  template <int output_vec_size>
  C10_DEVICE Array<arg_t, output_vec_size> global_reduce(Array<arg_t, output_vec_size> value, Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
    using arg_vec_t = Array<arg_t, output_vec_size>;
    using out_ptr_vec_t = Array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = Array<uint32_t, output_vec_size>;

    arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
    uint32_t output_idx = config.output_idx<output_vec_size>();
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    bool should_store = config.should_store(output_idx);
    if (should_store) {
      uint32_t offset = config.staging_memory_offset(blockIdx.y);
      reduce_buffer[offset] = value;
    }

    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
    bool is_last_block_done = mark_block_finished();

    if (is_last_block_done) {
      value = ident;
      if (config.should_block_x_reduce()) {
        uint32_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
        uint32_t step = blockDim.x * blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          uint32_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], next[i]);
          }
        }
      } else {
        uint32_t input_offset = threadIdx.y;
        uint32_t step = blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          uint32_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::combine(value[i], next[i]);
          }
        }
      }

      value = block_y_reduce(value, shared_memory);
      if (config.should_block_x_reduce()) {
        value = block_x_reduce<output_vec_size>(value, shared_memory);
      }

      if (should_store) {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = reducer::translate_idx(value[i], base_idx);
          }
        }

        if (acc == nullptr) {
          if (accumulate) {
            value = accumulate_in_output<output_vec_size>(out, value);
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              *(out[i]) = get_accumulated_output(out[i], value[i]);
            }
          }
        } else {
          if (accumulate) {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              value[i] = reducer::combine((*acc)[i], value[i]);
            }
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            *acc = value;
          }
        }
      }
    }

    return value;
  }
};
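// The generated kernel is a thin wrapper that forwards to ReduceJitOp::run();
// ${name} and ${max_threads_lb} are filled in by the host-side launcher.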
extern "C"
__launch_bounds__(${max_threads_lb}, 4)
__global__ void reduction_${name}_kernel(ReduceJitOp r) {
  r.run();
}
)ESCAPE";

const std::string reduction_template = reduction_template_0 + reduction_template_1;

const std::string &get_reduction_template() {
  return reduction_template;
}

}}