VmapInterpreter.h 965 B

12345678910111213141516171819202122232425
  1. #pragma once
  2. #include <ATen/functorch/Interpreter.h>
  3. namespace at { namespace functorch {
  4. // This is the interpreter that handles the functionalize() transform.
  5. // See NOTE: [functorch interpreter stack] for more details.
  6. struct VmapInterpreterPtr {
  7. explicit VmapInterpreterPtr(const Interpreter* base): base_(base) { TORCH_INTERNAL_ASSERT(base->key() == TransformType::Vmap); }
  8. TransformType key() const { return base_->key(); }
  9. int64_t level() const { return base_->level(); }
  10. void processImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack);
  11. void sendToNextInterpreterImpl(const c10::OperatorHandle& op, torch::jit::Stack* stack, bool grad_special_case);
  12. int64_t batchSize() const {
  13. return c10::get<VmapInterpreterMeta>(base_->meta()).batchSize_;
  14. }
  15. RandomnessType randomness() const {
  16. return c10::get<VmapInterpreterMeta>(base_->meta()).randomness_;
  17. }
  18. private:
  19. const Interpreter* base_;
  20. };
  21. }} // namespace at::functorch