123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- #ifndef CERES_INTERNAL_PARALLEL_INVOKE_H_
- #define CERES_INTERNAL_PARALLEL_INVOKE_H_
- #include <atomic>
- #include <condition_variable>
- #include <memory>
- #include <mutex>
- #include <tuple>
- #include <type_traits>
- namespace ceres::internal {
// Calls `function` with `args`, prepending `thread_id` to the argument list
// only if the callable can be invoked that way.
//
// The decision is made at compile time, so a callable may be written either
// as f(thread_id, args...) or as f(args...). When both forms are viable the
// thread-id form is used.
template <typename F, typename... Args>
void InvokeWithThreadId(int thread_id, F&& function, Args&&... args) {
  constexpr bool kOmitThreadId = !std::is_invocable_v<F, int, Args...>;
  if constexpr (kOmitThreadId) {
    // The callable does not accept a leading int: drop the thread id.
    function(std::forward<Args>(args)...);
  } else {
    function(thread_id, std::forward<Args>(args)...);
  }
}
// Applies `function` to the half-open index range [start, end) stored in
// `range`, on behalf of the thread identified by `thread_id`.
//
// The calling convention is picked at compile time from the signature of F:
//  - Per-index: if F is invocable with (int) or (int, int), it is called once
//    for every index in the range, in increasing order.
//  - Per-range: otherwise F is called a single time with the whole `range`
//    tuple.
// In both cases the call goes through InvokeWithThreadId, which passes
// `thread_id` as a leading argument only when the callable accepts it.
template <typename F>
void InvokeOnSegment(int thread_id, std::tuple<int, int> range, F&& function) {
  constexpr bool kExplicitLoop =
      std::is_invocable_v<F, int> || std::is_invocable_v<F, int, int>;
  if constexpr (kExplicitLoop) {
    const auto [start, end] = range;
    // `<` rather than `!=`: an inverted range (start > end) is treated as
    // empty instead of counting upwards until signed overflow.
    for (int i = start; i < end; ++i) {
      InvokeWithThreadId(thread_id, std::forward<F>(function), i);
    }
  } else {
    InvokeWithThreadId(thread_id, std::forward<F>(function), range);
  }
}
// Lets one thread wait until a fixed number of jobs has been reported as
// finished by other threads.
//
// Only the interface is declared here; the member functions are defined
// elsewhere. In this header, ParallelInvoke has every worker call Finished()
// with the number of work blocks it processed, and has the calling thread
// call Block() afterwards.
class BlockUntilFinished {
 public:
  // `num_total_jobs` is the number of jobs expected in total.
  explicit BlockUntilFinished(int num_total_jobs);

  // Reports that `num_jobs_finished` more jobs are done.
  // NOTE(review): presumably adds to the running total under `mutex_` and
  // signals `condition_` once the total is reached - confirm in the
  // definition.
  void Finished(int num_jobs_finished);

  // Waits for the jobs to complete.
  // NOTE(review): presumably returns once the running total equals
  // `num_total_jobs_` - confirm in the definition.
  void Block();

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  // Running count of jobs reported through Finished().
  int num_total_jobs_finished_;
  // Number of jobs expected, fixed at construction.
  const int num_total_jobs_;
};
- struct ParallelInvokeState {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- ParallelInvokeState(int start, int end, int num_work_blocks);
-
- const int start;
- const int end;
-
- const int num_work_blocks;
-
- const int base_block_size;
-
- const int num_base_p1_sized_blocks;
-
-
-
-
-
-
- std::atomic<int> block_id;
-
-
-
- std::atomic<int> thread_id;
-
- BlockUntilFinished block_until_finished;
- };
- template <typename F>
- void ParallelInvoke(ContextImpl* context,
- int start,
- int end,
- int num_threads,
- F&& function,
- int min_block_size) {
- CHECK(context != nullptr);
-
-
-
- constexpr int kWorkBlocksPerThread = 4;
-
-
-
-
-
- const int num_work_blocks = std::min((end - start) / min_block_size,
- num_threads * kWorkBlocksPerThread);
-
-
-
- auto shared_state =
- std::make_shared<ParallelInvokeState>(start, end, num_work_blocks);
-
-
-
- auto task = [context, shared_state, num_threads, &function](auto& task_copy) {
- int num_jobs_finished = 0;
- const int thread_id = shared_state->thread_id.fetch_add(1);
-
-
-
-
-
-
-
- if (thread_id >= num_threads) return;
- const int num_work_blocks = shared_state->num_work_blocks;
- if (thread_id + 1 < num_threads &&
- shared_state->block_id < num_work_blocks) {
-
-
-
-
-
-
- context->thread_pool.AddTask([task_copy]() { task_copy(task_copy); });
- }
- const int start = shared_state->start;
- const int base_block_size = shared_state->base_block_size;
- const int num_base_p1_sized_blocks = shared_state->num_base_p1_sized_blocks;
- while (true) {
-
-
- int block_id = shared_state->block_id.fetch_add(1);
- if (block_id >= num_work_blocks) {
- break;
- }
- ++num_jobs_finished;
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- const int curr_start = start + block_id * base_block_size +
- std::min(block_id, num_base_p1_sized_blocks);
-
-
-
-
- const int curr_end = curr_start + base_block_size +
- (block_id < num_base_p1_sized_blocks ? 1 : 0);
-
- const auto range = std::make_tuple(curr_start, curr_end);
- InvokeOnSegment(thread_id, range, function);
- }
- shared_state->block_until_finished.Finished(num_jobs_finished);
- };
-
-
-
- task(task);
-
- shared_state->block_until_finished.Block();
- }
- }
- #endif
|