#pragma once
#include <ATen/jit_macros.h>

#if AT_USE_JITERATOR()

#include <ATen/cuda/CUDAConfig.h>
#include <ATen/OpMathType.h>
#include <ATen/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <ATen/native/cuda/CUDAJitLoops.cuh>

namespace at {
namespace native {

template <
    char const* name,
    typename return_type,
    typename f_inputs_type,
    int arity,
    typename... Args>
void jitted_gpu_kernel(
    TensorIteratorBase& iter,
    const std::string& f,
    at::cuda::jit::BinaryFuncVariant scalar_pos =
        at::cuda::jit::BinaryFuncVariant::NoScalar,
    at::opmath_type<f_inputs_type> scalar_val = 0,
    std::tuple<Args...> extra_args = std::make_tuple()) {
  // Every operand of the TensorIterator must live on a CUDA device.
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(
        iter.device(arg).is_cuda(),
        "argument ", arg, ": expected a CUDA device but found ", iter.device(arg));
  }
  // Nothing to launch for an empty iterator.
  if (iter.numel() == 0) {
    return;
  }

  // The generated kernels use 32-bit indexing; split larger iterators into
  // 32-bit-indexable sub-iterators and recurse.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      jitted_gpu_kernel<name, return_type, f_inputs_type, arity>(
          sub_iter, f, scalar_pos, scalar_val, extra_args);
    }
    return;
  }
  // Computes whether dynamic casting is needed.
  // Dynamic casting is needed if an input's dtype differs from f_inputs_type,
  // or if the output's dtype differs from return_type.
  bool needs_dynamic_casting = false;

  // Checks the output (operand 0 of the iterator).
  const ScalarType return_scalar_type = c10::CppTypeToScalarType<return_type>::value;
  const auto dtype0 = iter.dtype(0);
  if (dtype0 != return_scalar_type) {
    needs_dynamic_casting = true;
  }

  // Checks the input(s) (operands 1..arity of the iterator).
  const ScalarType inputs_scalar_type = c10::CppTypeToScalarType<f_inputs_type>::value;
  for (auto i = decltype(arity){1}; i < (arity + 1); ++i) {
    const auto dtypei = iter.dtype(i);
    if (dtypei != inputs_scalar_type) {
      needs_dynamic_casting = true;
      break;
    }
  }
  if (scalar_pos == at::cuda::jit::BinaryFuncVariant::NoScalar) {
    // NOTE: with scalar_pos == NoScalar, scalar_val is not used by the
    // generated code, so the default value of 0 is passed through as a dummy.
    jitted_gpu_kernel_impl<
        name,
        return_type,
        f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::NoScalar>(
        iter, f, needs_dynamic_casting, scalar_val, extra_args);
  } else if (scalar_pos == at::cuda::jit::BinaryFuncVariant::RhsScalar) {
    jitted_gpu_kernel_impl<
        name,
        return_type,
        f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::RhsScalar>(
        iter, f, needs_dynamic_casting, scalar_val, extra_args);
  } else {
    jitted_gpu_kernel_impl<
        name,
        return_type,
        f_inputs_type,
        arity,
        at::cuda::jit::BinaryFuncVariant::LhsScalar>(
        iter, f, needs_dynamic_casting, scalar_val, extra_args);
  }
}
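
// Usage sketch (illustrative only; the op name and functor below are
// hypothetical, not part of this header). A unary caller typically captures
// the functor's CUDA-C++ source with jiterator_stringify and dispatches:
//
//   constexpr char my_op_name[] = "my_op";
//   const auto my_op_string = jiterator_stringify(
//       template <typename T>
//       T my_op(T x) {
//         return x + T{1};
//       }
//   );
//
//   void my_op_kernel_cuda(TensorIteratorBase& iter) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_op_cuda", [&]() {
//       jitted_gpu_kernel</*name=*/my_op_name,
//                         /*return_type=*/scalar_t,
//                         /*f_inputs_type=*/scalar_t,
//                         /*arity=*/1>(iter, my_op_string);
//     });
//   }
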
template <char const* name, typename return_type, typename f_inputs_type>
void opmath_jitted_gpu_kernel_with_scalars(TensorIteratorBase& iter, const std::string& f) {
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);
  // Currently the jiterator only handles binary functions whose inputs are
  // both of the same type (f_inputs_type).
  using opmath_t = at::opmath_type<f_inputs_type>;
  if (iter.is_cpu_scalar(1)) {
    auto scalar_val = iter.scalar_value<opmath_t>(1);
    iter.remove_operand(1);
    // Workaround for incorrect device guard generation in pre-structured
    // kernels: guard on the device of the remaining operand. Structured
    // kernels set the device correctly, so this guard can be removed once
    // all callers are ported to structured.
    const OptionalDeviceGuard device_guard(iter.device(1));
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(
        iter, f, at::cuda::jit::BinaryFuncVariant::LhsScalar, scalar_val);
  } else if (iter.is_cpu_scalar(2)) {
    auto scalar_val = iter.scalar_value<opmath_t>(2);
    iter.remove_operand(2);
    jitted_gpu_kernel<name, return_type, f_inputs_type, 1>(
        iter, f, at::cuda::jit::BinaryFuncVariant::RhsScalar, scalar_val);
  } else {
    jitted_gpu_kernel<name, return_type, f_inputs_type, 2>(iter, f);
  }
}
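
// Usage sketch (illustrative only; names are hypothetical). Binary ops route
// through this helper so a CPU scalar operand is hoisted into the
// LhsScalar/RhsScalar variants instead of being materialized as a tensor:
//
//   constexpr char my_mul_name[] = "my_mul";
//   const auto my_mul_string = jiterator_stringify(
//       template <typename T>
//       T my_mul(T a, T b) {
//         return a * b;
//       }
//   );
//
//   void my_mul_kernel_cuda(TensorIteratorBase& iter) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_mul_cuda", [&]() {
//       opmath_jitted_gpu_kernel_with_scalars<my_mul_name, scalar_t, scalar_t>(
//           iter, my_mul_string);
//     });
//   }
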
} // namespace native
} // namespace at

#endif // AT_USE_JITERATOR()