DynamicLayer.h

// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <ATen/functorch/Macros.h>
#include <c10/core/DispatchKey.h>
#include <ATen/core/function_schema.h>
#include <c10/util/Optional.h>
#include <c10/util/variant.h>
#include <unordered_map>
#include <mutex>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <ATen/functorch/Interpreter.h>
#include <ATen/functorch/VmapInterpreter.h>
#include <ATen/functorch/ADInterpreters.h>
#include <ATen/functorch/FunctionalizeInterpreter.h>
// Forward declared
namespace c10 { struct AutogradMetaInterface; }

namespace at {
namespace functorch {
// This file contains the implementation of functorch's interpreter stack.
// See NOTE: [functorch interpreter stack] first before reading on.
//
// NB: the functorch interpreter stack is also referred to as:
// - the "dynamic layer stack" -- an older name for "interpreter" was
//   "dynamic layer".
// - the "functorch mode stack". You can think of each functorch transform as a
//   "mode" (in the same sense as torch_dispatch mode or torch_function mode),
//   and functorch as an implementation of a "mode stack" where the modes
//   may be arbitrarily composed.
//
// DynamicLayer is basically the same thing as an Interpreter.
// It represents a functorch transform and it holds an Interpreter,
// which contains metadata related to the transform and instructions on
// how to perform the transform.
//
// TODO: we can excise DynamicLayer in favor of Interpreter,
// but I am going to leave it for now as a compatibility shim to avoid
// needing to refactor a lot of callsites...
struct TORCH_API DynamicLayer {
  explicit DynamicLayer(
      TransformType transform_type,
      int64_t layerId,
      optional<int64_t> batchSize = nullopt,
      optional<RandomnessType> randomness = nullopt,
      optional<bool> prev_grad_mode = nullopt,
      optional<bool> pre_fwd_grad_mode = nullopt,
      optional<bool> functionalize_add_back_views = nullopt);

  TransformType key() const;
  int64_t layerId() const;

  const Interpreter& interpreter() const { return interpreter_; }
  Interpreter& interpreter() { return interpreter_; }

  // Only valid for vmap
  int64_t batchSize() const;
  RandomnessType randomness() const;

 private:
  Interpreter interpreter_;
};
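
// Illustrative sketch (not part of this header) of constructing a DynamicLayer
// and reaching its Interpreter. The enumerator names TransformType::Vmap and
// RandomnessType::Error are assumptions based on the transforms listed above:
//
//   DynamicLayer layer(
//       TransformType::Vmap,
//       /*layerId=*/1,
//       /*batchSize=*/4,
//       /*randomness=*/RandomnessType::Error);
//   TORCH_INTERNAL_ASSERT(layer.layerId() == 1);
//   TORCH_INTERNAL_ASSERT(layer.batchSize() == 4);  // only valid for vmap
//   Interpreter& interp = layer.interpreter();  // metadata + transform rules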
TORCH_API int64_t initAndPushDynamicLayer(
    TransformType transform_type,
    optional<int64_t> batch_size = nullopt,
    optional<RandomnessType> randomness = nullopt,
    optional<bool> prev_grad_mode = nullopt,
    optional<bool> prev_fwd_grad_mode = nullopt,
    optional<bool> functionalize_add_back_views = nullopt);
TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata();
TORCH_API c10::optional<DynamicLayer> maybeCurrentDynamicLayer();
TORCH_API const std::vector<DynamicLayer>& getDynamicLayerStack();
TORCH_API void setDynamicLayerStack(const std::vector<DynamicLayer>& stack);
TORCH_API void setDynamicLayerFrontBackKeysIncluded(bool included);
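
// Illustrative sketch (not part of this header): a transform's entry point
// typically brackets its scope with a push/pop pair. The argument values
// below are placeholders:
//
//   int64_t level = initAndPushDynamicLayer(
//       TransformType::Grad,
//       /*batch_size=*/nullopt,
//       /*randomness=*/nullopt,
//       /*prev_grad_mode=*/true);
//   // ... run the transformed computation at `level` ...
//   DynamicLayer popped = popDynamicLayerAndDeleteMetadata();
//   TORCH_INTERNAL_ASSERT(popped.layerId() == level);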
// NOTE: [Life handles and lexically scoped transforms]
// functorch transforms are lexically scoped.
// Given a level, we store a "life handle": a boolean that tells us whether the
// transform with that level is active or not.
//
// functorch's TensorWrapper (for grad transforms) stores a life handle.
// If a TensorWrapper escapes from the scope of its transform, it needs some
// way to know that it escaped; it can tell by querying the life handle.
TORCH_API const std::shared_ptr<bool>& getLifeHandleForLevel(int64_t level);
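
// Illustrative sketch (not part of this header): the life handle is a
// shared_ptr<bool> that is presumably set to false once the transform at that
// level exits, so a wrapper that outlives its transform can detect that:
//
//   const std::shared_ptr<bool>& alive = getLifeHandleForLevel(level);
//   if (!*alive) {
//     // The transform at `level` has exited; the wrapper escaped its scope.
//   }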
// Returns whether an operator is in-place. An operator is in-place if:
// 1. The first argument is a Tensor and it is being written to
// 2. The first argument is being returned
// 3. No other arguments are aliased
// Here is an example of an in-place operator:
//   add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
TORCH_API bool isInplaceOp(const c10::FunctionSchema& schema);
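
// Illustrative sketch (not part of this header), assuming a schema obtained
// via torch::jit::parseSchema:
//
//   auto schema = torch::jit::parseSchema(
//       "add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)");
//   bool inplace = isInplaceOp(schema);  // true: self is written to and returned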
// Given the schema and the index of an unwrapped input, returns the index of
// the output that aliases it, if any; such outputs should also remain unwrapped.
TORCH_API c10::optional<size_t> findAliasedOutput(const FunctionSchema& schema, const int64_t immutable_input);
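
// Illustrative sketch (not part of this header), using a view-like schema
// (an approximation of aten::view) where output 0 aliases input 0:
//
//   auto view_schema = torch::jit::parseSchema(
//       "view(Tensor(a) self, int[] size) -> Tensor(a)");
//   auto aliased = findAliasedOutput(view_schema, /*immutable_input=*/0);
//   // aliased == 0; nullopt would mean no output aliases that input.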
TORCH_API Tensor unwrapIfDead(const Tensor& tensor);
TORCH_API bool isDeadTensorWrapper(const Tensor& tensor);
// Pretty printers
TORCH_API std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
TORCH_API std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
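
// Illustrative sketch (not part of this header): dumping the stack for
// debugging with the printers above:
//
//   std::cout << getDynamicLayerStack() << std::endl;
//   if (auto layer = maybeCurrentDynamicLayer()) {
//     std::cout << *layer << std::endl;
//   }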
// While a functorch transform is active, torch.autograd.function._SingleLevelFunction
// is disabled by default. The following two APIs allow it to be re-enabled.
// These are not user-facing APIs. We may delete them in the future, but they are
// useful for debugging when something goes wrong with the
// autograd.Function <> functorch interaction (which uses _SingleLevelFunction),
// because they lead to loud errors if something is incorrect.
TORCH_API void setSingleLevelAutogradFunctionAllowed(bool allowed);
TORCH_API bool getSingleLevelAutogradFunctionAllowed();
// While a functorch grad transform is active, Tensor.requires_grad_() gets
// disabled. These two functions are the mechanism for controlling that.
TORCH_API void setInplaceRequiresGradAllowed(bool allowed);
TORCH_API bool getInplaceRequiresGradAllowed();
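
// Illustrative sketch (not part of this header): both pairs of setters above
// lend themselves to a save/restore pattern around a debugging region:
//
//   bool prev = getInplaceRequiresGradAllowed();
//   setInplaceRequiresGradAllowed(true);
//   // ... code that calls Tensor::requires_grad_() under a grad transform ...
//   setInplaceRequiresGradAllowed(prev);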
TORCH_API DynamicLayer popDynamicLayer();
TORCH_API int64_t pushDynamicLayer(DynamicLayer&& layer);

} // namespace functorch
} // namespace at