CUDALoops.cuh 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. #pragma once
  2. // This file provides two functions to help write GPU elementwise kernels:
  3. //
  4. // gpu_kernel(TensorIterator iter, <lambda>)
  5. // gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
  6. //
// gpu_kernel_with_scalars generates specializations that support a single
// CPU scalar argument, such as the `5` in `cuda_tensor + 5`. The CPU scalar
// is passed to the kernel as a parameter instead of being copied to device
// memory. This should be used in conjunction with
// TensorIterator::allow_cpu_scalars_, which is enabled by default for
// TensorIterator::binary_op. Otherwise, all inputs and the output must be on
// the GPU.
  13. //
  14. // For example, to write a reciprocal kernel for GPU float Tensors:
  15. //
  16. // gpu_kernel(iter, []GPU_LAMBDA(float a) {
  17. // return 1.0f / a;
  18. // });
  19. //
  20. // To write a multiplication kernel for GPU float Tensors where one argument
  21. // may be a CPU scalar:
  22. //
  23. // gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
  24. // return a * b;
  25. // });
  26. //
  27. // See BinaryOpsKernel.cu for the complete implementation
  28. //
  29. #include <type_traits>
  30. #include <tuple>
  31. #include <iostream>
  32. #include <ATen/cuda/CUDAContext.h>
  33. #include <ATen/core/Array.h>
  34. #include <ATen/detail/FunctionTraits.h>
  35. #include <ATen/native/TensorIterator.h>
  36. #include <c10/macros/Macros.h>
  37. #include <c10/core/DynamicCast.h>
  38. #include <c10/core/ScalarType.h>
  39. #include <c10/util/TypeCast.h>
  40. #include <c10/util/C++17.h>
  41. #ifdef __NVCC__
  42. #define ASSERT_HOST_DEVICE_LAMBDA(type) \
  43. static_assert(__nv_is_extended_host_device_lambda_closure_type(type), \
  44. #type " must be a __host__ __device__ lambda")
  45. #else
  46. #define ASSERT_HOST_DEVICE_LAMBDA(type)
  47. #endif
  48. namespace at { namespace native {
  49. template<int vec_size, typename func_t, typename array_t>
  50. C10_LAUNCH_BOUNDS_1(num_threads())
  51. __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  52. using traits = function_traits<func_t>;
  53. int remaining = N - block_work_size() * blockIdx.x;
  54. if (remaining < block_work_size()) { // if this block handles the reminder, just do a naive unrolled loop
  55. auto input_calc = TrivialOffsetCalculator<traits::arity>();
  56. auto output_calc = TrivialOffsetCalculator<1>();
  57. auto loader = memory::LoadWithoutCast();
  58. auto storer = memory::StoreWithoutCast();
  59. auto policy = memory::policies::unroll<array_t, decltype(input_calc), decltype(output_calc),
  60. memory::LoadWithoutCast, memory::StoreWithoutCast>(
  61. data, remaining, input_calc, output_calc, loader, storer);
  62. elementwise_kernel_helper(f, policy);
  63. } else { // if this block has a full `block_work_size` data to handle, use vectorized memory access
  64. elementwise_kernel_helper(f, memory::policies::vectorized<vec_size, array_t>(data));
  65. }
  66. }
  67. template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
  68. C10_LAUNCH_BOUNDS_1(num_threads())
  69. __global__ void unrolled_elementwise_kernel(int N, func_t f, array_t data,
  70. inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s)
  71. {
  72. int remaining = N - block_work_size() * blockIdx.x;
  73. auto policy = memory::policies::unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(data, remaining, ic, oc, l, s);
  74. elementwise_kernel_helper(f, policy);
  75. }
  76. // this function assume trivial 1d and no dynamic casting
  77. template<typename func_t, typename array_t>
  78. static inline void launch_vectorized_kernel(int64_t N, const func_t& f, array_t data) {
  79. TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  80. using traits = function_traits<func_t>;
  81. int64_t grid = (N + block_work_size() - 1) / block_work_size();
  82. auto stream = at::cuda::getCurrentCUDAStream();
  83. int vec_size = memory::can_vectorize_up_to<func_t>(data);
  84. switch (vec_size) {
  85. case 4:
  86. vectorized_elementwise_kernel<4, func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data);
  87. C10_CUDA_KERNEL_LAUNCH_CHECK();
  88. break;
  89. case 2:
  90. vectorized_elementwise_kernel<2, func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data);
  91. C10_CUDA_KERNEL_LAUNCH_CHECK();
  92. break;
  93. case 1: {
  94. auto input_calc = TrivialOffsetCalculator<traits::arity>();
  95. auto output_calc = TrivialOffsetCalculator<1>();
  96. auto loader = memory::LoadWithoutCast();
  97. auto storer = memory::StoreWithoutCast();
  98. unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data, input_calc, output_calc, loader, storer);
  99. C10_CUDA_KERNEL_LAUNCH_CHECK();
  100. break;
  101. }
  102. default:
  103. TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  104. }
  105. }
  106. template<typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t, typename loader_t, typename storer_t>
  107. static inline void launch_unrolled_kernel(int64_t N, const func_t& f, array_t data,
  108. inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s)
  109. {
  110. TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  111. int64_t grid = (N + block_work_size() - 1) / block_work_size();
  112. auto stream = at::cuda::getCurrentCUDAStream();
  113. unrolled_elementwise_kernel<func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
  114. C10_CUDA_KERNEL_LAUNCH_CHECK();
  115. }
  116. template<int nt, int vt, typename func_t>
  117. C10_LAUNCH_BOUNDS_2(nt, 4)
  118. __global__ void elementwise_kernel(int N, func_t f) {
  119. int tid = threadIdx.x;
  120. int nv = nt * vt;
  121. int idx = nv * blockIdx.x + tid;
  122. #pragma unroll
  123. for (int i = 0; i < vt; i++) {
  124. if (idx < N) {
  125. f(idx);
  126. idx += nt;
  127. }
  128. }
  129. }
  130. template<int nt, int vt, typename func_t>
  131. static void launch_legacy_kernel(int64_t N, const func_t& f) {
  132. TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  133. if (N == 0) {
  134. return;
  135. }
  136. dim3 block(nt);
  137. dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  138. auto stream = at::cuda::getCurrentCUDAStream();
  139. elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  140. C10_CUDA_KERNEL_LAUNCH_CHECK();
  141. }
  142. template <typename traits, typename func_t, typename index_t, size_t... INDEX>
  143. C10_HOST_DEVICE typename traits::result_type
  144. invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i,
  145. std::index_sequence<INDEX...>) {
  146. (void)strides;
  147. (void)i;
  148. return f(c10::load<typename traits::template arg<INDEX>::type>(data[INDEX] + i * strides[INDEX])...);
  149. }
  150. template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
  151. C10_HOST_DEVICE typename traits::result_type
  152. invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], int i) {
  153. using Indices = std::make_index_sequence<traits::arity>;
  154. return invoke_impl<traits>(f, data, strides, i, Indices{});
  155. }
  156. template <typename traits, typename func_t, typename index_t, size_t... I>
  157. C10_HOST_DEVICE typename traits::result_type
  158. invoke_impl(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i,
  159. std::index_sequence<I...>) {
  160. (void)strides;
  161. (void)i;
  162. return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(dtypes[I], data[I] + i * strides[I])...);
  163. }
  164. template <typename func_t, typename index_t, typename traits = function_traits<func_t>>
  165. C10_HOST_DEVICE typename traits::result_type
  166. invoke(const func_t &f, char *const C10_RESTRICT data[], const index_t strides[], const ScalarType dtypes[], int i) {
  167. using Indices = std::make_index_sequence<traits::arity>;
  168. return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
  169. }
  170. template <typename func_t>
  171. void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  172. using traits = function_traits<func_t>;
  173. using arg0_t = typename traits::result_type;
  174. constexpr int ntensors = traits::arity + 1;
  175. TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  176. TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  177. TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  178. at::detail::Array<char*, ntensors> data;
  179. for (int i = 0; i < ntensors; i++) {
  180. data[i] = (char*)iter.data_ptr(i);
  181. }
  182. int64_t numel = iter.numel();
  183. bool contiguous = iter.is_contiguous();
  184. bool dynamic_casting = needs_dynamic_casting<func_t>::check(iter);
  185. if (!dynamic_casting) {
  186. if (contiguous) {
  187. launch_vectorized_kernel(numel, f, data);
  188. } else {
  189. auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
  190. constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
  191. launch_legacy_kernel<128,unroll_factor>(numel, [=]GPU_LAMBDA(int idx) {
  192. auto offsets = offset_calc.get(idx);
  193. arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
  194. *out = invoke(f, &data.data[1], &offsets.data[1], 1);
  195. });
  196. }
  197. } else {
  198. if (contiguous) {
  199. auto loader = memory::LoadWithCast<traits::arity>(iter);
  200. auto storer = memory::StoreWithCast<1>(iter);
  201. auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
  202. auto output_offset_calculator = TrivialOffsetCalculator<1>();
  203. launch_unrolled_kernel(numel, f, data, input_offset_calculator, output_offset_calculator, loader, storer);
  204. } else {
  205. at::detail::Array<ScalarType, ntensors> dtypes;
  206. for (int i = 0; i < ntensors; i++) {
  207. dtypes[i] = iter.dtype(i);
  208. }
  209. auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
  210. launch_legacy_kernel<128, 4>(numel, [=]GPU_LAMBDA(int idx) {
  211. auto offsets = offset_calc.get(idx);
  212. void* out = data[0] + offsets[0];
  213. arg0_t result = invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
  214. c10::cast_and_store<arg0_t>(dtypes[0], out, result);
  215. });
  216. }
  217. }
  218. }
  219. }} // namespace at::native