CallOnce.h 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. #pragma once
  2. #include <atomic>
  3. #include <mutex>
  4. #include <thread>
  5. #include <utility>
  6. #include <c10/macros/Macros.h>
  7. #include <c10/util/C++17.h>
  8. namespace c10 {
  9. // custom c10 call_once implementation to avoid the deadlock in std::call_once.
  10. // The implementation here is a simplified version from folly and likely much
  11. // much higher memory footprint.
  12. template <typename Flag, typename F, typename... Args>
  13. inline void call_once(Flag& flag, F&& f, Args&&... args) {
  14. if (C10_LIKELY(flag.test_once())) {
  15. return;
  16. }
  17. flag.call_once_slow(std::forward<F>(f), std::forward<Args>(args)...);
  18. }
  19. class once_flag {
  20. public:
  21. #ifndef _WIN32
  22. // running into build error on MSVC. Can't seem to get a repro locally so I'm
  23. // just avoiding constexpr
  24. //
  25. // C:/actions-runner/_work/pytorch/pytorch\c10/util/CallOnce.h(26): error:
  26. // defaulted default constructor cannot be constexpr because the
  27. // corresponding implicitly declared default constructor would not be
  28. // constexpr 1 error detected in the compilation of
  29. // "C:/actions-runner/_work/pytorch/pytorch/aten/src/ATen/cuda/cub.cu".
  30. constexpr
  31. #endif
  32. once_flag() noexcept = default;
  33. once_flag(const once_flag&) = delete;
  34. once_flag& operator=(const once_flag&) = delete;
  35. private:
  36. template <typename Flag, typename F, typename... Args>
  37. friend void call_once(Flag& flag, F&& f, Args&&... args);
  38. template <typename F, typename... Args>
  39. void call_once_slow(F&& f, Args&&... args) {
  40. std::lock_guard<std::mutex> guard(mutex_);
  41. if (init_.load(std::memory_order_relaxed)) {
  42. return;
  43. }
  44. c10::guts::invoke(f, std::forward<Args>(args)...);
  45. init_.store(true, std::memory_order_release);
  46. }
  47. bool test_once() {
  48. return init_.load(std::memory_order_acquire);
  49. }
  50. void reset_once() {
  51. init_.store(false, std::memory_order_release);
  52. }
  53. private:
  54. std::mutex mutex_;
  55. std::atomic<bool> init_{false};
  56. };
  57. } // namespace c10