#pragma once
#include <ATen/jit_macros.h>

// Jiterator functions are guarded behind this macro
#if AT_USE_JITERATOR()

#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/thread_constants.h>

#include <ATen/native/cuda/Loops.cuh>

#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/irange.h>

#include <initializer_list>
#include <mutex>
#include <tuple>
#include <type_traits>

namespace at {
namespace native {

template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
  constexpr auto size = seq.size();
  (void)t; // avoids an unused-parameter warning when the tuple is empty.
  return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
}

// Helper function to convert a tuple to std::array<void*, N>
// for passing the arguments to a CUDA kernel.
// NOTE: We capture the tuple by reference,
//       so the pointers in the returned array are only valid
//       as long as the tuple is alive.
template <typename... Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
  constexpr auto tuple_size = sizeof...(Args);
  return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
}

struct JittedVecKernelCache {
  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  at::cuda::jit::NvrtcFunction vec1;
  at::cuda::jit::NvrtcFunction vec2;
  at::cuda::jit::NvrtcFunction vec4;
};

struct JittedKernelVariantCache {
  JittedVecKernelCache vec;
  at::cuda::jit::NvrtcFunction noncontiguous;
  at::cuda::jit::NvrtcFunction dynamic_contiguous;
  at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
};

inline c10::SmallBuffer<void*, 64> pack_kernel_args(
    std::initializer_list<void*> args,
    c10::ArrayRef<void*> extra_args) {
  c10::SmallBuffer<void*, 64> ret(args.size() + extra_args.size());
  std::copy(args.begin(), args.end(), ret.data());
  std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
  return ret;
}

template <typename array_t,
          typename inp_calc_t,
          typename out_calc_t,
          typename loader_t,
          typename storer_t>
void launch_jitted_unrolled_kernel(
    std::mutex &jiterator_mutex,
    at::cuda::jit::NvrtcFunction &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s,
    bool contiguous,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void* scalar_val,
    c10::ArrayRef<void*> extra_args) {

  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // Casting the result to int is always safe: the intermediate is int64 and won't overflow.
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  if (!fn_cache.function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_cache.function) {
      constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
                                       !std::is_same<decltype(s), memory::StoreWithoutCast>();
      auto code = at::cuda::jit::generate_code(
          desc, contiguous, dynamic_casting, scalar_pos);
      fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
    }
  }

  auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
  at::cuda::jit::launch_jitted_pwise_function(
      fn_cache, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}

template <int arity, typename array_t>
void launch_jitted_vectorized_kernel(
    std::mutex &jiterator_mutex,
    JittedVecKernelCache &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void *scalar_val,
    c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // N is still int64_t for the computation, but it's always safe to cast the result to int.
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
  const int vec_size = at::cuda::jit::can_vectorize_up_to(
      desc, c10::ArrayRef<char*>(data.data, data.size()));

  // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
  //   fn_ptr is set to the appropriate function based on the vec size and GPU used
  at::cuda::jit::NvrtcFunction* fn_ptr;
  if (vec_size == 4) {
    fn_ptr = &fn_cache.vec4;
  } else if (vec_size == 2) {
    fn_ptr = &fn_cache.vec2;
  } else if (vec_size == 1) {
    fn_ptr = &fn_cache.vec1;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitted vectorized kernel");
  }

  bool vectorized = vec_size > 1;

  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_ptr->function) { // cache miss!

      // Generates the program
      auto code = at::cuda::jit::generate_code(
          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
          scalar_pos, vectorized, vec_size);
      std::string kernel_name = vectorized
          ? desc.name + "_vectorized" + std::to_string(vec_size)
          : desc.name;

      // Acquires the program
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  if (vectorized) {
    auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
    auto ic = TrivialOffsetCalculator<arity>();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();

    auto args = pack_kernel_args(
        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  }
}

template <int arity>
void jitted_gpu_kernel_generic(
    std::mutex &jiterator_mutex,
    JittedKernelVariantCache &cache,
    const at::cuda::jit::KernelDescriptor &desc,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    c10::ArrayRef<void*> extra_args,
    TensorIteratorBase& iter,
    const bool dynamic_casting,
    void *scalar_val) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  constexpr int ntensors = arity + 1;
  at::detail::Array<char*, ntensors> data;
  for (auto i : c10::irange(ntensors)) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of 4 kernel types to launch
  // Variations are:
  //   - Case 1: no dynamic casting and contiguous
  //   - Case 2: no dynamic casting and noncontiguous
  //   - Case 3: dynamic casting and contiguous
  //   - Case 4: dynamic casting and noncontiguous
  // These cases align with the non-jitted CUDALoops.cuh cases in gpu_kernel_impl

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel<arity>(
          jiterator_mutex, cache.vec, desc,
          numel, data, scalar_pos, scalar_val, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();

    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.noncontiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader,
        storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Cases 3 and 4 are handled below
  // Both require construction of a storer (this asserts 1 output) and one or more loaders

  // Creates store cast to output (the zeroth tensor in TensorIterator)
  auto storer = memory::StoreWithCast<1>(iter);

  // Creates load casts from inputs (note offset indexing into the iterator's 1...n tensors)
  auto loader = memory::LoadWithCast<arity>(iter);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();

    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader, storer,
        contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel(
      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data,
      input_offset_calculator, output_offset_calculator, loader, storer,
      contiguous, scalar_pos, scalar_val, extra_args);
}

// NOTE: static to reduce chances of name collision.
template <
    char const* name,
    typename result_type,
    typename f_inputs_type,
    int arity,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    typename... ExtraArgs>
static void jitted_gpu_kernel_impl(
    TensorIteratorBase& iter,
    const std::string &f,
    const bool dynamic_casting,
    at::opmath_type<f_inputs_type> scalar_val,
    std::tuple<ExtraArgs...> extra_args) {

  // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
  //   the same compute capability
  static std::mutex jiterator_mutex;
  static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());

  constexpr int nInputs = arity;
  constexpr int nOutputs = 1; // TODO: Support more than 1 output
  static const auto desc = at::cuda::jit::make_kernel_descriptor<
      result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);

  auto &cache = device_caches[iter.device().index()];
  auto extra_args_array = tuple_to_array(extra_args);
  return jitted_gpu_kernel_generic<arity>(
      jiterator_mutex,
      cache,
      desc,
      scalar_pos,
      extra_args_array,
      iter,
      dynamic_casting,
      &scalar_val);
}

}} // at::native

#endif // AT_USE_JITERATOR()
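// Usage sketch (illustrative only, not part of this header): an elementwise op
// is expected to reach jitted_gpu_kernel_impl through a TensorIterator-based
// kernel, passing the CUDA source of the scalar function as a string. The op
// name, source string, and wrapper below are hypothetical.
//
//   constexpr char my_op_name[] = "my_op";
//   const std::string my_op_string =  // CUDA definition of the scalar function
//       "template <typename T> T my_op(T a, T b) { return a + b; }";
//
//   void my_op_kernel_cuda(at::TensorIteratorBase& iter) {
//     // arity 2: two inputs, one output; no scalar operand, no extra kernel args
//     at::native::jitted_gpu_kernel_impl<
//         /*name=*/my_op_name,
//         /*result_type=*/float,
//         /*f_inputs_type=*/float,
//         /*arity=*/2>(
//         iter,
//         my_op_string,
//         /*dynamic_casting=*/false,
//         /*scalar_val=*/0.f,
//         /*extra_args=*/std::make_tuple());
//   }
//
// In ATen proper, call sites typically go through the jitted_gpu_kernel entry
// point (JitLoops.cuh), which handles the 32-bit-indexing split and the
// dynamic-casting check before dispatching here.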