// Interpreter.h
#pragma once
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/functorch/Macros.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Optional.h>
#include <c10/util/variant.h>
#include <bitset>
#include <cstdint>
#include <functional>
#include <memory>
#include <ostream>
#include <vector>
  8. namespace at { namespace functorch {
  9. // NOTE: [functorch interpreter stack]
  10. //
  11. // functorch's dispatching system uses a stack of interpreters.
  12. // Historically we've referred to this as the "DynamicLayerStack".
  13. //
  14. // An interpreter is something that reads in the code it is passed
  15. // and then executes it. We have a different interpreter per-transform:
  16. // the "VmapInterpreter" is responsible for reading in operators (like aten::mv)
  17. // and executing the batched version of it (the batching rule for aten::mv).
  18. //
  19. // Concretely, each interpreter is responsible for two things:
  20. //
  21. // 1) process(ophandle, stack)
  22. // Given an operator handle and a stack of arguments, the interpreter is
  23. // responsible for figuring out how to execute the operation under the semantics
  24. // of the interpreter. For e.g. VmapInterpreter, this is figuring out how to call
  25. // the batching rule.
  26. //
  27. // The batching rules are stored as kernels on the FuncTorchBatched key, so the way
  28. // VmapInterpreter calls the batching rule is roughly: (A) exclude all
  29. // dispatch keys aside from the Batched key, (B) redispatch so we get to the
  30. // Batched key.
  31. //
  32. // 2) sendToNextInterpreter(ophandle, stack)
  33. // The VmapInterpreter, when it sees aten::mv, will process it into a call to
  34. // aten::mm. It then needs to send the call to aten::mm to the next interpreter
  35. // in the interpreter stack.
  36. //
  37. // The VmapInterpreter just does this via a call to ophandle.callBoxed(stack)
  38. // and most Interpreters will implement it this way.
// Policy for how random operations behave under a vmap transform.
enum RandomnessType {
    Error,      // always errors when calling a random function
    Same,       // randomness appears the same across batches
    Different,  // randomness appears different across batches
    END         // sentinel (count of valid policies) — presumably used for
                // bounds/validation checks; not a real policy. TODO confirm at call sites.
};
// The set of functorch transforms. Each Interpreter on the interpreter
// stack (see NOTE: [functorch interpreter stack]) is tagged with one of these.
enum class TransformType {
  Torch,  // Unused
  Vmap,
  Grad, // reverse-mode AD, aka vjp
  Jvp, // forward-mode AD
  Functionalize,
};

// Pretty-printer for TransformType; declared here, defined out of line.
std::ostream& operator<<(std::ostream& os, const TransformType& t);
  53. // NOTE: [Interpreter "subclassing" design]
  54. //
  55. // How are various Interpreters for different transforms (vmap, grad, ...)
  56. // implemented?
  57. //
  58. // Accessing interpreters is in the hot-path of functorch so we have a constraint
  59. // that this code must be as fast as possible.
  60. //
  61. // As a result, we stay away from virtual methods and this causes our code
  62. // to look a little funny.
  63. //
  64. // `Interpreter` is the struct for Interpreters. It holds ALL of the
  65. // relevant information (what type of interpreter it is and the metadata).
  66. // Metadata for each interpreter is represented as a Union (c10::variant)
  67. // of all possible metadata (VmapInterpreterMeta, GradInterpreterMeta, ...).
  68. //
  69. // Given an Interpreter, how do I get a "VmapInterpreter"? You may wish to do this
  70. // if you want to access the metadata fields (like batchSize and randomness).
  71. //
  72. // Each type of interpreter (e.g. Vmap) has a convenience struct
  73. // (e.g. VmapInterpreterPtr) associated with it.
  74. //
  75. // Construct the convenience struct with VmapInterpreterPtr(Interpreter*),
  76. // and then one can access methods on VmapInterpreterPtr like so:
  77. // >>> VmapInterpreterPtr(&interpreter).batchSize()
  78. //
  79. // Finally, Interpreter::process switches on the type of the interpreter
// and calls one of {Transform}Interpreter::processImpl under the hood.
  81. // Same for Interpreter::sendToNextInterpreter :)
// Metadata for a Vmap interpreter: the size of the dimension being vmapped
// over, and how random operations should behave under this vmap.
struct VmapInterpreterMeta {
  explicit VmapInterpreterMeta(int64_t batchSize, RandomnessType randomness) :
    batchSize_(batchSize), randomness_(randomness) {}
  int64_t batchSize_;          // size of the vmapped (batch) dimension
  RandomnessType randomness_;  // randomness policy for random ops under this vmap
};
// Metadata for a Grad (reverse-mode AD) interpreter.
struct GradInterpreterMeta {
  explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {}
  bool prevGradMode_;  // grad-mode flag captured when entering this transform
};
// Metadata for a Jvp (forward-mode AD) interpreter.
struct JvpInterpreterMeta {
  explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {}
  bool prevFwdGradMode_;  // forward-grad-mode flag captured when entering this transform
};
// Metadata for a Functionalize interpreter.
struct FunctionalizeInterpreterMeta {
  explicit FunctionalizeInterpreterMeta(bool functionalizeAddBackViews) :
    functionalizeAddBackViews_(functionalizeAddBackViews) {}
  bool functionalizeAddBackViews_;  // whether views should be added back after functionalization
};
  101. typedef c10::variant<
  102. int64_t,
  103. GradInterpreterMeta,
  104. JvpInterpreterMeta,
  105. VmapInterpreterMeta,
  106. FunctionalizeInterpreterMeta
  107. > InterpreterMeta;
  108. struct Interpreter {
  109. // factory functions
  110. static Interpreter Vmap(int64_t level, int64_t batchSize, RandomnessType randomness) {
  111. return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(batchSize, randomness));
  112. }
  113. static Interpreter Grad(int64_t level, bool prevGradMode) {
  114. return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode));
  115. }
  116. static Interpreter Jvp(int64_t level, bool prevFwdGradMode) {
  117. return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode));
  118. }
  119. static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) {
  120. return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews));
  121. }
  122. // methods
  123. TransformType key() const { return type_; }
  124. int64_t level() const { return level_; }
  125. const InterpreterMeta& meta() const { return meta_; }
  126. void process(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  127. void sendToNextInterpreter(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  128. void saveLocalDispatchKeySet(c10::impl::LocalDispatchKeySet keyset) {
  129. TORCH_INTERNAL_ASSERT(!savedLocalDispatchKeySet_.has_value());
  130. savedLocalDispatchKeySet_ = std::move(keyset);
  131. }
  132. void clearSavedLocalDispatchKeySet() {
  133. TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
  134. savedLocalDispatchKeySet_ = c10::nullopt;
  135. }
  136. c10::impl::LocalDispatchKeySet getSavedLocalDispatchKeySet() const {
  137. TORCH_INTERNAL_ASSERT(savedLocalDispatchKeySet_.has_value());
  138. return *savedLocalDispatchKeySet_;
  139. }
  140. // An Interpreter is alive if we are currently inside the ongoing transform
  141. // for the interpreter. For example, vmap(f)(x); inside of f, the vmap's
  142. // corresponding Interpreter is alive, even when it is not on the DynamicLayerStack.
  143. bool is_alive() const {
  144. return *is_alive_;
  145. }
  146. const std::shared_ptr<bool>& is_alive_ptr() const {
  147. return is_alive_;
  148. }
  149. void set_is_alive(bool alive) {
  150. *is_alive_ = alive;
  151. }
  152. // Please don't use this
  153. explicit Interpreter() = default;
  154. private:
  155. explicit Interpreter(TransformType type, int64_t level, InterpreterMeta meta):
  156. type_(type), level_(level), is_alive_(std::make_shared<bool>(false)), meta_(meta) {}
  157. // fields
  158. TransformType type_;
  159. int64_t level_;
  160. optional<c10::impl::LocalDispatchKeySet> savedLocalDispatchKeySet_;
  161. std::shared_ptr<bool> is_alive_;
  162. InterpreterMeta meta_;
  163. };
// Applies the following for-loop:
// for i in range(begin, end):
//   args[i] = func(args[i])
void foreachTensorInplace(std::vector<IValue>& args, int64_t begin, int64_t end,
    std::function<Tensor(const Tensor&)> func);

// Like foreachTensorInplace, but func also receives a bool flag:
// for i in range(begin, end):
//   args[i] = func(args[i], use_flag_relative[i - begin])
// i.e. use_flag_relative is treated as a bitset indexed relative to begin.
// NOTE(review): the original comment showed a three-argument call
// func(args[i], i - begin, true), which disagrees with the declared
// std::function<Tensor(const Tensor&, bool)> signature — presumably the
// bool is the relative bitset entry; confirm against the definition.
void foreachTensorInplaceWithFlag(std::vector<IValue>& args, int64_t begin, int64_t end,
    const std::bitset<64> use_flag_relative, std::function<Tensor(const Tensor&, bool)> func);

// Returns indices in [begin, end) of args — presumably those inputs not
// wrapped by the current transform; confirm against the definition.
std::vector<int64_t> findUnwrappedInputs(std::vector<IValue>& args, int64_t begin, int64_t end);

// Dispatch keys to exclude when entering the dynamic layer for transform `key`.
DispatchKeySet keysToExcludeWhenEnteringDynamicLayer(TransformType key);

// Installs the given exclude/include sets into thread-local dispatch state.
void setup_dispatch_key_tls(DispatchKeySet exclude, DispatchKeySet include);

// Debug check that `stack` matches the schema of `op`; defined out of line.
void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  180. }} // namespace at::functorch