// Copyright (c) Facebook, Inc. and its affiliates. // All rights reserved. // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #pragma once #include #include // This file contains template metaprogramming things that are used for our // batching rules. // // See NOTE: [vmap plumbing] for more details on why this is necessary. // The plumbing has a bunch of metaprogramming hacks for determining the signature // of a batching rule from the signature of the operator, many of which use the // helper functions in this file. namespace at { namespace functorch { // Metaprogramming things template using typelist = c10::guts::typelist::typelist; template using head_t = c10::guts::typelist::head_t; template using concat_t = c10::guts::typelist::concat_t; template class debug_t; // tail operation template struct tail final { static_assert(c10::guts::false_t::value, "In typelist::tail, the T argument must be typelist<...>."); }; template struct tail> final { using type = typelist; }; template using tail_t = typename tail::type; template struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext { using type = Next; }; template struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { using type = Tail; }; template struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { using type = Tail; }; template struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, Next, Tail> { using type = Tail; }; template struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, optional, Next, Tail> { using type = Tail; }; template struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, optional, Next, Tail> { using type = Tail; }; template struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext&, optional, Next, Tail> { using type = Tail; }; template struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext, optional, Next, Tail> { using type = Tail; }; template struct RemoveBatchDimAfterTensor { using first = head_t; using next = tail_t; using second = head_t; using tail = tail_t; using type = concat_t< typelist, typename RemoveBatchDimAfterTensor< typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext::type >::type >; }; template struct RemoveBatchDimAfterTensor> { using type = typelist; }; template <> struct RemoveBatchDimAfterTensor> { using type = typelist<>; }; template using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor::type; template struct UnpackSingleItemTuple { using type = T; }; template struct UnpackSingleItemTuple> { using type = T; }; template using unpack_single_item_tuple_t = typename UnpackSingleItemTuple::type; template struct BuildFunctionHelper; template struct BuildFunctionHelper> { using type = Return(Args...); }; template struct BuildFunction { using type = typename BuildFunctionHelper>::type; }; template using build_function_t = typename BuildFunction::type; template struct ToOperatorType { using batch_rule_return_type = typename c10::guts::function_traits::return_type; using batch_rule_parameter_types = typename c10::guts::function_traits::parameter_types; using operator_parameter_types = remove_batch_dim_after_tensor_t; using operator_return_type = unpack_single_item_tuple_t< c10::guts::typelist::to_tuple_t< remove_batch_dim_after_tensor_t< c10::guts::typelist::from_tuple_t>>>; using type = build_function_t; }; template using to_operator_t = typename ToOperatorType::type; } } // namespace at