FunctionalizeInterpreter.h 919 B

12345678910111213141516171819202122
  1. #pragma once
  2. #include <ATen/functorch/Interpreter.h>
  3. namespace at { namespace functorch {
  4. // This is the interpreter that handles the functionalize() transform.
  5. // See NOTE: [functorch interpreter stack] for more details.
  6. struct FunctionalizeInterpreterPtr {
  7. explicit FunctionalizeInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Functionalize); }
  8. TransformType key() const { return base_->key(); }
  9. int64_t level() const { return base_->level(); }
  10. void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  11. void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  12. bool functionalizeAddBackViews() const {
  13. return c10::get<FunctionalizeInterpreterMeta>(base_->meta()).functionalizeAddBackViews_;
  14. }
  15. private:
  16. const Interpreter* base_;
  17. };
  18. }} // namespace at::functorch