// parallel_for.h
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2023 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Authors: vitus@google.com (Michael Vitus),
  30. // dmitriy.korchemkin@gmail.com (Dmitriy Korchemkin)
  31. #ifndef CERES_INTERNAL_PARALLEL_FOR_H_
  32. #define CERES_INTERNAL_PARALLEL_FOR_H_
  33. #include <mutex>
  34. #include <vector>
  35. #include "ceres/context_impl.h"
  36. #include "ceres/internal/eigen.h"
  37. #include "ceres/internal/export.h"
  38. #include "ceres/parallel_invoke.h"
  39. #include "ceres/partition_range_for_parallel_for.h"
  40. #include "glog/logging.h"
  41. namespace ceres::internal {
  42. // Use a dummy mutex if num_threads = 1.
  43. inline decltype(auto) MakeConditionalLock(const int num_threads,
  44. std::mutex& m) {
  45. return (num_threads == 1) ? std::unique_lock<std::mutex>{}
  46. : std::unique_lock<std::mutex>{m};
  47. }
  48. // Execute the function for every element in the range [start, end) with at most
  49. // num_threads. It will execute all the work on the calling thread if
  50. // num_threads or (end - start) is equal to 1.
  51. // Depending on function signature, it will be supplied with either loop index
  52. // or a range of loop indicies; function can also be supplied with thread_id.
  53. // The following function signatures are supported:
  54. // - Functions accepting a single loop index:
  55. // - [](int index) { ... }
  56. // - [](int thread_id, int index) { ... }
  57. // - Functions accepting a range of loop index:
  58. // - [](std::tuple<int, int> index) { ... }
  59. // - [](int thread_id, std::tuple<int, int> index) { ... }
  60. //
  61. // When distributing workload between threads, it is assumed that each loop
  62. // iteration takes approximately equal time to complete.
  63. template <typename F>
  64. void ParallelFor(ContextImpl* context,
  65. int start,
  66. int end,
  67. int num_threads,
  68. F&& function,
  69. int min_block_size = 1) {
  70. CHECK_GT(num_threads, 0);
  71. if (start >= end) {
  72. return;
  73. }
  74. if (num_threads == 1 || end - start < min_block_size * 2) {
  75. InvokeOnSegment(0, std::make_tuple(start, end), std::forward<F>(function));
  76. return;
  77. }
  78. CHECK(context != nullptr);
  79. ParallelInvoke(context,
  80. start,
  81. end,
  82. num_threads,
  83. std::forward<F>(function),
  84. min_block_size);
  85. }
  86. // Execute function for every element in the range [start, end) with at most
  87. // num_threads, using user-provided partitions array.
  88. // When distributing workload between threads, it is assumed that each segment
  89. // bounded by adjacent elements of partitions array takes approximately equal
  90. // time to process.
  91. template <typename F>
  92. void ParallelFor(ContextImpl* context,
  93. int start,
  94. int end,
  95. int num_threads,
  96. F&& function,
  97. const std::vector<int>& partitions) {
  98. CHECK_GT(num_threads, 0);
  99. if (start >= end) {
  100. return;
  101. }
  102. CHECK_EQ(partitions.front(), start);
  103. CHECK_EQ(partitions.back(), end);
  104. if (num_threads == 1 || end - start <= num_threads) {
  105. ParallelFor(context, start, end, num_threads, std::forward<F>(function));
  106. return;
  107. }
  108. CHECK_GT(partitions.size(), 1);
  109. const int num_partitions = partitions.size() - 1;
  110. ParallelFor(context,
  111. 0,
  112. num_partitions,
  113. num_threads,
  114. [&function, &partitions](int thread_id,
  115. std::tuple<int, int> partition_ids) {
  116. // partition_ids is a range of partition indices
  117. const auto [partition_start, partition_end] = partition_ids;
  118. // Execution over several adjacent segments is equivalent
  119. // to execution over union of those segments (which is also a
  120. // contiguous segment)
  121. const int range_start = partitions[partition_start];
  122. const int range_end = partitions[partition_end];
  123. // Range of original loop indices
  124. const auto range = std::make_tuple(range_start, range_end);
  125. InvokeOnSegment(thread_id, range, function);
  126. });
  127. }
  128. // Execute function for every element in the range [start, end) with at most
  129. // num_threads, taking into account user-provided integer cumulative costs of
  130. // iterations. Cumulative costs of iteration for indices in range [0, end) are
  131. // stored in objects from cumulative_cost_data. User-provided
  132. // cumulative_cost_fun returns non-decreasing integer values corresponding to
  133. // inclusive cumulative cost of loop iterations, provided with a reference to
  134. // user-defined object. Only indices from [start, end) will be referenced. This
  135. // routine assumes that cumulative_cost_fun is non-decreasing (in other words,
  136. // all costs are non-negative);
  137. // When distributing workload between threads, input range of loop indices will
  138. // be partitioned into disjoint contiguous intervals, with the maximal cost
  139. // being minimized.
  140. // For example, with iteration costs of [1, 1, 5, 3, 1, 4] cumulative_cost_fun
  141. // should return [1, 2, 7, 10, 11, 15], and with num_threads = 4 this range
  142. // will be split into segments [0, 2) [2, 3) [3, 5) [5, 6) with costs
  143. // [2, 5, 4, 4].
  144. template <typename F, typename CumulativeCostData, typename CumulativeCostFun>
  145. void ParallelFor(ContextImpl* context,
  146. int start,
  147. int end,
  148. int num_threads,
  149. F&& function,
  150. const CumulativeCostData* cumulative_cost_data,
  151. CumulativeCostFun&& cumulative_cost_fun) {
  152. CHECK_GT(num_threads, 0);
  153. if (start >= end) {
  154. return;
  155. }
  156. if (num_threads == 1 || end - start <= num_threads) {
  157. ParallelFor(context, start, end, num_threads, std::forward<F>(function));
  158. return;
  159. }
  160. // Creating several partitions allows us to tolerate imperfections of
  161. // partitioning and user-supplied iteration costs up to a certain extent
  162. constexpr int kNumPartitionsPerThread = 4;
  163. const int kMaxPartitions = num_threads * kNumPartitionsPerThread;
  164. const auto& partitions = PartitionRangeForParallelFor(
  165. start,
  166. end,
  167. kMaxPartitions,
  168. cumulative_cost_data,
  169. std::forward<CumulativeCostFun>(cumulative_cost_fun));
  170. CHECK_GT(partitions.size(), 1);
  171. ParallelFor(
  172. context, start, end, num_threads, std::forward<F>(function), partitions);
  173. }
  174. } // namespace ceres::internal
  175. #endif // CERES_INTERNAL_PARALLEL_FOR_H_