#pragma once

#include <complex>
#include <type_traits>

#include <c10/core/ScalarType.h>
#include <c10/util/C++17.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>

// This file includes utilities for dynamic_casting done by TensorIterator, see CUDALoops.cuh and Loops.h.
// dynamic_casting handles when the types expected by the iterator do not match the types of the arguments
// to the function that is being called.
// On CUDA, the cast is currently pushed down into the kernel (for performance reasons).
// On CPU, there is currently an internal assert that a dynamic_cast is not needed.

namespace at { namespace native {

// `needs_dynamic_casting` compares the types expected by the iterator
// (i.e. dtypes of the operands) with the actual types of the arguments
// (and return) of func_t
template<typename func_t, int nargs=function_traits<func_t>::arity>
struct needs_dynamic_casting {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::template arg<nargs - 1>::type;
    using cpp_map = c10::CppTypeToScalarType<cpp_type>;

    // Compare the (nargs-1)-th argument type against the corresponding input
    // dtype, then recurse over the remaining arguments.
    if (iter.input_dtype(nargs - 1) != cpp_map::value) {
      return true;
    }
    return needs_dynamic_casting<func_t, nargs - 1>::check(iter);
  }
};

// Base case of the recursion: all argument types matched, so check the
// return type against the output dtype (a void return needs no check).
template<typename func_t>
struct needs_dynamic_casting<func_t, 0> {
  static bool check(TensorIteratorBase& iter) {
    using traits = function_traits<func_t>;
    using cpp_type = typename traits::result_type;

    // we could assert output numbers are correct here, but checks
    // (including arity) are currently pushed outside of this struct.
    return c10::guts::if_constexpr<std::is_void<cpp_type>::value>([]() {
      return false;
    }, /* else */ [&](auto _) {
      // decltype(_) is used to delay computation
      using delayed_type = typename decltype(_)::template type_identity<cpp_type>;
      return iter.dtype(0) != c10::CppTypeToScalarType<delayed_type>::value;
    });
  }
};

}} //namespace at::native
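
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; `launch_kernel_sketch` is a hypothetical,
// simplified stand-in, not an actual ATen function). Loop implementations
// consult `needs_dynamic_casting` to choose between the direct path and the
// dtype-converting path; on CUDA (CUDALoops.cuh) this looks roughly like:
//
//   template <typename func_t>
//   void launch_kernel_sketch(at::TensorIteratorBase& iter, const func_t& f) {
//     if (at::native::needs_dynamic_casting<func_t>::check(iter)) {
//       // some operand dtype differs from func_t's signature: load inputs
//       // and store the result through runtime dtype conversions
//     } else {
//       // all dtypes match func_t exactly: take the direct (fast) path
//     }
//   }
//
// For example, for `float f(float, float)` running over an iterator whose
// operands are all kHalf, check() returns true. On CPU (Loops.h) the same
// check instead backs an internal assert, since the cast is not pushed down
// into the kernel there.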