// DispatchStub.h
  1. #pragma once
  2. #include <c10/core/DeviceType.h>
  3. #include <c10/macros/Export.h>
  4. #include <atomic>
  5. #include <utility>
  6. // Implements instruction set specific function dispatch.
  7. //
  8. // Kernels that may make use of specialized instruction sets (e.g. AVX2) are
  9. // compiled multiple times with different compiler flags (e.g. -mavx2). A
  10. // DispatchStub contains a table of function pointers for a kernel. At runtime,
  11. // the fastest available kernel is chosen based on the features reported by
  12. // cpuinfo.
  13. //
  14. // Example:
  15. //
  16. // In native/MyKernel.h:
  17. // using fn_type = void(*)(const Tensor& x);
  18. // DECLARE_DISPATCH(fn_type, stub);
  19. //
  20. // In native/MyKernel.cpp
  21. // DEFINE_DISPATCH(stub);
  22. //
  23. // In native/cpu/MyKernel.cpp:
  24. // namespace {
  25. // // use anonymous namespace so that different cpu versions won't conflict
  26. // void kernel(const Tensor& x) { ... }
  27. // }
  28. // REGISTER_DISPATCH(stub, &kernel);
  29. //
  30. // To call:
  31. // stub(kCPU, tensor);
  32. //
  33. // TODO: CPU instruction set selection should be folded into whatever
  34. // the main dispatch mechanism is.
  35. // ignore warnings about DispatchStub::DEFAULT, AVX, AVX2 defined elsewhere
  36. #if defined(__clang__)
  37. #pragma clang diagnostic push
  38. #pragma clang diagnostic ignored "-Wundefined-var-template"
  39. #endif
  40. namespace at { namespace native {
// Instruction-set levels a CPU kernel can be compiled for. Exactly one
// vector flavor is available per build (VSX, ZVECTOR, or AVX2/AVX512),
// hence the #if/#elif chain. Higher values are preferred at runtime when
// the machine supports them; NUM_OPTIONS is a count sentinel, not a level.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS
};
// Returns the capability level to dispatch to on this machine (defined out
// of line; per the file comment, based on features reported by cpuinfo).
CPUCapability get_cpu_capability();

// Primary template is only declared; the function-pointer specialization
// below provides the actual implementation.
template <typename FnPtr, typename T>
struct DispatchStub;
/**
 * The sole purpose of this class is to outline methods that don't need to be
 * specialized or otherwise inlined and duplicated (by the compiler due to
 * template expansion), since it causes size bloat if there are a significant
 * number of specialization of the DispatchStub<> class.
 */
struct TORCH_API DispatchStubImpl {
  // Resolves the kernel pointer for `device_type`. The trailing parameters
  // are the candidate CPU kernels, one per instruction-set variant compiled
  // into this build (hence the #ifdef'd parameter list); the chosen CPU
  // kernel is cached in cpu_dispatch_ptr.
  void* get_call_ptr(
    DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  /**
   * The CPU Dispatch actual method is chosen in decreasing order of preference by
   * DispatchStubImpl::choose_cpu_impl() in case none is found by
   * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
   */
  void* choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Fixing dispatch error in Windows debug builds.
  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
#if defined(_MSC_VER) && defined(_DEBUG)
  std::atomic<void*> cpu_dispatch_ptr;
  void* cuda_dispatch_ptr;
  void* hip_dispatch_ptr;
  void* mps_dispatch_ptr;
#else
  // cpu_dispatch_ptr is atomic because the cached CPU kernel may be resolved
  // lazily; the accelerator pointers are written by the registration macros.
  std::atomic<void*> cpu_dispatch_ptr{nullptr};
  void* cuda_dispatch_ptr = nullptr;
  void* hip_dispatch_ptr = nullptr;
  void* mps_dispatch_ptr = nullptr;
#endif
};
// Specialization for function-pointer kernels. Per-ISA candidates live in
// the static members (DEFAULT, AVX512, ...), each defined by a
// REGISTER_*_DISPATCH macro expansion in a kernel translation unit.
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*) (Args...);

  DispatchStub() = default;
  // A stub is a process-wide singleton; copying would fork its dispatch
  // state, so copies are disallowed.
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

private:
  // Forwards the statically registered per-ISA candidates to the
  // non-template DispatchStubImpl, which performs the actual selection.
  FnPtr get_call_ptr(DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
      impl.get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
      )
    );
  }

