canonical_views_clustering.cc 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2023 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Author: David Gallup (dgallup@google.com)
  30. // Sameer Agarwal (sameeragarwal@google.com)
#include "ceres/canonical_views_clustering.h"

#include <chrono>
#include <cmath>
#include <ctime>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ceres/graph.h"
#include "ceres/internal/export.h"
#include "ceres/map_util.h"
#include "glog/logging.h"
  39. namespace ceres::internal {
  40. using IntMap = std::unordered_map<int, int>;
  41. using IntSet = std::unordered_set<int>;
  42. class CERES_NO_EXPORT CanonicalViewsClustering {
  43. public:
  44. // Compute the canonical views clustering of the vertices of the
  45. // graph. centers will contain the vertices that are the identified
  46. // as the canonical views/cluster centers, and membership is a map
  47. // from vertices to cluster_ids. The i^th cluster center corresponds
  48. // to the i^th cluster. It is possible depending on the
  49. // configuration of the clustering algorithm that some of the
  50. // vertices may not be assigned to any cluster. In this case they
  51. // are assigned to a cluster with id = kInvalidClusterId.
  52. void ComputeClustering(const CanonicalViewsClusteringOptions& options,
  53. const WeightedGraph<int>& graph,
  54. std::vector<int>* centers,
  55. IntMap* membership);
  56. private:
  57. void FindValidViews(IntSet* valid_views) const;
  58. double ComputeClusteringQualityDifference(
  59. int candidate, const std::vector<int>& centers) const;
  60. void UpdateCanonicalViewAssignments(const int canonical_view);
  61. void ComputeClusterMembership(const std::vector<int>& centers,
  62. IntMap* membership) const;
  63. CanonicalViewsClusteringOptions options_;
  64. const WeightedGraph<int>* graph_;
  65. // Maps a view to its representative canonical view (its cluster
  66. // center).
  67. IntMap view_to_canonical_view_;
  68. // Maps a view to its similarity to its current cluster center.
  69. std::unordered_map<int, double> view_to_canonical_view_similarity_;
  70. };
  71. void ComputeCanonicalViewsClustering(
  72. const CanonicalViewsClusteringOptions& options,
  73. const WeightedGraph<int>& graph,
  74. std::vector<int>* centers,
  75. IntMap* membership) {
  76. time_t start_time = time(nullptr);
  77. CanonicalViewsClustering cv;
  78. cv.ComputeClustering(options, graph, centers, membership);
  79. VLOG(2) << "Canonical views clustering time (secs): "
  80. << time(nullptr) - start_time;
  81. }
  82. // Implementation of CanonicalViewsClustering
  83. void CanonicalViewsClustering::ComputeClustering(
  84. const CanonicalViewsClusteringOptions& options,
  85. const WeightedGraph<int>& graph,
  86. std::vector<int>* centers,
  87. IntMap* membership) {
  88. options_ = options;
  89. CHECK(centers != nullptr);
  90. CHECK(membership != nullptr);
  91. centers->clear();
  92. membership->clear();
  93. graph_ = &graph;
  94. IntSet valid_views;
  95. FindValidViews(&valid_views);
  96. while (!valid_views.empty()) {
  97. // Find the next best canonical view.
  98. double best_difference = -std::numeric_limits<double>::max();
  99. int best_view = 0;
  100. // TODO(sameeragarwal): Make this loop multi-threaded.
  101. for (const auto& view : valid_views) {
  102. const double difference =
  103. ComputeClusteringQualityDifference(view, *centers);
  104. if (difference > best_difference) {
  105. best_difference = difference;
  106. best_view = view;
  107. }
  108. }
  109. CHECK_GT(best_difference, -std::numeric_limits<double>::max());
  110. // Add canonical view if quality improves, or if minimum is not
  111. // yet met, otherwise break.
  112. if ((best_difference <= 0) && (centers->size() >= options_.min_views)) {
  113. break;
  114. }
  115. centers->push_back(best_view);
  116. valid_views.erase(best_view);
  117. UpdateCanonicalViewAssignments(best_view);
  118. }
  119. ComputeClusterMembership(*centers, membership);
  120. }
  121. // Return the set of vertices of the graph which have valid vertex
  122. // weights.
  123. void CanonicalViewsClustering::FindValidViews(IntSet* valid_views) const {
  124. const IntSet& views = graph_->vertices();
  125. for (const auto& view : views) {
  126. if (graph_->VertexWeight(view) != WeightedGraph<int>::InvalidWeight()) {
  127. valid_views->insert(view);
  128. }
  129. }
  130. }
  131. // Computes the difference in the quality score if 'candidate' were
  132. // added to the set of canonical views.
  133. double CanonicalViewsClustering::ComputeClusteringQualityDifference(
  134. const int candidate, const std::vector<int>& centers) const {
  135. // View score.
  136. double difference =
  137. options_.view_score_weight * graph_->VertexWeight(candidate);
  138. // Compute how much the quality score changes if the candidate view
  139. // was added to the list of canonical views and its nearest
  140. // neighbors became members of its cluster.
  141. const IntSet& neighbors = graph_->Neighbors(candidate);
  142. for (const auto& neighbor : neighbors) {
  143. const double old_similarity =
  144. FindWithDefault(view_to_canonical_view_similarity_, neighbor, 0.0);
  145. const double new_similarity = graph_->EdgeWeight(neighbor, candidate);
  146. if (new_similarity > old_similarity) {
  147. difference += new_similarity - old_similarity;
  148. }
  149. }
  150. // Number of views penalty.
  151. difference -= options_.size_penalty_weight;
  152. // Orthogonality.
  153. for (int center : centers) {
  154. difference -= options_.similarity_penalty_weight *
  155. graph_->EdgeWeight(center, candidate);
  156. }
  157. return difference;
  158. }
  159. // Reassign views if they're more similar to the new canonical view.
  160. void CanonicalViewsClustering::UpdateCanonicalViewAssignments(
  161. const int canonical_view) {
  162. const IntSet& neighbors = graph_->Neighbors(canonical_view);
  163. for (const auto& neighbor : neighbors) {
  164. const double old_similarity =
  165. FindWithDefault(view_to_canonical_view_similarity_, neighbor, 0.0);
  166. const double new_similarity = graph_->EdgeWeight(neighbor, canonical_view);
  167. if (new_similarity > old_similarity) {
  168. view_to_canonical_view_[neighbor] = canonical_view;
  169. view_to_canonical_view_similarity_[neighbor] = new_similarity;
  170. }
  171. }
  172. }
  173. // Assign a cluster id to each view.
  174. void CanonicalViewsClustering::ComputeClusterMembership(
  175. const std::vector<int>& centers, IntMap* membership) const {
  176. CHECK(membership != nullptr);
  177. membership->clear();
  178. // The i^th cluster has cluster id i.
  179. IntMap center_to_cluster_id;
  180. for (int i = 0; i < centers.size(); ++i) {
  181. center_to_cluster_id[centers[i]] = i;
  182. }
  183. static constexpr int kInvalidClusterId = -1;
  184. const IntSet& views = graph_->vertices();
  185. for (const auto& view : views) {
  186. auto it = view_to_canonical_view_.find(view);
  187. int cluster_id = kInvalidClusterId;
  188. if (it != view_to_canonical_view_.end()) {
  189. cluster_id = FindOrDie(center_to_cluster_id, it->second);
  190. }
  191. InsertOrDie(membership, view, cluster_id);
  192. }
  193. }
  194. } // namespace ceres::internal