// Loops.cuh

#pragma once

#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/OpMathType.h>
#include <ATen/native/cuda/thread_constants.h>

#include <thrust/tuple.h>

#include <ATen/native/cuda/MemoryAccess.cuh>

namespace at { namespace native {

template<int N>
static OffsetCalculator<N> make_input_offset_calculator(const TensorIteratorBase& iter) {
  // The array size cannot be 0; that would happen when N == 0.
  constexpr int array_size = std::max<int>(N, 1);
  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
  std::array<const int64_t*, array_size> strides;
  int64_t element_sizes[array_size];
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i + iter.noutputs()).data();
    element_sizes[i] = iter.element_size(i + iter.noutputs());
  }
  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

template <int num_outputs = 1>
static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
  std::array<const int64_t*, num_outputs> strides;
  int64_t element_sizes[num_outputs];
  for (int i = 0; i < num_outputs; i++) {
    strides[i] = iter.strides(i).data();
    element_sizes[i] = iter.element_size(i);
  }
  return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

template<typename func_t, typename policy_t>
__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  using args_t = typename traits::ArgsTuple;

  int idx = blockIdx.x;

  return_t results[thread_work_size()];
  args_t args[thread_work_size()];

  // load
  policy.load(args, idx);

  // compute
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (policy.check_inbounds(i)) {
      results[i] = c10::guts::apply(f, args[i]);
    }
  }

  // store
  policy.store(results, idx);
}
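// Note on the policy objects (a reading aid inferred from the calls above,
// not an exhaustive specification): the policies passed to
// elementwise_kernel_helper elsewhere in this file come from memory::policies
// in MemoryAccess.cuh and are expected to provide
//
//   policy.load(args, idx);     // fill `args` for this block's work items
//   policy.check_inbounds(i);   // whether work item i maps to a valid element
//   policy.store(results, idx); // write `results` back to the output(s)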
}}  // namespace at::native

// Note:
// The CUDA and ROCm code paths diverged in this PR:
//   https://github.com/pytorch/pytorch/pull/32383
// because, for some reason, enabling vectorized memory access
// introduced a regression on ROCm.
#if !defined(USE_ROCM)
#include <ATen/native/cuda/CUDALoops.cuh>
#else
#include <ATen/native/cuda/ROCmLoops.cuh>
#endif

namespace at { namespace native {

template <typename func_t>
void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel(sub_iter, f);
    }
    return;
  }

  gpu_kernel_impl(iter, f);
}
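// Illustrative usage (a sketch, not part of this header's API): given a
// TensorIteratorBase `iter` configured as a unary op over float tensors, a
// reciprocal kernel can be written with a __host__ __device__ lambda
// (the GPU_LAMBDA annotation used elsewhere in ATen's CUDA kernels):
//
//   gpu_kernel(iter, [] GPU_LAMBDA (float a) -> float {
//     return 1.0f / a;
//   });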
template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct AUnaryFunctor {
  using traits = function_traits<func_t>;
  using opmath_arg1_t = typename traits::template arg<0>::type;
  __device__ return_t operator()(arg2_t b) const {
    return f(a, b);
  }
  // NB: scalar is stored in higher precision!
  AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {}
  private:
    func_t f;
    opmath_arg1_t a;
};

template<typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BUnaryFunctor {
  using traits = function_traits<func_t>;
  using opmath_arg2_t = typename traits::template arg<1>::type;
  __device__ return_t operator()(arg1_t a) const {
    return f(a, b);
  }
  // NB: scalar is stored in higher precision!
  BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {}
  private:
    func_t f;
    opmath_arg2_t b;
};

// Though seemingly noop, this inserts casts from arg1_t to func_t's type
// (which may be higher precision), as well as casts to return_t
template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BinaryFunctor {
  __device__ return_t operator()(arg1_t a, arg2_t b) const {
    return f(a, b);
  }
  BinaryFunctor(func_t f_): f(f_) {}
  private:
    func_t f;
};
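// Illustrative note on the casts mentioned above (a sketch, not additional
// API): if func_t takes and returns float while the iterator's tensors hold
// at::Half, then BinaryFunctor<at::Half, at::Half, at::Half, func_t> receives
// Half arguments, implicitly converts them to float when calling f, and
// narrows the float result back to Half on return.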
// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t that
// accepts inputs at higher precision (typically opmath_t), while still
// ensuring that we load from memory at the correct precision (scalar_t)
// to avoid expensive loads. For the whole sordid story see
// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
template <typename arg1_t, typename arg2_t = arg1_t, typename return_t = arg1_t, typename func_t>
void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  using opmath_arg1_t = typename traits::template arg<0>::type;
  using opmath_arg2_t = typename traits::template arg<1>::type;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");

  if (iter.is_cpu_scalar(1)) {
    AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> af(f, iter.scalar_value<opmath_arg1_t>(1));
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels; structured kernels do it right and we can assume the
    // device is already set correctly.
    const OptionalDeviceGuard device_guard(iter.device(1));
    gpu_kernel(iter, af);
  } else if (iter.is_cpu_scalar(2)) {
    BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bf(f, iter.scalar_value<opmath_arg2_t>(2));
    iter.remove_operand(2);
    gpu_kernel(iter, bf);
  } else {
    gpu_kernel(iter, BinaryFunctor<arg1_t, arg2_t, return_t, func_t>(f));
  }
}
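// Illustrative sketch (hypothetical subtraction kernel, assuming `iter` is a
// binary op over tensors of scalar_t): the lambda computes in opmath_t while
// loads and stores stay at scalar_t precision, as described above.
//
//   using opmath_t = at::opmath_type<scalar_t>;
//   opmath_gpu_kernel_with_scalars<scalar_t>(
//       iter, [] GPU_LAMBDA (opmath_t a, opmath_t b) -> opmath_t {
//         return a - b;
//       });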
template <typename scalar_t, typename return_t = scalar_t, typename func_t>
void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  // Use the symmetric property of the functor to reduce the number of kernels;
  // requires f(a, b) == f(b, a).
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  using opmath_arg_t = typename traits::template arg<0>::type;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");
  static_assert(std::is_same<opmath_arg_t, typename traits::template arg<1>::type>::value,
      "f is not symmetric");

  OptionalDeviceGuard device_guard;
  opmath_arg_t scalar_val{};

  if (iter.is_cpu_scalar(1)) {
    scalar_val = iter.scalar_value<opmath_arg_t>(1);
    iter.remove_operand(1);

    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels; structured kernels do it right and we can assume the
    // device is already set correctly.
    device_guard.reset_device(iter.device(1));
  } else if (iter.is_cpu_scalar(2)) {
    scalar_val = iter.scalar_value<opmath_arg_t>(2);
    iter.remove_operand(2);
  }

  if (iter.ninputs() == 2) {
    gpu_kernel(iter, BinaryFunctor<scalar_t, scalar_t, return_t, func_t>(f));
  } else {
    AUnaryFunctor<scalar_t, scalar_t, return_t, func_t> unary_f(f, scalar_val);
    gpu_kernel(iter, unary_f);
  }
}
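// Illustrative sketch: because the functor must satisfy f(a, b) == f(b, a),
// a single scalar specialization (AUnaryFunctor) covers a CPU scalar in
// either operand position, e.g. a hypothetical multiplication kernel:
//
//   opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(
//       iter, [] GPU_LAMBDA (opmath_t a, opmath_t b) -> opmath_t {
//         return a * b;
//       });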
// Legacy variant that assumes that func_t has the correct types
// that we expect to load from memory
template <typename func_t>
void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");
  using arg1_t = typename traits::template arg<0>::type;
  using arg2_t = typename traits::template arg<1>::type;
  using return_t = typename traits::result_type;
  opmath_gpu_kernel_with_scalars<arg1_t, arg2_t, return_t, func_t>(iter, f);
}
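// Illustrative sketch: with the legacy variant, the lambda's argument and
// return types are also the memory types, e.g. a float multiply where either
// input may be a CPU scalar (the iterator must allow CPU scalar inputs):
//
//   gpu_kernel_with_scalars(iter, [] GPU_LAMBDA (float a, float b) -> float {
//     return a * b;
//   });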
namespace { // functions for `gpu_kernel_multiple_outputs`.

// check the return type is `thrust::tuple`, not `std::tuple`.
template <typename T> struct is_tuple: std::false_type {};

template <typename ...T> struct is_tuple<thrust::tuple<T...>>: std::true_type {};

template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
  int remaining = N - block_work_size() * blockIdx.x;
  elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll<array_t, inp_calc_t, out_calc_t, num_outputs>(data, remaining, ic, oc));
}

template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  unrolled_elementwise_kernel_for_multi_outputs<num_outputs, func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename func_t>
void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using output_t = typename traits::result_type;
  static_assert(is_tuple<output_t>::value, "f's return type must be `thrust::tuple`");
  constexpr int num_outputs = thrust::tuple_size<output_t>::value;
  constexpr int num_inputs = traits::arity;
  constexpr int ntensors = num_outputs + num_inputs;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();

  if (iter.is_contiguous()) {
    auto input_calc = TrivialOffsetCalculator<num_inputs>();
    auto output_calc = TrivialOffsetCalculator<num_outputs>();
    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
  } else {
    auto input_calc = make_input_offset_calculator<num_inputs>(iter);
    auto output_calc = make_output_offset_calculator<num_outputs>(iter);
    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
  }
}

} // namespace
template <typename func_t>
void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel_multiple_outputs(sub_iter, f);
    }
    return;
  }

  gpu_kernel_multiple_outputs_impl(iter, f);
}
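// Illustrative sketch (hypothetical kernel): a functor returning a
// thrust::tuple writes one element per output operand, so a single pass can
// produce, say, both the sum and the difference of two float inputs:
//
//   gpu_kernel_multiple_outputs(
//       iter, [] GPU_LAMBDA (float a, float b) -> thrust::tuple<float, float> {
//         return thrust::make_tuple(a + b, a - b);
//       });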
}}  // namespace at::native