BatchedFallback.h

// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>

namespace at {
namespace functorch {
// This file contains code for the vmap fallback (also known as the
// BatchedTensor fallback or the Batched fallback). When an operator
// doesn't have a batching rule implemented, we fall back to this
// implementation. The fallback doesn't work on out= variants or
// view operations; that is, it works for out-of-place operations and
// in-place non-view operations.
//
// For out-of-place operations, the fallback effectively takes all of the
// BatchedTensors in `stack`, slices them, and runs `op` on all of the
// corresponding slices to produce slices of the outputs. The output slices
// then get `torch.stack`ed to create the final returns.
//
// The performance of the fallback is not very good because it introduces an
// extra copy from stacking the sliced outputs. Because of this, we prefer to
// write batching rules for operators whenever possible.
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
void vmapErrorFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
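
// To make the slicing-and-stacking behavior above concrete, here is an
// illustrative sketch (not the actual implementation) of what the fallback
// does for a unary op vmapped over dim 0. `some_op` is a hypothetical
// per-example operation:
//
//   // Roughly: out = stack([some_op(batched[i]) for i in range(batch_size)])
//   Tensor fallback_sketch(const Tensor& batched) {
//     std::vector<Tensor> slices;
//     for (int64_t i = 0; i < batched.size(0); i++) {
//       slices.push_back(some_op(batched[i]));  // hypothetical op
//     }
//     return at::stack(slices);  // the extra copy mentioned above
//   }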

// The vmap fallback emits a warning by default, but it may be disabled if
// the user finds it to be too annoying.
TORCH_API bool isVmapFallbackWarningEnabled();
TORCH_API void setVmapFallbackWarningEnabled(bool enabled);
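
// Example (illustrative): a caller that wants to suppress the warning could
// toggle it around a vmapped computation. The save/restore pattern is an
// assumption, not something this header prescribes:
//
//   bool prev = at::functorch::isVmapFallbackWarningEnabled();
//   at::functorch::setVmapFallbackWarningEnabled(false);
//   // ... run vmapped code that may hit the fallback ...
//   at::functorch::setVmapFallbackWarningEnabled(prev);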

// Used for testing. The vmap fallback is enabled by default. When it is
// disabled, hitting the fallback raises an error instead of running the
// slow for-loop path.
TORCH_API bool isVmapFallbackEnabled();
TORCH_API void setVmapFallbackEnabled(bool enabled);
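
// Example (illustrative): a test could disable the fallback to assert that
// an operator has a real batching rule; if the fallback would have been
// taken, the op now errors out instead of silently running slowly:
//
//   at::functorch::setVmapFallbackEnabled(false);
//   // ... call the vmapped op; an error here means no batching rule ...
//   at::functorch::setVmapFallbackEnabled(true);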

template <typename A>
A vector_to_result(const std::vector<IValue>& buffer) {
  return buffer[0].to<A>();
}

template <typename A, typename B>
std::tuple<A, B> vector_to_result(const std::vector<IValue>& buffer) {
  return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>());
}

template <typename A, typename B, typename C>
std::tuple<A, B, C> vector_to_result(const std::vector<IValue>& buffer) {
  return std::make_tuple(buffer[0].to<A>(), buffer[1].to<B>(), buffer[2].to<C>());
}
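
// These helpers convert the boxed results left on the stack back into typed
// C++ values. For example (illustrative; the variable names are made up),
// after a boxed call that returns two tensors:
//
//   auto [mean, invstd] = vector_to_result<Tensor, Tensor>(stack);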

// slow_fallback is a way to call the vmap fallback inside some boxed kernel.
// There is probably some better way to metaprogram this.
template <typename Ret>
Ret slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
  std::vector<IValue> stack(args.begin(), args.end());
  batchedTensorForLoopFallback(op, &stack);
  return vector_to_result<Ret>(stack);
}

template <typename A, typename B>
std::tuple<A, B> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
  std::vector<IValue> stack(args.begin(), args.end());
  batchedTensorForLoopFallback(op, &stack);
  return vector_to_result<A, B>(stack);
}

template <typename A, typename B, typename C>
std::tuple<A, B, C> slow_fallback(const c10::OperatorHandle& op, ArrayRef<IValue> args) {
  std::vector<IValue> stack(args.begin(), args.end());
  batchedTensorForLoopFallback(op, &stack);
  return vector_to_result<A, B, C>(stack);
}
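
// Example (illustrative sketch, not code from this library): a batching
// rule that handles a common case itself but punts to the fallback
// otherwise might look like this. `aten::my_op` and the helper functions
// are hypothetical:
//
//   Tensor my_op_batching_rule(const Tensor& self) {
//     if (can_handle_directly(self)) {   // hypothetical helper
//       return fast_path(self);          // hypothetical helper
//     }
//     static auto op = c10::Dispatcher::singleton()
//         .findSchemaOrThrow("aten::my_op", "");
//     return slow_fallback<Tensor>(op, {self});
//   }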

} // namespace functorch
} // namespace at