BatchingMetaprogramming.h 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. // Copyright (c) Facebook, Inc. and its affiliates.
  2. // All rights reserved.
  3. //
  4. // This source code is licensed under the BSD-style license found in the
  5. // LICENSE file in the root directory of this source tree.
  6. #pragma once
  7. #include <ATen/Tensor.h>
  8. #include <ATen/VmapGeneratedPlumbing.h>
  9. // This file contains template metaprogramming things that are used for our
  10. // batching rules.
  11. //
  12. // See NOTE: [vmap plumbing] for more details on why this is necessary.
  13. // The plumbing has a bunch of metaprogramming hacks for determining the signature
  14. // of a batching rule from the signature of the operator, many of which use the
  15. // helper functions in this file.
  16. namespace at {
  17. namespace functorch {
  18. // Metaprogramming things
  19. template <class... Items> using typelist = c10::guts::typelist::typelist<Items...>;
  20. template <class TypeList> using head_t = c10::guts::typelist::head_t<TypeList>;
  21. template <class TL1, class TL2> using concat_t = c10::guts::typelist::concat_t<TL1, TL2>;
  22. template <typename T> class debug_t;
  23. // tail operation
  24. template<class TypeList>
  25. struct tail final {
  26. static_assert(c10::guts::false_t<TypeList>::value,
  27. "In typelist::tail<T>, the T argument must be typelist<...>.");
  28. };
  29. template<class Head, class... Tail>
  30. struct tail<typelist<Head, Tail...>> final {
  31. using type = typelist<Tail...>;
  32. };
  33. template<class TypeList> using tail_t = typename tail<TypeList>::type;
  34. template <class First, class Second, class Next, class Tail>
  35. struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext {
  36. using type = Next;
  37. };
  38. template <class Next, class Tail>
  39. struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor, optional<int64_t>, Next, Tail> {
  40. using type = Tail;
  41. };
  42. template <class Next, class Tail>
  43. struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const Tensor&, optional<int64_t>, Next, Tail> {
  44. using type = Tail;
  45. };
  46. template <class Next, class Tail>
  47. struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<Tensor&, optional<int64_t>, Next, Tail> {
  48. using type = Tail;
  49. };
  50. template <class Next, class Tail>
  51. struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<optional<Tensor>, optional<int64_t>, Next, Tail> {
  52. using type = Tail;
  53. };
  54. template <class Next, class Tail>
  55. struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<const optional<Tensor>&, optional<int64_t>, Next, Tail> {
  56. using type = Tail;
  57. };
  58. template <class Next, class Tail>
  59. struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<optional<Tensor>&, optional<int64_t>, Next, Tail> {
  60. using type = Tail;
  61. };
  62. template <class Next, class Tail>
  63. struct IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<std::vector<Tensor>, optional<int64_t>, Next, Tail> {
  64. using type = Tail;
  65. };
  66. template <class TypeList> struct RemoveBatchDimAfterTensor {
  67. using first = head_t<TypeList>;
  68. using next = tail_t<TypeList>;
  69. using second = head_t<next>;
  70. using tail = tail_t<next>;
  71. using type = concat_t<
  72. typelist<first>,
  73. typename RemoveBatchDimAfterTensor<
  74. typename IfFirstIsTensorAndSecondisBatchDimThenTailElseNext<first, second, next, tail>::type
  75. >::type
  76. >;
  77. };
  78. template <class Type> struct RemoveBatchDimAfterTensor<typelist<Type>> {
  79. using type = typelist<Type>;
  80. };
  81. template <> struct RemoveBatchDimAfterTensor<typelist<>> {
  82. using type = typelist<>;
  83. };
  84. template<class TypeList> using remove_batch_dim_after_tensor_t = typename RemoveBatchDimAfterTensor<TypeList>::type;
  85. template <typename T> struct UnpackSingleItemTuple {
  86. using type = T;
  87. };
  88. template <typename T> struct UnpackSingleItemTuple<std::tuple<T>> {
  89. using type = T;
  90. };
  91. template <typename T> using unpack_single_item_tuple_t = typename UnpackSingleItemTuple<T>::type;
  92. template <typename Return, typename TupleArgs> struct BuildFunctionHelper;
  93. template <typename Return, typename... Args> struct BuildFunctionHelper<Return, std::tuple<Args...>> {
  94. using type = Return(Args...);
  95. };
  96. template <typename Return, typename TL>
  97. struct BuildFunction {
  98. using type = typename BuildFunctionHelper<Return, c10::guts::typelist::to_tuple_t<TL>>::type;
  99. };
  100. template <typename Return, typename TL> using build_function_t = typename BuildFunction<Return, TL>::type;
  101. template <typename batch_rule_t> struct ToOperatorType {
  102. using batch_rule_return_type = typename c10::guts::function_traits<batch_rule_t>::return_type;
  103. using batch_rule_parameter_types = typename c10::guts::function_traits<batch_rule_t>::parameter_types;
  104. using operator_parameter_types = remove_batch_dim_after_tensor_t<batch_rule_parameter_types>;
  105. using operator_return_type =
  106. unpack_single_item_tuple_t<
  107. c10::guts::typelist::to_tuple_t<
  108. remove_batch_dim_after_tensor_t<
  109. c10::guts::typelist::from_tuple_t<batch_rule_return_type>>>>;
  110. using type = build_function_t<operator_return_type, operator_parameter_types>;
  111. };
  112. template <typename batch_rule_t> using to_operator_t = typename ToOperatorType<batch_rule_t>::type;
  113. }
  114. } // namespace at