// CUDAJitLoops.cuh

#pragma once
#include <ATen/jit_macros.h>

// Jiterator functions are guarded behind this macro
#if AT_USE_JITERATOR()

#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/thread_constants.h>
#include <ATen/native/cuda/Loops.cuh>

#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/irange.h>
#include <c10/util/C++17.h>

#include <algorithm>
#include <array>
#include <initializer_list>
#include <limits>
#include <type_traits>
#include <tuple>
#include <mutex>
#include <vector>

namespace at {
namespace native {

template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
  constexpr auto size = seq.size();
  (void)t; // suppress unused-parameter warning when the tuple is empty
  return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
}

// Helper function to convert a tuple to std::array<void*, N>
// for passing the arguments to a CUDA kernel.
// NOTE: We capture the tuple by reference, so the pointers in the
// returned array are only valid as long as the tuple is alive.
template <typename ...Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
  constexpr auto tuple_size = sizeof...(Args);
  return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
}
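
// Illustrative example (the values here are hypothetical, not part of this header):
//   std::tuple<float, int> t{2.0f, 3};
//   auto arr = tuple_to_array(t);
// yields std::array<void*, 2>{&std::get<0>(t), &std::get<1>(t)}; the pointers are
// only usable while `t` is alive, matching the NOTE above.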

struct JittedVecKernelCache {
  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  at::cuda::jit::NvrtcFunction vec1;
  at::cuda::jit::NvrtcFunction vec2;
  at::cuda::jit::NvrtcFunction vec4;
};
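
// The fields below map onto the four launch cases in jitted_gpu_kernel_generic:
// `vec` caches the contiguous, non-dynamically-casting (possibly vectorized) kernels,
// `noncontiguous` the non-dynamically-casting noncontiguous kernel, and
// `dynamic_contiguous`/`dynamic_noncontiguous` the dynamically casting variants.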
struct JittedKernelVariantCache {
  JittedVecKernelCache vec;
  at::cuda::jit::NvrtcFunction noncontiguous;
  at::cuda::jit::NvrtcFunction dynamic_contiguous;
  at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
};

inline c10::SmallBuffer<void*, 64> pack_kernel_args(
    std::initializer_list<void*> args,
    c10::ArrayRef<void*> extra_args) {
  c10::SmallBuffer<void*, 64> ret(args.size() + extra_args.size());
  std::copy(args.begin(), args.end(), ret.data());
  std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
  return ret;
}
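
// For example (illustrative only), pack_kernel_args({&N, &data, scalar_val}, extra_args)
// returns the flat buffer {&N, &data, scalar_val, extra_args[0], ..., extra_args[n-1]},
// which is what launch_jitted_pwise_function below expects as its kernel argument array.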

template<typename array_t,
         typename inp_calc_t,
         typename out_calc_t,
         typename loader_t,
         typename storer_t>
void launch_jitted_unrolled_kernel(
    std::mutex &jiterator_mutex,
    at::cuda::jit::NvrtcFunction &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s,
    bool contiguous,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void* scalar_val,
    c10::ArrayRef<void*> extra_args) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // Casting the result to int is always safe: the intermediate is int64 and won't overflow.
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
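
  // Lazily compiles the kernel on first use. The unsynchronized check avoids taking the
  // mutex on the hot path; the second check under the lock handles another thread having
  // compiled the kernel while we waited (double-checked locking on fn_cache.function).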
  if (!fn_cache.function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_cache.function) {
      constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
                                       !std::is_same<decltype(s), memory::StoreWithoutCast>();
      auto code = at::cuda::jit::generate_code(
          desc, contiguous, dynamic_casting, scalar_pos);
      fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
    }
  }

  auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
  at::cuda::jit::launch_jitted_pwise_function(
      fn_cache, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}

template<int arity, typename array_t>
void launch_jitted_vectorized_kernel(
    std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void *scalar_val, c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but it's always safe to cast the result to int.
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
  const int vec_size = at::cuda::jit::can_vectorize_up_to(
      desc, c10::ArrayRef<char*>(data.data, data.size()));

  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements).
  // fn_ptr is set to the appropriate function based on the vec size and GPU used.
  at::cuda::jit::NvrtcFunction* fn_ptr;
  if (vec_size == 4) {
    fn_ptr = &fn_cache.vec4;
  } else if (vec_size == 2) {
    fn_ptr = &fn_cache.vec2;
  } else if (vec_size == 1) {
    fn_ptr = &fn_cache.vec1;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitted vectorized kernel");
  }

  bool vectorized = vec_size > 1;

  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!
      // Generates the program
      auto code = at::cuda::jit::generate_code(
          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
          scalar_pos, vectorized, vec_size);
      std::string kernel_name = vectorized
          ? desc.name + "_vectorized" + std::to_string(vec_size)
          : desc.name;

      // Acquires the program
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  if (vectorized) {
    auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
    auto ic = TrivialOffsetCalculator<arity>();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();
    auto args = pack_kernel_args(
        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  }
}

template <int arity>
void jitted_gpu_kernel_generic(
    std::mutex &jiterator_mutex,
    JittedKernelVariantCache &cache,
    const at::cuda::jit::KernelDescriptor &desc,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    c10::ArrayRef<void*> extra_args,
    TensorIteratorBase& iter,
    const bool dynamic_casting,
    void *scalar_val) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  constexpr int ntensors = arity + 1;
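  // data[0] is the output tensor's pointer; data[1..arity] are the inputs
  // (TensorIterator orders the output before the inputs).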
  at::detail::Array<char*, ntensors> data;
  for (auto i : c10::irange(ntensors)) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of 4 kernel types to launch
  // Variations are:
  //   - Case 1: no dynamic casting and contiguous
  //   - Case 2: no dynamic casting and noncontiguous
  //   - Case 3: dynamic casting and contiguous
  //   - Case 4: dynamic casting and noncontiguous
  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel<arity>(
          jiterator_mutex, cache.vec, desc,
          numel, data, scalar_pos, scalar_val, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.noncontiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader,
        storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Cases 3 and 4 are handled below.
  // Both require construction of a storer (this asserts 1 output) and one or more loaders.

  // Creates store with cast to the output (the zeroth tensor in the TensorIterator)
  auto storer = memory::StoreWithCast<1>(iter);

  // Creates loads with cast from the inputs (note the offset indexing into the iterator's 1...n tensors)
  auto loader = memory::LoadWithCast<arity>(iter);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
        output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel(
      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
      output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
}

// NOTE: static to reduce chances of name collision.
template <
    char const* name,
    typename result_type,
    typename f_inputs_type,
    int arity,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    typename... ExtraArgs>
static void jitted_gpu_kernel_impl(
    TensorIteratorBase& iter,
    const std::string &f,
    const bool dynamic_casting,
    at::opmath_type<f_inputs_type> scalar_val,
    std::tuple<ExtraArgs...> extra_args) {
  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  // the same compute capability
  static std::mutex jiterator_mutex;
  static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());

  constexpr int nInputs = arity;
  constexpr int nOutputs = 1; // TODO: Support more than 1 output
  static const auto desc = at::cuda::jit::make_kernel_descriptor<
      result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);

  auto &cache = device_caches[iter.device().index()];
  auto extra_args_array = tuple_to_array(extra_args);
  return jitted_gpu_kernel_generic<arity>(
      jiterator_mutex,
      cache,
      desc,
      scalar_pos,
      extra_args_array,
      iter,
      dynamic_casting,
      &scalar_val);
}
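
// Usage sketch (illustrative only; the operator name and functor string below are
// hypothetical, not defined in this header). A caller with a TensorIteratorBase `iter`
// holding one float output and two float inputs might instantiate:
//
//   constexpr char my_add_name[] = "my_add";
//   jitted_gpu_kernel_impl<my_add_name, /*result_type=*/float,
//                          /*f_inputs_type=*/float, /*arity=*/2>(
//       iter,
//       "template <typename T> T my_add(T a, T b) { return a + b; }",
//       /*dynamic_casting=*/false,
//       /*scalar_val=*/0.f,
//       std::make_tuple());
//
// dynamic_casting is true when the iterator's tensor dtypes differ from
// result_type/f_inputs_type, which routes loads and stores through the casting paths above.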

}} // at::native

#endif // AT_USE_JITERATOR()