dense_solvers.cpp 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. #include <iostream>
  2. #include "BenchTimer.h"
  3. #include <Eigen/Dense>
  4. #include <map>
  5. #include <vector>
  6. #include <string>
  7. #include <sstream>
  8. using namespace Eigen;
  9. std::map<std::string,Array<float,1,8,DontAlign|RowMajor> > results;
  10. std::vector<std::string> labels;
  11. std::vector<Array2i> sizes;
  // Factorizes the normal-equation matrix associated with A.
  //
  //  - solver: decomposition object; only its compute() member is used.
  //  - A:      input matrix. A square A is handed to the solver unchanged,
  //            a rectangular A is replaced by the product A^* A.
  //
  // The product uses adjoint() rather than transpose(): for real scalars both
  // are the same operation, while for complex scalars only A^* A is
  // self-adjoint, which is what Cholesky-type solvers require. This also
  // matches bench(), which builds its square test matrix as A*A.adjoint().
  template<typename Solver, typename MatrixType>
  EIGEN_DONT_INLINE
  void compute_norm_equation(Solver &solver, const MatrixType &A)
  {
    if(A.rows() == A.cols())
    {
      // Already square: factorize as-is.
      solver.compute(A);
      return;
    }
    solver.compute(A.adjoint() * A);
  }
  // Runs the solver's factorization directly on A.
  //
  //  - solver: decomposition object; only its compute() member is used.
  //  - A:      matrix forwarded unchanged to solver.compute().
  //
  // NOTE(review): EIGEN_DONT_INLINE presumably keeps this call out of line so
  // that the timed region is a real function call -- confirm against Eigen.
  template<typename Solver, typename MatrixType>
  EIGEN_DONT_INLINE void compute(Solver &solver, const MatrixType &A)
  {
    solver.compute(A);
  }
  25. template<typename Scalar,int Size>
  26. void bench(int id, int rows, int size = Size)
  27. {
  28. typedef Matrix<Scalar,Dynamic,Size> Mat;
  29. typedef Matrix<Scalar,Dynamic,Dynamic> MatDyn;
  30. typedef Matrix<Scalar,Size,Size> MatSquare;
  31. Mat A(rows,size);
  32. A.setRandom();
  33. if(rows==size)
  34. A = A*A.adjoint();
  35. BenchTimer t_llt, t_ldlt, t_lu, t_fplu, t_qr, t_cpqr, t_cod, t_fpqr, t_jsvd, t_bdcsvd;
  36. int svd_opt = ComputeThinU|ComputeThinV;
  37. int tries = 5;
  38. int rep = 1000/size;
  39. if(rep==0) rep = 1;
  40. // rep = rep*rep;
  41. LLT<MatSquare> llt(size);
  42. LDLT<MatSquare> ldlt(size);
  43. PartialPivLU<MatSquare> lu(size);
  44. FullPivLU<MatSquare> fplu(size,size);
  45. HouseholderQR<Mat> qr(A.rows(),A.cols());
  46. ColPivHouseholderQR<Mat> cpqr(A.rows(),A.cols());
  47. CompleteOrthogonalDecomposition<Mat> cod(A.rows(),A.cols());
  48. FullPivHouseholderQR<Mat> fpqr(A.rows(),A.cols());
  49. JacobiSVD<MatDyn> jsvd(A.rows(),A.cols());
  50. BDCSVD<MatDyn> bdcsvd(A.rows(),A.cols());
  51. BENCH(t_llt, tries, rep, compute_norm_equation(llt,A));
  52. BENCH(t_ldlt, tries, rep, compute_norm_equation(ldlt,A));
  53. BENCH(t_lu, tries, rep, compute_norm_equation(lu,A));
  54. if(size<=1000)
  55. BENCH(t_fplu, tries, rep, compute_norm_equation(fplu,A));
  56. BENCH(t_qr, tries, rep, compute(qr,A));
  57. BENCH(t_cpqr, tries, rep, compute(cpqr,A));
  58. BENCH(t_cod, tries, rep, compute(cod,A));
  59. if(size*rows<=10000000)
  60. BENCH(t_fpqr, tries, rep, compute(fpqr,A));
  61. if(size<500) // JacobiSVD is really too slow for too large matrices
  62. BENCH(t_jsvd, tries, rep, jsvd.compute(A,svd_opt));
  63. // if(size*rows<=20000000)
  64. BENCH(t_bdcsvd, tries, rep, bdcsvd.compute(A,svd_opt));
  65. results["LLT"][id] = t_llt.best();
  66. results["LDLT"][id] = t_ldlt.best();
  67. results["PartialPivLU"][id] = t_lu.best();
  68. results["FullPivLU"][id] = t_fplu.best();
  69. results["HouseholderQR"][id] = t_qr.best();
  70. results["ColPivHouseholderQR"][id] = t_cpqr.best();
  71. results["CompleteOrthogonalDecomposition"][id] = t_cod.best();
  72. results["FullPivHouseholderQR"][id] = t_fpqr.best();
  73. results["JacobiSVD"][id] = t_jsvd.best();
  74. results["BDCSVD"][id] = t_bdcsvd.best();
  75. }
  76. int main()
  77. {
  78. labels.push_back("LLT");
  79. labels.push_back("LDLT");
  80. labels.push_back("PartialPivLU");
  81. labels.push_back("FullPivLU");
  82. labels.push_back("HouseholderQR");
  83. labels.push_back("ColPivHouseholderQR");
  84. labels.push_back("CompleteOrthogonalDecomposition");
  85. labels.push_back("FullPivHouseholderQR");
  86. labels.push_back("JacobiSVD");
  87. labels.push_back("BDCSVD");
  88. for(int i=0; i<labels.size(); ++i)
  89. results[labels[i]].fill(-1);
  90. const int small = 8;
  91. sizes.push_back(Array2i(small,small));
  92. sizes.push_back(Array2i(100,100));
  93. sizes.push_back(Array2i(1000,1000));
  94. sizes.push_back(Array2i(4000,4000));
  95. sizes.push_back(Array2i(10000,small));
  96. sizes.push_back(Array2i(10000,100));
  97. sizes.push_back(Array2i(10000,1000));
  98. sizes.push_back(Array2i(10000,4000));
  99. using namespace std;
  100. for(int k=0; k<sizes.size(); ++k)
  101. {
  102. cout << sizes[k](0) << "x" << sizes[k](1) << "...\n";
  103. bench<float,Dynamic>(k,sizes[k](0),sizes[k](1));
  104. }
  105. cout.width(32);
  106. cout << "solver/size";
  107. cout << " ";
  108. for(int k=0; k<sizes.size(); ++k)
  109. {
  110. std::stringstream ss;
  111. ss << sizes[k](0) << "x" << sizes[k](1);
  112. cout.width(10); cout << ss.str(); cout << " ";
  113. }
  114. cout << endl;
  115. for(int i=0; i<labels.size(); ++i)
  116. {
  117. cout.width(32); cout << labels[i]; cout << " ";
  118. ArrayXf r = (results[labels[i]]*100000.f).floor()/100.f;
  119. for(int k=0; k<sizes.size(); ++k)
  120. {
  121. cout.width(10);
  122. if(r(k)>=1e6) cout << "-";
  123. else cout << r(k);
  124. cout << " ";
  125. }
  126. cout << endl;
  127. }
  128. // HTML output
  129. cout << "<table class=\"manual\">" << endl;
  130. cout << "<tr><th>solver/size</th>" << endl;
  131. for(int k=0; k<sizes.size(); ++k)
  132. cout << " <th>" << sizes[k](0) << "x" << sizes[k](1) << "</th>";
  133. cout << "</tr>" << endl;
  134. for(int i=0; i<labels.size(); ++i)
  135. {
  136. cout << "<tr";
  137. if(i%2==1) cout << " class=\"alt\"";
  138. cout << "><td>" << labels[i] << "</td>";
  139. ArrayXf r = (results[labels[i]]*100000.f).floor()/100.f;
  140. for(int k=0; k<sizes.size(); ++k)
  141. {
  142. if(r(k)>=1e6) cout << "<td>-</td>";
  143. else
  144. {
  145. cout << "<td>" << r(k);
  146. if(i>0)
  147. cout << " (x" << numext::round(10.f*results[labels[i]](k)/results["LLT"](k))/10.f << ")";
  148. if(i<4 && sizes[k](0)!=sizes[k](1))
  149. cout << " <sup><a href=\"#note_ls\">*</a></sup>";
  150. cout << "</td>";
  151. }
  152. }
  153. cout << "</tr>" << endl;
  154. }
  155. cout << "</table>" << endl;
  156. // cout << "LLT (ms) " << (results["LLT"]*1000.).format(fmt) << "\n";
  157. // cout << "LDLT (%) " << (results["LDLT"]/results["LLT"]).format(fmt) << "\n";
  158. // cout << "PartialPivLU (%) " << (results["PartialPivLU"]/results["LLT"]).format(fmt) << "\n";
  159. // cout << "FullPivLU (%) " << (results["FullPivLU"]/results["LLT"]).format(fmt) << "\n";
  160. // cout << "HouseholderQR (%) " << (results["HouseholderQR"]/results["LLT"]).format(fmt) << "\n";
  161. // cout << "ColPivHouseholderQR (%) " << (results["ColPivHouseholderQR"]/results["LLT"]).format(fmt) << "\n";
  162. // cout << "CompleteOrthogonalDecomposition (%) " << (results["CompleteOrthogonalDecomposition"]/results["LLT"]).format(fmt) << "\n";
  163. // cout << "FullPivHouseholderQR (%) " << (results["FullPivHouseholderQR"]/results["LLT"]).format(fmt) << "\n";
  164. // cout << "JacobiSVD (%) " << (results["JacobiSVD"]/results["LLT"]).format(fmt) << "\n";
  165. // cout << "BDCSVD (%) " << (results["BDCSVD"]/results["LLT"]).format(fmt) << "\n";
  166. }