public:
  // Invokes the kernel registered for `device_type` with `args`.
  // NOTE(review): the resolved pointer is called without a null check here;
  // presumably impl.get_call_ptr reports unregistered stubs — confirm.
  template <typename... ArgTypes>
  rT operator()(DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }

  // Accelerator registration entry points, used by the registrar helpers.
  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }
  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // Per-ISA kernel candidates; each is a definition-per-stub static data
  // member filled in by REGISTER_ARCH_DISPATCH.
  static TORCH_API FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
  static TORCH_API FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  static TORCH_API FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  static TORCH_API FnPtr VSX;
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
  static TORCH_API FnPtr ZVECTOR;
#endif

private:
  DispatchStubImpl impl;
};
  170. namespace {
  171. template <typename DispatchStub>
  172. struct RegisterCUDADispatch {
  173. RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  174. stub.set_cuda_dispatch_ptr(value);
  175. }
  176. };
  177. template <typename DispatchStub>
  178. struct RegisterMPSDispatch {
  179. RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  180. stub.set_mps_dispatch_ptr(value);
  181. }
  182. };
  183. template <typename DispatchStub>
  184. struct RegisterHIPDispatch {
  185. RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
  186. // TODO: make this point at hip_dispatch_ptr
  187. stub.set_cuda_dispatch_ptr(value);
  188. }
  189. };
  190. } // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
//
// Declares a non-copyable dispatch stub type named `name` for kernels with
// signature `fn`, plus an extern global of that type. DEFINE_DISPATCH must
// provide the matching definition in exactly one .cpp file.
#define DECLARE_DISPATCH(fn, name) \
struct name : DispatchStub<fn, name> { \
name() = default; \
name(const name&) = delete; \
name& operator=(const name&) = delete; \
}; \
extern TORCH_API struct name name

// Defines the global stub object declared by DECLARE_DISPATCH.
#define DEFINE_DISPATCH(name) struct name name
// Defines the static per-ISA member `arch` of `name`'s DispatchStub base,
// thereby registering `fn` as that instruction-set's implementation.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;

// Per-ISA registration macros: each expands to REGISTER_ARCH_DISPATCH when
// the build provides that instruction set, and to nothing otherwise.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif
// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
REGISTER_AVX512_DISPATCH(name, fn) \
REGISTER_AVX2_DISPATCH(name, fn) \
REGISTER_VSX_DISPATCH(name, fn) \
REGISTER_ZVECTOR_DISPATCH(name, fn)

// Marks `name` as having no CPU implementation: every CPU slot is nullptr.
#define REGISTER_NO_CPU_DISPATCH(name) \
REGISTER_ALL_CPU_DISPATCH(name, nullptr)

// Accelerator registration: each defines a file-local registrar variable
// whose constructor stores `fn` into the stub's per-device pointer.
#define REGISTER_CUDA_DISPATCH(name, fn) \
static RegisterCUDADispatch<struct name> name ## __register(name, fn);
#define REGISTER_HIP_DISPATCH(name, fn) \
static RegisterHIPDispatch<struct name> name ## __register(name, fn);
#define REGISTER_MPS_DISPATCH(name, fn) \
static RegisterMPSDispatch<struct name> name ## __register(name, fn);
// Generic REGISTER_DISPATCH: resolves to the appropriate backend-specific
// registration macro based on which compiler is processing the file.
//
// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// CPU kernel TUs are compiled once per CPU_CAPABILITY (see file comment);
// here REGISTER_DISPATCH fills the slot for the capability being built.
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
// Opts `name` out of the AVX512 variant even when the build provides one.
#define REGISTER_NO_AVX512_DISPATCH(name) \
REGISTER_AVX512_DISPATCH(name, nullptr)
#endif
  259. }} // namespace at::native
  260. #if defined(__clang__)
  261. #pragma clang diagnostic pop
  262. #endif