#pragma once

// This file provides two functions to help write GPU elementwise kernels:
//
//   gpu_kernel(TensorIterator iter, <lambda>)
//   gpu_kernel_with_scalars(TensorIterator iter, <lambda>)
//
// gpu_kernel_with_scalars generates specializations that support a
// single scalar CPU argument, such as from `cuda_tensor + 5`. The CPU scalar
// is lifted to a kernel parameter instead of being copied to device memory.
// This should be used in conjunction with TensorIterator::allow_cpu_scalars_,
// which is the default for TensorIterator::binary_op. Otherwise, all inputs
// and the output must be on the GPU.
//
// For example, to write a reciprocal kernel for GPU float Tensors:
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a) {
//     return 1.0f / a;
//   });
//
// To write a multiplication kernel for GPU float Tensors where one argument
// may be a CPU scalar:
//
//   gpu_kernel_with_scalars(iter, []GPU_LAMBDA(float a, float b) {
//     return a * b;
//   });
//
// See BinaryOpsKernel.cu for the complete implementation
//
// NOTE(review): gpu_kernel / gpu_kernel_with_scalars are not defined in this
// file. This file defines gpu_kernel_impl and the launch helpers it uses.
// elementwise_kernel_helper and GPU_LAMBDA are used below but not defined
// here either. Presumably all of these come from the header that includes
// this one -- confirm.

// NOTE(review): the header paths below are reconstructed from the symbols
// this file references (function_traits, TensorIteratorBase,
// needs_dynamic_casting, memory::*, TrivialOffsetCalculator /
// make_offset_calculator, at::detail::Array, c10::load / fetch_and_cast /
// cast_and_store, C10_* macros, C10_CUDA_KERNEL_LAUNCH_CHECK) -- verify
// against the build tree.
#include <cstdint>
#include <limits>
#include <utility>

#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/core/DynamicCast.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAException.h>
#include <c10/macros/Macros.h>

// Compile-time check that `type` is an extended __host__ __device__ lambda
// closure type. It is only active under nvcc (__NVCC__); with any other
// compiler it expands to nothing.
#ifdef __NVCC__
#define ASSERT_HOST_DEVICE_LAMBDA(type)                       \
  static_assert(                                              \
      __nv_is_extended_host_device_lambda_closure_type(type), \
      #type " must be a __host__ __device__ lambda")
#else
#define ASSERT_HOST_DEVICE_LAMBDA(type)
#endif

namespace at {
namespace native {

// Elementwise kernel for contiguous inputs without dynamic casting.
//
// Block b owns the block_work_size() elements starting at
// b * block_work_size(). The data pointers go through
// memory::policies::vectorized with width `vec_size` when the block has a
// full tile. If fewer than block_work_size() elements remain (the last
// block), it falls back to memory::policies::unroll with
// TrivialOffsetCalculator offsets and LoadWithoutCast / StoreWithoutCast.
//
// Template params:
//   vec_size - vector width handed to memory::policies::vectorized.
//   func_t   - the per-element functor.
//   array_t  - container of data pointers (output first, then inputs).
template <int vec_size, typename func_t, typename array_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
  using traits = function_traits<func_t>;
  int remaining = N - block_work_size() * blockIdx.x;

  if (remaining < block_work_size()) {
    // This block handles the remainder: just do a naive unrolled loop.
    auto input_calc = TrivialOffsetCalculator<traits::arity>();
    auto output_calc = TrivialOffsetCalculator<1>();
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    auto policy = memory::policies::unroll<
        array_t,
        decltype(input_calc),
        decltype(output_calc),
        memory::LoadWithoutCast,
        memory::StoreWithoutCast>(
        data, remaining, input_calc, output_calc, loader, storer);
    elementwise_kernel_helper(f, policy);
  } else {
    // This block has a full `block_work_size` tile: use vectorized memory
    // access.
    elementwise_kernel_helper(
        f, memory::policies::vectorized<vec_size, array_t>(data));
  }
}

// Generic elementwise kernel built on memory::policies::unroll.
//
// Block b processes up to block_work_size() elements starting at
// b * block_work_size(). Addressing and loading are delegated to the
// caller-supplied objects:
//   ic / oc - input / output offset calculators
//   l  / s  - loader / storer
// Because of this delegation, the same kernel serves both the
// trivial-offset non-casting path and the casting path.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel(
    int N,
    func_t f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  int remaining = N - block_work_size() * blockIdx.x;
  auto policy =
      memory::policies::unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
          data, remaining, ic, oc, l, s);
  elementwise_kernel_helper(f, policy);
}

// Host launcher for the contiguous, no-dynamic-casting path.
// This function assumes trivial 1-D addressing and no dynamic casting.
//
// Asks memory::can_vectorize_up_to<func_t>(data) for a vector width:
//   4 or 2 -> vectorized_elementwise_kernel<width>
//   1      -> unrolled_elementwise_kernel with TrivialOffsetCalculator and
//             the non-casting loader/storer
//   else   -> internal assert
// Uses ceil(N / block_work_size()) blocks of num_threads() threads on the
// current CUDA stream, and checks each launch with
// C10_CUDA_KERNEL_LAUNCH_CHECK.
//
// Requires 0 < N <= INT32_MAX, since the kernels take the element count as
// `int`.
template <typename func_t, typename array_t>
static inline void launch_vectorized_kernel(
    int64_t N,
    const func_t& f,
    array_t data) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  using traits = function_traits<func_t>;
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  int vec_size = memory::can_vectorize_up_to<func_t>(data);

  switch (vec_size) {
    case 4:
      vectorized_elementwise_kernel<4, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 2:
      vectorized_elementwise_kernel<2, func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(N, f, data);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    case 1: {
      auto input_calc = TrivialOffsetCalculator<traits::arity>();
      auto output_calc = TrivialOffsetCalculator<1>();
      auto loader = memory::LoadWithoutCast();
      auto storer = memory::StoreWithoutCast();
      unrolled_elementwise_kernel<func_t, array_t>
          <<<grid, num_threads(), 0, stream>>>(
              N, f, data, input_calc, output_calc, loader, storer);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      break;
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
  }
}

// Host launcher for unrolled_elementwise_kernel with caller-chosen offset
// calculators and loader/storer.
//
// Uses ceil(N / block_work_size()) blocks of num_threads() threads on the
// current CUDA stream, and checks the launch with
// C10_CUDA_KERNEL_LAUNCH_CHECK.
//
// Requires 0 < N <= INT32_MAX.
template <
    typename func_t,
    typename array_t,
    typename inp_calc_t,
    typename out_calc_t,
    typename loader_t,
    typename storer_t>
static inline void launch_unrolled_kernel(
    int64_t N,
    const func_t& f,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  unrolled_elementwise_kernel<func_t, array_t>
      <<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Legacy index-based elementwise kernel.
//
// Each block covers nt * vt consecutive indices. Thread t of block b calls
// f(idx) for idx = b * nt * vt + t + k * nt, for k = 0 .. vt-1, and skips
// any idx >= N. The functor receives only the linear index and does its own
// addressing.
//
// Assumes blockDim.x == nt; launch_legacy_kernel launches it with that
// block size.
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel(int N, func_t f) {
  int tid = threadIdx.x;
  int nv = nt * vt;
  int idx = nv * blockIdx.x + tid;
#pragma unroll
  for (int i = 0; i < vt; i++) {
    if (idx < N) {
      f(idx);
      idx += nt;
    }
  }
}

// Host launcher for elementwise_kernel<nt, vt>.
//
// Returns immediately when N == 0. Otherwise it launches
// ceil(N / (nt * vt)) blocks of nt threads on the current CUDA stream and
// checks the launch with C10_CUDA_KERNEL_LAUNCH_CHECK.
//
// Requires 0 <= N <= INT32_MAX.
template <int nt, int vt, typename func_t>
static void launch_legacy_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  dim3 block(nt);
  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  auto stream = at::cuda::getCurrentCUDAStream();
  elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Calls f with one argument per input.
//
// Argument INDEX is read with c10::load, as the functor's declared
// parameter type, from the address data[INDEX] + i * strides[INDEX].
// The (void) casts silence unused-parameter warnings when f takes no
// arguments.
template <typename traits, typename func_t, typename index_t, size_t... INDEX>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i,
    std::index_sequence<INDEX...>) {
  (void)strides;
  (void)i;
  return f(c10::load<typename traits::template arg<INDEX>::type>(
      data[INDEX] + i * strides[INDEX])...);
}

// Non-casting invoke: expands to invoke_impl over indices 0 .. arity-1.
template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    int i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return invoke_impl<traits>(f, data, strides, i, Indices{});
}

