ThreadLocalState.h 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  #pragma once

  #include <memory>
  #include <stack>
  #include <utility>

  #include <c10/core/InferenceMode.h>
  #include <c10/core/impl/LocalDispatchKeySet.h>
  #include <c10/core/impl/PythonDispatcherTLS.h>
  #include <c10/core/impl/TorchDispatchModeTLS.h>
  #include <c10/util/Exception.h>
  #include <c10/util/ThreadLocalDebugInfo.h>

  #include <ATen/FuncTorchTLS.h>
  #include <ATen/PythonTorchFunctionTLS.h>
  #include <ATen/SavedTensorHooks.h>
  #include <ATen/ThreadLocalPythonObjects.h>
  #include <ATen/record_function.h>
  14. namespace at {
  15. // Thread local state contains values that are preserved across
  16. // thread boundaries (e.g. at::launch/JIT fork, autograd).
  17. // Note at::parallel_for doesn't preserve TLS across thread boundaries.
  18. class TORCH_API ThreadLocalState {
  19. public:
  20. // Saves the thread local variables' values and
  21. // returns them as a ThreadLocalState
  22. ThreadLocalState();
  23. // set_grad_mode - force the value of the grad mode TLS in
  24. // the current state object. This is used for example in the
  25. // autograd engine.
  26. void set_grad_mode(bool enabled);
  27. // set_multithreading_enabled - force the value of the multithreadinmaximum
  28. // threads TLS in
  29. // the current state object. This is used for example in the
  30. // autograd engine.
  31. void set_multithreading_enabled(bool enabled);
  32. // Sets thread local variables in the current thread,
  33. // according to the thread boundary specified
  34. static void setThreadLocalState(const ThreadLocalState& state);
  35. private:
  36. c10::impl::LocalDispatchKeySet dispatch_key_;
  37. // ThreadLocalDebugInfo does not change after being created
  38. // with DebugInfoGuard
  39. std::shared_ptr<c10::ThreadLocalDebugInfo> debug_info_;
  40. // RecordFunction TLS
  41. RecordFunctionTLS rf_tls_;
  42. // TLS for out-of-tree functorch
  43. // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a
  44. // pointer (spoiler alert: it's due to the indirection)
  45. // This needs to be a shared_ptr instead of a unique_ptr because
  46. // ThreadLocalState is copy-able and does indeed get copied. Maybe we can
  47. // consider adding an explicit copy constructor for ThreadLocalState in the
  48. // future but I didn't want to add one just for this.
  49. std::shared_ptr<const functorch::FuncTorchTLSBase> functorch_tls_;
  50. // TLS for AutogradModes
  51. AutogradState autograd_tls_;
  52. // TLS for enable_torch_dispatch_mode
  53. c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_;
  54. // TLS for enable_python_dispatcher
  55. c10::impl::PyInterpreter* python_dispatcher_state_;
  56. // TLS for __torch_function__ (mode and disable_torch_function)
  57. at::impl::PythonTorchFunctionTLS python_torch_function_state_;
  58. // TLS for saved tensors default hooks
  59. at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
  60. bool functionalization_reapply_views_state_;
  61. // TLS for arbitrary python objects that is registered via hooks
  62. at::impl::ThreadLocalPythonObjects saved_objects_;
  63. friend class ThreadLocalStateGuard;
  64. };
  65. // Guard to set and reset the thread local state
  66. class TORCH_API ThreadLocalStateGuard {
  67. public:
  68. explicit ThreadLocalStateGuard(const ThreadLocalState& state)
  69. : prev_state_(ThreadLocalState()) {
  70. // set the given state across the thread boundary
  71. ThreadLocalState::setThreadLocalState(state);
  72. }
  73. ~ThreadLocalStateGuard() {
  74. // restore previously set variables
  75. ThreadLocalState::setThreadLocalState(prev_state_);
  76. }
  77. private:
  78. const ThreadLocalState prev_state_;
  79. };
  80. template <typename T>
  81. auto wrapPropagateTLSState(T callback) {
  82. return [tls_state = ThreadLocalState(),
  83. callback = std::move(callback)](auto&&... args) {
  84. ThreadLocalStateGuard g(tls_state);
  85. // Propagate value returned by callback().
  86. return callback(std::forward<decltype(args)>(args)...);
  87. };
  88. }
  89. } // namespace at