#pragma once #include #include #include #include #include #include #include namespace at { namespace cuda { namespace jit { enum class BinaryFuncVariant {NoScalar, RhsScalar, LhsScalar}; struct NvrtcFunction { CUmodule module = CUmodule(); CUfunction function = nullptr; }; struct KernelDescriptor { std::string name; std::string f; c10::ScalarType f_inputs_type; c10::ScalarType result_type; c10::SmallVector extra_args_types; int nInputs, nOutputs; }; // Helper function to return a vector // corresponding to the type of the arguments in parameter pack. template c10::SmallVector get_extra_args_types() { return {c10::CppTypeToScalarType::value ...}; } template < typename result_type, typename f_inputs_type, typename... ExtraArgs> KernelDescriptor make_kernel_descriptor( std::string name, std::string f, int nInputs, int nOutputs) { KernelDescriptor ret; ret.name = std::move(name); ret.f = std::move(f); ret.f_inputs_type = c10::CppTypeToScalarType::value; ret.result_type = c10::CppTypeToScalarType::value; ret.extra_args_types = get_extra_args_types(); ret.nInputs = nInputs; ret.nOutputs = nOutputs; return ret; } inline int can_vectorize_up_to(size_t default_alignment, void *pointer) { auto ip = reinterpret_cast(pointer); if (ip % (4 * default_alignment) == 0) { return 4; } if (ip % (2 * default_alignment) == 0) { return 2; } return 1; } inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef pointers) { TORCH_INTERNAL_ASSERT(desc.nOutputs == 1); TORCH_INTERNAL_ASSERT(static_cast(pointers.size()) == 1 + desc.nInputs); // Deals with output auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize(); int result = can_vectorize_up_to(result_size, pointers[0]); // Incorporates input(s) auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize(); for (auto i : c10::irange(1, pointers.size())) { result = std::min(result, can_vectorize_up_to(input_size, pointers[i])); } return result; } std::string generate_code( int nInputs, int nOutputs, const std::string& func, const std::string& name, const std::string& f_input_type, const std::string& compute_type, const std::string& result_type, bool contiguous, bool dynamic_casting, BinaryFuncVariant scalar_pos, c10::SmallVector& extra_args_typenames, bool vectorized=false, int vec_size=0, bool return_by_ref=false); std::string generate_code( const KernelDescriptor &desc, bool contiguous, bool dynamic_casting, BinaryFuncVariant scalar_pos, bool vectorized=false, int vec_size=0, bool return_by_ref=false); std::string generate_reduction_code( int nOutputs, const std::string& func, const std::string& name, const int vt0, const std::string& f_inputs_type, const std::string& reduction_accum_type, const std::string& result_type, bool contiguous, bool vectorized, int vec_size, int max_threads_codegen); std::string generate_reduction_code( const KernelDescriptor &desc, const int vt0, bool contiguous, bool vectorized, int vec_size, int max_threads_codegen); NvrtcFunction jit_pwise_function( const std::string& code, const std::string& kernel_name); void launch_jitted_pwise_function( NvrtcFunction function, void* args[], const dim3 nBlocks, const dim3 kBlockSize, const int smem=0); template struct delayed_false : std::false_type { }; // Defines type names // NOTE: General case is instantiated only for invalid types. // All the valid types have specialization using the TYPE_NAME_FN // macro below. template inline std::string typeName() { // we can't use static_assert(false) directly as the // program will be not compiled even if the template is not // instantiated, so we use `delayed_false` // to make sure compiler doesn't eagerly raise // fail this assertion. static_assert(delayed_false::value, "invalid type for jiterator"); return "void"; } #define TYPE_NAME_FN(ctype, name) \ template <> inline std::string typeName(){ \ return std::string(#ctype); \ } AT_FORALL_SCALAR_TYPES(TYPE_NAME_FN) #undef TYPE_NAME_FN // JIT uses std::complex directly, because nvRTC compile programs // with -default-device, so there is no such issue like: // "std::sin(complex) is __host__ only" template <> inline std::string typeName(){ return "bool"; } template <> inline std::string typeName>(){ return "std::complex"; } template <> inline std::string typeName>(){ return "std::complex"; } template <> inline std::string typeName>(){ return "std::complex"; } template <> inline std::string typeName(){ return "at::Half"; } template <> inline std::string typeName(){ return "at::BFloat16"; } #define TYPE_NAME_CASE(ctype, scalartype) \ case ScalarType::scalartype: return typeName(); inline std::string typeName(ScalarType t) { switch (t) { AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE) default: TORCH_CHECK(false, "invalid type for jiterator"); } } #undef TYPE_NAME_CASE TORCH_CUDA_CPP_API void initializeCudaContext(); }}} // namespace at::cuda::jit