// cxx11_eventcount.cpp (3.9 KB)
  1. // This file is part of Eigen, a lightweight C++ template library
  2. // for linear algebra.
  3. //
  4. // Copyright (C) 2016 Dmitry Vyukov <dvyukov@google.com>
  5. // Copyright (C) 2016 Benoit Steiner <benoit.steiner.goog@gmail.com>
  6. //
  7. // This Source Code Form is subject to the terms of the Mozilla
  8. // Public License v. 2.0. If a copy of the MPL was not distributed
  9. // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
  10. #define EIGEN_USE_THREADS
  11. #include "main.h"
  12. #include <Eigen/CXX11/ThreadPool>
  13. // Visual studio doesn't implement a rand_r() function since its
  14. // implementation of rand() is already thread safe
  15. int rand_reentrant(unsigned int* s) {
  16. #ifdef EIGEN_COMP_MSVC_STRICT
  17. EIGEN_UNUSED_VARIABLE(s);
  18. return rand();
  19. #else
  20. return rand_r(s);
  21. #endif
  22. }
  23. static void test_basic_eventcount()
  24. {
  25. MaxSizeVector<EventCount::Waiter> waiters(1);
  26. waiters.resize(1);
  27. EventCount ec(waiters);
  28. EventCount::Waiter& w = waiters[0];
  29. ec.Notify(false);
  30. ec.Prewait();
  31. ec.Notify(true);
  32. ec.CommitWait(&w);
  33. ec.Prewait();
  34. ec.CancelWait();
  35. }
  36. // Fake bounded counter-based queue.
  37. struct TestQueue {
  38. std::atomic<int> val_;
  39. static const int kQueueSize = 10;
  40. TestQueue() : val_() {}
  41. ~TestQueue() { VERIFY_IS_EQUAL(val_.load(), 0); }
  42. bool Push() {
  43. int val = val_.load(std::memory_order_relaxed);
  44. for (;;) {
  45. VERIFY_GE(val, 0);
  46. VERIFY_LE(val, kQueueSize);
  47. if (val == kQueueSize) return false;
  48. if (val_.compare_exchange_weak(val, val + 1, std::memory_order_relaxed))
  49. return true;
  50. }
  51. }
  52. bool Pop() {
  53. int val = val_.load(std::memory_order_relaxed);
  54. for (;;) {
  55. VERIFY_GE(val, 0);
  56. VERIFY_LE(val, kQueueSize);
  57. if (val == 0) return false;
  58. if (val_.compare_exchange_weak(val, val - 1, std::memory_order_relaxed))
  59. return true;
  60. }
  61. }
  62. bool Empty() { return val_.load(std::memory_order_relaxed) == 0; }
  63. };
  64. const int TestQueue::kQueueSize;
  65. // A number of producers send messages to a set of consumers using a set of
  66. // fake queues. Ensure that it does not crash, consumers don't deadlock and
  67. // number of blocked and unblocked threads match.
  68. static void test_stress_eventcount()
  69. {
  70. const int kThreads = std::thread::hardware_concurrency();
  71. static const int kEvents = 1 << 16;
  72. static const int kQueues = 10;
  73. MaxSizeVector<EventCount::Waiter> waiters(kThreads);
  74. waiters.resize(kThreads);
  75. EventCount ec(waiters);
  76. TestQueue queues[kQueues];
  77. std::vector<std::unique_ptr<std::thread>> producers;
  78. for (int i = 0; i < kThreads; i++) {
  79. producers.emplace_back(new std::thread([&ec, &queues]() {
  80. unsigned int rnd = static_cast<unsigned int>(std::hash<std::thread::id>()(std::this_thread::get_id()));
  81. for (int j = 0; j < kEvents; j++) {
  82. unsigned idx = rand_reentrant(&rnd) % kQueues;
  83. if (queues[idx].Push()) {
  84. ec.Notify(false);
  85. continue;
  86. }
  87. EIGEN_THREAD_YIELD();
  88. j--;
  89. }
  90. }));
  91. }
  92. std::vector<std::unique_ptr<std::thread>> consumers;
  93. for (int i = 0; i < kThreads; i++) {
  94. consumers.emplace_back(new std::thread([&ec, &queues, &waiters, i]() {
  95. EventCount::Waiter& w = waiters[i];
  96. unsigned int rnd = static_cast<unsigned int>(std::hash<std::thread::id>()(std::this_thread::get_id()));
  97. for (int j = 0; j < kEvents; j++) {
  98. unsigned idx = rand_reentrant(&rnd) % kQueues;
  99. if (queues[idx].Pop()) continue;
  100. j--;
  101. ec.Prewait();
  102. bool empty = true;
  103. for (int q = 0; q < kQueues; q++) {
  104. if (!queues[q].Empty()) {
  105. empty = false;
  106. break;
  107. }
  108. }
  109. if (!empty) {
  110. ec.CancelWait();
  111. continue;
  112. }
  113. ec.CommitWait(&w);
  114. }
  115. }));
  116. }
  117. for (int i = 0; i < kThreads; i++) {
  118. producers[i]->join();
  119. consumers[i]->join();
  120. }
  121. }
// Test driver for cxx11_eventcount: runs the single-threaded basic EventCount
// test, then the multi-threaded producer/consumer stress test, each wrapped
// in CALL_SUBTEST.
EIGEN_DECLARE_TEST(cxx11_eventcount)
{
  CALL_SUBTEST(test_basic_eventcount());
  CALL_SUBTEST(test_stress_eventcount());
}