Dispatch.h 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. #pragma once
  2. #include <ATen/core/DeprecatedTypeProperties.h>
  3. #include <c10/macros/Macros.h>
  4. #include <c10/util/Exception.h>
  5. #include <c10/util/Half.h>
  6. #include <c10/util/Metaprogramming.h>
  7. #include <c10/util/complex.h>
  8. #include <c10/util/string_view.h>
  9. #ifdef __CUDACC__
  10. #include <cuda.h> // For CUDA_VERSION
  11. #endif
  12. #ifdef TEMPLATE_SELECTIVE_BUILD
  13. #include <ATen/selected_mobile_ops.h>
  14. #else
  15. namespace at {
  16. /**
  17. * The method should_include_kernel_dtype() returns true/false
  18. * based on whether the switching code for a specific dtype should be
  19. * included based on build time constants generated from tracing model
  20. * execution. This method will be implmeneted via code-generation and
  21. * included in this file when code-gen is ready.
  22. */
  23. inline constexpr bool should_include_kernel_dtype(
  24. const char* /*kernel_tag_str*/,
  25. at::ScalarType /*scalar_type*/
  26. ) {
  27. return true;
  28. }
  29. } // namespace at
  30. #endif
  /**
   * RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) passes the string
   * "<NAME>$<toString(enum_type)>" to at::detail::record_kernel_function_dtype.
   *
   * In the Facebook internal build (using BUCK), recording is enabled by
   * passing -c pt.enable_record_kernel_dtype=1 when building the tracer
   * binary; that is expected to define ENABLE_RECORD_KERNEL_FUNCTION_DTYPE.
   * When it is not defined, the macro expands to nothing.
   *
   * The caller supplies the terminating semicolon:
   *
   *     RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);
   *
   * The enabled expansion deliberately does not end in ';'. A built-in
   * semicolon would turn the line above into "...;;", which is an extra empty
   * statement and breaks `if (c) RECORD_KERNEL_FUNCTION_DTYPE(...); else ...`.
   */
  #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE
  namespace at {
  namespace detail {
  TORCH_API void record_kernel_function_dtype(std::string name);
  } // namespace detail
  } // namespace at
  #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \
    at::detail::record_kernel_function_dtype(           \
        std::string(NAME) + "$" + toString(enum_type))
  #else
  #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type)
  #endif
  // AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) raises an AT_ERROR when
  // at::should_include_kernel_dtype(at_dispatch_name, enum_type) is false,
  // i.e. when the dtype was not selected for the current kernel tag.
  //
  // It must be expanded where a constexpr `at_dispatch_name` is in scope (the
  // body of AT_DISPATCH_SWITCH). The language-level `if constexpr` is preferred
  // whenever the compiler supports it; the at::guts::if_constexpr library
  // emulation used as a fallback is more expensive to compile.
  #if defined __cpp_if_constexpr
  #define AT_PRIVATE_CHECK_SELECTIVE_BUILD(ENUM_VALUE)   \
    do {                                                 \
      if constexpr (!at::should_include_kernel_dtype(    \
                        at_dispatch_name, ENUM_VALUE)) { \
        AT_ERROR(                                        \
            "dtype '",                                   \
            toString(ENUM_VALUE),                        \
            "' not selected for kernel tag ",            \
            at_dispatch_name);                           \
      }                                                  \
    } while (0)
  #else // !defined __cpp_if_constexpr
  #define AT_PRIVATE_CHECK_SELECTIVE_BUILD(ENUM_VALUE)       \
    at::guts::if_constexpr<!at::should_include_kernel_dtype( \
        at_dispatch_name, ENUM_VALUE)>([&] {                 \
      AT_ERROR(                                              \
          "dtype '",                                         \
          toString(ENUM_VALUE),                              \
          "' not selected for kernel tag ",                  \
          at_dispatch_name);                                 \
    })
  #endif
  // Emits one `case ENUM_VALUE:` label for a dispatch switch. Inside the case,
  // ALIAS names the C++ type for ENUM_VALUE
  // (c10::impl::ScalarTypeToCPPTypeT), and the callable passed as __VA_ARGS__
  // is invoked with no arguments; its result is returned from the case.
  #define AT_PRIVATE_CASE_TYPE_USING_HINT(ENUM_VALUE, ALIAS, ...)            \
    case ENUM_VALUE: {                                                      \
      AT_PRIVATE_CHECK_SELECTIVE_BUILD(ENUM_VALUE);                         \
      using ALIAS C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT<ENUM_VALUE>; \
      return __VA_ARGS__();                                                 \
    }

  // The usual case label: the type alias is spelled `scalar_t`.
  #define AT_DISPATCH_CASE(ENUM_VALUE, ...) \
    AT_PRIVATE_CASE_TYPE_USING_HINT(ENUM_VALUE, scalar_t, __VA_ARGS__)

  // Case label for a quantized dtype. Inside the case:
  //   scalar_t        - CPP_TYPE, the quantized wrapper type
  //   underlying_t    - CPP_TYPE::underlying
  //   SCALAR_TYPE     - ENUM_VALUE
  //   UNDERLYING_TYPE - toUnderlying(ENUM_VALUE)
  #define AT_DISPATCH_CASE_QINT(ENUM_VALUE, CPP_TYPE, ...)                \
    case ENUM_VALUE: {                                                   \
      AT_PRIVATE_CHECK_SELECTIVE_BUILD(ENUM_VALUE);                      \
      using scalar_t = CPP_TYPE;                                         \
      using underlying_t C10_UNUSED = typename scalar_t::underlying;     \
      const auto& SCALAR_TYPE C10_UNUSED = ENUM_VALUE;                   \
      const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(ENUM_VALUE); \
      return __VA_ARGS__();                                              \
    }

  // Like AT_DISPATCH_CASE_QINT, but the case additionally defines
  //   bit_width            - BIT_WIDTH (as int)
  //   quant_min, quant_max - QMIN and QMAX (as int64_t)
  #define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                             \
      ENUM_VALUE, CPP_TYPE, BIT_WIDTH, QMIN, QMAX, ...)                  \
    case ENUM_VALUE: {                                                   \
      AT_PRIVATE_CHECK_SELECTIVE_BUILD(ENUM_VALUE);                      \
      using scalar_t = CPP_TYPE;                                         \
      using underlying_t C10_UNUSED = typename scalar_t::underlying;     \
      const auto& SCALAR_TYPE C10_UNUSED = ENUM_VALUE;                   \
      const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(ENUM_VALUE); \
      C10_UNUSED int bit_width = BIT_WIDTH;                              \
      C10_UNUSED int64_t quant_min = QMIN;                               \
      C10_UNUSED int64_t quant_max = QMAX;                               \
      return __VA_ARGS__();                                              \
    }
  102. namespace detail {
  103. inline at::ScalarType scalar_type(at::ScalarType s) {
  104. return s;
  105. }
  106. C10_DEPRECATED_MESSAGE(
  107. "passing at::DeprecatedTypeProperties to an AT_DISPATCH macro is deprecated, "
  108. "pass an at::ScalarType instead")
  109. inline at::ScalarType scalar_type(const at::DeprecatedTypeProperties& t) {
  110. return t.scalarType();
  111. }
  112. C10_DEPRECATED_MESSAGE(
  113. "AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated, "
  114. "use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) instead")
  115. inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() {}
  116. C10_DEPRECATED_MESSAGE(
  117. "AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX is deprecated, "
  118. "use AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Half, ...) "
  119. "instead")
  120. inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
  121. } // namespace detail
  // The AT_DISPATCH_* family of macros provides the ability to
  // conveniently generate specializations of a kernel over all of the
  // dtypes we care about in PyTorch. We call it "dispatch" because
  // we are "dispatching" to the correct, dtype-specific kernel.
  //
  // A standard usage looks like:
  //
  //      AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] {
  //          // Your code here, with 'scalar_t' now defined to
  //          // be the dtype in question
  //      });
  //
  // There are many variations of this macro, so it's important to
  // understand exactly /which/ dtypes you want to get instantiated, as
  // well as what the "default" set is.
  //
  // The default set of dtypes that are instantiated (e.g., by
  // AT_DISPATCH_ALL_TYPES) are floating point types (float, double),
  // and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t),
  // but NOT booleans (bool), half-precision floats (Half) or
  // complex numbers (c10::complex<float>, c10::complex<double>).
  // This "cut" is somewhat historical (the default types are the
  // ones that TH historically supported), but it also reflects the
  // fact that the non-default types are "poorly" behaved (booleans
  // are NOT integers mod 2, half precision operations ~essentially
  // don't exist on CPU, complex numbers are an experimental application).
  //
  // Here are the questions you should generally ask to decide which
  // dispatch you want:
  //
  // 1. Is this an integral or floating point specific operation?
  //    (If so, you'll want one of the FLOATING or INTEGRAL macros.)
  //
  // 2. Should half be supported? (If you're on CPU, the answer is almost
  //    definitely no. If you do want support, pass at::ScalarType::Half to
  //    one of the *_AND macros, e.g.
  //    AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...), or use
  //    AT_DISPATCH_FLOATING_TYPES_AND_HALF. Note that
  //    AT_DISPATCH_ALL_TYPES_AND_HALF is deprecated.)
  //
  // Much rarer situations:
  //
  // 3. Should bool be supported? (You often have to write your kernel
  //    differently if arithmetic operations are involved.) If so,
  //    use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool.
  //
  // 4. Should complex be supported? The answer is almost always no,
  //    unless you are working on "generic" code that should work on
  //    all dtypes.
  //
  // Parameters:
  // -----------
  //
  // 1. The NAME argument is a "tag" that is used to trace and then
  //    conditionally compile fragments of the case statements such
  //    that the kernel functions are specialized only for the dtypes
  //    that are needed. The NAME parameter *must* be a build time
  //    const char* (can't be std::string, etc...), because it initializes
  //    a constexpr const char*.
  //
  //    Please ensure that the NAME is unique for every implementation
  //    or you run the risk of over-including code for the kernel
  //    functions. There is no risk of missing out on any code, so
  //    it's mostly a risk of a Type-2 error, and not a Type-1 error.
  //
  // Switch-like syntax:
  // -------------------
  // There is also a switch-case like syntax which is useful if a kernel
  // needs to be specialized for particular scalar types
  //
  //      AT_DISPATCH_SWITCH(self.scalar_type(), "op_name",
  //          AT_DISPATCH_CASE_INTEGRAL_TYPES([&] {
  //            op_integral<scalar_t>(iter);
  //          })
  //          AT_DISPATCH_CASE_FLOATING_TYPES([&] {
  //            op_floating<scalar_t>(iter);
  //          })
  //          AT_DISPATCH_CASE(kBool, [&] {
  //            op_bool(iter);
  //          })
  //      );
  //
  // Most AT_DISPATCH_FOO macros have a corresponding AT_DISPATCH_CASE_FOO
  // macro which can be used inside of an AT_DISPATCH_SWITCH block
  // (AT_DISPATCH_INDEX_TYPES is an exception: it has no CASE_ variant).
  //
  // AT_DISPATCH_SWITCH itself:
  // --------------------------
  // The expansion is an immediately-invoked lambda, so the macro is an
  // expression whose value is whatever the matching case returns. TYPE is
  // evaluated exactly once. It is converted to an at::ScalarType by
  // ::detail::scalar_type, which also accepts the deprecated
  // at::DeprecatedTypeProperties. The (NAME, dtype) pair is passed to
  // RECORD_KERNEL_FUNCTION_DTYPE, and then the switch runs. A dtype with no
  // matching case raises an AT_ERROR naming NAME and the dtype.
  //
  // Names in scope for the case bodies:
  //   at_dispatch_name - NAME as a constexpr const char*; required by
  //                      AT_PRIVATE_CHECK_SELECTIVE_BUILD.
  //   the_type         - reference to the original TYPE argument. Nothing
  //                      here needs it except to compute `_st` (which also
  //                      keeps it from tripping unused-variable warnings);
  //                      it keeps its name only for backwards compatibility
  //                      with any code that might refer to it.
  //   _st              - the resolved at::ScalarType being switched on.
  #define AT_DISPATCH_SWITCH(TYPE, NAME, ...)                                 \
    [&] {                                                                     \
      const auto& the_type = TYPE;                                            \
      constexpr const char* at_dispatch_name = NAME;                          \
      /* don't use TYPE again in case it is an expensive or side-effect op */ \
      at::ScalarType _st = ::detail::scalar_type(the_type);                   \
      RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st);                    \
      switch (_st) {                                                          \
        __VA_ARGS__                                                           \
        default:                                                              \
          AT_ERROR(                                                           \
              '"',                                                            \
              at_dispatch_name,                                               \
              "\" not implemented for '",                                     \
              toString(_st),                                                  \
              "'");                                                           \
      }                                                                       \
    }()
  // Floating point family: cases for Double then Float, optionally followed by
  // Half or by caller-chosen extra dtypes (EXTRA*).
  #define AT_DISPATCH_CASE_FLOATING_TYPES(...)            \
    AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \
    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)

  #define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                               \
        TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

  #define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \
    AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)        \
    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)

  #define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                        \
        TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__))

  #define AT_DISPATCH_CASE_FLOATING_TYPES_AND(EXTRA, ...) \
    AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)          \
    AT_DISPATCH_CASE(EXTRA, __VA_ARGS__)

  #define AT_DISPATCH_FLOATING_TYPES_AND(EXTRA, TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                          \
        TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND(EXTRA, __VA_ARGS__))

  #define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(EXTRA1, EXTRA2, ...) \
    AT_DISPATCH_CASE_FLOATING_TYPES_AND(EXTRA1, __VA_ARGS__)        \
    AT_DISPATCH_CASE(EXTRA2, __VA_ARGS__)

  #define AT_DISPATCH_FLOATING_TYPES_AND2(EXTRA1, EXTRA2, TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                                    \
        TYPE,                                                              \
        NAME,                                                              \
        AT_DISPATCH_CASE_FLOATING_TYPES_AND2(EXTRA1, EXTRA2, __VA_ARGS__))
  // Complex family: cases for ComplexDouble then ComplexFloat.
  #define AT_DISPATCH_CASE_COMPLEX_TYPES(...)                    \
    AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
    AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)

  #define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                              \
        TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__))

  #define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(EXTRA, ...) \
    AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)          \
    AT_DISPATCH_CASE(EXTRA, __VA_ARGS__)

  #define AT_DISPATCH_COMPLEX_TYPES_AND(EXTRA, TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                         \
        TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(EXTRA, __VA_ARGS__))

  // Floating point followed by complex (Double, Float, ComplexDouble,
  // ComplexFloat), optionally followed by up to three extra dtypes. Each
  // _ANDn variant appends one case to the _AND(n-1) variant.
  #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \
    AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)           \
    AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

  #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                           \
        TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__))

  #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(EXTRA, ...) \
    AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)           \
    AT_DISPATCH_CASE(EXTRA, __VA_ARGS__)

  #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(EXTRA, TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                                       \
        TYPE,                                                                 \
        NAME,                                                                 \
        AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(EXTRA, __VA_ARGS__))

  #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(  \
      EXTRA1, EXTRA2, ...)                                   \
    AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(        \
        EXTRA1, __VA_ARGS__)                                 \
    AT_DISPATCH_CASE(EXTRA2, __VA_ARGS__)

  #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(       \
      EXTRA1, EXTRA2, TYPE, NAME, ...)                       \
    AT_DISPATCH_SWITCH(                                      \
        TYPE,                                                \
        NAME,                                                \
        AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(    \
            EXTRA1, EXTRA2, __VA_ARGS__))

  #define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(  \
      EXTRA1, EXTRA2, EXTRA3, ...)                           \
    AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2(        \
        EXTRA1, EXTRA2, __VA_ARGS__)                         \
    AT_DISPATCH_CASE(EXTRA3, __VA_ARGS__)

  #define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3(       \
      EXTRA1, EXTRA2, EXTRA3, TYPE, NAME, ...)               \
    AT_DISPATCH_SWITCH(                                      \
        TYPE,                                                \
        NAME,                                                \
        AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3(    \
            EXTRA1, EXTRA2, EXTRA3, __VA_ARGS__))
  // Integral family: cases for Byte, Char, Int, Long and Short (in that
  // order), optionally followed by one extra dtype.
  #define AT_DISPATCH_CASE_INTEGRAL_TYPES(...)          \
    AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
    AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
    AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
    AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)

  #define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                               \
        TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

  #define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(EXTRA, ...) \
    AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)          \
    AT_DISPATCH_CASE(EXTRA, __VA_ARGS__)

  #define AT_DISPATCH_INTEGRAL_TYPES_AND(EXTRA, TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                          \
        TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(EXTRA, __VA_ARGS__))

  // The default "all types" set: the integral cases, then the floating point
  // cases.
  #define AT_DISPATCH_CASE_ALL_TYPES(...)        \
    AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \
    AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__))
  // Quantized dispatch families. Each case binds `scalar_t` to the quantized
  // C++ type, plus the other locals defined by AT_DISPATCH_CASE_QINT or
  // AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE.
  //
  // Byte-sized quantized dtypes only: qint8, quint8.
  #define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...)               \
    AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \
    AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__)

  #define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                \
        TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__))

  // qint8, quint8, then qint32.
  #define AT_DISPATCH_CASE_QINT_TYPES(...)        \
    AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__) \
    AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__)

  #define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__))

  // qint8, quint8, qint32 and the packed sub-byte dtypes quint4x2 and
  // quint2x4. Each case also defines `bit_width`, `quant_min`, `quant_max`:
  //
  //   dtype      bit_width                [quant_min, quant_max]
  //   qint8      CHAR_BIT                 [SCHAR_MIN, SCHAR_MAX]
  //   quint8     CHAR_BIT                 [0, UCHAR_MAX]
  //   qint32     CHAR_BIT * sizeof(int)   [INT_MIN, INT_MAX]
  //   quint4x2   4                        [0, 15]
  //   quint2x4   2                        [0, 3]
  //
  // NB: CHAR_BIT and the *_MIN / *_MAX limits are <climits> macros. <climits>
  // is not among this header's direct includes, so presumably it arrives
  // transitively; verify before relying on that.
  #define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...)                       \
    AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
        at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \
    AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
        at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__)       \
    AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
        at::kQInt32,                                                         \
        at::qint32,                                                          \
        CHAR_BIT * sizeof(int),                                              \
        INT_MIN,                                                             \
        INT_MAX,                                                             \
        __VA_ARGS__)                                                         \
    AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
        at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__)                  \
    AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE(                                      \
        at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__)

  #define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                        \
        TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__))
  // "All types" variants: the default set (integral then floating point),
  // optionally followed by the complex cases, then up to four caller-chosen
  // dtypes. Each _ANDn variant appends one case to the _AND(n-1) variant, so
  // the extra cases appear in argument order.
  #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \
    AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)           \
    AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                      \
        TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__))

  #define AT_DISPATCH_CASE_ALL_TYPES_AND(EXTRA, ...) \
    AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)          \
    AT_DISPATCH_CASE(EXTRA, __VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES_AND(EXTRA, TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                     \
        TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(EXTRA, __VA_ARGS__))

  #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(EXTRA, ...) \
    AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)          \
    AT_DISPATCH_CASE(EXTRA, __VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(EXTRA, TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                                 \
        TYPE,                                                           \
        NAME,                                                           \
        AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(EXTRA, __VA_ARGS__))

  #define AT_DISPATCH_CASE_ALL_TYPES_AND2(EXTRA1, EXTRA2, ...) \
    AT_DISPATCH_CASE_ALL_TYPES_AND(EXTRA1, __VA_ARGS__)        \
    AT_DISPATCH_CASE(EXTRA2, __VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES_AND2(EXTRA1, EXTRA2, TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                                               \
        TYPE,                                                         \
        NAME,                                                         \
        AT_DISPATCH_CASE_ALL_TYPES_AND2(EXTRA1, EXTRA2, __VA_ARGS__))

  #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(EXTRA1, EXTRA2, ...) \
    AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(EXTRA1, __VA_ARGS__)        \
    AT_DISPATCH_CASE(EXTRA2, __VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(    \
      EXTRA1, EXTRA2, TYPE, NAME, ...)               \
    AT_DISPATCH_SWITCH(                              \
        TYPE,                                        \
        NAME,                                        \
        AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \
            EXTRA1, EXTRA2, __VA_ARGS__))

  #define AT_DISPATCH_CASE_ALL_TYPES_AND3(EXTRA1, EXTRA2, EXTRA3, ...) \
    AT_DISPATCH_CASE_ALL_TYPES_AND2(EXTRA1, EXTRA2, __VA_ARGS__)       \
    AT_DISPATCH_CASE(EXTRA3, __VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES_AND3(              \
      EXTRA1, EXTRA2, EXTRA3, TYPE, NAME, ...)     \
    AT_DISPATCH_SWITCH(                            \
        TYPE,                                      \
        NAME,                                      \
        AT_DISPATCH_CASE_ALL_TYPES_AND3(           \
            EXTRA1, EXTRA2, EXTRA3, __VA_ARGS__))

  #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \
      EXTRA1, EXTRA2, EXTRA3, ...)                     \
    AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2(       \
        EXTRA1, EXTRA2, __VA_ARGS__)                   \
    AT_DISPATCH_CASE(EXTRA3, __VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(      \
      EXTRA1, EXTRA2, EXTRA3, TYPE, NAME, ...)         \
    AT_DISPATCH_SWITCH(                                \
        TYPE,                                          \
        NAME,                                          \
        AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(   \
            EXTRA1, EXTRA2, EXTRA3, __VA_ARGS__))

  #define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \
      EXTRA1, EXTRA2, EXTRA3, EXTRA4, ...)             \
    AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3(       \
        EXTRA1, EXTRA2, EXTRA3, __VA_ARGS__)           \
    AT_DISPATCH_CASE(EXTRA4, __VA_ARGS__)

  #define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(          \
      EXTRA1, EXTRA2, EXTRA3, EXTRA4, TYPE, NAME, ...)     \
    AT_DISPATCH_SWITCH(                                    \
        TYPE,                                              \
        NAME,                                              \
        AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(       \
            EXTRA1, EXTRA2, EXTRA3, EXTRA4, __VA_ARGS__))
  // Dispatches over at::ScalarType::Int and at::ScalarType::Long. Unlike the
  // other AT_DISPATCH_* macros, the type alias visible inside the callable is
  // named `index_t` rather than `scalar_t`. There is no CASE_ variant.
  #define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(                            \
        TYPE,                                      \
        NAME,                                      \
        AT_PRIVATE_CASE_TYPE_USING_HINT(           \
            at::ScalarType::Int,                   \
            index_t,                               \
            __VA_ARGS__)                           \
        AT_PRIVATE_CASE_TYPE_USING_HINT(           \
            at::ScalarType::Long,                  \
            index_t,                               \
            __VA_ARGS__))
  // ----------------------------------------------------------------------------
  // DEPRECATED MACROS, DON'T USE THESE
  // ----------------------------------------------------------------------------

  // Deprecated: use AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...).
  //
  // The call to ::detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF() exists
  // only to emit the deprecation diagnostic. It is joined to the dispatch with
  // the comma operator, so the whole expansion is a single expression, like
  // the non-deprecated AT_DISPATCH_* macros: it can follow `return`, serve as
  // an initializer, or be the body of an unbraced `if`. As two separate
  // statements, `if (c) AT_DISPATCH_ALL_TYPES_AND_HALF(...);` made only the
  // warning call conditional and always ran the dispatch.
  //
  // The helper is qualified with `::`, as AT_DISPATCH_SWITCH does for
  // ::detail::scalar_type. Otherwise a `detail` namespace enclosing the call
  // site would be found first by unqualified lookup and hide the global one.
  #define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...)   \
    (::detail::deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF(), \
     AT_DISPATCH_SWITCH(                                    \
         TYPE,                                              \
         NAME,                                              \
         AT_DISPATCH_CASE_ALL_TYPES_AND(at::ScalarType::Half, __VA_ARGS__)))