#pragma once

#include <assert.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/thread_constants.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/OpMathType.h>
#include <c10/macros/Macros.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <functional>
#include <iosfwd>
#include <type_traits>
#include <utility>
#include <thrust/pair.h>
#include <ATen/native/cuda/jit_utils.h>

namespace at { namespace native {

using at::detail::Array;

static inline int64_t div_up(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}
// returns the largest power of 2 that is less than or equal to n (i.e. 2^floor(log2(n)))
static inline int last_pow2(int n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  return std::max(1, n - (n >> 1));
}

// returns reduced fraction numerator & denominator
C10_HOST_DEVICE static void reduce_fraction(size_t &numerator, size_t &denominator) {
  // get GCD of num and denom using Euclid's algorithm.
  // Can replace this with std::gcd if we ever support c++17.
  size_t a = denominator;
  size_t b = numerator;
  while (b != 0) {
    a %= b;
    // swap(a, b)
    size_t tmp = a;
    a = b;
    b = tmp;
  }

  // a is now the GCD
  numerator /= a;
  denominator /= a;
}
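// For example, numerator = 6 and denominator = 8 have GCD 2, so
// reduce_fraction leaves numerator = 3 and denominator = 4.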
// template for changing MAX_NUM_THREADS based on op dtype
template <typename T>
struct mnt_wrapper {
  static constexpr int MAX_NUM_THREADS = 512;
};

template <>
struct mnt_wrapper<c10::complex<double>> {
  static constexpr int MAX_NUM_THREADS = 256;
};

constexpr int max_reduce_threads(c10::ScalarType type) {
  return type == kComplexDouble ? 256 : 512;
}
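// ReduceConfig describes how a reduction is mapped onto the launch grid.
// Work is split over three levels of parallelism: threadIdx.x (BLOCK_X),
// threadIdx.y (BLOCK_Y) and blockIdx.y (CTA). For a given level, a non-zero
// entry in input_mult means threads at that level stride over elements of the
// same reduction (and therefore must be combined later), while a non-zero
// entry in output_mult means they stride over independent outputs.
// split_input()/split_output() record the corresponding strides in
// step_input/step_output.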
struct ReduceConfig {
  static constexpr int BLOCK_X = 0;
  static constexpr int BLOCK_Y = 1;
  static constexpr int CTA = 2;

  static constexpr int input_vec_size = 4;

  ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
    : element_size_bytes(element_size_bytes)
    , num_inputs(num_inputs)
    , num_outputs(num_outputs) {}

  int element_size_bytes;
  int num_inputs;
  int num_outputs;
  int step_input = 1;
  int step_output = 1;
  int ctas_per_output = 1;
  int input_mult[3] = {0, 0, 0};
  int output_mult[2] = {0, 0};

  int block_width;
  int block_height;
  int num_threads;

  bool vectorize_input = false;
  int output_vec_size = 1;
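
  // Heuristic for the block shape: cap each dimension at a power of two,
  // bound block_width by the warp size, then shrink each dimension in turn so
  // that block_width * block_height stays within the dtype-dependent thread
  // limit.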
  template <typename T>
  void set_block_dimension(int64_t dim0, int64_t dim1) {
    const int max_num_threads = mnt_wrapper<T>::MAX_NUM_THREADS / output_vec_size;
    int dim0_pow2 = dim0 < max_num_threads ? static_cast<int>(last_pow2(dim0)) : max_num_threads;
    int dim1_pow2 = dim1 < max_num_threads ? static_cast<int>(last_pow2(dim1)) : max_num_threads;
    block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
    block_height = std::min(dim1_pow2, int(max_num_threads / block_width));
    block_width = std::min(dim0_pow2, int(max_num_threads / block_height));
    num_threads = block_width * block_height;
  }

  int split_input(int parallelism) {
    int step = step_input;
    step_input *= parallelism;
    return step;
  }

  int split_output(int parallelism) {
    int step = step_output;
    step_output *= parallelism;
    return step;
  }

  dim3 block() const {
    return dim3(block_width, block_height);
  }

  dim3 grid() const {
    return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output);
  }

  C10_HOST_DEVICE bool should_block_x_reduce() const {
    return input_mult[BLOCK_X] != 0;
  }

  C10_HOST_DEVICE bool should_block_y_reduce() const {
    return input_mult[BLOCK_Y] != 0;
  }

  C10_HOST_DEVICE bool should_global_reduce() const {
    return input_mult[CTA] != 0;
  }

  C10_DEVICE bool should_store(int output_idx) const {
    return output_idx < num_outputs &&
      (!should_block_x_reduce() || threadIdx.x == 0) &&
      (!should_block_y_reduce() || threadIdx.y == 0);
  }

  C10_DEVICE bool should_reduce_tail() const {
    return (!should_block_y_reduce() || threadIdx.y == 0) &&
      (!should_global_reduce() || blockIdx.y == 0);
  }

  C10_HOST_DEVICE int input_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta2 = blockIdx.y;
    return (lane * input_mult[BLOCK_X] +
            warp * input_mult[BLOCK_Y] +
            cta2 * input_mult[CTA]);
  }

  template <int output_vec_size>
  C10_HOST_DEVICE int output_idx() const {
    int lane = threadIdx.x;
    int warp = threadIdx.y;
    int cta1 = blockIdx.x;
    return (lane * output_mult[BLOCK_X] +
            warp * output_mult[BLOCK_Y] +
            cta1 * step_output) * output_vec_size;
  }

  C10_DEVICE int shared_memory_offset(int offset) const {
    return threadIdx.x + (threadIdx.y + offset) * blockDim.x;
  }

  C10_DEVICE int staging_memory_offset(int cta2) const {
    int offset = cta2 + blockIdx.x * gridDim.y;
    if (!should_block_x_reduce()) {
      offset = threadIdx.x + offset * blockDim.x;
    }
    return offset;
  }

  int shared_memory_size() const {
    if (!should_block_y_reduce() &&
        (!should_block_x_reduce() ||
         block_width <= at::cuda::warp_size())) {
      return 0;
    }
    return element_size_bytes * num_threads * output_vec_size;
  }

  int64_t global_memory_size() const {
    if (!should_global_reduce()) {
      return 0;
    }
    auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
    if (!should_block_x_reduce()) {
      size *= block().x * output_vec_size;
    }
    return size;
  }

  int semaphore_size() const {
    if (!should_global_reduce()) {
      return 0;
    }
    return sizeof(int) * grid().x;
  }

  int values_per_thread() const {
    return div_up(num_inputs, step_input);
  }
};
std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);

template<int nt, int output_vec_size, typename R>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void reduce_kernel(R reduction) {
  reduction.template run<output_vec_size>();
}

template <typename index_t>
static OffsetCalculator<2, index_t> make_output_calculator(const TensorIterator& iter) {
  int num_reduce_dims = iter.num_reduce_dims();
  int num_output_dims = iter.ndim() - num_reduce_dims;
  int input_index = iter.ntensors() - 1;
  int output_index = 0;
  std::array<const int64_t*, 2> strides = {
    iter.strides(output_index).data() + num_reduce_dims,
    iter.strides(input_index).data() + num_reduce_dims,
  };
  auto shape = iter.shape().data() + num_reduce_dims;
  return OffsetCalculator<2, index_t>(num_output_dims, shape, strides.data());
}

template <typename index_t>
static OffsetCalculator<1, index_t> make_input_calculator(const TensorIterator& iter) {
  int num_reduce_dims = iter.num_reduce_dims();
  int input_index = iter.ntensors() - 1;
  std::array<const int64_t*, 1> strides = {
    iter.strides(input_index).data(),
  };
  return OffsetCalculator<1, index_t>(num_reduce_dims, iter.shape().data(), strides.data());
}

template <typename out_scalar_t, typename func_t>
struct func_wrapper_t {
  using arg_t = typename binary_function_traits<func_t>::arg1_t;
  using scalar_t = typename binary_function_traits<func_t>::arg2_t;

  func_t combine;
  static inline __device__ out_scalar_t project(arg_t arg) {
    return (out_scalar_t) arg;
  }
  static inline __device__ arg_t warp_shfl_down(arg_t arg, int offset) {
    return WARP_SHFL_DOWN(arg, offset);
  }

  static __device__ arg_t translate_idx(arg_t acc, int64_t /*idx*/) {
    return acc;
  }

  func_wrapper_t(const func_t& op) : combine(op) {
  }

  // wrap a normal reduction that ignores the index
  __device__ arg_t reduce(arg_t acc, scalar_t val, int64_t idx) const {
    return combine(acc, val);
  }
};

template <typename scalar_t, typename func_t>
func_wrapper_t<scalar_t, func_t> func_wrapper(const func_t& op) {
  return func_wrapper_t<scalar_t, func_t> { op };
}
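// Illustrative usage (not part of this header): a plain sum-style reduction
// can be expressed by wrapping a binary combine lambda, e.g.
//   gpu_reduce_kernel<scalar_t, scalar_t>(iter, func_wrapper<scalar_t>(
//       [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t { return a + b; }));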
template <typename scalar_t, typename out_scalar_t=scalar_t>
struct ReduceJitOp {
  // ReduceJitOp is almost like ReduceOp, but it doesn't have the ops functor that specifies the reduction operation.
  // Maybe we can find a way to unify ReduceOp and ReduceJitOp
  using InputCalculator = OffsetCalculator<1, uint32_t>;
  using OutputCalculator = OffsetCalculator<2, uint32_t>;
  // TODO for now arg_t is always opmath_t of the input, later we'll need to change it
  using arg_t = at::opmath_type<scalar_t>;

  static constexpr int input_vec_size = ReduceConfig::input_vec_size;
  // TODO - ReduceJitOp will probably need to be changed for reductions that need a full functor,
  // not just the wrapper
  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  const char* dst[2]; // it accepts at most two destinations

  // acc_buf used for accumulation among sub tensor iterators when accumulation
  // in the output is not permissible
  void* acc_buf;
  // cta_buf used for accumulation between blocks during global reduction
  void* cta_buf;
  int* semaphores;
  int64_t base_idx;
  bool accumulate;
  bool final_output;
  int noutputs;

  ReduceJitOp(
      ReduceConfig config,
      InputCalculator input_calc,
      OutputCalculator output_calc,
      const void* src,
      char* dst0,
      optional<char*> dst1,
      void* acc_buf,
      void* cta_buf,
      int* semaphores,
      arg_t ident,
      int noutputs,
      int64_t base_idx)
      : ident(ident),
        config(config),
        input_calc(input_calc),
        output_calc(output_calc),
        src(src),
        acc_buf(acc_buf),
        cta_buf(cta_buf),
        semaphores(semaphores),
        base_idx(base_idx),
        noutputs(noutputs) {
    dst[0] = dst0;
    if (dst1.has_value()) {
      dst[1] = dst1.value();
    }
  }
};
template <typename scalar_t, typename ops_t, typename index_t, typename out_scalar_t=scalar_t, int vt0=4>
struct ReduceOp {
  using traits = function_traits<decltype(&ops_t::reduce)>;
  using arg_t = typename std::decay<typename traits::template arg<0>::type>::type;
  using InputCalculator = OffsetCalculator<1, index_t>;
  using OutputCalculator = OffsetCalculator<2, index_t>;

  static constexpr bool can_accumulate_in_output =
    std::is_convertible<arg_t, out_scalar_t>::value
    && std::is_convertible<out_scalar_t, arg_t>::value;

  static constexpr int input_vec_size = ReduceConfig::input_vec_size;

  ops_t ops;
  arg_t ident;
  ReduceConfig config;
  InputCalculator input_calc;
  OutputCalculator output_calc;
  const void* src;
  const char* dst[2]; // it accepts at most two destinations

  // acc_buf used for accumulation among sub tensor iterators when accumulation
  // in the output is not permissible
  void* acc_buf;
  // cta_buf used for accumulation between blocks during global reduction
  void* cta_buf;
  int* semaphores;
  int64_t base_idx;
  bool accumulate;
  bool final_output;
  int noutputs;

  ReduceOp(
      ops_t ops,
      ReduceConfig config,
      InputCalculator input_calc,
      OutputCalculator output_calc,
      const void* src,
      char* dst0,
      optional<char*> dst1,
      void* acc_buf,
      void* cta_buf,
      int* semaphores,
      arg_t ident,
      int noutputs,
      int64_t base_idx)
      : ops(ops),
        ident(ident),
        config(config),
        input_calc(input_calc),
        output_calc(output_calc),
        src(src),
        acc_buf(acc_buf),
        cta_buf(cta_buf),
        semaphores(semaphores),
        base_idx(base_idx),
        noutputs(noutputs) {
    dst[0] = dst0;
    if (dst1.has_value()) {
      dst[1] = dst1.value();
    }
  }
  template <int output_vec_size>
  C10_DEVICE void run() const {
    extern __shared__ char shared_memory[];
    index_t output_idx = config.output_idx<output_vec_size>();
    index_t input_idx = config.input_idx();
    auto base_offsets1 = output_calc.get(output_idx)[1];

    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
    arg_vec_t value;

    if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
      const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
      value = thread_reduce<output_vec_size>(input_slice);
    }

    if (config.should_block_y_reduce()) {
      value = block_y_reduce<output_vec_size>(value, shared_memory);
    }
    if (config.should_block_x_reduce()) {
      value = block_x_reduce<output_vec_size>(value, shared_memory);
    }

    using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    arg_vec_t* acc = nullptr;
    if (acc_buf != nullptr) {
      size_t numerator = sizeof(arg_t);
      size_t denominator = sizeof(out_scalar_t);
      reduce_fraction(numerator, denominator);
      acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
    }

    if (config.should_global_reduce()) {
      value = global_reduce<output_vec_size>(value, acc, shared_memory);
    } else if (config.should_store(output_idx)) {
      if (accumulate) {
        #pragma unroll
        for (int i = 0; i < output_vec_size; i++) {
          value[i] = ops.translate_idx(value[i], base_idx);
        }
      }

      if (acc == nullptr) {
        if (accumulate) {
          value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
        }
        if (final_output) {
          set_results_to_output<output_vec_size>(value, base_offsets);
        } else {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            *(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
          }
        }
      } else {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.combine((*acc)[i], value[i]);
          }
        }
        if (final_output) {
          set_results_to_output<output_vec_size>(value, base_offsets);
        } else {
          *acc = value;
        }
      }
    }
  }
  template <int output_vec_size>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
    if (config.vectorize_input) {
      assert(output_vec_size == 1);
      // reduce at the head of the input slice where memory is not aligned,
      // so that thread_reduce will have aligned memory to work on.
      return {input_vectorized_thread_reduce_impl(data)};
    } else {
      index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
      bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
      if (is_contiguous) {
        return thread_reduce_impl<output_vec_size>(data, [](index_t idx) { return idx; });
      } else if (input_calc.dims == 1) {
        return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return idx * element_stride; });
      } else {
        return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
      }
    }
  }
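
  // Vectorized per-thread reduction over a contiguous input slice: reduce the
  // misaligned head element by element, consume the aligned body with vector
  // loads of input_vec_size elements (one accumulator per vector lane to break
  // the dependency chain), then pick up the scalar tail.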
  C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
    index_t end = config.num_inputs;

    // Handle the head of the input slice where data is not aligned
    arg_t value = ident;
    constexpr int align_bytes = alignof(at::native::memory::aligned_vector<scalar_t, input_vec_size>);
    constexpr int align_elements = align_bytes / sizeof(scalar_t);
    int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t);
    if (shift > 0) {
      data -= shift;
      end += shift;
      if (threadIdx.x >= shift && threadIdx.x < align_elements && config.should_reduce_tail()) {
        value = ops.reduce(value, c10::load(data + threadIdx.x), threadIdx.x - shift);
      }
      end -= align_elements;
      data += align_elements;
      shift = align_elements - shift;
    }

    // Do the vectorized reduction
    using load_t = at::native::memory::aligned_vector<scalar_t, input_vec_size>;

    index_t idx = config.input_idx();
    const index_t stride = config.step_input;

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_t value_list[input_vec_size];
    value_list[0] = value;

    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[i] = ident;
    }

    while (idx * input_vec_size + input_vec_size - 1 < end) {
      const auto values_vec = memory::load_vector<input_vec_size>(data, idx);
      #pragma unroll
      for (index_t i = 0; i < input_vec_size; i++) {
        value_list[i] = ops.reduce(value_list[i], values_vec.val[i], shift + idx * input_vec_size + i);
      }
      idx += stride;
    }

    // tail
    index_t tail_start = end - end % input_vec_size;
    if (config.should_reduce_tail()) {
      int idx = tail_start + threadIdx.x;
      if (idx < end) {
        const auto value = c10::load(data + idx);
        value_list[0] = ops.reduce(value_list[0], value, idx + shift);
      }
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < input_vec_size; i++) {
      value_list[0] = ops.combine(value_list[0], value_list[i]);
    }
    return value_list[0];
  }
  template <int output_vec_size, typename offset_calc_t>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
    index_t idx = config.input_idx();
    const index_t end = config.num_inputs;
    const index_t stride = config.step_input;

    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
    using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;

    // Multiple accumulators to remove dependency between unrolled loops.
    arg_vec_t value_list[vt0];

    #pragma unroll
    for (int i = 0; i < vt0; i++) {
      #pragma unroll
      for (int j = 0; j < output_vec_size; j++) {
        value_list[i][j] = ident;
      }
    }

    load_t values[vt0];

    while (idx + (vt0 - 1) * stride < end) {
      #pragma unroll
      for (index_t i = 0; i < vt0; i++) {
        const auto offset = calc(idx + i * stride) / output_vec_size;
        values[i] = memory::load_vector<output_vec_size>(data_, offset);
      }
      #pragma unroll
      for (index_t i = 0; i < vt0; i++) {
        #pragma unroll
        for (index_t j = 0; j < output_vec_size; j++) {
          value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride);
        }
      }
      idx += stride * vt0;
    }

    // tail
    int idx_ = idx;
    #pragma unroll
    for (index_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      const auto offset = calc(idx) / output_vec_size;
      values[i] = memory::load_vector<output_vec_size>(data_, offset);
      idx += stride;
    }
    idx = idx_;
    #pragma unroll
    for (index_t i = 0; i < vt0; i++) {
      if (idx >= end) {
        break;
      }
      #pragma unroll
      for (index_t j = 0; j < output_vec_size; j++) {
        value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx);
      }
      idx += stride;
    }

    // combine accumulators
    #pragma unroll
    for (int i = 1; i < vt0; i++) {
      #pragma unroll
      for (index_t j = 0; j < output_vec_size; j++) {
        value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]);
      }
    }
    return value_list[0];
  }
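
  // Reduce across threadIdx.x. If the block is wider than a warp, first fold
  // the upper halves into the lowest warp through shared memory, then finish
  // the remaining warp-sized reduction with warp shuffles.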
  template <int output_vec_size>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_x_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
    int dim_x = blockDim.x;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    if (dim_x > warpSize) {
      int address_base = threadIdx.x + threadIdx.y * blockDim.x;
      shared[address_base] = value;
      for (int offset = dim_x / 2; offset >= warpSize; offset >>= 1) {
        __syncthreads();
        if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
          args_vec_t other = shared[address_base + offset];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.combine(value[i], other[i]);
          }
          shared[address_base] = value;
        }
      }
      dim_x = warpSize;
    }

    __syncthreads();

    for (int offset = 1; offset < dim_x; offset <<= 1) {
      #pragma unroll
      for (int i = 0; i < output_vec_size; i++) {
        arg_t other = ops.warp_shfl_down(value[i], offset);
        value[i] = ops.combine(value[i], other);
      }
    }
    return value;
  }
  template <int output_vec_size>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_y_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
    args_vec_t* shared = (args_vec_t*)shared_memory;
    shared[config.shared_memory_offset(0)] = value;
    for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
      __syncthreads();
      if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
        args_vec_t other = shared[config.shared_memory_offset(offset)];
        #pragma unroll
        for (int i = 0; i < output_vec_size; i++) {
          value[i] = ops.combine(value[i], other[i]);
        }
        shared[config.shared_memory_offset(0)] = value;
      }
    }
    return value;
  }
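
  // Returns true in every thread of the last block (along blockIdx.y) to
  // finish staging its partial result; the per-output semaphore counts how
  // many blocks have arrived.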
  C10_DEVICE bool mark_block_finished() const {
    __shared__ bool is_last_block_done_shared;

    __syncthreads();
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      int prev_blocks_finished = atomicAdd(&semaphores[blockIdx.x], 1);
      is_last_block_done_shared = (prev_blocks_finished == gridDim.y - 1);
    }

    __syncthreads();

    return is_last_block_done_shared;
  }
  template <int output_vec_size, bool can_acc>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
    at::detail::Array<out_scalar_t*, output_vec_size> out,
    at::detail::Array<arg_t, output_vec_size> value,
    typename std::enable_if<can_acc>::type* = nullptr
  ) const {
    at::detail::Array<arg_t, output_vec_size> ret;
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      ret[i] = ops.combine(*(out[i]), value[i]);
    }
    return ret;
  }

  template <bool can_acc>
  C10_DEVICE out_scalar_t get_accumulated_output(
    out_scalar_t* out, arg_t value,
    typename std::enable_if<can_acc>::type* = nullptr
  ) const {
    assert(!final_output);
    return (out_scalar_t)value;
  }

  // This function should never be called --
  // it's the version of `accumulate_in_output`
  // when accumulation in the output is not possible.
  template <int output_vec_size, bool can_acc>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
    at::detail::Array<out_scalar_t*, output_vec_size>,
    at::detail::Array<arg_t, output_vec_size>,
    typename std::enable_if<!can_acc>::type* = nullptr
  ) const {
    assert(false); // can't use AT_ASSERT in CUDA code
    return arg_t {};
  }

  // This function should never be called --
  // it's the version of `get_accumulated_output`
  // when accumulation in the output is not possible.
  template <bool can_acc>
  C10_DEVICE out_scalar_t get_accumulated_output(
    out_scalar_t* out, arg_t value,
    typename std::enable_if<!can_acc>::type* = nullptr
  ) const {
    assert(false);
    return *out;
  }
  template<class T>
  C10_DEVICE void set_results(const T x, const index_t base_offset) const {
    assert(noutputs == 1);
    auto res = (out_scalar_t*)((char*)dst[0] + base_offset);
    *res = x;
  }

  // Currently implemented for a maximum of two outputs
  template<class T1, class T2>
  C10_DEVICE void set_results(const thrust::pair<T1, T2> x, const index_t base_offset) const {
    if (noutputs >= 1) {
      auto res0 = (T1*)((char*)dst[0] + base_offset);
      *res0 = x.first;
    }
    if (noutputs >= 2) {
      // base_offset is computed assuming the element size is sizeof(T1), so we need to make a
      // correction to obtain the correct base offset
      auto res1 = (T2*) ((char *) dst[1] + base_offset / sizeof(T1) * sizeof(T2));
      *res1 = x.second;
    }
  }

  template <int output_vec_size>
  C10_DEVICE void set_results_to_output(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<index_t, output_vec_size> base_offset) const {
    assert(final_output);
    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      set_results(ops.project(value[i]), base_offset[i]);
    }
  }
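
  // Two-stage reduction across blocks (along blockIdx.y): every participating
  // block stages its partial result in cta_buf, then the last block to arrive
  // (detected via mark_block_finished) reloads all staged partials, reduces
  // them within the block, and writes the final value.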
  template <int output_vec_size>
  C10_DEVICE at::detail::Array<arg_t, output_vec_size> global_reduce(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
    using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
    using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
    using offset_vec_t = at::detail::Array<index_t, output_vec_size>;

    arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
    index_t output_idx = config.output_idx<output_vec_size>();
    offset_vec_t base_offsets;
    out_ptr_vec_t out;

    #pragma unroll
    for (int i = 0; i < output_vec_size; i++) {
      base_offsets[i] = output_calc.get(output_idx + i)[0];
      out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
    }

    bool should_store = config.should_store(output_idx);
    if (should_store) {
      index_t offset = config.staging_memory_offset(blockIdx.y);
      reduce_buffer[offset] = value;
    }

    __threadfence(); // make sure writes are globally visible
    __syncthreads(); // if multiple warps in this block wrote to staging, make sure they're all done
    bool is_last_block_done = mark_block_finished();

    if (is_last_block_done) {
      value = ident;
      if (config.should_block_x_reduce()) {
        index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
        index_t step = blockDim.x * blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          index_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.combine(value[i], next[i]);
          }
        }
      } else {
        index_t input_offset = threadIdx.y;
        index_t step = blockDim.y;
        for (; input_offset < config.ctas_per_output; input_offset += step) {
          index_t idx = config.staging_memory_offset(input_offset);
          arg_vec_t next = reduce_buffer[idx];
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.combine(value[i], next[i]);
          }
        }
      }
      value = block_y_reduce(value, shared_memory);
      if (config.should_block_x_reduce()) {
        value = block_x_reduce<output_vec_size>(value, shared_memory);
      }
      if (should_store) {
        if (accumulate) {
          #pragma unroll
          for (int i = 0; i < output_vec_size; i++) {
            value[i] = ops.translate_idx(value[i], base_idx);
          }
        }

        if (acc == nullptr) {
          if (accumulate) {
            value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              *(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
            }
          }
        } else {
          if (accumulate) {
            #pragma unroll
            for (int i = 0; i < output_vec_size; i++) {
              value[i] = ops.combine((*acc)[i], value[i]);
            }
          }
          if (final_output) {
            set_results_to_output<output_vec_size>(value, base_offsets);
          } else {
            *acc = value;
          }
        }
      }
    }

    return value;
  }
};
template<int max_threads, typename R>
static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
  dim3 block = config.block();
  dim3 grid = config.grid();

  auto stream = at::cuda::getCurrentCUDAStream();
  int shared_memory = config.shared_memory_size();

  switch(config.output_vec_size) {
  case 4:
    reduce_kernel<max_threads / 4, 4, R><<<grid, block, shared_memory, stream>>>(reduction);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    break;
  case 2:
    reduce_kernel<max_threads / 2, 2, R><<<grid, block, shared_memory, stream>>>(reduction);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    break;
  default:
    reduce_kernel<max_threads / 1, 1, R><<<grid, block, shared_memory, stream>>>(reduction);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
}
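
// The jitted variant below caches one compiled kernel per output_vec_size:
// fn_cache[0] holds the vec-4 kernel, fn_cache[1] the vec-2 kernel, and
// fn_cache[2] the scalar kernel. The kernel is generated and compiled lazily
// on first use.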
inline void launch_jitted_reduce_kernel(
    std::mutex &jiterator_mutex,
    std::array<at::cuda::jit::NvrtcFunction, 3> &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int vt0, const ReduceConfig& config, void *reduction) {
  dim3 block = config.block();
  dim3 grid = config.grid();

  int shared_memory = config.shared_memory_size();
  at::cuda::jit::NvrtcFunction* fn_ptr;
  switch(config.output_vec_size) {
  case 4:
    fn_ptr = &fn_cache[0];
    break;
  case 2:
    fn_ptr = &fn_cache[1];
    break;
  default:
    fn_ptr = &fn_cache[2];
  }
  if (!fn_ptr->function) {
    int max_threads_codegen =
        max_reduce_threads(desc.f_inputs_type) / config.output_vec_size;
    auto code = at::cuda::jit::generate_reduction_code(
        desc, vt0, true, false, config.output_vec_size, max_threads_codegen);

    *fn_ptr = at::cuda::jit::jit_pwise_function(code, "reduction_" + desc.name);
  }
  constexpr int kernel_args = 1;
  void* args[kernel_args];
  args[0] = reduction;
  at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, grid, block, shared_memory);
}
class AccumulationBuffer {
 public:
  AccumulationBuffer() {}

  AccumulationBuffer(size_t acc_t_size, size_t out_t_size, char* out_ptr, int64_t size) {
    out_ptr_ = (char*)out_ptr;
    if (out_t_size >= acc_t_size) {
      // reusing output buffer for accumulation.
      acc_ptr_ = (char*)out_ptr;
      numerator_ = 1;
      denominator_ = 1;
    } else {
      auto& allocator = *c10::cuda::CUDACachingAllocator::get();
      buffer_ = allocator.allocate(size);
      acc_ptr_ = (char*)buffer_.get();
      numerator_ = acc_t_size;
      denominator_ = out_t_size;
      reduce_fraction(numerator_, denominator_);
    }
  }

  char* get_acc_slice(char* out_ptr) {
    if (acc_ptr_ == nullptr) {
      return nullptr;
    }
    return acc_ptr_ + ((out_ptr - out_ptr_) * numerator_ / denominator_);
  }

 private:
  char* acc_ptr_ = nullptr;
  char* out_ptr_ = nullptr;
  size_t numerator_;
  size_t denominator_;
  at::DataPtr buffer_;
};
template <typename scalar_t>
int get_output_vec_size(const TensorIterator &iter) {
  int vec_size = 4;
  auto update_vec_size = [&vec_size](uint64_t n) {
    while (n % vec_size != 0) {
      vec_size /= 2;
    }
  };

  uint64_t base_address = reinterpret_cast<uint64_t>(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t);
  update_vec_size(base_address);

  const int output_index = iter.num_reduce_dims();
  update_vec_size(iter.shape()[output_index]);

  int j = 0;
  for (auto i : iter.strides(iter.noutputs())) {
    if (j != output_index) {
      update_vec_size(i / sizeof(scalar_t));
    }
    j++;
  }
  return vec_size;
}
template<typename arg_t, typename scalar_t, int vt0>
ReduceConfig setReduceConfig(const TensorIterator& iter){
  // Start by assuming that each thread handles a single output and all
  // the inputs for that output.
  int64_t num_outputs = iter.num_output_elements();
  int64_t inputs_per_output = iter.numel() / num_outputs;
  int input_index = iter.ntensors() - 1;

  auto config = ReduceConfig(sizeof(arg_t), num_outputs, inputs_per_output);

  int64_t dim0;
  int64_t dim1;
  int64_t fastest_moving_stride;
  bool reduction_on_fastest_striding_dimension;

  if (iter.ndim() > 0) {
    // Adjust the block size to map block width to the fastest changing dimension of the
    // input tensor. This grants the best possible memory access pattern, given that
    // for a non-contiguous tensor with space in between, we cannot have perfect
    // memory coalescing.
    reduction_on_fastest_striding_dimension =
        (iter.num_reduce_dims() == iter.ndim()) ||
        (iter.strides(/*arg=*/input_index)[0] <
         iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]);
    // Notice that dim0 & dim1 do NOT guarantee any launch configuration here!
    // dim0 & dim1 are more like the upper bound of the block dimension. The
    // actual launch config and reduction scheme are determined by setting values
    // to `config.input_mult` and `config.output_mult`.
    // We try to max out dim1 so that we have enough threads per CTA to deliver
    // performance for larger problem sizes.
    if (reduction_on_fastest_striding_dimension) {
      // Map block.x to the fastest reducing dimension. It implies:
      //   1. block_x_reduce is required.
      //   2. block.y now maxes out at num_outputs.
      dim0 = inputs_per_output;
      dim1 = num_outputs;
      fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
    } else {
      // Map block.x to the fastest non-reducing dimension. It implies:
      //   1. block_x_reduce is turned off.
      //   2. block.y now maxes out at inputs_per_output.
      dim0 = num_outputs;
      dim1 = inputs_per_output;
      fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
    }
  } else {
    reduction_on_fastest_striding_dimension = true;
    fastest_moving_stride = sizeof(scalar_t);
    dim0 = 1;
    dim1 = 1;
  }

  // We do vectorization to gain better memory access; there are two cases which we call
  // "vectorize along input" and "vectorize along output". Note that the "input/output"
  // here does not mean we are vectorizing load/store instructions. We always only vectorize
  // load instructions.
  //
  // Case 1: "vectorize along input"
  // This case happens when we are reducing along the fastest moving dimension. In such a case,
  // threads with the same threadIdx.y work on the same reduction cooperatively and will produce
  // results for the same output. Values in each loaded vector therefore always correspond to the
  // same output.
  //
  // Case 2: "vectorize along output"
  // This case happens when the fastest moving dimension is not the dimension of reduction. In such
  // a case, threads with different threadIdx.x are independent and will produce results for
  // different outputs. Values in each loaded vector therefore always correspond to different outputs.
  if (fastest_moving_stride == sizeof(scalar_t)) {
    if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= ReduceConfig::input_vec_size) {
      // Case 1: "vectorize along input"
      // Note that if vt0 < ReduceConfig::input_vec_size, then register pressure could be high; in
      // such a case, we should avoid vectorization.
      config.vectorize_input = true;
      dim0 /= config.input_vec_size;
    } else if (!reduction_on_fastest_striding_dimension) {
      // Case 2: "vectorize along output"
      config.output_vec_size = get_output_vec_size<scalar_t>(iter);
      dim0 /= config.output_vec_size;
    }
  }

  // Adjust block_width and block_height
  config.set_block_dimension<scalar_t>(dim0, dim1);

  int block_width = config.block_width;
  int block_height = config.block_height;

  if (iter.ndim() == 0 || reduction_on_fastest_striding_dimension) {
    // Split the input across lanes if the input is contiguous in the reduced
    // dimension. This will require reduction between threads using warp
    // shuffle instructions and shared memory (if block_width > warpSize).
    config.input_mult[0] = config.split_input(block_width);
  } else {
    // Otherwise split the output across lanes in a warp.
    config.output_mult[0] = config.split_output(block_width);
  }

  constexpr int min_values_per_thread = 16;
  constexpr int max_values_per_thread = 256;

  if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) {
    // Divide the input across warps in a thread-block, if that leaves at least
    // 16 elements to be summed by each thread. This will require inter-warp
    // reduction using shared memory.
    config.input_mult[1] = config.split_input(block_height);
  } else {
    // Otherwise, each warp handles a separate output.
    config.output_mult[1] = config.split_output(block_height);
  }

  const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / config.num_threads;
  const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  const int target_grid_size = num_mp * blocks_per_sm;
  int grid = config.grid().x;
  if (config.input_mult[1] != 0 && config.values_per_thread() >= max_values_per_thread && grid <= target_grid_size) {
    // Divide the input across thread-blocks if the amount of work per-thread
    // is large enough and the size of the output is small enough. This will
    // require a reduction using global memory.
    // If we decide to split the input across blocks, as long as we can get
    // enough blocks (`target_grid_size`) to balance the SMs, we should still
    // keep the number of values per thread large for best performance.
    int ctas_per_output1 = div_up(target_grid_size, grid);
    int ctas_per_output2 = div_up(config.values_per_thread(), min_values_per_thread);
    int ctas_per_output3 = div_up(config.values_per_thread(), max_values_per_thread);
    // We want the minimum of ctas_per_output1 and ctas_per_output2, so that each thread can have
    // a large number of values to deal with. But we don't want values_per_thread to be larger than
    // max_values_per_thread
    config.ctas_per_output = std::max(std::min<int>(ctas_per_output1, ctas_per_output2), ctas_per_output3);
    if (config.ctas_per_output > 1) {
      config.input_mult[2] = config.split_input(config.ctas_per_output);
    }
  }
  return config;
};
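
// gpu_reduce_kernel below drives the whole pipeline: if the iterator cannot be
// indexed with 32-bit offsets it is split into 32-bit-indexable sub-iterators
// and the function recurses on each of them, sharing one AccumulationBuffer so
// partial results can be combined across sub-iterators when they cannot be
// accumulated directly in the output dtype.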
template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
                              AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);

  using traits = function_traits<decltype(&ops_t::reduce)>;
  using arg_t = typename traits::template arg<0>::type;
  // at::Half/at::ComplexHalf overflows easily as its range is very small.
  // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
  // set can_accumulate_in_output to False.
  static constexpr bool is_inp_out_type_half_or_chalf =
      (std::is_same<at::Half, scalar_t>::value &&
       std::is_same<at::Half, out_scalar_t>::value) ||
      (std::is_same<c10::complex<Half>, scalar_t>::value &&
       std::is_same<c10::complex<Half>, out_scalar_t>::value);
  // at::BFloat16 has lower precision and can lead to rounding errors.
  // So when scalar_t and out_scalar_t are at::BFloat16, we
  // set can_accumulate_in_output to False.
  static constexpr bool is_inp_out_type_bfloat16 =
      (std::is_same<at::BFloat16, scalar_t>::value &&
       std::is_same<at::BFloat16, out_scalar_t>::value);
  static constexpr bool can_accumulate_in_output =
      std::is_convertible<arg_t, out_scalar_t>::value &&
      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);

  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;

  // acc_buf_ptr is shared across recursive calls: it is created on the first
  // entry and reused by all recursive function calls.
  if (acc_buf_ptr == NULL) {
    // acc_buf_ptr holds the buffer used for accumulation among multiple sub_iters
    // when accumulation in the output is not possible.
    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
      int64_t output_memory_size = iter.element_size(0);
      for (int dim = 0; dim < iter.ndim(); dim++) {
        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
      }
      output_memory_size /= iter.element_size(0); // iter.strides is in bytes
      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(arg_t),
                                                 sizeof(out_scalar_t),
                                                 (char*) iter.data_ptr(0),
                                                 output_memory_size * sizeof(arg_t)));
    } else {
      owned_buf_ptr.reset(new AccumulationBuffer());
    }
    acc_buf_ptr = owned_buf_ptr.get();
  }

  if (!can_use_32bit_indexing) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];

      gpu_reduce_kernel<scalar_t, out_scalar_t, vt0>(sub_iter, ops, ident,
          acc_buf_ptr, sub_iter_base_idx);
    }
    return;
  }

  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
  char* out_data = (char*)iter.data_ptr(0);
  const auto noutputs = iter.noutputs();
  optional<char*> out_data_extra;
  if (noutputs > 1) {
    out_data_extra = (char*)iter.data_ptr(1);
  } else {
    out_data_extra = nullopt;
  }
  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);

  ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);
  at::DataPtr buffer;
  at::DataPtr semaphores;
  if (config.should_global_reduce()) {
    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
    buffer = allocator.allocate(config.global_memory_size());
    semaphores = allocator.allocate(config.semaphore_size());

    auto stream = at::cuda::getCurrentCUDAStream();
    AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
  }

  AT_ASSERT(can_use_32bit_indexing);
  auto output_calc = make_output_calculator<uint32_t>(iter);
  auto input_calc = make_input_calculator<uint32_t>(iter);
  auto reduce = ReduceOp<scalar_t, ops_t, uint32_t, out_scalar_t, vt0>(
      ops,
      config,
      input_calc,
      output_calc,
      in_data,
      out_data,
      out_data_extra,
      acc_data,
      buffer.get(),
      (int*)semaphores.get(),
      ident,
      noutputs,
      base_idx);
  reduce.accumulate = iter.should_accumulate();
  reduce.final_output = iter.is_final_output();

  launch_reduce_kernel<mnt_wrapper<scalar_t>::MAX_NUM_THREADS>(config, reduce);
}
// TODO: this is 100 lines of almost-copy-paste, because we have to have different template args
// for this function; try unifying with gpu_reduce_kernel
template <char const* name, typename scalar_t, typename out_scalar_t, int vt0=4, typename ident_t=double>
inline void jitted_gpu_reduce_kernel(TensorIterator& iter, const std::string& func, ident_t ident=0,
                                     AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
  AT_ASSERT(iter.numel() > 0 && iter.ntensors() - iter.noutputs() == 1 && iter.noutputs() >= 1);

  // TODO - this will be different for more complicated reductions, but for now reductions using
  // func_wrapper all have arg_t = opmath
  using arg_t = at::opmath_type<scalar_t>;
  // at::Half/at::ComplexHalf overflows easily as its range is very small.
  // So when scalar_t and out_scalar_t are at::Half/at::ComplexHalf, we
  // set can_accumulate_in_output to False.
  static constexpr bool is_inp_out_type_half_or_chalf =
      (std::is_same<at::Half, scalar_t>::value &&
       std::is_same<at::Half, out_scalar_t>::value) ||
      (std::is_same<c10::complex<Half>, scalar_t>::value &&
       std::is_same<c10::complex<Half>, out_scalar_t>::value);
  // at::BFloat16 has lower precision and can lead to rounding errors.
  // So when scalar_t and out_scalar_t are at::BFloat16, we
  // set can_accumulate_in_output to False.
  static constexpr bool is_inp_out_type_bfloat16 =
      (std::is_same<at::BFloat16, scalar_t>::value &&
       std::is_same<at::BFloat16, out_scalar_t>::value);
  static constexpr bool can_accumulate_in_output =
      std::is_convertible<arg_t, out_scalar_t>::value &&
      !(is_inp_out_type_half_or_chalf || is_inp_out_type_bfloat16);

  bool can_use_32bit_indexing = iter.can_use_32bit_indexing();
  std::unique_ptr<AccumulationBuffer> owned_buf_ptr;

  // acc_buf_ptr is shared across recursive calls: it is created on the first
  // entry and reused by all recursive function calls.
  if (acc_buf_ptr == NULL) {
    // acc_buf_ptr holds the buffer used for accumulation among multiple sub_iters
    // when accumulation in the output is not possible.
    if (!can_accumulate_in_output && !can_use_32bit_indexing) {
      int64_t output_memory_size = iter.element_size(0);
      for (int dim = 0; dim < iter.ndim(); dim++) {
        output_memory_size = std::max(output_memory_size, iter.shape()[dim] * iter.strides(0)[dim]);
      }
      output_memory_size /= iter.element_size(0); // iter.strides is in bytes
      owned_buf_ptr.reset(new AccumulationBuffer(sizeof(out_scalar_t), // TODO
                                                 sizeof(out_scalar_t),
                                                 (char*) iter.data_ptr(0),
                                                 output_memory_size * sizeof(out_scalar_t))); // TODO
    } else {
      owned_buf_ptr.reset(new AccumulationBuffer());
    }
    acc_buf_ptr = owned_buf_ptr.get();
  }

  if (!can_use_32bit_indexing) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      int64_t sub_iter_base_idx = sub_iter.view_offsets()[0];

      jitted_gpu_reduce_kernel<name, scalar_t, out_scalar_t, vt0>(sub_iter, func, ident,
          acc_buf_ptr, sub_iter_base_idx);
    }
    return;
  }

  // TODO - for now we support a single input; we may be able to relax this constraint
  const char* in_data = (char*)iter.data_ptr(iter.ntensors() - 1);
  char* out_data = (char*)iter.data_ptr(0);
  const auto noutputs = iter.noutputs();
  optional<char*> out_data_extra;
  if (noutputs > 1) {
    out_data_extra = (char*)iter.data_ptr(1);
  } else {
    out_data_extra = nullopt;
  }
  char* acc_data = acc_buf_ptr->get_acc_slice(out_data);

  ReduceConfig config = setReduceConfig<arg_t, scalar_t, vt0>(iter);

  at::DataPtr buffer;
  at::DataPtr semaphores;
  if (config.should_global_reduce()) {
    auto& allocator = *c10::cuda::CUDACachingAllocator::get();
    buffer = allocator.allocate(config.global_memory_size());
    semaphores = allocator.allocate(config.semaphore_size());

    auto stream = at::cuda::getCurrentCUDAStream();
    AT_CUDA_CHECK(cudaMemsetAsync(semaphores.get(), 0, config.semaphore_size(), stream));
  }

  AT_ASSERT(can_use_32bit_indexing);
  auto output_calc = make_output_calculator<uint32_t>(iter);
  auto input_calc = make_input_calculator<uint32_t>(iter);
  auto reduce = ReduceJitOp<scalar_t, out_scalar_t>(
      config,
      input_calc,
      output_calc,
      in_data,
      out_data,
      out_data_extra,
      acc_data,
      buffer.get(),
      (int*)semaphores.get(),
      ident,
      noutputs,
      base_idx);
  reduce.accumulate = iter.should_accumulate();
  reduce.final_output = iter.is_final_output();

  constexpr int nInputs = 1;
  constexpr int nOutputs = 1;
  static auto desc = at::cuda::jit::make_kernel_descriptor<
    out_scalar_t, scalar_t>(name, func, nInputs, nOutputs);

  static std::mutex jiterator_mutex;
  static std::vector<std::array<at::cuda::jit::NvrtcFunction, 3>> fn_cache(c10::cuda::device_count());
  auto &cache = fn_cache[iter.device().index()];

  launch_jitted_reduce_kernel(
      jiterator_mutex, cache, desc, vt0, config, &reduce);
}

}} // namespace at::native