123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- #ifndef CERES_PUBLIC_LOSS_FUNCTION_H_
- #define CERES_PUBLIC_LOSS_FUNCTION_H_
- #include <memory>
- #include "ceres/internal/disable_warnings.h"
- #include "ceres/internal/export.h"
- #include "ceres/types.h"
- #include "glog/logging.h"
- namespace ceres {
- class CERES_EXPORT LossFunction {
- public:
- virtual ~LossFunction();
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- virtual void Evaluate(double sq_norm, double out[3]) const = 0;
- };
- class CERES_EXPORT TrivialLoss final : public LossFunction {
- public:
- void Evaluate(double, double*) const override;
- };
- class CERES_EXPORT HuberLoss final : public LossFunction {
- public:
- explicit HuberLoss(double a) : a_(a), b_(a * a) {}
- void Evaluate(double, double*) const override;
- private:
- const double a_;
-
- const double b_;
- };
- class CERES_EXPORT SoftLOneLoss final : public LossFunction {
- public:
- explicit SoftLOneLoss(double a) : b_(a * a), c_(1 / b_) {}
- void Evaluate(double, double*) const override;
- private:
-
- const double b_;
-
- const double c_;
- };
- class CERES_EXPORT CauchyLoss final : public LossFunction {
- public:
- explicit CauchyLoss(double a) : b_(a * a), c_(1 / b_) {}
- void Evaluate(double, double*) const override;
- private:
-
- const double b_;
-
- const double c_;
- };
- class CERES_EXPORT ArctanLoss final : public LossFunction {
- public:
- explicit ArctanLoss(double a) : a_(a), b_(1 / (a * a)) {}
- void Evaluate(double, double*) const override;
- private:
- const double a_;
-
- const double b_;
- };
- class CERES_EXPORT TolerantLoss final : public LossFunction {
- public:
- explicit TolerantLoss(double a, double b);
- void Evaluate(double, double*) const override;
- private:
- const double a_, b_, c_;
- };
- class CERES_EXPORT TukeyLoss final : public ceres::LossFunction {
- public:
- explicit TukeyLoss(double a) : a_squared_(a * a) {}
- void Evaluate(double, double*) const override;
- private:
- const double a_squared_;
- };
- class CERES_EXPORT ComposedLoss final : public LossFunction {
- public:
- explicit ComposedLoss(const LossFunction* f,
- Ownership ownership_f,
- const LossFunction* g,
- Ownership ownership_g);
- ~ComposedLoss() override;
- void Evaluate(double, double*) const override;
- private:
- std::unique_ptr<const LossFunction> f_, g_;
- const Ownership ownership_f_, ownership_g_;
- };
- class CERES_EXPORT ScaledLoss final : public LossFunction {
- public:
-
-
-
- ScaledLoss(const LossFunction* rho, double a, Ownership ownership)
- : rho_(rho), a_(a), ownership_(ownership) {}
- ScaledLoss(const ScaledLoss&) = delete;
- void operator=(const ScaledLoss&) = delete;
- ~ScaledLoss() override {
- if (ownership_ == DO_NOT_TAKE_OWNERSHIP) {
- rho_.release();
- }
- }
- void Evaluate(double, double*) const override;
- private:
- std::unique_ptr<const LossFunction> rho_;
- const double a_;
- const Ownership ownership_;
- };
- class CERES_EXPORT LossFunctionWrapper final : public LossFunction {
- public:
- LossFunctionWrapper(LossFunction* rho, Ownership ownership)
- : rho_(rho), ownership_(ownership) {}
- LossFunctionWrapper(const LossFunctionWrapper&) = delete;
- void operator=(const LossFunctionWrapper&) = delete;
- ~LossFunctionWrapper() override {
- if (ownership_ == DO_NOT_TAKE_OWNERSHIP) {
- rho_.release();
- }
- }
- void Evaluate(double sq_norm, double out[3]) const override {
- if (rho_.get() == nullptr) {
- out[0] = sq_norm;
- out[1] = 1.0;
- out[2] = 0.0;
- } else {
- rho_->Evaluate(sq_norm, out);
- }
- }
- void Reset(LossFunction* rho, Ownership ownership) {
- if (ownership_ == DO_NOT_TAKE_OWNERSHIP) {
- rho_.release();
- }
- rho_.reset(rho);
- ownership_ = ownership;
- }
- private:
- std::unique_ptr<const LossFunction> rho_;
- Ownership ownership_;
- };
- }
- #include "ceres/internal/reenable_warnings.h"
- #endif
|