// FuncTorchTLS.h
  #pragma once

  #include <c10/macros/Macros.h>

  #include <cstdint>
  #include <memory>
  4. namespace at {
  5. namespace functorch {
  6. // NOTE [functorch TLS in pytorch/pytorch]
  7. //
  8. // functorch lives out-of-tree. However, it has some TLS that needs to be
  9. // propagated. The solution for that is we store a pointer to the TLS
  10. // inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to
  11. // include whatever functorch needs.
  12. //
  13. // We need to store a pointer due to the indirection:
  14. // inside functorch, we will create a subclass of FunctorchTLSBase called
  15. // FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack.
  16. // FuncTorchTLSBase doesn't have any metadata because it hasn't been defined
  17. // yet.
  18. //
  19. // Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside
  20. // functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*.
  21. // We can't directly pass around FunctorchTLSBase (without a pointer) because
  22. // FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having
  23. // more elements.
  24. struct TORCH_API FuncTorchTLSBase {
  25. virtual ~FuncTorchTLSBase() = default;
  26. virtual std::unique_ptr<FuncTorchTLSBase> deepcopy() const = 0;
  27. virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0;
  28. virtual void checkSupportsInplaceRequiresGrad() const = 0;
  29. virtual void checkSupportsRetainGrad() const = 0;
  30. };
  31. // returns deepcopy of the functorch tls
  32. TORCH_API std::unique_ptr<FuncTorchTLSBase> getCopyOfFuncTorchTLS();
  33. // sets the functorch tls. always does a deep copy.
  34. TORCH_API void setFuncTorchTLS(
  35. const std::shared_ptr<const FuncTorchTLSBase>& state);
  36. // get a mutable reference to the functorch tls
  37. TORCH_API std::unique_ptr<FuncTorchTLSBase>& functorchTLSAccessor();
  38. } // namespace functorch
  39. } // namespace at