DistributionTemplates.h 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. #pragma once
  2. #include <ATen/AccumulateType.h>
  3. #include <ATen/Dispatch.h>
  4. #include <ATen/ExpandBase.h>
  5. #include <ATen/native/TensorIterator.h>
  6. #include <ATen/native/cuda/Loops.cuh>
  7. #include <c10/util/Half.h>
  8. #include <ATen/cuda/CUDAApplyUtils.cuh>
  9. #include <ATen/cuda/CUDAContext.h>
  10. #include <ATen/cuda/detail/OffsetCalculator.cuh>
  11. #include <ATen/cuda/CUDAGraphsUtils.cuh>
  12. #include <ATen/detail/FunctionTraits.h>
  13. #include <ATen/core/DistributionsHelper.h>
  14. #include <curand.h>
  15. #include <curand_kernel.h>
  16. #include <curand_philox4x32_x.h>
  17. #include <cstdint>
  18. #include <limits>
  19. #include <utility>
  20. #include <mutex>
  21. #include <tuple>
  22. #include <type_traits>
  23. namespace at {
  24. namespace native {
  25. namespace {
  // Arguments handed to C10_LAUNCH_BOUNDS_2 on the grid-stride distribution
  // kernel: at most 256 threads per block, and 4 as the second launch-bounds
  // argument (presumably the min-blocks-per-SM hint; see the macro).
  constexpr uint32_t block_size_bound = 256;
  constexpr uint32_t grid_size_bound = 4;
  // How many random values one call of a "4-wide" curand distribution
  // (curand4, curand_uniform4, curand_uniform2_double, ...) is accounted as.
  // Used as the unroll factor for 32-bit samples and when computing the
  // philox offset increment.
  constexpr uint32_t curand4_engine_calls = 4;
  // Launch configuration for distribution_elementwise_grid_stride_kernel.
  //
  // Returns (counter_offset, grid, block):
  //  * block: block_size_bound threads;
  //  * grid: enough blocks for one element per thread, capped at
  //    multiProcessorCount * (maxThreadsPerMultiProcessor / block) blocks;
  //  * counter_offset: how far the Philox counter must advance for this launch.
  //    Each thread runs ceil(numel / (block * grid * unroll)) outer-loop steps,
  //    and each step makes one 4-wide curand call (curand4_engine_calls values).
  //    The offset is therefore steps * curand4_engine_calls, which is never less
  //    than the number of randoms a single thread consumes. For common sizes a
  //    thread only uses rand.x of its first draw. The unrolled components only
  //    come into play at the tail of large tensors.
  //
  // total_elements is expected to be positive: (numel - 1) below wraps for 0.
  // Callers return early on empty iterators.
  std::tuple<uint64_t, dim3, dim3> calc_execution_policy(int64_t total_elements) {
    const uint64_t numel = static_cast<uint64_t>(total_elements);
    const uint32_t threads_per_block = block_size_bound;
    const uint32_t unroll = curand4_engine_calls;

    const auto* props = at::cuda::getCurrentDeviceProperties();
    const uint32_t blocks_per_sm = props->maxThreadsPerMultiProcessor / threads_per_block;
    const uint32_t max_resident_blocks =
        static_cast<uint32_t>(props->multiProcessorCount) * blocks_per_sm;
    const uint32_t blocks_to_cover =
        static_cast<uint32_t>((numel + threads_per_block - 1) / threads_per_block);
    const uint32_t num_blocks = std::min(max_resident_blocks, blocks_to_cover);

    // Elements covered by the whole grid in one iteration of the kernel's outer loop.
    const uint32_t elements_per_step = threads_per_block * num_blocks * unroll;
    const uint64_t steps_per_thread = (numel - 1) / elements_per_step + 1;
    const uint64_t counter_offset = steps_per_thread * curand4_engine_calls;

    return std::make_tuple(counter_offset, dim3(num_blocks), dim3(threads_per_block));
  }
  // Grid-stride kernel shared by all nullary distributions.
  //
  // Thread `idx` (global thread index) initialises a Philox state:
  //  * seed and offset come from `philox_args`;
  //  * the subsequence is `idx`.
  //
  // The thread then walks the flat index space in steps of
  // blockDim.x * gridDim.x * unroll_factor. Each step:
  //  * draws one vector `rand = dist_func(&state)`;
  //  * passes component `ii` of it to transform_func(li, value), where
  //    li = linear_index + ii * blockDim.x * gridDim.x;
  //  * skips any li >= numel.
  //
  // The loop bound is rounded up to a multiple of the step. Every thread of a
  // block therefore runs the same number of iterations, which keeps the
  // __syncthreads() inside the loop non-divergent.
  //
  // Index arithmetic is done in int64_t. With numel close to INT_MAX, the
  // rounded-up bound and the strided indices can exceed the int range. With
  // int arithmetic that wrapped negative and skipped the loop entirely,
  // leaving the output unwritten. numel itself still fits in int (callers
  // split iterators down to 32-bit indexing), so every li < numel passed to
  // transform_func fits in int.
  template<typename accscalar_t, int unroll_factor, typename dist_t, typename transform_t>
  C10_LAUNCH_BOUNDS_2(block_size_bound, grid_size_bound)
  __global__ void distribution_elementwise_grid_stride_kernel(int numel,
                                                              PhiloxCudaState philox_args,
                                                              const dist_t dist_func,
                                                              const transform_t transform_func) {
    auto seeds = at::cuda::philox::unpack(philox_args);
    const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(std::get<0>(seeds),
                idx,
                std::get<1>(seeds),
                &state);

    const int64_t n = numel;
    const int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    const int64_t step = grid_stride * unroll_factor;
    const int64_t rounded_size = ((n - 1) / step + 1) * step;
    for (int64_t linear_index = idx; linear_index < rounded_size; linear_index += step) {
      auto rand = dist_func(&state);
      #pragma unroll
      for (int ii = 0; ii < unroll_factor; ii++) {
        const int64_t li = linear_index + grid_stride * ii;
        if (li < n) {
          // li < numel <= INT_MAX, so the narrowing is lossless.
          transform_func(static_cast<int>(li), static_cast<accscalar_t>((&rand.x)[ii]));
        }
      }
      __syncthreads();
    }
  }
  /**
   * Host-side launcher for nullary (output-only) distribution kernels, the
   * distribution analogue of gpu_kernel in ATen/native/cuda/Loops.cuh.
   *
   * It differs from gpu_kernel in two ways:
   * - It always launches distribution_elementwise_grid_stride_kernel, a
   *   grid-stride kernel specialised for sampling, rather than the generic
   *   elementwise_kernel.
   * - Iterators that cannot use 32-bit indexing are split with
   *   with_32bit_indexing(), and this function recurses on each piece. The
   *   philox offset is therefore reserved here, once per call.
   *
   * scalar_t is the element type stored into operand 0 of `iter`. accscalar_t
   * is the type each sampled component is cast to before transform_func.
   * unroll_factor is the number of components of dist_func's result used per
   * loop step.
   *
   * FIXME: Can we specialize elementwise_kernel and launch_kernel in Loops.cuh
   * to have a grid-stride loop kernel, and use that to launch our distribution
   * kernels? We need a grid-stride loop kernel because testing showed that it
   * achieves peak effective bandwidth.
   */
  template<typename scalar_t,
           typename accscalar_t,
           int unroll_factor,
           typename RNG,
           typename dist_t,
           typename transform_t>
  void distribution_nullary_kernel(at::TensorIteratorBase& iter,
                                   RNG gen,
                                   const dist_t& dist_func,
                                   const transform_t transform_func) {
    static_assert(unroll_factor >= 1, "unroll_factor must be >= 1.");
    const int64_t numel = iter.numel();
    if (numel == 0) {
      return;
    }

    uint64_t counter_offset;
    dim3 grid;
    dim3 block;
    std::tie(counter_offset, grid, block) = calc_execution_policy(numel);

    PhiloxCudaState rng_engine_inputs;
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen->mutex_);
      rng_engine_inputs = gen->philox_cuda_state(counter_offset);
    }

    if (!iter.can_use_32bit_indexing()) {
      for (auto& sub_iter : iter.with_32bit_indexing()) {
        distribution_nullary_kernel<scalar_t, accscalar_t, unroll_factor>(
            sub_iter, gen, dist_func, transform_func);
      }
      return;
    }

    char* const out_data = static_cast<char*>(iter.data_ptr(0));
    auto stream = at::cuda::getCurrentCUDAStream();

    if (iter.is_trivial_1d()) {
      // Evenly strided output: byte offset is simply stride * index.
      const auto inner_strides = iter.get_inner_strides();
      const int stride0 = inner_strides[0];
      distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
          numel,
          rng_engine_inputs,
          dist_func,
          [=] __device__ (int idx, accscalar_t rand) {
            scalar_t* out = reinterpret_cast<scalar_t*>(out_data + stride0 * idx);
            *out = transform_func(rand);
          });
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      // General layout: translate the flat index through an offset calculator.
      auto offset_calc = make_offset_calculator<1>(iter);
      distribution_elementwise_grid_stride_kernel<accscalar_t, unroll_factor><<<grid, block, 0, stream>>>(
          numel,
          rng_engine_inputs,
          dist_func,
          [=] __device__ (int idx, accscalar_t rand) {
            auto offsets = offset_calc.get(idx);
            scalar_t* out = reinterpret_cast<scalar_t*>(out_data + offsets[0]);
            *out = transform_func(rand);
          });
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }
  }
  // Elementwise kernel for distributions driven by two input tensors.
  //
  // Block b owns flat indices [b * block_work_size(), (b + 1) * block_work_size()).
  // Each thread handles up to thread_work_size() of them, spaced num_threads()
  // apart. Inputs are first gathered into per-thread arrays. Then
  // f(state, in1, in2) is evaluated and stored. All calls made by one thread
  // share a single Philox state:
  //  * subsequence = global thread index;
  //  * seed and offset come from philox_args.
  // This lets f draw random numbers across calls without repeats.
  template <typename func_t, typename inp_offset_calc_t, typename out_offset_calc_t>
  __global__ void distribution_binary_elementwise_kernel(
      int numel,
      func_t f,
      PhiloxCudaState philox_args,
      typename function_traits<func_t>::result_type *output_data,
      const typename function_traits<func_t>::template arg<1>::type *input_data_1,
      const typename function_traits<func_t>::template arg<2>::type *input_data_2,
      inp_offset_calc_t inp_calc,
      out_offset_calc_t out_calc) {
    using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
    using input_t_2 = typename function_traits<func_t>::template arg<2>::type;

    auto seeds = at::cuda::philox::unpack(philox_args);
    curandStatePhilox4_32_10_t state;
    curand_init(std::get<0>(seeds),
                blockIdx.x * blockDim.x + threadIdx.x,
                std::get<1>(seeds),
                &state);

    // First flat index of this block, and how many of its slots are in bounds.
    const int block_start = block_work_size() * blockIdx.x;
    const int valid = std::min<int>(numel - block_start, block_work_size());

    input_t_1 in1[thread_work_size()];
    input_t_2 in2[thread_work_size()];

    // Pass 1: gather this thread's inputs.
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      const int local = static_cast<int>(threadIdx.x) + i * num_threads();
      if (local >= valid) {
        break;
      }
      auto offsets = inp_calc.get(block_start + local);
      in1[i] = input_data_1[offsets[0]];
      in2[i] = input_data_2[offsets[1]];
    }

    // Pass 2: sample with f and store to the output.
    #pragma unroll
    for (int i = 0; i < thread_work_size(); i++) {
      const int local = static_cast<int>(threadIdx.x) + i * num_threads();
      if (local >= valid) {
        break;
      }
      auto offsets = out_calc.get(block_start + local);
      output_data[offsets[0]] = f(state, in1[i], in2[i]);
    }
  }
  // Host launcher for distribution_binary_elementwise_kernel.
  //
  // `iter` has the output as operand 0 and the two inputs as operands 1 and 2.
  // `f` is invoked on the device as f(curandStatePhilox4_32_10_t&, in1, in2),
  // and its result is stored into the output. Iterators that need 64-bit
  // indexing are split into 32-bit-indexable pieces, each launched separately.
  //
  // NOTE(review): every 32-bit piece is launched with the same philox_args, so
  // all pieces see the same seed, offset and per-thread subsequences. That
  // means identical random streams. Confirm callers account for this.
  template <typename func_t>
  void distribution_binary_kernel(TensorIteratorBase &iter, PhiloxCudaState philox_args, const func_t &f) {
    static_assert(std::is_same<typename function_traits<func_t>::template arg<0>::type, curandStatePhilox4_32_10_t&>::value, "the first argument of functor must be curandStatePhilox4_32_10_t");
    using input_t_1 = typename function_traits<func_t>::template arg<1>::type;
    using input_t_2 = typename function_traits<func_t>::template arg<2>::type;
    using output_t = typename function_traits<func_t>::result_type;

    if (!iter.can_use_32bit_indexing()) {
      for (auto& sub_iter : iter.with_32bit_indexing()) {
        distribution_binary_kernel(sub_iter, philox_args, f);
      }
      return;
    }
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(iter.can_use_32bit_indexing());

    const int64_t numel = iter.numel();
    if (numel == 0) {
      return;
    }

    output_t* const out = static_cast<output_t*>(iter.data_ptr(0));
    const input_t_1* const in1 = static_cast<const input_t_1*>(iter.data_ptr(1));
    const input_t_2* const in2 = static_cast<const input_t_2*>(iter.data_ptr(2));
    const int64_t grid = (numel + block_work_size() - 1) / block_work_size();
    auto stream = at::cuda::getCurrentCUDAStream();

    if (iter.is_contiguous()) {
      distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
          numel, f, philox_args, out, in1, in2,
          TrivialOffsetCalculator<2>(), TrivialOffsetCalculator<1>());
    } else {
      distribution_binary_elementwise_kernel<<<grid, num_threads(), 0, stream>>>(
          numel, f, philox_args, out, in1, in2,
          make_input_offset_calculator<2>(iter), make_output_offset_calculator(iter));
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  246. } // namespace
  247. }} // namespace at::native
  248. namespace at {
  249. namespace native {
  250. namespace templates {
  251. namespace cuda {
  252. // ==================================================== Random ========================================================
  // Fills iter's output with random integers via
  // transformation::uniform_int_from_to(rand, range, base) (presumably values
  // in [base, base + range)).
  //
  // 64-bit random words are used only when both hold:
  //  * the output type can represent them (int64, double, float, bfloat16);
  //  * range needs more than 32 bits.
  // Otherwise each element consumes one 32-bit word from curand4.
  template<typename RNG>
  void random_from_to_kernel(TensorIteratorBase& iter, uint64_t range, int64_t base, RNG gen) {
    AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_from_to_kernel_cuda", [&] {
      const bool wide_output = std::is_same<scalar_t, int64_t>::value ||
                               std::is_same<scalar_t, double>::value ||
                               std::is_same<scalar_t, float>::value ||
                               std::is_same<scalar_t, at::BFloat16>::value;
      const bool needs_64_bits = range >= 1ULL << 32;
      if (wide_output && needs_64_bits) {
        // Pack the four 32-bit words of one curand4 call into two 64-bit words.
        auto draw_u64x2 = [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
          const uint4 words = curand4(state);
          ulonglong2 packed;
          packed.x = (static_cast<uint64_t>(words.x) << 32) | words.y;
          packed.y = (static_cast<uint64_t>(words.z) << 32) | words.w;
          return packed;
        };
        auto to_range64 = [range, base] __device__ (uint64_t rand) {
          return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
        };
        distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(
            iter, gen, draw_u64x2, to_range64);
      } else {
        auto draw_u32x4 = [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand4(state);
        };
        auto to_range32 = [range, base] __device__ (uint32_t rand) {
          return transformation::uniform_int_from_to<scalar_t>(rand, range, base);
        };
        distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(
            iter, gen, draw_u32x4, to_range32);
      }
    });
  }
  // Special case of random_from_to covering the full 64-bit range:
  //   from (inclusive) = std::numeric_limits<int64_t>::lowest()
  //   to   (exclusive) = None (= std::numeric_limits<int64_t>::max() + 1)
  // Only int64, double, float and bfloat16 outputs are accepted. Any other
  // dispatched dtype raises an error before a kernel is launched.
  template<typename RNG>
  void random_full_64_bits_range_kernel(TensorIteratorBase& iter, RNG gen) {
    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::BFloat16, iter.dtype(), "random_full_64_bits_range_kernel_cuda", [&] {
      const bool supported = std::is_same<scalar_t, int64_t>::value ||
                             std::is_same<scalar_t, double>::value ||
                             std::is_same<scalar_t, float>::value ||
                             std::is_same<scalar_t, at::BFloat16>::value;
      TORCH_CHECK(supported, "random_full_64_bits_range_kernel_cuda handles only int64, double, float and bfloat16");
      // Pack the four 32-bit words of one curand4 call into two 64-bit words.
      auto draw_u64x2 = [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
        const uint4 words = curand4(state);
        ulonglong2 packed;
        packed.x = (static_cast<uint64_t>(words.x) << 32) | words.y;
        packed.y = (static_cast<uint64_t>(words.z) << 32) | words.w;
        return packed;
      };
      auto to_full_range = [] __device__ (uint64_t rand) {
        return transformation::uniform_int_full_range<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(
          iter, gen, draw_u64x2, to_full_range);
    });
  }
  // Functor adapter: resolves the optional Generator with check_generator<RNG>
  // and forwards to the matching random kernel.
  template<typename RNG>
  struct RandomFromToKernel {
    // Finite range: base and range are forwarded to random_from_to_kernel.
    void operator()(TensorIteratorBase& iter, uint64_t range, int64_t base, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      random_from_to_kernel(iter, range, base, generator);
    }
    // Full 64-bit range (from = int64 lowest, to = None).
    void operator()(TensorIteratorBase& iter, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      random_full_64_bits_range_kernel(iter, generator);
    }
  };
  // Fills iter's output using transformation::uniform_int.
  // double and int64 outputs consume one 64-bit word per element, assembled
  // from two 32-bit curand4 words. Every other dtype consumes one 32-bit word.
  template<typename RNG>
  void random_kernel(TensorIteratorBase& iter, RNG gen) {
    AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "random_kernel_cuda", [&] {
      const bool use_64_bit_words =
          std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value;
      if (!use_64_bit_words) {
        auto draw_u32x4 = [] __device__ (curandStatePhilox4_32_10_t* state) {
          return curand4(state);
        };
        auto to_int32 = [] __device__ (uint32_t rand) {
          return transformation::uniform_int<scalar_t>(rand);
        };
        distribution_nullary_kernel<scalar_t, uint32_t, curand4_engine_calls>(
            iter, gen, draw_u32x4, to_int32);
        return;
      }
      // Pack the four 32-bit words of one curand4 call into two 64-bit words.
      auto draw_u64x2 = [] __device__ (curandStatePhilox4_32_10_t* state) -> ulonglong2 {
        const uint4 words = curand4(state);
        ulonglong2 packed;
        packed.x = (static_cast<uint64_t>(words.x) << 32) | words.y;
        packed.y = (static_cast<uint64_t>(words.z) << 32) | words.w;
        return packed;
      };
      auto to_int64 = [] __device__ (uint64_t rand) {
        return transformation::uniform_int<scalar_t>(rand);
      };
      distribution_nullary_kernel<scalar_t, uint64_t, curand4_engine_calls/2>(
          iter, gen, draw_u64x2, to_int64);
    });
  }
  // Functor adapter for random_kernel. The generator is forwarded unchanged.
  template<typename RNG>
  struct RandomKernel {
    void operator()(TensorIteratorBase& iter, RNG gen) {
      random_kernel<RNG>(iter, gen);
    }
  };
  361. // ====================================================================================================================
  // Samples uniform values and applies `transform` to each one before storing.
  // double outputs draw two values per curand_uniform2_double call, so the
  // unroll factor is halved. All other dtypes draw four per curand_uniform4
  // call. Per the cuRAND docs, both sample (0, 1].
  template<typename scalar_t, typename accscalar_t, size_t curand4_engine_calls, typename RNG, typename transform_t>
  void uniform_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
    if (!std::is_same<scalar_t, double>::value) {
      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(
          iter,
          gen,
          [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
          transform);
      return;
    }
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(
        iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
        transform);
  }
  // Samples standard-normal values and applies `transform` to each one before
  // storing. double outputs use curand_normal2_double (two values per call,
  // halved unroll factor). All other dtypes use curand_normal4 (four per call).
  template<typename scalar_t, typename accscalar_t, size_t curand4_engine_calls, typename RNG, typename transform_t>
  void normal_and_transform(TensorIteratorBase& iter, RNG gen, transform_t transform) {
    if (!std::is_same<scalar_t, double>::value) {
      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(
          iter,
          gen,
          [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal4(state); },
          transform);
      return;
    }
    distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(
        iter,
        gen,
        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_normal2_double(state); },
        transform);
  }
  390. // ==================================================== Normal ========================================================
  // Fills `self` with normal samples. Each standard-normal draw is scaled and
  // shifted by transformation::normal(rand, mean_, std_), computed in
  // accscalar_t and then cast to the output dtype.
  template<typename RNG>
  void normal_kernel(const TensorBase &self, double mean_, double std_, RNG gen) {
    auto iter = TensorIterator::borrowing_nullary_op(self);
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_kernel_cuda", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const auto mu = static_cast<accscalar_t>(mean_);
      const auto sigma = static_cast<accscalar_t>(std_);
      auto to_normal = [mu, sigma] __device__ (accscalar_t rand) {
        return static_cast<scalar_t>(transformation::normal<accscalar_t>(rand, mu, sigma));
      };
      normal_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, to_normal);
    });
  }
  // Functor adapter: resolves the optional Generator and forwards to normal_kernel.
  template<typename RNG>
  struct NormalKernel {
    void operator()(const TensorBase &self, double mean, double std, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      normal_kernel(self, mean, std, generator);
    }
  };
  411. // ==================================================== Uniform ========================================================
  // Fills iter's output with samples from the uniform distribution on [from_, to_).
  //
  // curand_uniform4 / curand_uniform2_double sample (0, 1]. Each sample r is
  // mapped to r * (to - from) + from, then cast to scalar_t. Two guards keep
  // the exclusive upper bound out of the output:
  //  * r == 1 is reversed to 0 (legacy THCTensorRandom behaviour);
  //  * after rounding to scalar_t, a value equal to `to` is replaced by `from`.
  //    For Half/BFloat16 (or float with wide ranges), r * range + from can round
  //    up to `to` even when r < 1.
  // The range is computed in accscalar_t rather than scalar_t. This stops
  // `to - from` from overflowing (e.g. Half with from = -60000, to = 60000
  // used to give inf) or being rounded before scaling.
  template<typename RNG>
  void uniform_kernel(TensorIteratorBase& iter, double from_, double to_, RNG gen) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_kernel_cuda", [&] {
      auto from = static_cast<scalar_t>(from_);
      auto to = static_cast<scalar_t>(to_);
      using accscalar_t = at::acc_type<scalar_t, true>;
      auto range = static_cast<accscalar_t>(to) - static_cast<accscalar_t>(from);
      // define lambda to reverse bounds, multiply 'range' and add 'from_'
      auto uniform_func = [range, from, to] __device__ (accscalar_t rand) {
        // reverse the bounds of curand4 from (0, 1] to [0, 1)
        // Note that this method is from legacy THCTensorRandom and is likely to give
        // you more 0-s, since, the probability of gettings 1-s is higher than 0-s and
        // by reversing the bounds, we are flipping the probabilities of 1-s and 0-s.
        // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/16706
        auto reverse_bound_rand = rand == static_cast<accscalar_t>(1.0) ? static_cast<accscalar_t>(0.0) : rand;
        auto value = static_cast<scalar_t>(reverse_bound_rand * range + from);
        // Rounding to a low-precision scalar_t can still produce `to`.
        // BEFORE TOUCHING THIS CODE READ: https://github.com/pytorch/pytorch/issues/96947
        return value == to ? from : value;
      };
      uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, uniform_func);
    });
  }
  // Functor adapter: resolves the optional Generator and forwards to uniform_kernel.
  template<typename RNG>
  struct UniformKernel {
    void operator()(TensorIteratorBase& iter, double from, double to, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      uniform_kernel(iter, from, to, generator);
    }
  };
  438. // ================================================== LogNormal =======================================================
  // Fills iter's output with log-normal samples.
  // Each standard-normal draw goes through transformation::normal(rand, mean_, std_)
  // and then transformation::log_normal, in accscalar_t, before the cast to
  // the output dtype.
  template<typename RNG>
  void log_normal_kernel(TensorIteratorBase& iter, double mean_, double std_, RNG gen) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const auto mu = static_cast<accscalar_t>(mean_);
      const auto sigma = static_cast<accscalar_t>(std_);
      auto to_log_normal = [mu, sigma] __device__ (accscalar_t rand) {
        const accscalar_t gaussian = transformation::normal<accscalar_t>(rand, mu, sigma);
        return static_cast<scalar_t>(transformation::log_normal<accscalar_t>(gaussian));
      };
      normal_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, to_log_normal);
    });
  }
  // Functor adapter: resolves the optional Generator and forwards to log_normal_kernel.
  template<typename RNG>
  struct LogNormalKernel {
    void operator()(TensorIteratorBase& iter, double mean, double std, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      log_normal_kernel(iter, mean, std, generator);
    }
  };
  458. // =================================================== Geometric ======================================================
  // Fills iter's output by applying transformation::geometric(rand, p) to
  // uniform samples. Samples are taken in DiscreteDistributionType<scalar_t>::type
  // and the result is cast to scalar_t.
  template<typename RNG>
  void geometric_kernel(TensorIteratorBase& iter, double p, RNG gen) {
    AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
      using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
      auto to_geometric = [p] __device__ (accscalar_t rand) {
        const auto sample = transformation::geometric<accscalar_t>(rand, p);
        return static_cast<scalar_t>(sample);
      };
      uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, to_geometric);
    });
  }
  // Functor adapter: resolves the optional Generator and forwards to geometric_kernel.
  template<typename RNG>
  struct GeometricKernel {
    void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      geometric_kernel(iter, p, generator);
    }
  };
  476. // ================================================== Exponential =====================================================
  // Fills iter's output by applying transformation::exponential(rand, lambda_)
  // to uniform samples (in accscalar_t), then casting to the output dtype.
  // Integral dtypes are rejected up front.
  template<typename RNG>
  void exponential_kernel(TensorIteratorBase& iter, double lambda_, RNG gen) {
    TORCH_CHECK(isFloatingType(iter.dtype()), "Exponential distribution is a continuous probability distribution. dtype must be a floating point but you specified ", iter.dtype());
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const auto rate = static_cast<accscalar_t>(lambda_);
      auto to_exponential = [rate] __device__ (accscalar_t rand) {
        return static_cast<scalar_t>(transformation::exponential<accscalar_t>(rand, rate));
      };
      uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, to_exponential);
    });
  }
  // Functor adapter: resolves the optional Generator and forwards to exponential_kernel.
  template<typename RNG>
  struct ExponentialKernel {
    void operator()(TensorIteratorBase& iter, double lambda, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      exponential_kernel(iter, lambda, generator);
    }
  };
  496. // ==================================================== Cauchy ========================================================
  // Fills iter's output by applying transformation::cauchy(rand, median_, sigma_)
  // to uniform samples (in accscalar_t), then casting to the output dtype.
  template<typename RNG>
  void cauchy_kernel(TensorIteratorBase& iter, double median_, double sigma_, RNG gen) {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
      using accscalar_t = at::acc_type<scalar_t, true>;
      const auto location = static_cast<accscalar_t>(median_);
      const auto scale = static_cast<accscalar_t>(sigma_);
      auto to_cauchy = [location, scale] __device__ (accscalar_t rand) {
        return static_cast<scalar_t>(transformation::cauchy<accscalar_t>(rand, location, scale));
      };
      uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, to_cauchy);
    });
  }
  // Functor adapter: resolves the optional Generator and forwards to cauchy_kernel.
  template<typename RNG>
  struct CauchyKernel {
    void operator()(TensorIteratorBase& iter, double median, double sigma, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      cauchy_kernel(iter, median, sigma, generator);
    }
  };
  516. // ==================================================== Bernoulli =====================================================
  // Elementwise Bernoulli sampling: ret[i] = (u_i <= p[i]) with u_i uniform.
  // p is presumably already expanded to ret's shape (bernoulli_kernel does so).
  //
  // CUDA_tensor_apply2 calls the functor with up to four (ret, p) pairs, of
  // which the first n are valid. The functor:
  //  * builds a Philox state keyed on the global thread index, with seed and
  //    offset from philox_args;
  //  * draws one float4;
  //  * fills slots n..1 from the highest down, asserting each used
  //    probability lies in [0, 1].
  // Any n outside 1..4 writes nothing.
  template<typename scalar_t, typename prob_t>
  void bernoulli_tensor_cuda_kernel(
      const TensorBase &ret, const at::TensorBase &p,
      PhiloxCudaState philox_args) {
    auto functor = [philox_args] __device__(
        int n, scalar_t& v1, scalar_t& v2, scalar_t& v3, scalar_t& v4,
        const prob_t& p1, const prob_t& p2, const prob_t& p3, const prob_t& p4) {
      curandStatePhilox4_32_10_t state;
      {
        auto seeds = at::cuda::philox::unpack(philox_args);
        curand_init(std::get<0>(seeds),
                    blockIdx.x * blockDim.x + threadIdx.x,
                    std::get<1>(seeds),
                    &state);
      }
      // See Note [Register spilling in curand call for CUDA < 10]
      const float4 rand = curand_uniform4(&state);
      const bool valid_n = n >= 1 && n <= 4;
      if (valid_n && n >= 4) {
        CUDA_KERNEL_ASSERT(0 <= p4 && p4 <= 1);
        v4 = static_cast<scalar_t>(rand.w <= p4);
      }
      if (valid_n && n >= 3) {
        CUDA_KERNEL_ASSERT(0 <= p3 && p3 <= 1);
        v3 = static_cast<scalar_t>(rand.z <= p3);
      }
      if (valid_n && n >= 2) {
        CUDA_KERNEL_ASSERT(0 <= p2 && p2 <= 1);
        v2 = static_cast<scalar_t>(rand.y <= p2);
      }
      if (valid_n) {
        CUDA_KERNEL_ASSERT(0 <= p1 && p1 <= 1);
        v1 = static_cast<scalar_t>(rand.x <= p1);
      }
    };
    // The template argument `4` below asks CUDA_tensor_apply2 to hand the
    // functor four elements at a time. See NOTE [ CUDA_tensor_applyN helpers ].
    at::cuda::CUDA_tensor_apply2<scalar_t, prob_t, 4, decltype(functor),
                                 /*max_threads_per_block=*/512,
                                 /*min_blocks_per_sm==*/2>(ret, p, functor);
  }
  // Bernoulli sampling into `self` with per-element probabilities from p_.
  //
  // p_ must have a floating dtype. It is then:
  //  * moved to self's device;
  //  * converted to double when self is double and to float otherwise;
  //  * expanded to self's shape with expand_inplace.
  // Each element is then sampled by bernoulli_tensor_cuda_kernel.
  //
  // The dtype check now runs before the generator is touched, so a rejected
  // call no longer advances the generator's Philox offset.
  template<typename RNG>
  void bernoulli_kernel(const TensorBase &self, const TensorBase &p_, RNG gen) {
    TORCH_CHECK(at::isFloatingType(p_.scalar_type()), "expected probabilities tensor to have floating type, got ", p_.scalar_type());
    PhiloxCudaState rng_engine_inputs;
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen->mutex_);
      rng_engine_inputs = gen->philox_cuda_state(10);
    }
    // cast probabilities tensor to double for double `self` tensor, and to `float` for everything else
    const auto p_type = self.dtype() == at::kDouble ? at::kDouble : at::kFloat;
    auto p_cuda = p_.to(TensorOptions().device(self.device()).dtype(p_type));
    auto p = expand_inplace(self, p_cuda);
    AT_DISPATCH_ALL_TYPES_AND3(
      at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
        if (std::is_same<scalar_t, double>::value) {
          return bernoulli_tensor_cuda_kernel<double, double>(self, *p, rng_engine_inputs);
        } else {
          return bernoulli_tensor_cuda_kernel<scalar_t, float>(self, *p, rng_engine_inputs);
        }
      });
  }
  // Bernoulli sampling with one scalar probability p for every element.
  // transformation::bernoulli(rand, p) is applied to uniform samples drawn in
  // DiscreteDistributionType<scalar_t>::type, then cast to scalar_t.
  template<typename RNG>
  void bernoulli_kernel(TensorIteratorBase& iter, double p, RNG gen) {
    AT_DISPATCH_ALL_TYPES_AND3(
      at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
        using accscalar_t = at::DiscreteDistributionType<scalar_t>::type;
        auto to_bernoulli = [p] __device__ (accscalar_t rand) {
          const auto outcome = transformation::bernoulli<accscalar_t>(rand, p);
          return static_cast<scalar_t>(outcome);
        };
        uniform_and_transform<scalar_t, accscalar_t, curand4_engine_calls>(iter, gen, to_bernoulli);
      });
  }
  // Functor adapter: resolves the optional Generator and forwards to the
  // matching bernoulli_kernel overload.
  template<typename RNG>
  struct BernoulliKernel {
    // One scalar probability for all elements.
    void operator()(TensorIteratorBase& iter, double p, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      bernoulli_kernel(iter, p, generator);
    }
    // Per-element probabilities taken from the tensor p_.
    void operator()(const TensorBase &self, const TensorBase &p_, c10::optional<Generator> gen) {
      auto generator = check_generator<RNG>(gen);
      bernoulli_kernel(self, p_, generator);
    }
  };
  603. }}}}