LegacyBatchedFallback.h

#pragma once

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

namespace at {

// If an operator doesn't have a batching rule implemented, then we fall back
// to this implementation. The fallback only works on out-of-place operators
// that return only tensors with new memory (e.g., no in-place operators, no
// view operations).
//
// The fallback effectively takes all of the BatchedTensors in `stack`, slices
// them, and runs `op` on all of the corresponding slices to produce slices
// of the outputs. The output slices then get `torch.stack`ed to create the
// final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
void batchedTensorForLoopFallback(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack);

} // namespace at
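
For intuition, here is a minimal sketch of the slice-run-stack strategy the comment describes. It is not the actual batchedTensorForLoopFallback (which works on a boxed torch::jit::Stack and an OperatorHandle); forLoopApply is a hypothetical helper that applies a single unary op example-by-example along dim 0.

// Illustrative sketch only, assuming a unary op and batch dimension 0.
#include <functional>
#include <vector>
#include <ATen/ATen.h>

at::Tensor forLoopApply(
    const at::Tensor& batched,
    const std::function<at::Tensor(const at::Tensor&)>& op) {
  std::vector<at::Tensor> output_slices;
  output_slices.reserve(batched.size(0));
  for (int64_t i = 0; i < batched.size(0); ++i) {
    // Slice out example i (a view) and run the op on that slice.
    output_slices.push_back(op(batched.select(0, i)));
  }
  // The extra copy mentioned above: stacking the per-example outputs
  // materializes a new tensor for the final return.
  return at::stack(output_slices, /*dim=*/0);
}

// Example usage:
//   at::Tensor x = at::randn({8, 3});
//   at::Tensor y = forLoopApply(x, [](const at::Tensor& t) { return t.sin(); });
//   // y has shape [8, 3] and matches x.sin(), but was computed slice by slice.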