123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- #ifndef CERES_EXAMPLES_FIELDS_OF_EXPERTS_H_
- #define CERES_EXAMPLES_FIELDS_OF_EXPERTS_H_
#include <iostream>
#include <string>
#include <vector>

#include "ceres/cost_function.h"
#include "ceres/loss_function.h"
#include "ceres/sized_cost_function.h"
#include "pgm_image.h"
- namespace ceres::examples {
- class FieldsOfExpertsCost : public ceres::CostFunction {
- public:
- explicit FieldsOfExpertsCost(const std::vector<double>& filter);
-
-
- bool Evaluate(double const* const* parameters,
- double* residuals,
- double** jacobians) const override;
- private:
- const std::vector<double>& filter_;
- };
- class FieldsOfExpertsLoss : public ceres::LossFunction {
- public:
- explicit FieldsOfExpertsLoss(double alpha) : alpha_(alpha) {}
- void Evaluate(double, double*) const override;
- private:
- const double alpha_;
- };
- class FieldsOfExperts {
- public:
-
- FieldsOfExperts();
-
-
- bool LoadFromFile(const std::string& filename);
-
- int Size() const { return size_; }
-
- int NumVariables() const { return size_ * size_; }
-
- int NumFilters() const { return num_filters_; }
-
-
- ceres::CostFunction* NewCostFunction(int alpha_index) const;
-
-
- ceres::LossFunction* NewLossFunction(int alpha_index) const;
-
- const std::vector<int>& GetXDeltaIndices() const { return x_delta_indices_; }
- const std::vector<int>& GetYDeltaIndices() const { return y_delta_indices_; }
- private:
-
- int size_;
-
- int num_filters_;
-
- std::vector<int> x_delta_indices_, y_delta_indices_;
-
- std::vector<double> alpha_;
-
- std::vector<std::vector<double>> filters_;
- };
- }
- #endif
|