#pragma once
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <ATen/cuda/CUDAConfig.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/CUDAJitLoops.cuh>

namespace at {
namespace native {
/* Note [Jiterator]
The "jiterator" just-in-time compiles the same kernels that
Loops.cuh (and CUDALoops.cuh) usually build. This reduces build time,
build size, and initial CUDA context size.

By default on non-Windows systems, it also caches compiled kernels in
~/.cache/torch/kernels. This behavior is controlled with two environment
variables:
  - USE_PYTORCH_KERNEL_CACHE, if set to zero then this will disable all cache use
  - PYTORCH_KERNEL_CACHE_PATH, if set specifies the folder to use for cached kernels

The jiterator currently has some limitations, however. It cannot:
  - handle math on complex datatypes
  - handle kernels with scalar parameters

These improvements will likely come soon.

For examples of how to use the jiterator see the i1 and gcd kernel
implementations, which pass jittable strings implementing their
operations instead of the typical CUDA functors.

To pass a runtime argument (similar to lambda captures in non-JIT kernels),
we need to pass additional arguments to `jitted_gpu_kernel` by value.
Currently only primitive C++ types used for computation are valid.
The order of these extra arguments should be the same as the order in which
they appear in the kernel's function signature. (See polygamma for an example.)

NOTE: One big restriction is that these arguments must come after the
arguments provided by TensorIterator. E.g., while capturing `n`, where
`scalar_t x` and `scalar_t y` are provided by TensorIterator,

  * foo(scalar_t x, scalar_t y, int n) works!
  * foo(int n, scalar_t x, scalar_t y) doesn't work
  * foo(scalar_t x, int n, scalar_t y) doesn't work
*/
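
// The sketch below is illustrative only and is not part of this header: it
// shows how a caller might wire a jittable string and a captured runtime
// argument into `jitted_gpu_kernel`, loosely mirroring the in-tree gcd and
// polygamma kernels. The names `my_op_name`, `my_op_string`, and
// `my_op_kernel_cuda` are hypothetical, and such a caller would also need
// <ATen/Dispatch.h> for the AT_DISPATCH_FLOATING_TYPES macro.
//
//   constexpr char my_op_name[] = "my_op";
//
//   // `x` and `y` come from the TensorIterator; `n` is the trailing
//   // extra argument captured by value.
//   const auto my_op_string = jiterator_stringify(
//       template <typename T>
//       T my_op(T x, T y, int n) {
//         return x + y * n;
//       });
//
//   void my_op_kernel_cuda(TensorIteratorBase& iter, int n) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_op_cuda", [&]() {
//       jitted_gpu_kernel</*name=*/my_op_name,
//                         /*return_type=*/scalar_t,
//                         /*f_inputs_type=*/scalar_t,
//                         /*arity=*/2>(
//           iter,
//           my_op_string,
//           at::cuda::jit::BinaryFuncVariant::NoScalar,
//           /*scalar_val=*/0,
//           /*extra_args=*/std::make_tuple(n));
//     });
//   }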

// Entrypoint for jitted GPU kernels.
// Only handles elementwise unary and binary kernels with a
//   common dtype and a single output.
// NOTE: this assumes the op's iterator has a common_dtype.
// NOTE: We use std::tuple instead of a parameter pack for `extra_args`
//   due to the following bug on older versions of clang:
//   https://bugs.llvm.org/show_bug.cgi?id=23029
template <
    char const* name,
    typename return_type,
    typename f_inputs_type,
    int arity,
    typename... Args>
void jitted_gpu_kernel(
    TensorIteratorBase& iter,
    const std::string& f,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    at::opmath_type<f_inputs_type> scalar_val = 0,
    std::tuple<Args...> extra_args = std::make_tuple()) {
  // TODO: much of preamble is common to both jitted_gpu_kernel and gpu_kernel
  //   Maybe it could be refactored?
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
          sub_iter, f, scalar_pos, scalar_val, extra_args);
    }
    return;
  }

  // Computes if dynamic casting is needed
  // Dynamic casting is needed if an input's dtype differs from the common dtype
  //   or if the result dtype differs from the output's dtype
  // Note: this is intentionally divergent from calling needs_dynamic_casting,
  //   which is more general and inspects a lambda to determine if dynamic
  //   casting is needed.
  bool needs_dynamic_casting = false;

  // Checks output
  const ScalarType return_scalar_type = c10::CppTypeToScalarType<return_type>::value;
  const auto dtype0 = iter.dtype(0);
  if (dtype0 != return_scalar_type) {
    needs_dynamic_casting = true;
  }

  // Checks input(s)
  const ScalarType inputs_scalar_type = c10::CppTypeToScalarType<f_inputs_type>::value;
  for (auto i = decltype(arity){1}; i < (arity + 1); ++i) {
    const auto dtypei = iter.dtype(i);
    if (dtypei != inputs_scalar_type) {
      needs_dynamic_casting = true;
      break;
    }
  }
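
  // Illustrative example (not from the original comments): a kernel
  // instantiated with return_type = bool but writing into a float output
  // tensor, or reading inputs whose dtype differs from f_inputs_type, sets
  // needs_dynamic_casting, so the generated code loads and stores through
  // casts rather than assuming the template types match the tensors' dtypes.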

  if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
    // NOTE: With `scalar_pos=NoScalar`, `scalar_val` is not used
    //   for computation in the generated code and hence we pass a dummy
    //   value of `0`.
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::NoScalar>(
        iter, f, needs_dynamic_casting, /*scalar_val=*/scalar_val, extra_args);
  } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::RhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);
  } else {
    jitted_gpu_kernel_impl<
        /*name*/ name,
        /*return_type=*/return_type,
        /*f_inputs_type=*/f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::LhsScalar>(
        iter,
        f,
        needs_dynamic_casting,
        scalar_val,
        extra_args);
  }
}

// TODO: support runtime state capture similar to `jitted_gpu_kernel`.
template <char const *name, typename return_type, typename f_inputs_type>
void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
  // Currently the jiterator only handles binary functions where both inputs
  // are of the same type (f_inputs_type).
  using opmath_t = at::opmath_type<f_inputs_type>;
  if (iter.is_cpu_scalar(1)) {
    auto scalar_val = iter.scalar_value<opmath_t>(1);
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    //   ported to structured, this device guard can be deleted. This
    //   works around incorrect device guard generation for pre-structured
    //   kernels' device guards, but structured kernels do it right and
    //   we can assume the device is already set correctly.
    const OptionalDeviceGuard device_guard(iter.device(1));
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val);
  } else if (iter.is_cpu_scalar(2)) {
    auto scalar_val = iter.scalar_value<opmath_t>(2);
    iter.remove_operand(2);
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val);
  } else {
    jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
  }
}
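
// Illustrative sketch (hypothetical names, not part of this header): how a
// binary op might route through opmath_jitted_gpu_kernel_with_scalars, which
// peels off a CPU scalar operand (if present) before calling into
// jitted_gpu_kernel above. `my_binary_string` would be a jittable string
// defining a two-input function named "my_binary", and the caller would need
// <ATen/Dispatch.h> for the dispatch macro.
//
//   constexpr char my_binary_name[] = "my_binary";
//
//   void my_binary_kernel_cuda(TensorIteratorBase& iter) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_binary_cuda", [&]() {
//       opmath_jitted_gpu_kernel_with_scalars<my_binary_name, scalar_t, scalar_t>(
//           iter, my_binary_string);
//     });
//   }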

}} // namespace at::native

#endif // AT_USE_JITERATOR()