- #pragma once
- #include <c10/core/DeviceType.h>
- #include <c10/macros/Export.h>
- #include <atomic>
- #include <utility>
- // Implements instruction set specific function dispatch.
- //
- // Kernels that may make use of specialized instruction sets (e.g. AVX2) are
- // compiled multiple times with different compiler flags (e.g. -mavx2). A
- // DispatchStub contains a table of function pointers for a kernel. At runtime,
- // the fastest available kernel is chosen based on the features reported by
- // cpuinfo.
- //
- // Example:
- //
- // In native/MyKernel.h:
- // using fn_type = void(*)(const Tensor& x);
- // DECLARE_DISPATCH(fn_type, stub);
- //
- // In native/MyKernel.cpp
- // DEFINE_DISPATCH(stub);
- //
- // In native/cpu/MyKernel.cpp:
- // namespace {
- // // use anonymous namespace so that different cpu versions won't conflict
- // void kernel(const Tensor& x) { ... }
- // }
- // REGISTER_DISPATCH(stub, &kernel);
- //
- // To call:
- // stub(kCPU, tensor);
- //
- // TODO: CPU instruction set selection should be folded into whatever
- // the main dispatch mechanism is.
- // ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
- #if defined(__clang__)
- #pragma clang diagnostic push
- #pragma clang diagnostic ignored "-Wundefined-var-template"
- #endif
- namespace at { namespace native {
// CPU instruction-set levels a kernel may be registered for. The numeric
// value doubles as the preference order: higher value is preferred at
// runtime. Exactly one family of levels exists per build: VSX (ppc64le),
// ZVECTOR (s390x), or AVX2/AVX512 (x86); DEFAULT is always present.
// NUM_OPTIONS is the count of levels in the current build.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS
};

// Returns the best capability available at runtime (per the file header,
// detected via the features reported by cpuinfo; implementation lives in
// the corresponding .cpp).
CPUCapability get_cpu_capability();

// Primary template; only the function-pointer specialization below is defined.
template <typename FnPtr, typename T>
struct DispatchStub;
/**
 * The sole purpose of this class is to outline methods that don't need to be
 * specialized or otherwise inlined and duplicated (by the compiler due to
 * template expansion), since it causes size bloat if there are a significant
 * number of specialization of the DispatchStub<> class.
 */
struct TORCH_API DispatchStubImpl {
  // Returns the kernel pointer to invoke for `device_type`. Each per-ISA
  // candidate parameter exists only when the matching HAVE_*_CPU_DEFINITION
  // macro was defined at build time, mirroring the static members of
  // DispatchStub<>. NOTE(review): for CPU the result is presumably cached in
  // cpu_dispatch_ptr after the first call (see choose_cpu_impl's doc);
  // confirm in the .cpp.
  void* get_call_ptr(
    DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  /**
   * The CPU Dispatch actual method is chosen in decreasing order of preference by
   * DispatchStubImpl::choose_cpu_impl() in case none is found by
   * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
   */
  void* choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Fixing dispatch error in Windows debug builds.
  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
#if defined(_MSC_VER) && defined(_DEBUG)
  // MSVC debug builds: members left without initializers (see issue above).
  std::atomic<void*> cpu_dispatch_ptr;
  void* cuda_dispatch_ptr;
  void* hip_dispatch_ptr;
  void* mps_dispatch_ptr;
#else
  // cpu_dispatch_ptr is atomic because it is written lazily on first
  // dispatch while other threads may be reading it; the device pointers are
  // set once at static-registration time via the setters in DispatchStub<>.
  std::atomic<void*> cpu_dispatch_ptr{nullptr};
  void* cuda_dispatch_ptr = nullptr;
  void* hip_dispatch_ptr = nullptr;
  void* mps_dispatch_ptr = nullptr;
#endif
};
// Specialization of DispatchStub for function-pointer kernels. Holds one
// static function pointer per CPU capability (defined externally by the
// REGISTER_*_DISPATCH macros) plus per-device pointers inside `impl`.
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*) (Args...);

  DispatchStub() = default;
  // Non-copyable: each stub is a unique global dispatch-table entry.
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

private:
  // Forwards the per-ISA static members to the non-template
  // DispatchStubImpl, which performs the actual device/ISA selection —
  // keeping that logic out of line avoids duplicating it in every
  // template instantiation. Pointers travel as void* and are cast back
  // to FnPtr here.
  FnPtr get_call_ptr(DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
      impl.get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
      )
    );
  }

