SavedTensorHooks.h 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. #pragma once
  2. #include <c10/macros/Export.h>
  3. #include <c10/util/Optional.h>
  4. #include <c10/util/python_stub.h>
  5. #include <stack>
  6. #include <string>
  7. #include <utility>
  8. namespace at {
  9. namespace impl {
  10. struct TORCH_API SavedTensorDefaultHooksTLS {
  11. // PyObject is defined in c10/util/python_stub.h
  12. std::stack<std::pair<PyObject*, PyObject*>> stack;
  13. // See NOTE: [Disabling SavedTensorDefaultHooks] for context
  14. // NOTE: [disabled_error_message invariant]
  15. // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
  16. // We did this for efficiency (so we didn't have to keep a separate bool
  17. // around)
  18. c10::optional<std::string> disabled_error_message;
  19. };
  20. } // namespace impl
  21. struct TORCH_API SavedTensorDefaultHooks {
  22. static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
  23. static void pop_hooks();
  24. static std::pair<PyObject*, PyObject*> get_hooks();
  25. static void lazy_initialize();
  26. static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
  27. static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);
  28. static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
  29. static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
  30. // NOTE: [Disabling SavedTensorDefaultHooks]
  31. // A developer of a PyTorch feature may choose to disable SavedTensorDefault
  32. // hooks, especially if their feature does not work with it. If they are
  33. // disabled, then the following will raise an error:
  34. // - Attempting to push_hooks
  35. // - calling disable(message) with a non-zero stack (from get_stack) size
  36. static void disable(const std::string& error_message);
  37. static void enable();
  38. static bool is_enabled();
  39. static const c10::optional<std::string>& get_disabled_error_message();
  40. };
  41. } // namespace at