// Casting variant of invoke_impl.
//
// Argument I is read with c10::fetch_and_cast from the address
// data[I] + i * strides[I]. The value is read as runtime dtype dtypes[I]
// and converted to the functor's declared parameter type.
template <typename traits, typename func_t, typename index_t, size_t... I>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i,
    std::index_sequence<I...>) {
  (void)strides;
  (void)i;
  return f(c10::fetch_and_cast<typename traits::template arg<I>::type>(
      dtypes[I], data[I] + i * strides[I])...);
}

// Casting invoke: expands to the casting invoke_impl over indices
// 0 .. arity-1.
template <
    typename func_t,
    typename index_t,
    typename traits = function_traits<func_t>>
C10_HOST_DEVICE typename traits::result_type invoke(
    const func_t& f,
    char* const C10_RESTRICT data[],
    const index_t strides[],
    const ScalarType dtypes[],
    int i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return invoke_impl<traits>(f, data, strides, dtypes, i, Indices{});
}

// Runs functor `f` elementwise over `iter`. Tensor 0 is the output;
// tensors 1 .. arity are the inputs.
//
// Preconditions (asserted):
//   - iter can use 32-bit indexing;
//   - iter.ninputs() == arity of f;
//   - iter has exactly one output.
//
// Dispatch:
//   no dynamic casting, contiguous     -> launch_vectorized_kernel
//   no dynamic casting, non-contiguous -> legacy kernel + offset calculator
//   dynamic casting,    contiguous     -> unrolled kernel + LoadWithCast /
//                                         StoreWithCast
//   dynamic casting,    non-contiguous -> legacy kernel + offset calculator +
//                                         fetch_and_cast / cast_and_store
//
// NOTE(review): launch_vectorized_kernel / launch_unrolled_kernel assert
// numel > 0, so callers presumably skip empty iterators before reaching
// this point -- confirm.
template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using arg0_t = typename traits::result_type;
  constexpr int ntensors = traits::arity + 1;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();

  bool contiguous = iter.is_contiguous();
  bool dynamic_casting = needs_dynamic_casting<func_t>::check(iter);

  if (!dynamic_casting) {
    if (contiguous) {
      launch_vectorized_kernel(numel, f, data);
    } else {
      auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
      // Fewer items per thread for wider result types.
      constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
      launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
        auto offsets = offset_calc.get(idx);
        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
        // `offsets` already holds each tensor's byte offset for this
        // element, so it is passed as the "strides" array with i == 1.
        *out = invoke(f, &data.data[1], &offsets.data[1], 1);
      });
    }
  } else {
    if (contiguous) {
      auto loader = memory::LoadWithCast<traits::arity>(iter);
      auto storer = memory::StoreWithCast<1>(iter);
      auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
      auto output_offset_calculator = TrivialOffsetCalculator<1>();
      launch_unrolled_kernel(
          numel,
          f,
          data,
          input_offset_calculator,
          output_offset_calculator,
          loader,
          storer);
    } else {
      at::detail::Array<ScalarType, ntensors> dtypes;
      for (int i = 0; i < ntensors; i++) {
        dtypes[i] = iter.dtype(i);
      }
      auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
      launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
        auto offsets = offset_calc.get(idx);
        void* out = data[0] + offsets[0];
        // Same i == 1 trick as above; inputs are cast from their runtime
        // dtypes and the result is cast to the output's runtime dtype.
        arg0_t result =
            invoke(f, &data.data[1], &offsets.data[1], &dtypes.data[1], 1);
        c10::cast_and_store<arg0_t>(dtypes[0], out, result);
      });
    }
  }
}

} // namespace native
} // namespace at