123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485 |
- #pragma once
- #include <c10/util/Array.h>
- #include <c10/util/TypeList.h>
- #include <array>
- #include <functional>
- #include <type_traits>
- namespace c10 {
- namespace guts {
- /**
- * Access information about result type or arguments from a function type.
- * Example:
- * using A = function_traits<int (float, double)>::return_type // A == int
- * using A = function_traits<int (float, double)>::parameter_types::tuple_type
- * // A == tuple<float, double>
- */
- template <class Func>
- struct function_traits {
- static_assert(
- !std::is_same<Func, Func>::value,
- "In function_traits<Func>, Func must be a plain function type.");
- };
- template <class Result, class... Args>
- struct function_traits<Result(Args...)> {
- using func_type = Result(Args...);
- using return_type = Result;
- using parameter_types = typelist::typelist<Args...>;
- static constexpr auto number_of_parameters = sizeof...(Args);
- };
- /**
- * infer_function_traits: creates a `function_traits` type for a simple
- * function (pointer) or functor (lambda/struct). Currently does not support
- * class methods.
- */
- template <typename Functor>
- struct infer_function_traits {
- using type = function_traits<
- c10::guts::detail::strip_class_t<decltype(&Functor::operator())>>;
- };
- template <typename Result, typename... Args>
- struct infer_function_traits<Result (*)(Args...)> {
- using type = function_traits<Result(Args...)>;
- };
- template <typename Result, typename... Args>
- struct infer_function_traits<Result(Args...)> {
- using type = function_traits<Result(Args...)>;
- };
- template <typename T>
- using infer_function_traits_t = typename infer_function_traits<T>::type;
- /**
- * make_function_traits: creates a `function_traits` type given a Return type
- * and a typelist of Argument types
- *
- * Example:
- * bool f(int, int);
- *
- * infer_function_traits_t<f> == make_function_traits_t<bool,
- * typelist::typelist<int, int>>
- */
- template <typename Result, typename ArgList>
- struct make_function_traits {
- static_assert(
- false_t<ArgList>::value,
- "In guts::make_function_traits<Result, TypeList>, the ArgList argument must be typelist<...>.");
- };
- template <typename Result, typename... Args>
- struct make_function_traits<Result, typelist::typelist<Args...>> {
- using type = function_traits<Result(Args...)>;
- };
- template <typename Result, typename ArgList>
- using make_function_traits_t =
- typename make_function_traits<Result, ArgList>::type;
/**
 * Use extract_arg_by_filtered_index to return the i-th argument whose
 * type fulfills a given type trait. The argument itself is perfectly
 * forwarded.
 *
 * Example:
 *   std::string arg1 = "Hello";
 *   std::string arg2 = "World";
 *   std::string&& result = extract_arg_by_filtered_index<is_string, 1>(0,
 *       arg1, 2.0, std::move(arg2));
 *
 * Warning: The result is a reference to the selected argument, not a copy;
 * ownership is not passed on. If the selected argument was a temporary, it
 * dies at the end of the calling expression and a result that was bound to a
 * reference is left dangling.
 */
namespace detail {
// Walks the argument list from the front. `index` counts how many matching
// arguments still have to be skipped; `Enable` is the SFINAE slot (always
// instantiated with `void`) that selects exactly one specialization per step.
template <
    template <class> class Condition,
    size_t index,
    class Enable,
    class... Args>
struct extract_arg_by_filtered_index_;

// No arguments left: the requested filtered index does not exist.
template <template <class> class Condition, size_t index>
struct extract_arg_by_filtered_index_<Condition, index, void> {
  static void call() {
    static_assert(
        index != index, "extract_arg_by_filtered_index out of range.");
  }
};

// The first argument matches and nothing remains to be skipped: return it.
template <
    template <class> class Condition,
    size_t index,
    class First,
    class... Rest>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    typename std::enable_if<(Condition<First>::value && index == 0)>::type,
    First,
    Rest...> {
  static decltype(auto) call(First&& first, Rest&&... /*rest*/) {
    return std::forward<First>(first);
  }
};

// The first argument matches but earlier matches are still to be skipped:
// drop it and look for match number `index - 1` in the remainder.
template <
    template <class> class Condition,
    size_t index,
    class First,
    class... Rest>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    typename std::enable_if<(Condition<First>::value && index != 0)>::type,
    First,
    Rest...> {
  static decltype(auto) call(First&& /*first*/, Rest&&... rest) {
    using next =
        extract_arg_by_filtered_index_<Condition, index - 1, void, Rest...>;
    return next::call(std::forward<Rest>(rest)...);
  }
};

// The first argument does not match: drop it, `index` stays unchanged.
template <
    template <class> class Condition,
    size_t index,
    class First,
    class... Rest>
struct extract_arg_by_filtered_index_<
    Condition,
    index,
    typename std::enable_if<!Condition<First>::value>::type,
    First,
    Rest...> {
  static decltype(auto) call(First&& /*first*/, Rest&&... rest) {
    using next =
        extract_arg_by_filtered_index_<Condition, index, void, Rest...>;
    return next::call(std::forward<Rest>(rest)...);
  }
};
} // namespace detail
- template <template <class> class Condition, size_t index, class... Args>
- decltype(auto) extract_arg_by_filtered_index(Args&&... args) {
- static_assert(
- is_type_condition<Condition>::value,
- "In extract_arg_by_filtered_index, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
- return detail::
- extract_arg_by_filtered_index_<Condition, index, void, Args...>::call(
- std::forward<Args>(args)...);
- }
- /**
- * Use filter_map to map a subset of the arguments to values.
- * The subset is defined by type traits, and will be evaluated at compile time.
- * At runtime, it will just loop over the pre-filtered arguments to create an
- * std::array.
- *
- * Example:
- * std::array<double, 2> result = filter_map<double, std::is_integral>([] (auto
- * a) {return (double)a;}, 3, "bla", 4);
- * // result == {3.0, 4.0}
- */
- namespace detail {
- template <class ResultType, size_t num_results>
- struct filter_map_ {
- template <
- template <class>
- class Condition,
- class Mapper,
- class... Args,
- size_t... INDEX>
- static guts::array<ResultType, num_results> call(
- const Mapper& mapper,
- std::index_sequence<INDEX...>,
- Args&&... args) {
- return guts::array<ResultType, num_results>{
- mapper(extract_arg_by_filtered_index<Condition, INDEX>(
- std::forward<Args>(args)...))...};
- }
- };
- template <class ResultType>
- struct filter_map_<ResultType, 0> {
- template <
- template <class>
- class Condition,
- class Mapper,
- class... Args,
- size_t... INDEX>
- static guts::array<ResultType, 0> call(
- const Mapper& /*mapper*/,
- std::index_sequence<INDEX...>,
- Args&&... /*args*/) {
- return guts::array<ResultType, 0>{};
- }
- };
- } // namespace detail
- template <
- class ResultType,
- template <class>
- class Condition,
- class Mapper,
- class... Args>
- decltype(auto) filter_map(const Mapper& mapper, Args&&... args) {
- static_assert(
- is_type_condition<Condition>::value,
- "In filter_map<Result, Condition>, the Condition argument must be a condition type trait, i.e. have a static constexpr bool ::value member.");
- static constexpr size_t num_results =
- typelist::count_if<Condition, typelist::typelist<Args...>>::value;
- return detail::filter_map_<ResultType, num_results>::
- template call<Condition, Mapper, Args...>(
- mapper,
- std::make_index_sequence<num_results>(),
- std::forward<Args>(args)...);
- }
/**
 * make_offset_index_sequence<Start, N>
 * Like make_index_sequence<N>, but starting from Start instead of 0.
 *
 * Example:
 *   make_offset_index_sequence<10, 3> == std::index_sequence<10, 11, 12>
 *
 * make_offset_index_sequence_impl<Start, N, Is...>::type is
 * std::index_sequence<Start, ..., Start + N - 1, Is...>, i.e. any extra
 * indices `Is...` are appended after the generated run.
 *
 * The sequence is derived from std::make_index_sequence<N> by adding Start to
 * every element instead of peeling off one index per inheritance level, so
 * the template instantiation depth no longer grows with N and the range
 * checks below are reported before anything else is instantiated.
 */
template <size_t Start, size_t N, size_t... Is>
struct make_offset_index_sequence_impl {
  // size_t cannot be negative; these checks catch negative values that were
  // converted to size_t by the caller.
  static_assert(
      static_cast<int>(Start) >= 0,
      "make_offset_index_sequence: Start < 0");
  static_assert(static_cast<int>(N) >= 0, "make_offset_index_sequence: N < 0");

 private:
  // Declaration only: used in an unevaluated context to turn
  // index_sequence<0, ..., N - 1> into
  // index_sequence<Start, ..., Start + N - 1, Is...>.
  template <size_t... Offsets>
  static std::index_sequence<(Start + Offsets)..., Is...> shift(
      std::index_sequence<Offsets...>);

 public:
  // When N fails the check above, fall back to length 0 so the only
  // diagnostic is the static_assert message.
  using type = decltype(shift(
      std::make_index_sequence<(static_cast<int>(N) >= 0 ? N : 0)>{}));
};

template <size_t Start, size_t N>
using make_offset_index_sequence =
    typename make_offset_index_sequence_impl<Start, N>::type;
/**
 * Use tuple_elements to extract a position-indexed subset of elements
 * from the argument tuple into a result tuple. The result's element types are
 * the tuple_element types at the requested positions, in the requested order.
 *
 * Example:
 *   std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
 *   std::tuple<int, double> result =
 *       tuple_elements(t, std::index_sequence<0, 2>());
 */
template <class Tuple, size_t... Is>
constexpr auto tuple_elements(Tuple t, std::index_sequence<Is...>) {
  using Selected = std::tuple<std::tuple_element_t<Is, Tuple>...>;
  return Selected(std::get<Is>(t)...);
}
- /**
- * Use tuple_take to extract the first or last n elements from the argument
- * tuple into a result tuple.
- *
- * Example:
- * std::tuple<int, const char*, double> t = std::make_tuple(0, "HEY", 2.0);
- * std::tuple<int, const char*> first_two = tuple_take<decltype(t), 2>(t);
- * std::tuple<const char*, double> last_two = tuple_take<decltype(t), -2>(t);
- */
- template <class Tuple, int N, class Enable = void>
- struct TupleTake {};
- template <class Tuple, int N>
- struct TupleTake<Tuple, N, std::enable_if_t<N >= 0, void>> {
- static auto call(Tuple t) {
- constexpr size_t size = std::tuple_size<Tuple>();
- static_assert(N <= size, "tuple_take: N > size");
- return tuple_elements(t, std::make_index_sequence<N>{});
- }
- };
- template <class Tuple, int N>
- struct TupleTake < Tuple,
- N, std::enable_if_t<N<0, void>> {
- static auto call(Tuple t) {
- constexpr size_t size = std::tuple_size<Tuple>();
- static_assert(-N <= size, "tuple_take: -N > size");
- return tuple_elements(t, make_offset_index_sequence<size + N, -N>{});
- }
- };
- template <class Tuple, int N>
- auto tuple_take(Tuple t) {
- return TupleTake<Tuple, N>::call(t);
- }
- /**
- * Use tuple_slice to extract a contiguous subtuple from the argument.
- *
- * Example:
- * std::tuple<int, const char*, double, bool> t = std::make_tuple(0,
- * "HEY", 2.0, false); std::tuple<int, const char*> middle_two =
- * tuple_slice<decltype(t), 1, 2>(t);
- */
- template <class Tuple, size_t Start, size_t N>
- constexpr auto tuple_slice(Tuple t) {
- constexpr size_t size = std::tuple_size<Tuple>();
- static_assert(Start + N <= size, "tuple_slice: Start + N > size");
- return tuple_elements(t, make_offset_index_sequence<Start, N>{});
- }
/**
 * Use tuple_map to run a mapping function over a tuple to get a new tuple.
 * The input tuple is taken by rvalue reference, each element is forwarded to
 * the mapper, and the result's element types are whatever the mapper returns
 * for the respective element.
 *
 * Example 1:
 *   auto result = tuple_map(std::tuple<int32_t, int32_t, int32_t>(3, 4, 5),
 *       [] (int32_t a) -> int16_t {return a+1;});
 *   // result == std::tuple<int16_t, int16_t, int16_t>(4, 5, 6)
 *
 * Example 2:
 *   struct Mapper {
 *     std::string operator()(int32_t a) const {
 *       return std::to_string(a);
 *     }
 *     int64_t operator()(const std::string& a) const {
 *        return atoi(a.c_str());
 *     }
 *   };
 *   auto result = tuple_map(std::tuple<int32_t, std::string>(3, "4"),
 *       Mapper());
 *   // result == std::tuple<std::string, int64_t>("3", 4)
 *
 * Example 3:
 *   struct A final {
 *    int32_t func() {
 *      return 5;
 *    }
 *  };
 *  struct B final {
 *    std::string func() {
 *      return "5";
 *    }
 *  };
 *  auto result = tuple_map(std::make_tuple(A(), B()),
 *      [] (auto a) { return a.func(); });
 *  // result == std::tuple<int32_t, std::string>(5, "5");
 */
namespace detail {
// Index-driven helper: Indices enumerates the positions of `tuple`.
template <class Mapper, class... Args, size_t... Indices>
auto tuple_map(
    std::tuple<Args...>&& tuple,
    const Mapper& mapper,
    std::index_sequence<Indices...>) {
  // One result element per input element, typed by the mapper's return type.
  using Mapped = std::tuple<decltype(mapper(
      std::forward<Args>(std::get<Indices>(tuple))))...>;
  return Mapped(mapper(std::forward<Args>(std::get<Indices>(tuple)))...);
}
} // namespace detail

template <class Mapper, class... Args>
auto tuple_map(std::tuple<Args...>&& tuple, const Mapper& mapper) {
  using positions = std::index_sequence_for<Args...>;
  return detail::tuple_map(std::move(tuple), mapper, positions());
}
/**
 * tuple_concat concatenates several tuples into one.
 * The helpers in namespace detail below do the element-wise work.
 */
namespace detail {
// extract_tuple_element_by_index is a helper that takes a list of tuples and
// extracts the i-th element in a flattened view of the tuples. Example:
// extract_tuple_element_by_index<3>(tuple(2,3), tuple(4,5), tuple(6,7)) == 5.
//
// The tuples are taken by forwarding reference and the element is obtained
// with std::get on the forwarded tuple, so the result is a reference into the
// tuple that holds the element (an rvalue reference for rvalue tuples).
// std::tuple_size is not defined for reference types, hence the tuple types
// are stripped of their reference before asking for their size; this is what
// makes lvalue tuples usable here.

// Case 1: the requested element lives in the first tuple.
template <
    size_t index,
    class HeadTuple,
    class... TailTuples,
    std::enable_if_t<
        (index < std::tuple_size<std::remove_reference_t<HeadTuple>>::value),
        int> = 0>
decltype(auto) extract_tuple_element_by_index(
    HeadTuple&& head_tuple,
    TailTuples&&... /*tail_tuples*/) {
  // TODO if constexpr instead of enable_if
  return std::get<index>(std::forward<HeadTuple>(head_tuple));
}

// Case 2: the requested element lies beyond the first tuple. Drop the first
// tuple and continue with the index rebased onto the remaining tuples.
template <
    size_t index,
    class HeadTuple,
    class... TailTuples,
    std::enable_if_t<
        (index >= std::tuple_size<std::remove_reference_t<HeadTuple>>::value),
        int> = 0>
decltype(auto) extract_tuple_element_by_index(
    HeadTuple&& /*head_tuple*/,
    TailTuples&&... tail_tuples) {
  // TODO if constexpr instead of enable_if
  return extract_tuple_element_by_index<
      index - std::tuple_size<std::remove_reference_t<HeadTuple>>::value,
      TailTuples...>(std::forward<TailTuples>(tail_tuples)...);
}

static_assert(
    std::is_same<
        int&&,
        decltype(extract_tuple_element_by_index<2>(
            std::tuple<int32_t>(2),
            std::tuple<int32_t&&, int32_t>(std::declval<int32_t>(), 3)))>::
        value,
    "extract_tuple_element_by_index should return rvalue references if the tuple contains them. It should not move them into a value");

// Builds a ConcatenatedTuple from the flattened elements of `tuples`;
// ElementIndices enumerates the positions of the flattened view. Tuples must
// be specified explicitly by the caller because the pack is not the last
// function parameter.
template <class ConcatenatedTuple, class... Tuples, size_t... ElementIndices>
auto tuple_concat(Tuples&&... tuples, std::index_sequence<ElementIndices...>) {
  return ConcatenatedTuple(extract_tuple_element_by_index<ElementIndices>(
      std::forward<Tuples>(tuples)...)...);
}
} // namespace detail
- template <class... Tuples>
- auto tuple_concat(Tuples&&... tuples) {
- using flattened_types =
- guts::typelist::concat_t<guts::typelist::from_tuple_t<Tuples>...>;
- using concatenated_tuple = guts::typelist::to_tuple_t<flattened_types>;
- constexpr size_t num_elements = guts::typelist::size<flattened_types>::value;
- return detail::tuple_concat<concatenated_tuple, Tuples...>(
- std::forward<Tuples>(tuples)...,
- std::make_index_sequence<num_elements>());
- }
- /**
- * Concatenate multiple integer sequences
- * Example:
- * concat_iseq_t<std::index_sequence<2, 5, 3>, std::index_sequence<4, 2>,
- * std::index_sequence<5>>
- * == std::index_sequence<2, 5, 3, 4, 2, 5>
- */
- template <class... ISeqs>
- struct concat_iseq {
- static_assert(
- false_t<ISeqs...>::value,
- "In concat_iseq<T1, ...>, the T arguments each must be std::integer_sequence<...> with the same IntType.");
- };
- template <>
- struct concat_iseq<> {
- using type = std::index_sequence<>;
- };
- template <class IntType, IntType... Indices>
- struct concat_iseq<std::integer_sequence<IntType, Indices...>> {
- using type = std::integer_sequence<IntType, Indices...>;
- };
- template <
- class IntType,
- IntType... Head1Indices,
- IntType... Head2Indices,
- class... TailISeqs>
- struct concat_iseq<
- std::integer_sequence<IntType, Head1Indices...>,
- std::integer_sequence<IntType, Head2Indices...>,
- TailISeqs...> {
- using type = typename concat_iseq<
- std::integer_sequence<IntType, Head1Indices..., Head2Indices...>,
- TailISeqs...>::type;
- };
- template <class... ISeqs>
- using concat_iseq_t = typename concat_iseq<ISeqs...>::type;
- } // namespace guts
- } // namespace c10
|