ADInterpreters.h 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. #pragma once
  2. #include <ATen/functorch/Interpreter.h>
  3. namespace at { namespace functorch {
  4. // These are the interpreters for our AD transforms
  5. // (grad, vjp and jvp).
  6. // See NOTE: [functorch interpreter stack] for more details.
  7. struct TORCH_API GradInterpreterPtr {
  8. explicit GradInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Grad); }
  9. TransformType key() const { return base_->key(); }
  10. int64_t level() const { return base_->level(); }
  11. void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  12. void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  13. bool prevGradMode() const {
  14. return c10::get<GradInterpreterMeta>(base_->meta()).prevGradMode_;
  15. }
  16. Tensor lift(const Tensor& tensor) const;
  17. private:
  18. const Interpreter* base_;
  19. };
  20. struct TORCH_API JvpInterpreterPtr {
  21. explicit JvpInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Jvp); }
  22. TransformType key() const { return base_->key(); }
  23. int64_t level() const { return base_->level(); }
  24. void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  25. void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  26. bool prevFwdGradMode() const {
  27. return c10::get<JvpInterpreterMeta>(base_->meta()).prevFwdGradMode_;
  28. }
  29. Tensor lift(const Tensor& tensor) const;
  30. private:
  31. const Interpreter* base_;
  32. };
  33. }} // namespace at::functorch