Metaprogramming.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. #pragma once
  2. #include <c10/util/Array.h>
  3. #include <c10/util/TypeList.h>
  4. #include <array>
  5. #include <functional>
  6. #include <type_traits>
  7. namespace c10 {
  8. namespace guts {
  9. /**
  10. * Access information about result type or arguments from a function type.
  11. * Example:
  12. * using A = function_traits<int (float, double)>::return_type // A == int
  13. * using A = function_traits<int (float, double)>::parameter_types::tuple_type
  14. * // A == tuple<float, double>
  15. */
  16. template <class Func>
  17. struct function_traits {
  18. static_assert(
  19. !std::is_same<Func, Func>::value,
  20. "In function_traits<Func>, Func must be a plain function type.");
  21. };
  22. template <class Result, class... Args>
  23. struct function_traits<Result(Args...)> {
  24. using func_type = Result(Args...);
  25. using return_type = Result;
  26. using parameter_types = typelist::typelist<Args...>;
  27. static constexpr auto number_of_parameters = sizeof...(Args);
  28. };
  29. /**
  30. * infer_function_traits: creates a `function_traits` type for a simple
  31. * function (pointer) or functor (lambda/struct). Currently does not support
  32. * class methods.
  33. */
  34. template <typename Functor>
  35. struct infer_function_traits {
  36. using type = function_traits<
  37. c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
  38. };
  39. template <typename Result, typename... Args>
  40. struct infer_function_traits<Result (*)(Args...)> {
  41. using type = function_traits<Result(Args...)>;
  42. };
  43. template <typename Result, typename... Args>
  44. struct infer_function_traits<Result(Args...)> {
  45. using type = function_traits<Result(Args...)>;
  46. };
  47. template <typename T>
  48. using infer_function_traits_t = typename infer_function_traits<T>::type;
  49. /**
  50. * make_function_traits: creates a `function_traits` type given a Return type
  51. * and a typelist of Argument types
  52. *
  53. * Example:
  54. * bool f(int, int);
  55. *
  56. * infer_function_traits_t<f> == make_function_traits_t<bool,
  57. * typelist::typelist<int, int>>
  58. */
  59. template <typename Result, typename ArgList>
  60. struct make_function_traits {
  61. static_assert(
  62. false_t<ArgList>::value,
  63. "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
  64. };
  65. template <typename Result, typename... Args>
  66. struct make_function_traits<Result, typelist::typelist<Args...>> {
  67. using type = function_traits<Result(Args...)>;
  68. };
  69. template <typename Result, typename ArgList>
  70. using make_function_traits_t =
  71. typename make_function_traits<Result, ArgList>::type;
  /**
   * Use extract_arg_by_filtered_index to return the i-th argument whose
   * type fulfills a given type trait. The argument itself is perfectly forwarded.
   *
   * Example:
   * std::string arg1 = "Hello";
   * std::string arg2 = "World";
   * std::string&& result = extract_arg_by_filtered_index<is_string, 1>(0,
   * arg1, 2.0, std::move(arg2));
   *
   * Warning: Taking the result by rvalue reference can cause segfaults because
   * ownership will not be passed on from the original reference. If the
   * original argument is a temporary, it dies at the end of the expression and
   * the resulting reference is left dangling.
   */
  namespace detail {

  // Walks the argument list front to back and returns (perfectly forwarded)
  // the `index`-th argument whose type satisfies Condition. `Enable` is always
  // passed as void; the partial specializations use it for enable_if dispatch.
  template <
      template <class> class Condition,
      size_t index,
      class Enable,
      class... Args>
  struct extract_arg_by_filtered_index_;

  // Ran out of arguments before finding a match: compile-time error. The
  // condition depends on `index` so it only fires if call() is instantiated.
  template <template <class> class Condition, size_t index>
  struct extract_arg_by_filtered_index_<Condition, index, void> {
    static void call() {
      static_assert(
          index != index, "extract_arg_by_filtered_index out of range.");
    }
  };

  // The first argument matches and no more matches need to be skipped:
  // this is the requested argument.
  template <
      template <class> class Condition,
      size_t index,
      class First,
      class... Rest>
  struct extract_arg_by_filtered_index_<
      Condition,
      index,
      std::enable_if_t<Condition<First>::value && index == 0>,
      First,
      Rest...> {
    static decltype(auto) call(First&& first, Rest&&... /*rest*/) {
      return std::forward<First>(first);
    }
  };

  // Every other case: drop the first argument and continue with the rest.
  // A dropped argument that satisfies Condition consumes one unit of `index`;
  // a non-matching one leaves `index` unchanged.
  template <
      template <class> class Condition,
      size_t index,
      class First,
      class... Rest>
  struct extract_arg_by_filtered_index_<
      Condition,
      index,
      std::enable_if_t<!(Condition<First>::value && index == 0)>,
      First,
      Rest...> {
    static decltype(auto) call(First&& /*first*/, Rest&&... rest) {
      using Next = extract_arg_by_filtered_index_<
          Condition,
          index - (Condition<First>::value ? 1 : 0),
          void,
          Rest...>;
      return Next::call(std::forward<Rest>(rest)...);
    }
  };

  } // namespace detail
  152. template <template <class> class Condition, size_t index, class... Args>
  153. decltype(auto) extract_arg_by_filtered_index(Args&&... args) {
  154. static_assert(
  155. is_type_condition<Condition>::value,
  156. "In extract_arg_by_filtered_index, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
  157. return detail::
  158. extract_arg_by_filtered_index_<Condition, index, void, Args...>::call(
  159. std::forward<Args>(args)...);
  160. }
  /**
   * Use filter_map to map a subset of the arguments to values.
   * The subset is defined by type traits, and will be evaluated at compile time.
   * At runtime, it will just loop over the pre-filtered arguments to create a
   * guts::array.
   *
   * Example:
   * guts::array<double, 2> result = filter_map<double, std::is_integral>([] (auto
   * a) {return (double)a;}, 3, "bla", 4);
   * // result == {3.0, 4.0}
   */
  172. namespace detail {
  173. template <class ResultType, size_t num_results>
  174. struct filter_map_ {
  175. template <
  176. template <class>
  177. class Condition,
  178. class Mapper,
  179. class... Args,
  180. size_t... INDEX>
  181. static guts::array<ResultType, num_results> call(
  182. const Mapper& mapper,
  183. std::index_sequence<INDEX...>,
  184. Args&&... args) {
  185. return guts::array<ResultType, num_results>{
  186. mapper(extract_arg_by_filtered_index<Condition, INDEX>(
  187. std::forward<Args>(args)...))...};
  188. }
  189. };
  190. template <class ResultType>
  191. struct filter_map_<ResultType, 0> {
  192. template <
  193. template <class>
  194. class Condition,
  195. class Mapper,
  196. class... Args,
  197. size_t... INDEX>
  198. static guts::array<ResultType, 0> call(
  199. const Mapper& /*mapper*/,
  200. std::index_sequence<INDEX...>,
  201. Args&&... /*args*/) {
  202. return guts::array<ResultType, 0>{};
  203. }
  204. };
  205. } // namespace detail
  206. template <
  207. class ResultType,
  208. template <class>
  209. class Condition,
  210. class Mapper,
  211. class... Args>
  212. decltype(auto) filter_map(const Mapper& mapper, Args&&... args) {
  213. static_assert(
  214. is_type_condition<Condition>::value,
  215. "In filter_map<Result, Condition>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
  216. static constexpr size_t num_results =
  217. typelist::count_if<Condition, typelist::typelist<Args...>>::value;
  218. return detail::filter_map_<ResultType, num_results>::
  219. template call<Condition, Mapper, Args...>(
  220. mapper,
  221. std::make_index_sequence<num_results>(),
  222. std::forward<Args>(args)...);
  223. }
  /**
   * make_offset_index_sequence<Start, N>
   * Like make_index_sequence<N>, but starting from Start instead of 0.
   *
   * Example:
   *  make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
   */
  // Recursive step: prepends Start + N - 1 to the accumulated indices and
  // continues with N - 1 until the N == 0 specialization below is reached.
  template <size_t Start, size_t N, size_t... Is>
  struct make_offset_index_sequence_impl
      : make_offset_index_sequence_impl<Start, N - 1, Start + N - 1, Is...> {
    // Start and N are unsigned; the casts presumably exist to reject negative
    // values that wrapped around when converted to size_t.
    static_assert(
        static_cast<int>(Start) >= 0,
        "make_offset_index_sequence: Start < 0");
    static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");
  };

  // Recursion end: all indices have been accumulated in Is.
  template <size_t Start, size_t... Is>
  struct make_offset_index_sequence_impl<Start, 0, Is...> {
    using type = std::index_sequence<Is...>;
  };

  template <size_t Start, size_t N>
  using make_offset_index_sequence =
      typename make_offset_index_sequence_impl<Start, N>::type;

  /**
   * Use tuple_elements to extract a position-indexed subset of elements
   * from the argument tuple into a result tuple.
   *
   * Example:
   *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
   *  std::tuple<int, double> result = tuple_elements(t, std::index_sequence<0,
   *  2>());
   */
  template <class Tuple, size_t... Is>
  constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) {
    // Elements are deliberately copied out of `t` rather than moved: the same
    // index may legitimately appear more than once in Is.
    return std::tuple<std::tuple_element_t<Is, Tuple>...>(std::get<Is>(t)...);
  }

  /**
   * Use tuple_take to extract the first or last n elements from the argument
   * tuple into a result tuple. A non-negative N takes the first N elements, a
   * negative N takes the last -N elements. Tuple must be given explicitly.
   *
   * Example:
   *  std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
   *  std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
   *  std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
   */
  template <class Tuple, int N, class Enable = void>
  struct TupleTake {};

  // N >= 0: take the first N elements.
  template <class Tuple, int N>
  struct TupleTake<Tuple, N, std::enable_if_t<(N >= 0), void>> {
    static auto call(Tuple t) {
      constexpr size_t size = std::tuple_size<Tuple>::value;
      static_assert(static_cast<size_t>(N) <= size, "tuple_take: N > size");
      // `t` is our own by-value copy and is not used afterwards, so hand it on
      // by move instead of copying the whole tuple again.
      return tuple_elements(std::move(t), std::make_index_sequence<N>{});
    }
  };

  // N < 0: take the last -N elements, i.e. indices [size - (-N), size).
  template <class Tuple, int N>
  struct TupleTake<Tuple, N, std::enable_if_t<(N < 0), void>> {
    static auto call(Tuple t) {
      constexpr size_t size = std::tuple_size<Tuple>::value;
      constexpr size_t count = static_cast<size_t>(-N);
      static_assert(count <= size, "tuple_take: -N > size");
      return tuple_elements(
          std::move(t), make_offset_index_sequence<size - count, count>{});
    }
  };

  template <class Tuple, int N>
  auto tuple_take(Tuple t) {
    // Forward our by-value copy instead of copying it a second time.
    return TupleTake<Tuple, N>::call(std::move(t));
  }

  /**
   * Use tuple_slice to extract a contiguous subtuple from the argument:
   * the N elements starting at position Start.
   *
   * Example:
   *  std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
   *  "HEY", 2.0, false);
   *  std::tuple<const char*, double> middle_two =
   *  tuple_slice<decltype(t), 1, 2>(t);
   */
  template <class Tuple, size_t Start, size_t N>
  constexpr auto tuple_slice(Tuple t) {
    constexpr size_t size = std::tuple_size<Tuple>::value;
    static_assert(Start + N <= size, "tuple_slice: Start + N > size");
    // `t` is not used after this call, so move it into tuple_elements.
    return tuple_elements(std::move(t), make_offset_index_sequence<Start, N>{});
  }
  305. /**
  306. * Use tuple_map to run a mapping function over a tuple to get a new tuple.
  307. *
  308. * Example 1:
  309. * auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5), []
  310. * (int32_t a) -> int16_t {return a+1;});
  311. * // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
  312. *
  313. * Example 2:
  314. * struct Mapper {
  315. * std::string operator()(int32_t a) const {
  316. * return std::to_string(a);
  317. * }
  318. * int64_t operator()(const std::string& a) const {
  319. * return atoi(a.c_str());
  320. * }
  321. * };
  322. * auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"),
  323. * Mapper());
  324. * // result == std::tuple<std::string, int64_t>("3", 4)
  325. *
  326. * Example 3:
  327. * struct A final {
  328. * int32_t func() {
  329. * return 5;
  330. * }
  331. * };
  332. * struct B final {
  333. * std::string func() {
  334. * return "5";
  335. * }
  336. * };
  337. * auto result = tuple_map(std::make_tuple(A(), B()), [] (auto a) { return
  338. * a.func(); });
  339. * // result == std::tuple<int32_t, std::string>(5, "5");
  340. */
  namespace detail {
  // Calls `mapper` once for each tuple element (element type Args forwarded as
  // Args&&) and collects the results in a std::tuple whose element types are
  // exactly the mapper's return types.
  template <class Mapper, class... Args, size_t... Indices>
  auto tuple_map(
      std::tuple<Args...>&& tuple,
      const Mapper& mapper,
      std::index_sequence<Indices...>) {
    using mapped_tuple = std::tuple<decltype(mapper(
        std::forward<Args>(std::get<Indices>(tuple))))...>;
    return mapped_tuple(
        mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
  }
  } // namespace detail

  // Public entry point: maps every element of an rvalue tuple through `mapper`
  // and returns the tuple of results.
  template <class Mapper, class... Args>
  auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
    using positions = std::index_sequence_for<Args...>;
    return detail::tuple_map(std::move(tuple), mapper, positions());
  }
  356. /**
  357. * tuple_concat concatenates several tuples into one.
  358. */
  namespace detail {
  // extract_tuple_element_by_index takes a list of tuples and returns the
  // element at position `index` of their flattened view. Example:
  // extract_tuple_element_by_index<3>(tuple(2,3), tuple(4,5), tuple(6,7)) == 5.

  // Case 1: the requested position lies inside the first tuple.
  // TODO if constexpr instead of enable_if
  template <
      size_t index,
      class FirstTuple,
      class... OtherTuples,
      std::enable_if_t<(index < std::tuple_size<FirstTuple>::value), int> = 0>
  decltype(auto) extract_tuple_element_by_index(
      FirstTuple&& first_tuple,
      OtherTuples&&... /*other_tuples*/) {
    return std::get<index>(std::forward<FirstTuple>(first_tuple));
  }

  // Case 2: the requested position lies past the first tuple. Skip it and
  // continue the search in the remaining tuples with the index shifted down by
  // the first tuple's size.
  // TODO if constexpr instead of enable_if
  template <
      size_t index,
      class FirstTuple,
      class... OtherTuples,
      std::enable_if_t<(index >= std::tuple_size<FirstTuple>::value), int> = 0>
  decltype(auto) extract_tuple_element_by_index(
      FirstTuple&& /*first_tuple*/,
      OtherTuples&&... other_tuples) {
    return extract_tuple_element_by_index<
        index - std::tuple_size<FirstTuple>::value,
        OtherTuples...>(std::forward<OtherTuples>(other_tuples)...);
  }

  // Compile-time check: an element stored as an rvalue reference must come
  // back as an rvalue reference.
  static_assert(
      std::is_same<
          int&&,
          decltype(extract_tuple_element_by_index<2>(
              std::tuple<int32_t>(2),
              std::tuple<int32_t&&, int32_t>(std::declval<int32_t>(), 3)))>::
          value,
      "extract_tuple_element_by_index should return rvalue references if the tuple contains them. It should not move them into a value");

  // Constructs ConcatenatedTuple from the flattened elements of `tuples`, one
  // constructor argument per entry in ElementIndices. Tuples must be specified
  // explicitly by the caller because the pack is not the last parameter.
  template <class ConcatenatedTuple, class... Tuples, size_t... ElementIndices>
  auto tuple_concat(
      Tuples&&... tuples,
      std::index_sequence<ElementIndices...>) {
    return ConcatenatedTuple(extract_tuple_element_by_index<ElementIndices>(
        std::forward<Tuples>(tuples)...)...);
  }
  } // namespace detail
  402. template <class... Tuples>
  403. auto tuple_concat(Tuples&&... tuples) {
  404. using flattened_types =
  405. guts::typelist::concat_t<guts::typelist::from_tuple_t<Tuples>...>;
  406. using concatenated_tuple = guts::typelist::to_tuple_t<flattened_types>;
  407. constexpr size_t num_elements = guts::typelist::size<flattened_types>::value;
  408. return detail::tuple_concat<concatenated_tuple, Tuples...>(
  409. std::forward<Tuples>(tuples)...,
  410. std::make_index_sequence<num_elements>());
  411. }
  412. /**
  413. * Concatenate multiple integer sequences
  414. * Example:
  415. * concat_iseq_t<std::index_sequence<2, 5, 3>, std::index_sequence<4, 2>,
  416. * std::index_sequence<5>>
  417. * == std::index_sequence<2, 5, 3, 4, 2, 5>
  418. */
  419. template <class... ISeqs>
  420. struct concat_iseq {
  421. static_assert(
  422. false_t<ISeqs...>::value,
  423. "In concat_iseq<T1, ...>, the T arguments each must be std::integer_sequence<...> with the same IntType.");
  424. };
  425. template <>
  426. struct concat_iseq<> {
  427. using type = std::index_sequence<>;
  428. };
  429. template <class IntType, IntType... Indices>
  430. struct concat_iseq<std::integer_sequence<IntType, Indices...>> {
  431. using type = std::integer_sequence<IntType, Indices...>;
  432. };
  433. template <
  434. class IntType,
  435. IntType... Head1Indices,
  436. IntType... Head2Indices,
  437. class... TailISeqs>
  438. struct concat_iseq<
  439. std::integer_sequence<IntType, Head1Indices...>,
  440. std::integer_sequence<IntType, Head2Indices...>,
  441. TailISeqs...> {
  442. using type = typename concat_iseq<
  443. std::integer_sequence<IntType, Head1Indices..., Head2Indices...>,
  444. TailISeqs...>::type;
  445. };
  446. template <class... ISeqs>
  447. using concat_iseq_t = typename concat_iseq<ISeqs...>::type;
  448. } // namespace guts
  449. } // namespace c10