#pragma once

#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/OpMathType.h>
#include <thrust/tuple.h>

namespace at { namespace native {

template <int N>
static OffsetCalculator<N> make_input_offset_calculator(const TensorIteratorBase& iter) {
  // array size cannot be 0; this happens when N == 0
  constexpr int array_size = std::max(N, 1);
  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - iter.noutputs());
  std::array<const int64_t*, array_size> strides;
  int64_t element_sizes[array_size];
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i + iter.noutputs()).data();
    element_sizes[i] = iter.element_size(i + iter.noutputs());
  }
  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

template <int num_outputs = 1>
static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorIteratorBase& iter) {
  TORCH_INTERNAL_ASSERT(num_outputs == iter.noutputs());
  std::array<const int64_t*, num_outputs> strides;
  int64_t element_sizes[num_outputs];
  for (int i = 0; i < num_outputs; i++) {
    strides[i] = iter.strides(i).data();
    element_sizes[i] = iter.element_size(i);
  }
  return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}

template <typename func_t, typename policy_t>
__device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
  using traits = function_traits<func_t>;
  using return_t = typename traits::result_type;
  using args_t = typename traits::ArgsTuple;

  int idx = blockIdx.x;

  return_t results[thread_work_size()];
  args_t args[thread_work_size()];

  // load
  policy.load(args, idx);

  // compute
  #pragma unroll
  for (int i = 0; i < thread_work_size(); i++) {
    if (policy.check_inbounds(i)) {
      results[i] = c10::guts::apply(f, args[i]);
    }
  }

  // store
  policy.store(results, idx);
}

}}  // namespace at::native

// Note:
// CUDA and ROCm diverged in this PR:
//   https://github.com/pytorch/pytorch/pull/32383
// because, for some reason, trying to enable vectorized
// memory access introduces a regression on ROCm.

#if !defined(USE_ROCM)
  #include <ATen/native/cuda/CUDALoops.cuh>
#else
  #include <ATen/native/cuda/ROCmLoops.cuh>
#endif

namespace at { namespace native {

template <typename func_t>
void gpu_kernel(TensorIteratorBase& iter, const func_t& f) {

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
      iter.device(arg).is_cuda(),
      "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel(sub_iter, f);
    }
    return;
  }

  gpu_kernel_impl(iter, f);
}
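// Example (illustrative sketch, not part of the original header): a typical
// call site builds a TensorIterator and passes a __host__ __device__ lambda
// whose argument types match the tensors' dtypes. The iterator construction
// (TensorIterator::binary_op over float tensors) and the GPU_LAMBDA
// annotation macro are assumptions about the surrounding ATen code, not
// definitions made by this file:
//
//   auto iter = at::TensorIterator::binary_op(out, a, b);
//   at::native::gpu_kernel(iter, [] GPU_LAMBDA (float x, float y) -> float {
//     return x + y;
//   });
//
// gpu_kernel checks that every operand lives on a CUDA device, splits
// iterators that need 64-bit indexing into 32-bit-indexable sub-iterators,
// and then dispatches to gpu_kernel_impl from CUDALoops.cuh / ROCmLoops.cuh.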
template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct AUnaryFunctor {
  using traits = function_traits<func_t>;
  using opmath_arg1_t = typename traits::template arg<0>::type;
  __device__ return_t operator()(arg2_t b) const {
    return f(a, b);
  }
  // NB: scalar is stored in higher precision!
  AUnaryFunctor(func_t f_, opmath_arg1_t a_): f(f_), a(a_) {}
  private:
    func_t f;
    opmath_arg1_t a;
};

template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BUnaryFunctor {
  using traits = function_traits<func_t>;
  using opmath_arg2_t = typename traits::template arg<1>::type;
  __device__ return_t operator()(arg1_t a) const {
    return f(a, b);
  }
  // NB: scalar is stored in higher precision!
  BUnaryFunctor(func_t f_, opmath_arg2_t b_): f(f_), b(b_) {}
  private:
    func_t f;
    opmath_arg2_t b;
};

// Though seemingly a no-op, this inserts casts from arg1_t to func_t's
// argument type (which may be higher precision), as well as a cast to return_t.
template <typename arg1_t, typename arg2_t, typename return_t, typename func_t>
struct BinaryFunctor {
  __device__ return_t operator()(arg1_t a, arg2_t b) const {
    return f(a, b);
  }
  BinaryFunctor(func_t f_): f(f_) {}
  private:
    func_t f;
};

// Unlike gpu_kernel_with_scalars, this allows you to pass a func_t which
// accepts inputs at higher precision (typically opmath_t), but then
// ensures that we load from memory at the correct precision (scalar_t)
// to avoid expensive loads. For the whole sordid story see
// https://dev-discuss.pytorch.org/t/cuda-loops-case-study-code-generation-vs-templates/302
template <typename arg1_t, typename arg2_t = arg1_t, typename return_t = arg1_t, typename func_t>
void opmath_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  using opmath_arg1_t = typename traits::template arg<0>::type;
  using opmath_arg2_t = typename traits::template arg<1>::type;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");

  if (iter.is_cpu_scalar(1)) {
    AUnaryFunctor<arg1_t, arg2_t, return_t, func_t> af(f, iter.scalar_value<opmath_arg1_t>(1));
    iter.remove_operand(1);
    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels' device guards; structured kernels do it right, and
    // we can assume the device is already set correctly.
    const OptionalDeviceGuard device_guard(iter.device(1));
    gpu_kernel(iter, af);
  } else if (iter.is_cpu_scalar(2)) {
    BUnaryFunctor<arg1_t, arg2_t, return_t, func_t> bf(f, iter.scalar_value<opmath_arg2_t>(2));
    iter.remove_operand(2);
    gpu_kernel(iter, bf);
  } else {
    gpu_kernel(iter, BinaryFunctor<arg1_t, arg2_t, return_t, func_t>(f));
  }
}
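// Example (illustrative sketch, not part of the original header): a binary
// multiply that loads/stores at::Half from memory while the functor computes
// in float, the opmath type for Half. The iterator is assumed to have one
// output and two inputs (ntensors() == 3), and GPU_LAMBDA is assumed to be
// the __host__ __device__ annotation macro used elsewhere in ATen:
//
//   // arg1_t/arg2_t/return_t are at::Half; the Half -> float casts are
//   // inserted by BinaryFunctor (or the unary functors for a CPU scalar).
//   at::native::opmath_gpu_kernel_with_scalars<at::Half>(iter,
//       [] GPU_LAMBDA (float a, float b) -> float {
//         return a * b;
//       });
//
// If operand 1 or 2 is a CPU scalar, it is captured into AUnaryFunctor /
// BUnaryFunctor at opmath precision and the remaining input is launched as a
// unary kernel; otherwise BinaryFunctor simply wraps f.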
template <typename scalar_t, typename return_t = scalar_t, typename func_t>
void opmath_symmetric_gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  // Use the symmetric property of the functor to reduce the number of kernels;
  // requires f(a, b) == f(b, a).
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  using opmath_arg_t = typename traits::template arg<0>::type;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");
  static_assert(std::is_same<opmath_arg_t, typename traits::template arg<1>::type>::value,
      "f is not symmetric");

  OptionalDeviceGuard device_guard;
  opmath_arg_t scalar_val{};

  if (iter.is_cpu_scalar(1)) {
    scalar_val = iter.scalar_value<opmath_arg_t>(1);
    iter.remove_operand(1);

    // TODO: When all kernels that use gpu_kernel_with_scalars are
    // ported to structured, this device guard can be deleted. This
    // works around incorrect device guard generation for pre-structured
    // kernels' device guards; structured kernels do it right, and
    // we can assume the device is already set correctly.
    device_guard.reset_device(iter.device(1));
  } else if (iter.is_cpu_scalar(2)) {
    scalar_val = iter.scalar_value<opmath_arg_t>(2);
    iter.remove_operand(2);
  }

  if (iter.ninputs() == 2) {
    gpu_kernel(iter, BinaryFunctor<scalar_t, scalar_t, return_t, func_t>(f));
  } else {
    AUnaryFunctor<scalar_t, scalar_t, return_t, func_t> unary_f(f, scalar_val);
    gpu_kernel(iter, unary_f);
  }
}

// Legacy variant that assumes that func_t already has the exact types
// that we expect to load from memory.
template <typename func_t>
void gpu_kernel_with_scalars(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");
  using arg1_t = typename traits::template arg<0>::type;
  using arg2_t = typename traits::template arg<1>::type;
  using return_t = typename traits::result_type;
  opmath_gpu_kernel_with_scalars<arg1_t, arg2_t, return_t, func_t>(iter, f);
}

namespace { // functions for `gpu_kernel_multiple_outputs`.

// check that the return type is `thrust::tuple`, not `std::tuple`.
template <typename T> struct is_tuple: std::false_type {};

template <typename ...T> struct is_tuple<thrust::tuple<T...>>: std::true_type {};

template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_elementwise_kernel_for_multi_outputs(int N, func_t f, array_t data, inp_calc_t ic, out_calc_t oc) {
  int remaining = N - block_work_size() * blockIdx.x;
  elementwise_kernel_helper(f, memory::policies::multi_outputs_unroll<array_t, inp_calc_t, out_calc_t, num_outputs>(data, remaining, ic, oc));
}

template <int num_outputs, typename func_t, typename array_t, typename inp_calc_t, typename out_calc_t>
static inline void launch_unrolled_kernel_for_multi_outputs(int64_t N, const func_t& f, array_t data, inp_calc_t ic, out_calc_t oc) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  int64_t grid = (N + block_work_size() - 1) / block_work_size();
  auto stream = at::cuda::getCurrentCUDAStream();
  unrolled_elementwise_kernel_for_multi_outputs<num_outputs, func_t, array_t><<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename func_t>
void gpu_kernel_multiple_outputs_impl(TensorIteratorBase& iter, const func_t& f) {
  using traits = function_traits<func_t>;
  using output_t = typename traits::result_type;
  static_assert(is_tuple<output_t>::value, "f's return type must be `thrust::tuple`");
  constexpr int num_outputs = thrust::tuple_size<output_t>::value;
  constexpr int num_inputs = traits::arity;
  constexpr int ntensors = num_outputs + num_inputs;

  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ntensors() == ntensors);

  at::detail::Array<char*, ntensors> data;
  for (int i = 0; i < ntensors; i++) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();

  if (iter.is_contiguous()) {
    auto input_calc = TrivialOffsetCalculator<num_inputs>();
    auto output_calc = TrivialOffsetCalculator<num_outputs>();
    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
  } else {
    auto input_calc = make_input_offset_calculator<num_inputs>(iter);
    auto output_calc = make_output_offset_calculator<num_outputs>(iter);
    launch_unrolled_kernel_for_multi_outputs<num_outputs>(numel, f, data, input_calc, output_calc);
  }
}
} // namespace

template <typename func_t>
void gpu_kernel_multiple_outputs(TensorIteratorBase& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
  }

  if (iter.numel() == 0) {
    return;
  }

  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel_multiple_outputs(sub_iter, f);
    }
    return;
  }

  gpu_kernel_multiple_outputs_impl(iter, f);
}

}} // namespace at::native
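// Example (illustrative sketch, not part of the original header): a functor
// that produces two outputs by returning a thrust::tuple, e.g. the sum and
// difference of its inputs. The iterator is assumed to be configured with two
// float outputs followed by two float inputs, and GPU_LAMBDA is again assumed
// to be the __host__ __device__ annotation macro used elsewhere in ATen:
//
//   at::native::gpu_kernel_multiple_outputs(iter,
//       [] GPU_LAMBDA (float a, float b) -> thrust::tuple<float, float> {
//         return thrust::make_tuple(a + b, a - b);
//       });
//
// The number of tuple elements must equal iter.noutputs(), and the return
// type must be thrust::tuple rather than std::tuple, as enforced by the
// is_tuple check above.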