123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308 |
- #ifndef CERES_INTERNAL_DENSE_CHOLESKY_H_
- #define CERES_INTERNAL_DENSE_CHOLESKY_H_
- #include "ceres/internal/config.h"
- #include <memory>
- #include <vector>
- #include "Eigen/Dense"
- #include "ceres/context_impl.h"
- #include "ceres/cuda_buffer.h"
- #include "ceres/linear_solver.h"
- #include "glog/logging.h"
- #ifndef CERES_NO_CUDA
- #include "ceres/context_impl.h"
- #include "cuda_runtime.h"
- #include "cusolverDn.h"
- #endif
- namespace ceres::internal {
- class CERES_NO_EXPORT DenseCholesky {
- public:
- static std::unique_ptr<DenseCholesky> Create(
- const LinearSolver::Options& options);
- virtual ~DenseCholesky();
-
-
-
-
-
-
-
-
-
-
- virtual LinearSolverTerminationType Factorize(int num_cols,
- double* lhs,
- std::string* message) = 0;
-
-
-
-
-
-
-
- virtual LinearSolverTerminationType Solve(const double* rhs,
- double* solution,
- std::string* message) = 0;
-
-
-
-
-
-
-
- LinearSolverTerminationType FactorAndSolve(int num_cols,
- double* lhs,
- const double* rhs,
- double* solution,
- std::string* message);
- };
- class CERES_NO_EXPORT EigenDenseCholesky final : public DenseCholesky {
- public:
- LinearSolverTerminationType Factorize(int num_cols,
- double* lhs,
- std::string* message) override;
- LinearSolverTerminationType Solve(const double* rhs,
- double* solution,
- std::string* message) override;
- private:
- using LLTType = Eigen::LLT<Eigen::Ref<Eigen::MatrixXd>, Eigen::Lower>;
- std::unique_ptr<LLTType> llt_;
- };
- class CERES_NO_EXPORT FloatEigenDenseCholesky final : public DenseCholesky {
- public:
- LinearSolverTerminationType Factorize(int num_cols,
- double* lhs,
- std::string* message) override;
- LinearSolverTerminationType Solve(const double* rhs,
- double* solution,
- std::string* message) override;
- private:
- Eigen::MatrixXf lhs_;
- Eigen::VectorXf rhs_;
- Eigen::VectorXf solution_;
- using LLTType = Eigen::LLT<Eigen::MatrixXf, Eigen::Lower>;
- std::unique_ptr<LLTType> llt_;
- };
- #ifndef CERES_NO_LAPACK
- class CERES_NO_EXPORT LAPACKDenseCholesky final : public DenseCholesky {
- public:
- LinearSolverTerminationType Factorize(int num_cols,
- double* lhs,
- std::string* message) override;
- LinearSolverTerminationType Solve(const double* rhs,
- double* solution,
- std::string* message) override;
- private:
- double* lhs_ = nullptr;
- int num_cols_ = -1;
- LinearSolverTerminationType termination_type_ =
- LinearSolverTerminationType::FATAL_ERROR;
- };
- class CERES_NO_EXPORT FloatLAPACKDenseCholesky final : public DenseCholesky {
- public:
- LinearSolverTerminationType Factorize(int num_cols,
- double* lhs,
- std::string* message) override;
- LinearSolverTerminationType Solve(const double* rhs,
- double* solution,
- std::string* message) override;
- private:
- Eigen::MatrixXf lhs_;
- Eigen::VectorXf rhs_and_solution_;
- int num_cols_ = -1;
- LinearSolverTerminationType termination_type_ =
- LinearSolverTerminationType::FATAL_ERROR;
- };
- #endif
- class DenseIterativeRefiner;
- class CERES_NO_EXPORT RefinedDenseCholesky final : public DenseCholesky {
- public:
- RefinedDenseCholesky(
- std::unique_ptr<DenseCholesky> dense_cholesky,
- std::unique_ptr<DenseIterativeRefiner> iterative_refiner);
- ~RefinedDenseCholesky() override;
- LinearSolverTerminationType Factorize(int num_cols,
- double* lhs,
- std::string* message) override;
- LinearSolverTerminationType Solve(const double* rhs,
- double* solution,
- std::string* message) override;
- private:
- std::unique_ptr<DenseCholesky> dense_cholesky_;
- std::unique_ptr<DenseIterativeRefiner> iterative_refiner_;
- double* lhs_ = nullptr;
- int num_cols_;
- };
- #ifndef CERES_NO_CUDA
- class CERES_NO_EXPORT CUDADenseCholesky final : public DenseCholesky {
- public:
- static std::unique_ptr<CUDADenseCholesky> Create(
- const LinearSolver::Options& options);
- CUDADenseCholesky(const CUDADenseCholesky&) = delete;
- CUDADenseCholesky& operator=(const CUDADenseCholesky&) = delete;
- LinearSolverTerminationType Factorize(int num_cols,
- double* lhs,
- std::string* message) override;
- LinearSolverTerminationType Solve(const double* rhs,
- double* solution,
- std::string* message) override;
- private:
- explicit CUDADenseCholesky(ContextImpl* context);
- ContextImpl* context_ = nullptr;
-
-
- size_t num_cols_ = 0;
-
- CudaBuffer<double> lhs_;
-
- CudaBuffer<double> rhs_;
-
- CudaBuffer<double> device_workspace_;
-
- CudaBuffer<int> error_;
-
-
- LinearSolverTerminationType factorize_result_ =
- LinearSolverTerminationType::FATAL_ERROR;
- };
- class CERES_NO_EXPORT CUDADenseCholeskyMixedPrecision final
- : public DenseCholesky {
- public:
- static std::unique_ptr<CUDADenseCholeskyMixedPrecision> Create(
- const LinearSolver::Options& options);
- CUDADenseCholeskyMixedPrecision(const CUDADenseCholeskyMixedPrecision&) =
- delete;
- CUDADenseCholeskyMixedPrecision& operator=(
- const CUDADenseCholeskyMixedPrecision&) = delete;
- LinearSolverTerminationType Factorize(int num_cols,
- double* lhs,
- std::string* message) override;
- LinearSolverTerminationType Solve(const double* rhs,
- double* solution,
- std::string* message) override;
- private:
- CUDADenseCholeskyMixedPrecision(ContextImpl* context,
- int max_num_refinement_iterations);
-
- LinearSolverTerminationType CudaCholeskyFactorize(std::string* message);
-
- LinearSolverTerminationType CudaCholeskySolve(std::string* message);
-
-
-
-
- bool Init(const LinearSolver::Options& options, std::string* message);
- ContextImpl* context_ = nullptr;
-
-
- size_t num_cols_ = 0;
- CudaBuffer<double> lhs_fp64_;
- CudaBuffer<double> rhs_fp64_;
- CudaBuffer<float> lhs_fp32_;
-
- CudaBuffer<float> device_workspace_;
-
- CudaBuffer<int> error_;
-
- CudaBuffer<double> x_fp64_;
-
- CudaBuffer<float> correction_fp32_;
-
- CudaBuffer<float> residual_fp32_;
- CudaBuffer<double> residual_fp64_;
-
- int max_num_refinement_iterations_ = 0;
-
-
- LinearSolverTerminationType factorize_result_ =
- LinearSolverTerminationType::FATAL_ERROR;
- };
- #endif
- }
- #endif
|