123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
#include "ceres/canonical_views_clustering.h"

#include <cmath>
#include <ctime>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "ceres/graph.h"
#include "ceres/internal/export.h"
#include "ceres/map_util.h"
#include "glog/logging.h"
- namespace ceres::internal {
// Shorthand container types used throughout this file; views are identified
// by int vertex ids.
using IntSet = std::unordered_set<int>;       // A set of view ids.
using IntMap = std::unordered_map<int, int>;  // View id -> int (view or id).
- class CERES_NO_EXPORT CanonicalViewsClustering {
- public:
-
-
-
-
-
-
-
-
- void ComputeClustering(const CanonicalViewsClusteringOptions& options,
- const WeightedGraph<int>& graph,
- std::vector<int>* centers,
- IntMap* membership);
- private:
- void FindValidViews(IntSet* valid_views) const;
- double ComputeClusteringQualityDifference(
- int candidate, const std::vector<int>& centers) const;
- void UpdateCanonicalViewAssignments(const int canonical_view);
- void ComputeClusterMembership(const std::vector<int>& centers,
- IntMap* membership) const;
- CanonicalViewsClusteringOptions options_;
- const WeightedGraph<int>* graph_;
-
-
- IntMap view_to_canonical_view_;
-
- std::unordered_map<int, double> view_to_canonical_view_similarity_;
- };
- void ComputeCanonicalViewsClustering(
- const CanonicalViewsClusteringOptions& options,
- const WeightedGraph<int>& graph,
- std::vector<int>* centers,
- IntMap* membership) {
- time_t start_time = time(nullptr);
- CanonicalViewsClustering cv;
- cv.ComputeClustering(options, graph, centers, membership);
- VLOG(2) << "Canonical views clustering time (secs): "
- << time(nullptr) - start_time;
- }
- void CanonicalViewsClustering::ComputeClustering(
- const CanonicalViewsClusteringOptions& options,
- const WeightedGraph<int>& graph,
- std::vector<int>* centers,
- IntMap* membership) {
- options_ = options;
- CHECK(centers != nullptr);
- CHECK(membership != nullptr);
- centers->clear();
- membership->clear();
- graph_ = &graph;
- IntSet valid_views;
- FindValidViews(&valid_views);
- while (!valid_views.empty()) {
-
- double best_difference = -std::numeric_limits<double>::max();
- int best_view = 0;
-
- for (const auto& view : valid_views) {
- const double difference =
- ComputeClusteringQualityDifference(view, *centers);
- if (difference > best_difference) {
- best_difference = difference;
- best_view = view;
- }
- }
- CHECK_GT(best_difference, -std::numeric_limits<double>::max());
-
-
- if ((best_difference <= 0) && (centers->size() >= options_.min_views)) {
- break;
- }
- centers->push_back(best_view);
- valid_views.erase(best_view);
- UpdateCanonicalViewAssignments(best_view);
- }
- ComputeClusterMembership(*centers, membership);
- }
- void CanonicalViewsClustering::FindValidViews(IntSet* valid_views) const {
- const IntSet& views = graph_->vertices();
- for (const auto& view : views) {
- if (graph_->VertexWeight(view) != WeightedGraph<int>::InvalidWeight()) {
- valid_views->insert(view);
- }
- }
- }
- double CanonicalViewsClustering::ComputeClusteringQualityDifference(
- const int candidate, const std::vector<int>& centers) const {
-
- double difference =
- options_.view_score_weight * graph_->VertexWeight(candidate);
-
-
-
- const IntSet& neighbors = graph_->Neighbors(candidate);
- for (const auto& neighbor : neighbors) {
- const double old_similarity =
- FindWithDefault(view_to_canonical_view_similarity_, neighbor, 0.0);
- const double new_similarity = graph_->EdgeWeight(neighbor, candidate);
- if (new_similarity > old_similarity) {
- difference += new_similarity - old_similarity;
- }
- }
-
- difference -= options_.size_penalty_weight;
-
- for (int center : centers) {
- difference -= options_.similarity_penalty_weight *
- graph_->EdgeWeight(center, candidate);
- }
- return difference;
- }
- void CanonicalViewsClustering::UpdateCanonicalViewAssignments(
- const int canonical_view) {
- const IntSet& neighbors = graph_->Neighbors(canonical_view);
- for (const auto& neighbor : neighbors) {
- const double old_similarity =
- FindWithDefault(view_to_canonical_view_similarity_, neighbor, 0.0);
- const double new_similarity = graph_->EdgeWeight(neighbor, canonical_view);
- if (new_similarity > old_similarity) {
- view_to_canonical_view_[neighbor] = canonical_view;
- view_to_canonical_view_similarity_[neighbor] = new_similarity;
- }
- }
- }
- void CanonicalViewsClustering::ComputeClusterMembership(
- const std::vector<int>& centers, IntMap* membership) const {
- CHECK(membership != nullptr);
- membership->clear();
-
- IntMap center_to_cluster_id;
- for (int i = 0; i < centers.size(); ++i) {
- center_to_cluster_id[centers[i]] = i;
- }
- static constexpr int kInvalidClusterId = -1;
- const IntSet& views = graph_->vertices();
- for (const auto& view : views) {
- auto it = view_to_canonical_view_.find(view);
- int cluster_id = kInvalidClusterId;
- if (it != view_to_canonical_view_.end()) {
- cluster_id = FindOrDie(center_to_cluster_id, it->second);
- }
- InsertOrDie(membership, view, cluster_id);
- }
- }
- }
|