public:
  // Invokes the registered kernel for `device_type`, perfect-forwarding the
  // arguments. NOTE(review): call_ptr is not null-checked here — presumably
  // DispatchStubImpl::get_call_ptr reports unregistered stubs; confirm.
  template <typename... ArgTypes>
  rT operator()(DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }

  // Device-side registration hooks, called by the Register*Dispatch
  // helper structs at static-initialization time.
  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // Per-ISA kernel pointers. Each is a static data member declaration;
  // the definition (one per stub type) is emitted by REGISTER_ARCH_DISPATCH
  // in the arch-specific translation units.
  static TORCH_API FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
  static TORCH_API FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  static TORCH_API FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  static TORCH_API FnPtr VSX;
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
  static TORCH_API FnPtr ZVECTOR;
#endif
private:
  DispatchStubImpl impl;
};
namespace {

// Registrar helpers: constructing one of these stores `fn` into the matching
// device slot of `stub`. The REGISTER_*_DISPATCH macros instantiate them as
// file-scope statics, so registration happens during static initialization.

template <typename Stub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(Stub& stub, typename Stub::FnPtr fn) {
    stub.set_cuda_dispatch_ptr(fn);
  }
};

template <typename Stub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(Stub& stub, typename Stub::FnPtr fn) {
    stub.set_mps_dispatch_ptr(fn);
  }
};

template <typename Stub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(Stub& stub, typename Stub::FnPtr fn) {
    // TODO: make this point at hip_dispatch_ptr
    stub.set_cuda_dispatch_ptr(fn);
  }
};

} // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.

// Declares a dispatch stub `name` for kernels of signature `fn`. The struct
// wrapper (instead of a plain DispatchStub variable) gives each stub a
// distinct type, so the REGISTER_* macros can name it as `struct name`.
#define DECLARE_DISPATCH(fn, name) \
  struct name : DispatchStub<fn, name> { \
    name() = default; \
    name(const name&) = delete; \
    name& operator=(const name&) = delete; \
  }; \
  extern TORCH_API struct name name

// Defines the stub variable declared by DECLARE_DISPATCH; place in exactly
// one .cpp file.
#define DEFINE_DISPATCH(name) struct name name

// Defines the static per-arch member (e.g. name::AVX2) with the given kernel.
// Used from arch-specific TUs compiled with the matching compiler flags.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;

// Per-ISA wrappers: expand to nothing when that ISA isn't part of the build,
// so registration sites don't need their own #ifdefs.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif

// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  REGISTER_AVX512_DISPATCH(name, fn) \
  REGISTER_AVX2_DISPATCH(name, fn) \
  REGISTER_VSX_DISPATCH(name, fn) \
  REGISTER_ZVECTOR_DISPATCH(name, fn)

// Registers nullptr for every CPU arch, marking the stub CPU-unsupported.
#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)

// Device registrations: instantiate a file-scope registrar whose constructor
// stores `fn` in the stub. The generated variable is `name ## _register`
// (a single underscore: identifiers containing "__" are reserved to the
// implementation per [lex.name], which the previous `__register` suffix
// violated; the static is TU-local and never referenced, so this is safe).
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name> name ## _register(name, fn);
#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name> name ## _register(name, fn);
#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name> name ## _register(name, fn);

// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define REGISTER_NO_AVX512_DISPATCH(name) \
  REGISTER_AVX512_DISPATCH(name, nullptr)
#endif
- }} // namespace at::native
- #if defined(__clang__)
- #pragma clang diagnostic pop
- #endif