ThreadLocalDebugInfo.h 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #pragma once
  #include <c10/macros/Export.h>
  #include <cstdint>
  #include <memory>
  #include <string>
  5. namespace c10 {
  6. enum class C10_API_ENUM DebugInfoKind : uint8_t {
  7. PRODUCER_INFO = 0,
  8. MOBILE_RUNTIME_INFO,
  9. PROFILER_STATE,
  10. INFERENCE_CONTEXT, // for inference usage
  11. PARAM_COMMS_INFO,
  12. TEST_INFO, // used only in tests
  13. TEST_INFO_2, // used only in tests
  14. };
  15. class C10_API DebugInfoBase {
  16. public:
  17. DebugInfoBase() = default;
  18. virtual ~DebugInfoBase() = default;
  19. };
  20. // Thread local debug information is propagated across the forward
  21. // (including async fork tasks) and backward passes and is supposed
  22. // to be utilized by the user's code to pass extra information from
  23. // the higher layers (e.g. model id) down to the lower levels
  24. // (e.g. to the operator observers used for debugging, logging,
  25. // profiling, etc)
  26. class C10_API ThreadLocalDebugInfo {
  27. public:
  28. static DebugInfoBase* get(DebugInfoKind kind);
  29. // Get current ThreadLocalDebugInfo
  30. static std::shared_ptr<ThreadLocalDebugInfo> current();
  31. // Internal, use DebugInfoGuard/ThreadLocalStateGuard
  32. static void _forceCurrentDebugInfo(
  33. std::shared_ptr<ThreadLocalDebugInfo> info);
  34. // Push debug info struct of a given kind
  35. static void _push(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
  36. // Pop debug info, throws in case the last pushed
  37. // debug info is not of a given kind
  38. static std::shared_ptr<DebugInfoBase> _pop(DebugInfoKind kind);
  39. // Peek debug info, throws in case the last pushed debug info is not of the
  40. // given kind
  41. static std::shared_ptr<DebugInfoBase> _peek(DebugInfoKind kind);
  42. private:
  43. std::shared_ptr<DebugInfoBase> info_;
  44. DebugInfoKind kind_;
  45. std::shared_ptr<ThreadLocalDebugInfo> parent_info_;
  46. friend class DebugInfoGuard;
  47. };
  48. // DebugInfoGuard is used to set debug information,
  49. // ThreadLocalDebugInfo is semantically immutable, the values are set
  50. // through the scope-based guard object.
  51. // Nested DebugInfoGuard adds/overrides existing values in the scope,
  52. // restoring the original values after exiting the scope.
  53. // Users can access the values through the ThreadLocalDebugInfo::get() call;
  54. class C10_API DebugInfoGuard {
  55. public:
  56. DebugInfoGuard(DebugInfoKind kind, std::shared_ptr<DebugInfoBase> info);
  57. explicit DebugInfoGuard(std::shared_ptr<ThreadLocalDebugInfo> info);
  58. ~DebugInfoGuard();
  59. DebugInfoGuard(const DebugInfoGuard&) = delete;
  60. DebugInfoGuard(DebugInfoGuard&&) = delete;
  61. private:
  62. bool active_ = false;
  63. std::shared_ptr<ThreadLocalDebugInfo> prev_info_ = nullptr;
  64. };
  65. } // namespace c10