// Ceres Solver - A fast non-linear least squares minimizer
// Copyright 2023 Google Inc. All rights reserved.
// http://ceres-solver.org/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
//   this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
// * Neither the name of Google Inc. nor the names of its contributors may be
//   used to endorse or promote products derived from this software without
//   specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Authors: vitus@google.com (Michael Vitus),
//          dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)

#ifndef CERES_INTERNAL_PARALLEL_INVOKE_H_
#define CERES_INTERNAL_PARALLEL_INVOKE_H_

#include <algorithm>  // std::min
#include <atomic>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <tuple>
#include <type_traits>
#include <utility>  // std::forward

#include "ceres/context_impl.h"  // ContextImpl and its thread_pool
#include "glog/logging.h"        // CHECK

namespace ceres::internal {

// InvokeWithThreadId handles passing thread_id to the function
template <typename F, typename... Args>
void InvokeWithThreadId(int thread_id, F&& function, Args&&... args) {
  constexpr bool kPassThreadId = std::is_invocable_v<F, int, Args...>;

  if constexpr (kPassThreadId) {
    function(thread_id, std::forward<Args>(args)...);
  } else {
    function(std::forward<Args>(args)...);
  }
}
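
// For instance (illustrative sketch, not part of the original header), given a
// single index argument i:
//
//   InvokeWithThreadId(thread_id, [](int i) {}, i);                 // no id
//   InvokeWithThreadId(thread_id, [](int thread_id, int i) {}, i);  // with id
//
// The binary callable is detected via std::is_invocable_v<F, int, Args...> and
// receives thread_id as its first argument; the unary one is called with the
// index only.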

// InvokeOnSegment either loops over the individual indices of the segment or
// passes the whole [start, end) range to the function, depending on the
// signature of the function.
template <typename F>
void InvokeOnSegment(int thread_id, std::tuple<int, int> range, F&& function) {
  constexpr bool kExplicitLoop =
      std::is_invocable_v<F, int> || std::is_invocable_v<F, int, int>;

  if constexpr (kExplicitLoop) {
    const auto [start, end] = range;
    for (int i = start; i != end; ++i) {
      InvokeWithThreadId(thread_id, std::forward<F>(function), i);
    }
  } else {
    InvokeWithThreadId(thread_id, std::forward<F>(function), range);
  }
}
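
// For example (illustrative): with the range std::make_tuple(0, n),
//
//   InvokeOnSegment(thread_id, {0, n}, [](int i) {});
//       // called once per index i in [0, n)
//   InvokeOnSegment(thread_id, {0, n}, [](std::tuple<int, int> r) {});
//       // called once with the whole range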

// A thread safe barrier which blocks until Finished has been called with a
// cumulative total of num_total_jobs jobs. This allows us to block the main
// thread until all the parallel threads have finished processing all the
// work.
class BlockUntilFinished {
 public:
  explicit BlockUntilFinished(int num_total_jobs);

  // Increment the count of finished jobs by num_jobs_finished (the number of
  // jobs processed by the caller) and signal the blocking thread if all jobs
  // have finished.
  void Finished(int num_jobs_finished);

  // Block until receiving confirmation of all jobs being finished.
  void Block();

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  int num_total_jobs_finished_;
  const int num_total_jobs_;
};
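
// Usage sketch (as in ParallelInvoke below): the scheduling thread constructs
// BlockUntilFinished(num_work_blocks), every worker reports the number of
// blocks it completed via Finished(), and the scheduling thread calls Block()
// to wait until the reported total reaches num_work_blocks.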

// Shared state between the parallel tasks. Each thread will use this
// information to get the next block of work to be performed.
struct ParallelInvokeState {
  // The entire range [start, end) is split into num_work_blocks contiguous
  // disjoint intervals (blocks), which are as equal as possible given the
  // total index count and the requested number of blocks.
  //
  // Those num_work_blocks blocks are then processed in parallel.
  //
  // The total number of integer indices in the interval [start, end) is
  // end - start, and when splitting them into num_work_blocks blocks we
  // either
  //  - Split into equal blocks when (end - start) is divisible by
  //    num_work_blocks
  //  - Split into blocks whose sizes differ by at most 1:
  //     - The size of the smallest block(s) is (end - start) / num_work_blocks
  //     - (end - start) % num_work_blocks blocks are 1 index larger
  //
  // Note that this splitting is optimal in the sense of minimizing the maximal
  // difference between block sizes, since splitting into equal blocks is
  // possible if and only if the number of indices is divisible by the number
  // of blocks.
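  //
  // For instance, splitting [0, 10) into num_work_blocks = 4 blocks gives
  // base_block_size = 10 / 4 = 2 and num_base_p1_sized_blocks = 10 % 4 = 2,
  // i.e. the blocks [0, 3), [3, 6), [6, 8), [8, 10).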
  ParallelInvokeState(int start, int end, int num_work_blocks);

  // The start and end index of the for loop.
  const int start;
  const int end;
  // The number of blocks that need to be processed.
  const int num_work_blocks;
  // Size of the smallest block.
  const int base_block_size;
  // Number of blocks of size base_block_size + 1.
  const int num_base_p1_sized_blocks;

  // The next block of work to be assigned to a worker. The parallel for loop
  // range is split into num_work_blocks blocks of work, with a single block
  // of work being of size
  //  - base_block_size + 1 for the first num_base_p1_sized_blocks blocks
  //  - base_block_size for the rest of the blocks
  // Blocks of indices are contiguous and disjoint.
  std::atomic<int> block_id;

  // Provides a unique thread ID among all active threads.
  // We do not schedule more than num_threads threads via the thread pool,
  // and the caller thread might steal one ID.
  std::atomic<int> thread_id;

  // Used to signal when all the work has been completed. Thread safe.
  BlockUntilFinished block_until_finished;
};

// This implementation uses a fixed size max worker pool with a shared task
// queue. The problem of executing the function for the interval of
// [start, end) is broken up into at most num_threads * kWorkBlocksPerThread
// blocks (each of size at least min_block_size) and added to the thread pool.
// To avoid deadlocks, the calling thread is allowed to steal work from the
// worker pool. This is implemented via a shared state between the tasks. In
// order for the calling thread or thread pool to get a block of work, it will
// query the shared state for the next block of work to be done. If there is
// nothing left, it will return. We will exit the ParallelFor call when all of
// the work has been done, not when all of the tasks have been popped off the
// task queue.
//
// A unique thread ID among all active tasks will be acquired once for each
// block of work. This avoids the significant performance penalty for
// acquiring it on every iteration of the for loop. The thread ID is
// guaranteed to be in [0, num_threads).
//
// A performance analysis has shown this implementation is on par with OpenMP
// and TBB.
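//
// Illustrative call (a sketch; `values` and `sums` are hypothetical, and
// `context` is assumed to point at a ContextImpl with an initialized thread
// pool):
//
//   std::vector<double> sums(num_threads, 0.0);
//   ParallelInvoke(
//       context, 0, n, num_threads,
//       [&](int thread_id, int i) { sums[thread_id] += values[i]; },
//       /*min_block_size=*/1);
//
// A callable taking only the index, or the whole std::tuple<int, int> range,
// is also accepted (see InvokeOnSegment above).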
template <typename F>
void ParallelInvoke(ContextImpl* context,
                    int start,
                    int end,
                    int num_threads,
                    F&& function,
                    int min_block_size) {
  CHECK(context != nullptr);

  // Maximal number of work items scheduled for a single thread
  //  - Lower number of work items results in larger runtimes on unequal tasks
  //  - Higher number of work items results in larger losses for
  //    synchronization
  constexpr int kWorkBlocksPerThread = 4;

  // Interval [start, end) is being split into
  // num_threads * kWorkBlocksPerThread contiguous disjoint blocks.
  //
  // In order to avoid creating empty blocks of work, we need to limit the
  // number of work blocks by the total number of indices.
  const int num_work_blocks = std::min((end - start) / min_block_size,
                                       num_threads * kWorkBlocksPerThread);

  // We use a std::shared_ptr because the main thread can finish all
  // the work before the tasks have been popped off the queue. So the
  // shared state needs to exist for the duration of all the tasks.
  auto shared_state =
      std::make_shared<ParallelInvokeState>(start, end, num_work_blocks);

  // A function which tries to schedule another task in the thread pool and
  // perform several chunks of work. The function expects itself as the
  // argument in order to schedule the next task in the thread pool.
  auto task = [context, shared_state, num_threads, &function](auto& task_copy) {
    int num_jobs_finished = 0;
    const int thread_id = shared_state->thread_id.fetch_add(1);
    // In order to avoid dead-locks in nested parallel for loops, task() will
    // be invoked num_threads + 1 times:
    //  - num_threads times via enqueueing task into thread pool
    //  - one more time in the main thread
    // Tasks enqueued to the thread pool might take some time before execution,
    // and the last task being executed will be terminated here in order to
    // avoid having more than num_threads active threads
    if (thread_id >= num_threads) return;
    const int num_work_blocks = shared_state->num_work_blocks;
    if (thread_id + 1 < num_threads &&
        shared_state->block_id < num_work_blocks) {
      // Add another thread to the thread pool.
      // Note we are taking the task as value so the copy of shared_state
      // shared pointer (captured by value at declaration of task
      // lambda-function) is copied and the ref count is increased. This is to
      // prevent it from being deleted when the main thread finishes all the
      // work and exits before the threads finish.
      context->thread_pool.AddTask([task_copy]() { task_copy(task_copy); });
    }

    const int start = shared_state->start;
    const int base_block_size = shared_state->base_block_size;
    const int num_base_p1_sized_blocks = shared_state->num_base_p1_sized_blocks;

    while (true) {
      // Get the next available chunk of work to be performed. If there is no
      // work, return.
      int block_id = shared_state->block_id.fetch_add(1);
      if (block_id >= num_work_blocks) {
        break;
      }
      ++num_jobs_finished;

      // For-loop interval [start, end) was split into num_work_blocks,
      // with num_base_p1_sized_blocks of size base_block_size + 1 and the
      // remaining num_work_blocks - num_base_p1_sized_blocks of size
      // base_block_size
      //
      // Then, the start index of block #block_id is given by the total
      // length of preceding blocks:
      //  * Total length of preceding blocks of size base_block_size + 1:
      //     min(block_id, num_base_p1_sized_blocks) * (base_block_size + 1)
      //
      //  * Total length of preceding blocks of size base_block_size:
      //     (block_id - min(block_id, num_base_p1_sized_blocks)) *
      //     base_block_size
      //
      // Simplifying the sum of those quantities yields the following
      // expression for the start index of block #block_id
      const int curr_start = start + block_id * base_block_size +
                             std::min(block_id, num_base_p1_sized_blocks);
      // The first num_base_p1_sized_blocks have size base_block_size + 1
      //
      // Note that it is guaranteed that all blocks are within the
      // [start, end) interval
      const int curr_end = curr_start + base_block_size +
                           (block_id < num_base_p1_sized_blocks ? 1 : 0);
      // Perform each task in the current block
      const auto range = std::make_tuple(curr_start, curr_end);
      InvokeOnSegment(thread_id, range, function);
    }
    shared_state->block_until_finished.Finished(num_jobs_finished);
  };

  // Start scheduling threads and doing work. We might end up with fewer
  // threads scheduled than expected, if scheduling overhead is larger than
  // the amount of work to be done.
  task(task);

  // Wait until all tasks have finished.
  shared_state->block_until_finished.Block();
}

}  // namespace ceres::internal

#endif  // CERES_INTERNAL_PARALLEL_INVOKE_H_