#pragma once
#include <ATen/jit_macros.h>

// Jiterator functions are guarded behind this macro
#if AT_USE_JITERATOR()

#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/core/Array.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/thread_constants.h>
#include <ATen/native/cuda/Loops.cuh>

#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/util/SmallBuffer.h>
#include <c10/util/C++17.h>

#include <initializer_list>
#include <type_traits>
#include <tuple>
#include <mutex>
namespace at {
namespace native {

template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
  constexpr auto size = seq.size();
  (void)t; // suppresses an unused-parameter warning when the tuple is empty
  return std::array<void*, size>{static_cast<void*>(&std::get<I>(t))...};
}

// Converts a tuple to std::array<void*, N> for passing the arguments to a
// CUDA kernel.
// NOTE: the array holds pointers into the tuple's elements, so the tuple
// must outlive the returned array.
template <typename ...Args>
constexpr auto tuple_to_array(std::tuple<Args...>& extra_args) {
  constexpr auto tuple_size = sizeof...(Args);
  return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
}
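
// Example (illustrative): the returned array aliases the tuple's elements.
//   std::tuple<int, float> t{1, 2.0f};
//   auto arr = tuple_to_array(t); // arr[0] == &std::get<0>(t), arr[1] == &std::get<1>(t)
// Dereferencing the pointers after `t` is destroyed is undefined behavior.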

struct JittedVecKernelCache {
  // Different kernels are compiled depending on what we're vectorizing up to
  // (1, 2, or 4 elements)
  at::cuda::jit::NvrtcFunction vec1;
  at::cuda::jit::NvrtcFunction vec2;
  at::cuda::jit::NvrtcFunction vec4;
};

struct JittedKernelVariantCache {
  JittedVecKernelCache vec;
  at::cuda::jit::NvrtcFunction noncontiguous;
  at::cuda::jit::NvrtcFunction dynamic_contiguous;
  at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
};

// Packs the fixed kernel arguments, followed by any extra arguments, into a
// single contiguous buffer of void* suitable for a driver-API kernel launch.
inline c10::SmallBuffer<void*, 64> pack_kernel_args(
    std::initializer_list<void*> args,
    c10::ArrayRef<void*> extra_args) {
  c10::SmallBuffer<void*, 64> ret(args.size() + extra_args.size());
  std::copy(args.begin(), args.end(), ret.data());
  std::copy(extra_args.begin(), extra_args.end(), ret.data() + args.size());
  return ret;
}
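
// Example (illustrative): pack_kernel_args({&N, &data}, extra) produces
//   [&N, &data, extra[0], ..., extra[n-1]]
// which must match the parameter order of the generated kernel.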

template <typename array_t,
          typename inp_calc_t,
          typename out_calc_t,
          typename loader_t,
          typename storer_t>
void launch_jitted_unrolled_kernel(
    std::mutex &jiterator_mutex,
    at::cuda::jit::NvrtcFunction &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc,
    int64_t N,
    array_t data,
    inp_calc_t ic,
    out_calc_t oc,
    loader_t l,
    storer_t s,
    bool contiguous,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void* scalar_val,
    c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // The intermediate computation is done in int64_t, so casting the result
  // to uint32_t is safe given the assert above.
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();

  // Compiles the kernel on first use, with double-checked locking so the
  // common (already-compiled) path takes no lock.
  if (!fn_cache.function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_cache.function) {
      constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
                                       !std::is_same<decltype(s), memory::StoreWithoutCast>();
      auto code = at::cuda::jit::generate_code(
          desc, contiguous, dynamic_casting, scalar_pos);
      fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
    }
  }

  auto args = pack_kernel_args({&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
  at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
      {num_threads(), 1u, 1u});
}
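
// Worked example (illustrative, assuming C10_WARP_SIZE == 32 so that
// num_threads() == 128 and thread_work_size() == 4): block_work_size() == 512,
// hence N == 1000 launches grid == (1000 + 511) / 512 == 2 blocks.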

template <int arity, typename array_t>
void launch_jitted_vectorized_kernel(
    std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
    const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    void *scalar_val, c10::ArrayRef<void*> extra_args) {
  TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
  // The intermediate computation is done in int64_t, so casting the result
  // to uint32_t is safe given the assert above.
  const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
  const int vec_size = at::cuda::jit::can_vectorize_up_to(
      desc, c10::ArrayRef<char*>(data.data, data.size()));

  // Different kernels are compiled depending on what we're vectorizing up to
  // (1, 2, or 4 elements); fn_ptr selects the matching cache slot.
  at::cuda::jit::NvrtcFunction* fn_ptr;
  if (vec_size == 4) {
    fn_ptr = &fn_cache.vec4;
  } else if (vec_size == 2) {
    fn_ptr = &fn_cache.vec2;
  } else if (vec_size == 1) {
    fn_ptr = &fn_cache.vec1;
  } else {
    TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitted vectorized kernel");
  }

  bool vectorized = vec_size > 1;

  if (!fn_ptr->function) {
    const std::lock_guard<std::mutex> lock{jiterator_mutex};
    if (!fn_ptr->function) { // cache miss: generate and compile the program
      auto code = at::cuda::jit::generate_code(
          desc, /*contiguous=*/true, /*dynamic_casting=*/false,
          scalar_pos, vectorized, vec_size);
      std::string kernel_name = vectorized
          ? desc.name + "_vectorized" + std::to_string(vec_size)
          : desc.name;
      *fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
    }
  }

  if (vectorized) {
    auto args = pack_kernel_args({&N, &data, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  } else {
    // The scalar (vec1) variant takes the same arguments as the unrolled
    // kernel, with trivial offset calculators and no-op loaders/storers.
    auto ic = TrivialOffsetCalculator<arity>();
    auto oc = TrivialOffsetCalculator<1>();
    auto l = memory::LoadWithoutCast();
    auto s = memory::StoreWithoutCast();
    auto args = pack_kernel_args(
        {&N, &data, &ic, &oc, &l, &s, scalar_val}, extra_args);
    at::cuda::jit::launch_jitted_pwise_function(
        *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
  }
}
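
// Example (illustrative): for a kernel named "my_op" whose inputs permit
// 4-element vectorization, the generated entry point is "my_op_vectorized4"
// and it is cached in fn_cache.vec4; if can_vectorize_up_to returns 1, the
// plain "my_op" variant is compiled and launched with trivial offset
// calculators instead.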

template <int arity>
void jitted_gpu_kernel_generic(
    std::mutex &jiterator_mutex,
    JittedKernelVariantCache &cache,
    const at::cuda::jit::KernelDescriptor &desc,
    at::cuda::jit::BinaryFuncVariant scalar_pos,
    c10::ArrayRef<void*> extra_args,
    TensorIteratorBase& iter,
    const bool dynamic_casting,
    void *scalar_val) {
  TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
  TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);

  constexpr int ntensors = arity + 1;
  at::detail::Array<char*, ntensors> data;
  for (auto i : c10::irange(ntensors)) {
    data[i] = (char*)iter.data_ptr(i);
  }

  int64_t numel = iter.numel();
  bool contiguous = iter.is_contiguous();

  // Decides which of four kernel variants to launch:
  //   Case 1: no dynamic casting, contiguous
  //   Case 2: no dynamic casting, noncontiguous
  //   Case 3: dynamic casting, contiguous
  //   Case 4: dynamic casting, noncontiguous
  // Each case has its own entry in JittedKernelVariantCache.

  if (!dynamic_casting) {
    if (contiguous) {
      // Case 1: no dynamic casting and contiguous
      launch_jitted_vectorized_kernel<arity>(
          jiterator_mutex, cache.vec, desc,
          numel, data, scalar_pos, scalar_val, extra_args);
      return;
    }

    // Case 2: no dynamic casting and noncontiguous
    auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
    auto output_offset_calculator = make_output_offset_calculator(iter);
    auto loader = memory::LoadWithoutCast();
    auto storer = memory::StoreWithoutCast();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.noncontiguous, desc, numel, data,
        input_offset_calculator, output_offset_calculator, loader,
        storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Cases 3 and 4 need a casting storer for the single output
  // (the zeroth tensor in the TensorIterator) ...
  auto storer = memory::StoreWithCast<1>(iter);
  // ... and casting loaders for the inputs (tensors 1..arity).
  auto loader = memory::LoadWithCast<arity>(iter);

  if (contiguous) {
    // Case 3: dynamic casting and contiguous
    auto input_offset_calculator = TrivialOffsetCalculator<arity>();
    auto output_offset_calculator = TrivialOffsetCalculator<1>();
    launch_jitted_unrolled_kernel(
        jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
        output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
    return;
  }

  // Case 4: dynamic casting and noncontiguous
  auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
  auto output_offset_calculator = make_output_offset_calculator(iter);
  launch_jitted_unrolled_kernel(
      jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
      output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
}

// NOTE: declared static to reduce the chance of name collisions across
// translation units that include this header.
template <
    char const* name,
    typename result_type,
    typename f_inputs_type,
    int arity,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    typename... ExtraArgs>
static void jitted_gpu_kernel_impl(
    TensorIteratorBase& iter,
    const std::string &f,
    const bool dynamic_casting,
    at::opmath_type<f_inputs_type> scalar_val,
    std::tuple<ExtraArgs...> extra_args) {

  // Compiled kernels are cached per device; all compilation is serialized
  // through jiterator_mutex.
  static std::mutex jiterator_mutex;
  static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());

  constexpr int nInputs = arity;
  constexpr int nOutputs = 1; // only a single output is supported
  static const auto desc = at::cuda::jit::make_kernel_descriptor<
      result_type, f_inputs_type, ExtraArgs...>(name, f, nInputs, nOutputs);

  auto &cache = device_caches[iter.device().index()];
  auto extra_args_array = tuple_to_array(extra_args);
  return jitted_gpu_kernel_generic<arity>(
      jiterator_mutex,
      cache,
      desc,
      scalar_pos,
      extra_args_array,
      iter,
      dynamic_casting,
      &scalar_val);
}
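
// Example (illustrative; `my_op_name` and `my_op_string` are hypothetical
// names made up for this sketch): invoking the jitted kernel for a binary op,
// where `iter` is a TensorIteratorBase with two inputs and one output.
//   static constexpr char my_op_name[] = "my_op";
//   const std::string my_op_string =
//       "template <typename T> T my_op(T a, T b) { return a + b; }";
//   jitted_gpu_kernel_impl<my_op_name, float, float, /*arity=*/2>(
//       iter, my_op_string, /*dynamic_casting=*/false,
//       /*scalar_val=*/0.0f, std::make_tuple());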

}} // namespace at::native

#endif // AT_USE_JITERATOR()