#pragma once #include #include #include #include // Implements instruction set specific function dispatch. // // Kernels that may make use of specialized instruction sets (e.g. AVX2) are // compiled multiple times with different compiler flags (e.g. -mavx2). A // DispatchStub contains a table of function pointers for a kernel. At runtime, // the fastest available kernel is chosen based on the features reported by // cpuinfo. // // Example: // // In native/MyKernel.h: // using fn_type = void(*)(const Tensor& x); // DECLARE_DISPATCH(fn_type, stub); // // In native/MyKernel.cpp // DEFINE_DISPATCH(stub); // // In native/cpu/MyKernel.cpp: // namespace { // // use anonymous namespace so that different cpu versions won't conflict // void kernel(const Tensor& x) { ... } // } // REGISTER_DISPATCH(stub, &kernel); // // To call: // stub(kCPU, tensor); // // TODO: CPU instruction set selection should be folded into whatever // the main dispatch mechanism is. // ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere #if defined(__clang__) #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-var-template" #endif namespace at { namespace native { enum class CPUCapability { DEFAULT = 0, #if defined(HAVE_VSX_CPU_DEFINITION) VSX = 1, #elif defined(HAVE_ZVECTOR_CPU_DEFINITION) ZVECTOR = 1, #else AVX2 = 1, AVX512 = 2, #endif NUM_OPTIONS }; CPUCapability get_cpu_capability(); template struct DispatchStub; /** * The sole purpose of this class is to outline methods that don't need to be * specialized or otherwise inlined and duplicated (by the compiler due to * template expansion), since it causes size bloat if there are a significant * number of specialization of the DispatchStub<> class. 
*/ struct TORCH_API DispatchStubImpl { void* get_call_ptr( DeviceType device_type , void *DEFAULT #ifdef HAVE_AVX512_CPU_DEFINITION , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 #endif #ifdef HAVE_VSX_CPU_DEFINITION , void *VSX #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif ); /** * The CPU Dispatch actual method is chosen in decreasing order of preference by * DispatchStubImpl::choose_cpu_impl() in case none is found by * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr. */ void* choose_cpu_impl( void *DEFAULT #ifdef HAVE_AVX512_CPU_DEFINITION , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 #endif #ifdef HAVE_VSX_CPU_DEFINITION , void *VSX #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif ); // Fixing dispatch error in Windows debug builds. // See https://github.com/pytorch/pytorch/issues/22681 for more details. #if defined(_MSC_VER) && defined(_DEBUG) std::atomic cpu_dispatch_ptr; void* cuda_dispatch_ptr; void* hip_dispatch_ptr; void* mps_dispatch_ptr; #else std::atomic cpu_dispatch_ptr{nullptr}; void* cuda_dispatch_ptr = nullptr; void* hip_dispatch_ptr = nullptr; void* mps_dispatch_ptr = nullptr; #endif }; template struct DispatchStub { using FnPtr = rT (*) (Args...); DispatchStub() = default; DispatchStub(const DispatchStub&) = delete; DispatchStub& operator=(const DispatchStub&) = delete; private: FnPtr get_call_ptr(DeviceType device_type) { return reinterpret_cast( impl.get_call_ptr(device_type , reinterpret_cast(DEFAULT) #ifdef HAVE_AVX512_CPU_DEFINITION , reinterpret_cast(AVX512) #endif #ifdef HAVE_AVX2_CPU_DEFINITION , reinterpret_cast(AVX2) #endif #ifdef HAVE_VSX_CPU_DEFINITION , reinterpret_cast(VSX) #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , reinterpret_cast(ZVECTOR) #endif ) ); } public: template rT operator()(DeviceType device_type, ArgTypes&&... 
args) { FnPtr call_ptr = get_call_ptr(device_type); return (*call_ptr)(std::forward(args)...); } void set_cuda_dispatch_ptr(FnPtr fn_ptr) { impl.cuda_dispatch_ptr = reinterpret_cast(fn_ptr); } void set_hip_dispatch_ptr(FnPtr fn_ptr) { impl.hip_dispatch_ptr = reinterpret_cast(fn_ptr); } void set_mps_dispatch_ptr(FnPtr fn_ptr) { impl.mps_dispatch_ptr = reinterpret_cast(fn_ptr); } static TORCH_API FnPtr DEFAULT; #ifdef HAVE_AVX512_CPU_DEFINITION static TORCH_API FnPtr AVX512; #endif #ifdef HAVE_AVX2_CPU_DEFINITION static TORCH_API FnPtr AVX2; #endif #ifdef HAVE_VSX_CPU_DEFINITION static TORCH_API FnPtr VSX; #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION static TORCH_API FnPtr ZVECTOR; #endif private: DispatchStubImpl impl; }; namespace { template struct RegisterCUDADispatch { RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { stub.set_cuda_dispatch_ptr(value); } }; template struct RegisterMPSDispatch { RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { stub.set_mps_dispatch_ptr(value); } }; template struct RegisterHIPDispatch { RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) { // TODO: make this point at hip_dispatch_ptr stub.set_cuda_dispatch_ptr(value); } }; } // anonymous namespace // Compiler will complain if you put things like std::tuple in // the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g., // adding parentheses and using helper struct to get rid of the parentheses, do // not work with MSVC. So do a `using`-declaration if you need to pass in such // `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h. 
// Declares a dispatch stub type `name` (and its extern instance) for kernels
// of signature `fn`. Pair with DEFINE_DISPATCH in exactly one .cpp file.
#define DECLARE_DISPATCH(fn, name)         \
  struct name : DispatchStub<fn, name> {   \
    name() = default;                      \
    name(const name&) = delete;            \
    name& operator=(const name&) = delete; \
  };                                       \
  extern TORCH_API struct name name

#define DEFINE_DISPATCH(name) struct name name

// Defines the static per-arch member `arch` of the stub's DispatchStub base.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;

#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif

// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  REGISTER_AVX512_DISPATCH(name, fn)        \
  REGISTER_AVX2_DISPATCH(name, fn)          \
  REGISTER_VSX_DISPATCH(name, fn)           \
  REGISTER_ZVECTOR_DISPATCH(name, fn)

#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)

// Device registration macros: each expands to a file-local registrar object
// whose constructor stores `fn` into the stub at static-init time.
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name> name ## __register(name, fn);

// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__) #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn) #elif defined(__HIPCC__) // TODO: cut this over to HIP dispatch once we stop pretending that CUDA // is HIP in the PyTorch HIPify build. #define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn) // #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn) #elif defined(__OBJC__) && defined(USE_MPS) // NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel #define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn) #elif defined(CPU_CAPABILITY) #define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn) #define REGISTER_NO_AVX512_DISPATCH(name) \ REGISTER_AVX512_DISPATCH(name, nullptr) #endif }} // namespace at::native #if defined(__clang__) #pragma clang diagnostic pop #endif