123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
#include "ceres/thread_pool.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <mutex>
#include <thread>

#include "ceres/internal/config.h"
- namespace ceres::internal {
- namespace {
- int GetNumAllowedThreads(int requested_num_threads) {
- return std::min(requested_num_threads, ThreadPool::MaxNumThreadsAvailable());
- }
- }
- int ThreadPool::MaxNumThreadsAvailable() {
- const int num_hardware_threads = std::thread::hardware_concurrency();
-
-
- return num_hardware_threads == 0 ? std::numeric_limits<int>::max()
- : num_hardware_threads;
- }
- ThreadPool::ThreadPool() = default;
- ThreadPool::ThreadPool(int num_threads) { Resize(num_threads); }
- ThreadPool::~ThreadPool() {
- std::lock_guard<std::mutex> lock(thread_pool_mutex_);
-
-
- Stop();
- for (std::thread& thread : thread_pool_) {
- thread.join();
- }
- }
- void ThreadPool::Resize(int num_threads) {
- std::lock_guard<std::mutex> lock(thread_pool_mutex_);
- const int num_current_threads = thread_pool_.size();
- if (num_current_threads >= num_threads) {
- return;
- }
- const int create_num_threads =
- GetNumAllowedThreads(num_threads) - num_current_threads;
- for (int i = 0; i < create_num_threads; ++i) {
- thread_pool_.emplace_back(&ThreadPool::ThreadMainLoop, this);
- }
- }
- void ThreadPool::AddTask(const std::function<void()>& func) {
- task_queue_.Push(func);
- }
- int ThreadPool::Size() {
- std::lock_guard<std::mutex> lock(thread_pool_mutex_);
- return thread_pool_.size();
- }
- void ThreadPool::ThreadMainLoop() {
- std::function<void()> task;
- while (task_queue_.Wait(&task)) {
- task();
- }
- }
- void ThreadPool::Stop() { task_queue_.StopWaiters(); }
- }
|