dynamic_sparsity_test.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. // Ceres Solver - A fast non-linear least squares minimizer
  2. // Copyright 2023 Google Inc. All rights reserved.
  3. // http://ceres-solver.org/
  4. //
  5. // Redistribution and use in source and binary forms, with or without
  6. // modification, are permitted provided that the following conditions are met:
  7. //
  8. // * Redistributions of source code must retain the above copyright notice,
  9. // this list of conditions and the following disclaimer.
  10. // * Redistributions in binary form must reproduce the above copyright notice,
  11. // this list of conditions and the following disclaimer in the documentation
  12. // and/or other materials provided with the distribution.
  13. // * Neither the name of Google Inc. nor the names of its contributors may be
  14. // used to endorse or promote products derived from this software without
  15. // specific prior written permission.
  16. //
  17. // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  18. // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  19. // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  20. // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
  21. // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  22. // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  23. // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  24. // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  25. // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  26. // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  27. // POSSIBILITY OF SUCH DAMAGE.
  28. //
  29. // Author: richie.stebbing@gmail.com (Richard Stebbing)
  30. // sameeragarwal@google.com (Sameer Agarwal)
  31. //
  32. // Based on examples/ellipse_approximation.cc
#include <cmath>
#include <limits>
#include <utility>
#include <vector>

#include "ceres/ceres.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
  39. namespace ceres::internal {
  40. // Data generated with the following Python code.
  41. // import numpy as np
  42. // np.random.seed(1337)
  43. // t = np.linspace(0.0, 2.0 * np.pi, 212, endpoint=False)
  44. // t += 2.0 * np.pi * 0.01 * np.random.randn(t.size)
  45. // theta = np.deg2rad(15)
  46. // a, b = np.cos(theta), np.sin(theta)
  47. // R = np.array([[a, -b],
  48. // [b, a]])
  49. // Y = np.dot(np.c_[4.0 * np.cos(t), np.sin(t)], R.T)
  50. const int kYRows = 212;
  51. const int kYCols = 2;
  52. // clang-format off
  53. const double kYData[kYRows * kYCols] = {
  54. +3.871364e+00, +9.916027e-01,
  55. +3.864003e+00, +1.034148e+00,
  56. +3.850651e+00, +1.072202e+00,
  57. +3.868350e+00, +1.014408e+00,
  58. +3.796381e+00, +1.153021e+00,
  59. +3.857138e+00, +1.056102e+00,
  60. +3.787532e+00, +1.162215e+00,
  61. +3.704477e+00, +1.227272e+00,
  62. +3.564711e+00, +1.294959e+00,
  63. +3.754363e+00, +1.191948e+00,
  64. +3.482098e+00, +1.322725e+00,
  65. +3.602777e+00, +1.279658e+00,
  66. +3.585433e+00, +1.286858e+00,
  67. +3.347505e+00, +1.356415e+00,
  68. +3.220855e+00, +1.378914e+00,
  69. +3.558808e+00, +1.297174e+00,
  70. +3.403618e+00, +1.343809e+00,
  71. +3.179828e+00, +1.384721e+00,
  72. +3.054789e+00, +1.398759e+00,
  73. +3.294153e+00, +1.366808e+00,
  74. +3.247312e+00, +1.374813e+00,
  75. +2.988547e+00, +1.404247e+00,
  76. +3.114508e+00, +1.392698e+00,
  77. +2.899226e+00, +1.409802e+00,
  78. +2.533256e+00, +1.414778e+00,
  79. +2.654773e+00, +1.415909e+00,
  80. +2.565100e+00, +1.415313e+00,
  81. +2.976456e+00, +1.405118e+00,
  82. +2.484200e+00, +1.413640e+00,
  83. +2.324751e+00, +1.407476e+00,
  84. +1.930468e+00, +1.378221e+00,
  85. +2.329017e+00, +1.407688e+00,
  86. +1.760640e+00, +1.360319e+00,
  87. +2.147375e+00, +1.396603e+00,
  88. +1.741989e+00, +1.358178e+00,
  89. +1.743859e+00, +1.358394e+00,
  90. +1.557372e+00, +1.335208e+00,
  91. +1.280551e+00, +1.295087e+00,
  92. +1.429880e+00, +1.317546e+00,
  93. +1.213485e+00, +1.284400e+00,
  94. +9.168172e-01, +1.232870e+00,
  95. +1.311141e+00, +1.299839e+00,
  96. +1.231969e+00, +1.287382e+00,
  97. +7.453773e-01, +1.200049e+00,
  98. +6.151587e-01, +1.173683e+00,
  99. +5.935666e-01, +1.169193e+00,
  100. +2.538707e-01, +1.094227e+00,
  101. +6.806136e-01, +1.187089e+00,
  102. +2.805447e-01, +1.100405e+00,
  103. +6.184807e-01, +1.174371e+00,
  104. +1.170550e-01, +1.061762e+00,
  105. +2.890507e-01, +1.102365e+00,
  106. +3.834234e-01, +1.123772e+00,
  107. +3.980161e-04, +1.033061e+00,
  108. -3.651680e-01, +9.370367e-01,
  109. -8.386351e-01, +7.987201e-01,
  110. -8.105704e-01, +8.073702e-01,
  111. -8.735139e-01, +7.878886e-01,
  112. -9.913836e-01, +7.506100e-01,
  113. -8.784011e-01, +7.863636e-01,
  114. -1.181440e+00, +6.882566e-01,
  115. -1.229556e+00, +6.720191e-01,
  116. -1.035839e+00, +7.362765e-01,
  117. -8.031520e-01, +8.096470e-01,
  118. -1.539136e+00, +5.629549e-01,
  119. -1.755423e+00, +4.817306e-01,
  120. -1.337589e+00, +6.348763e-01,
  121. -1.836966e+00, +4.499485e-01,
  122. -1.913367e+00, +4.195617e-01,
  123. -2.126467e+00, +3.314900e-01,
  124. -1.927625e+00, +4.138238e-01,
  125. -2.339862e+00, +2.379074e-01,
  126. -1.881736e+00, +4.322152e-01,
  127. -2.116753e+00, +3.356163e-01,
  128. -2.255733e+00, +2.754930e-01,
  129. -2.555834e+00, +1.368473e-01,
  130. -2.770277e+00, +2.895711e-02,
  131. -2.563376e+00, +1.331890e-01,
  132. -2.826715e+00, -9.000818e-04,
  133. -2.978191e+00, -8.457804e-02,
  134. -3.115855e+00, -1.658786e-01,
  135. -2.982049e+00, -8.678322e-02,
  136. -3.307892e+00, -2.902083e-01,
  137. -3.038346e+00, -1.194222e-01,
  138. -3.190057e+00, -2.122060e-01,
  139. -3.279086e+00, -2.705777e-01,
  140. -3.322028e+00, -2.999889e-01,
  141. -3.122576e+00, -1.699965e-01,
  142. -3.551973e+00, -4.768674e-01,
  143. -3.581866e+00, -5.032175e-01,
  144. -3.497799e+00, -4.315203e-01,
  145. -3.565384e+00, -4.885602e-01,
  146. -3.699493e+00, -6.199815e-01,
  147. -3.585166e+00, -5.061925e-01,
  148. -3.758914e+00, -6.918275e-01,
  149. -3.741104e+00, -6.689131e-01,
  150. -3.688331e+00, -6.077239e-01,
  151. -3.810425e+00, -7.689015e-01,
  152. -3.791829e+00, -7.386911e-01,
  153. -3.789951e+00, -7.358189e-01,
  154. -3.823100e+00, -7.918398e-01,
  155. -3.857021e+00, -8.727074e-01,
  156. -3.858250e+00, -8.767645e-01,
  157. -3.872100e+00, -9.563174e-01,
  158. -3.864397e+00, -1.032630e+00,
  159. -3.846230e+00, -1.081669e+00,
  160. -3.834799e+00, -1.102536e+00,
  161. -3.866684e+00, -1.022901e+00,
  162. -3.808643e+00, -1.139084e+00,
  163. -3.868840e+00, -1.011569e+00,
  164. -3.791071e+00, -1.158615e+00,
  165. -3.797999e+00, -1.151267e+00,
  166. -3.696278e+00, -1.232314e+00,
  167. -3.779007e+00, -1.170504e+00,
  168. -3.622855e+00, -1.270793e+00,
  169. -3.647249e+00, -1.259166e+00,
  170. -3.655412e+00, -1.255042e+00,
  171. -3.573218e+00, -1.291696e+00,
  172. -3.638019e+00, -1.263684e+00,
  173. -3.498409e+00, -1.317750e+00,
  174. -3.304143e+00, -1.364970e+00,
  175. -3.183001e+00, -1.384295e+00,
  176. -3.202456e+00, -1.381599e+00,
  177. -3.244063e+00, -1.375332e+00,
  178. -3.233308e+00, -1.377019e+00,
  179. -3.060112e+00, -1.398264e+00,
  180. -3.078187e+00, -1.396517e+00,
  181. -2.689594e+00, -1.415761e+00,
  182. -2.947662e+00, -1.407039e+00,
  183. -2.854490e+00, -1.411860e+00,
  184. -2.660499e+00, -1.415900e+00,
  185. -2.875955e+00, -1.410930e+00,
  186. -2.675385e+00, -1.415848e+00,
  187. -2.813155e+00, -1.413363e+00,
  188. -2.417673e+00, -1.411512e+00,
  189. -2.725461e+00, -1.415373e+00,
  190. -2.148334e+00, -1.396672e+00,
  191. -2.108972e+00, -1.393738e+00,
  192. -2.029905e+00, -1.387302e+00,
  193. -2.046214e+00, -1.388687e+00,
  194. -2.057402e+00, -1.389621e+00,
  195. -1.650250e+00, -1.347160e+00,
  196. -1.806764e+00, -1.365469e+00,
  197. -1.206973e+00, -1.283343e+00,
  198. -8.029259e-01, -1.211308e+00,
  199. -1.229551e+00, -1.286993e+00,
  200. -1.101507e+00, -1.265754e+00,
  201. -9.110645e-01, -1.231804e+00,
  202. -1.110046e+00, -1.267211e+00,
  203. -8.465274e-01, -1.219677e+00,
  204. -7.594163e-01, -1.202818e+00,
  205. -8.023823e-01, -1.211203e+00,
  206. -3.732519e-01, -1.121494e+00,
  207. -1.918373e-01, -1.079668e+00,
  208. -4.671988e-01, -1.142253e+00,
  209. -4.033645e-01, -1.128215e+00,
  210. -1.920740e-01, -1.079724e+00,
  211. -3.022157e-01, -1.105389e+00,
  212. -1.652831e-01, -1.073354e+00,
  213. +4.671625e-01, -9.085886e-01,
  214. +5.940178e-01, -8.721832e-01,
  215. +3.147557e-01, -9.508290e-01,
  216. +6.383631e-01, -8.591867e-01,
  217. +9.888923e-01, -7.514088e-01,
  218. +7.076339e-01, -8.386023e-01,
  219. +1.326682e+00, -6.386698e-01,
  220. +1.149834e+00, -6.988221e-01,
  221. +1.257742e+00, -6.624207e-01,
  222. +1.492352e+00, -5.799632e-01,
  223. +1.595574e+00, -5.421766e-01,
  224. +1.240173e+00, -6.684113e-01,
  225. +1.706612e+00, -5.004442e-01,
  226. +1.873984e+00, -4.353002e-01,
  227. +1.985633e+00, -3.902561e-01,
  228. +1.722880e+00, -4.942329e-01,
  229. +2.095182e+00, -3.447402e-01,
  230. +2.018118e+00, -3.768991e-01,
  231. +2.422702e+00, -1.999563e-01,
  232. +2.370611e+00, -2.239326e-01,
  233. +2.152154e+00, -3.205250e-01,
  234. +2.525121e+00, -1.516499e-01,
  235. +2.422116e+00, -2.002280e-01,
  236. +2.842806e+00, +9.536372e-03,
  237. +3.030128e+00, +1.146027e-01,
  238. +2.888424e+00, +3.433444e-02,
  239. +2.991609e+00, +9.226409e-02,
  240. +2.924807e+00, +5.445844e-02,
  241. +3.007772e+00, +1.015875e-01,
  242. +2.781973e+00, -2.282382e-02,
  243. +3.164737e+00, +1.961781e-01,
  244. +3.237671e+00, +2.430139e-01,
  245. +3.046123e+00, +1.240014e-01,
  246. +3.414834e+00, +3.669060e-01,
  247. +3.436591e+00, +3.833600e-01,
  248. +3.626207e+00, +5.444311e-01,
  249. +3.223325e+00, +2.336361e-01,
  250. +3.511963e+00, +4.431060e-01,
  251. +3.698380e+00, +6.187442e-01,
  252. +3.670244e+00, +5.884943e-01,
  253. +3.558833e+00, +4.828230e-01,
  254. +3.661807e+00, +5.797689e-01,
  255. +3.767261e+00, +7.030893e-01,
  256. +3.801065e+00, +7.532650e-01,
  257. +3.828523e+00, +8.024454e-01,
  258. +3.840719e+00, +8.287032e-01,
  259. +3.848748e+00, +8.485921e-01,
  260. +3.865801e+00, +9.066551e-01,
  261. +3.870983e+00, +9.404873e-01,
  262. +3.870263e+00, +1.001884e+00,
  263. +3.864462e+00, +1.032374e+00,
  264. +3.870542e+00, +9.996121e-01,
  265. +3.865424e+00, +1.028474e+00
  266. };
  267. // clang-format on
  268. ConstMatrixRef kY(kYData, kYRows, kYCols);
  269. class PointToLineSegmentContourCostFunction : public CostFunction {
  270. public:
  271. // This class needs to have an Eigen aligned operator new as it contains
  272. // fixed-size Eigen types.
  273. EIGEN_MAKE_ALIGNED_OPERATOR_NEW
  274. PointToLineSegmentContourCostFunction(const int num_segments,
  275. Eigen::Vector2d y)
  276. : num_segments_(num_segments), y_(std::move(y)) {
  277. // The first parameter is the preimage position.
  278. mutable_parameter_block_sizes()->push_back(1);
  279. // The next parameters are the control points for the line segment contour.
  280. for (int i = 0; i < num_segments_; ++i) {
  281. mutable_parameter_block_sizes()->push_back(2);
  282. }
  283. set_num_residuals(2);
  284. }
  285. bool Evaluate(const double* const* x,
  286. double* residuals,
  287. double** jacobians) const final {
  288. // Convert the preimage position `t` into a segment index `i0` and the
  289. // line segment interpolation parameter `u`. `i1` is the index of the next
  290. // control point.
  291. const double t = ModuloNumSegments(*x[0]);
  292. CHECK_GE(t, 0.0);
  293. CHECK_LT(t, num_segments_);
  294. const int i0 = floor(t), i1 = (i0 + 1) % num_segments_;
  295. const double u = t - i0;
  296. // Linearly interpolate between control points `i0` and `i1`.
  297. residuals[0] = y_[0] - ((1.0 - u) * x[1 + i0][0] + u * x[1 + i1][0]);
  298. residuals[1] = y_[1] - ((1.0 - u) * x[1 + i0][1] + u * x[1 + i1][1]);
  299. if (jacobians == nullptr) {
  300. return true;
  301. }
  302. if (jacobians[0] != nullptr) {
  303. jacobians[0][0] = x[1 + i0][0] - x[1 + i1][0];
  304. jacobians[0][1] = x[1 + i0][1] - x[1 + i1][1];
  305. }
  306. for (int i = 0; i < num_segments_; ++i) {
  307. if (jacobians[i + 1] != nullptr) {
  308. MatrixRef(jacobians[i + 1], 2, 2).setZero();
  309. if (i == i0) {
  310. jacobians[i + 1][0] = -(1.0 - u);
  311. jacobians[i + 1][3] = -(1.0 - u);
  312. } else if (i == i1) {
  313. jacobians[i + 1][0] = -u;
  314. jacobians[i + 1][3] = -u;
  315. }
  316. }
  317. }
  318. return true;
  319. }
  320. static CostFunction* Create(const int num_segments,
  321. const Eigen::Vector2d& y) {
  322. return new PointToLineSegmentContourCostFunction(num_segments, y);
  323. }
  324. private:
  325. inline double ModuloNumSegments(const double t) const {
  326. return t - num_segments_ * floor(t / num_segments_);
  327. }
  328. const int num_segments_;
  329. const Eigen::Vector2d y_;
  330. };
  331. class EuclideanDistanceFunctor {
  332. public:
  333. explicit EuclideanDistanceFunctor(const double sqrt_weight)
  334. : sqrt_weight_(sqrt_weight) {}
  335. template <typename T>
  336. bool operator()(const T* x0, const T* x1, T* residuals) const {
  337. residuals[0] = sqrt_weight_ * (x0[0] - x1[0]);
  338. residuals[1] = sqrt_weight_ * (x0[1] - x1[1]);
  339. return true;
  340. }
  341. static CostFunction* Create(const double sqrt_weight) {
  342. return new AutoDiffCostFunction<EuclideanDistanceFunctor, 2, 2, 2>(
  343. new EuclideanDistanceFunctor(sqrt_weight));
  344. }
  345. private:
  346. const double sqrt_weight_;
  347. };
  348. TEST(DynamicSparsity, StaticAndDynamicSparsityProduceSameSolution) {
  349. // Skip test if there is no sparse linear algebra library that
  350. // supports dynamic sparsity.
  351. if (!IsSparseLinearAlgebraLibraryTypeAvailable(SUITE_SPARSE) &&
  352. !IsSparseLinearAlgebraLibraryTypeAvailable(EIGEN_SPARSE)) {
  353. return;
  354. }
  355. // Problem configuration.
  356. const int num_segments = 151;
  357. const double regularization_weight = 1e-2;
  358. // `X` is the matrix of control points which make up the contour of line
  359. // segments. The number of control points is equal to the number of line
  360. // segments because the contour is closed.
  361. //
  362. // Initialize `X` to points on the unit circle.
  363. Vector w(num_segments + 1);
  364. w.setLinSpaced(num_segments + 1, 0.0, 2.0 * constants::pi);
  365. w.conservativeResize(num_segments);
  366. Matrix X(num_segments, 2);
  367. X.col(0) = w.array().cos();
  368. X.col(1) = w.array().sin();
  369. // Each data point has an associated preimage position on the line segment
  370. // contour. For each data point we initialize the preimage positions to
  371. // the index of the closest control point.
  372. const int num_observations = kY.rows();
  373. Vector t(num_observations);
  374. for (int i = 0; i < num_observations; ++i) {
  375. (X.rowwise() - kY.row(i)).rowwise().squaredNorm().minCoeff(&t[i]);
  376. }
  377. Problem problem;
  378. // For each data point add a residual which measures its distance to its
  379. // corresponding position on the line segment contour.
  380. std::vector<double*> parameter_blocks(1 + num_segments);
  381. parameter_blocks[0] = nullptr;
  382. for (int i = 0; i < num_segments; ++i) {
  383. parameter_blocks[i + 1] = X.data() + 2 * i;
  384. }
  385. for (int i = 0; i < num_observations; ++i) {
  386. parameter_blocks[0] = &t[i];
  387. problem.AddResidualBlock(
  388. PointToLineSegmentContourCostFunction::Create(num_segments, kY.row(i)),
  389. nullptr,
  390. parameter_blocks);
  391. }
  392. // Add regularization to minimize the length of the line segment contour.
  393. for (int i = 0; i < num_segments; ++i) {
  394. problem.AddResidualBlock(
  395. EuclideanDistanceFunctor::Create(sqrt(regularization_weight)),
  396. nullptr,
  397. X.data() + 2 * i,
  398. X.data() + 2 * ((i + 1) % num_segments));
  399. }
  400. Solver::Options options;
  401. options.max_num_iterations = 100;
  402. options.linear_solver_type = SPARSE_NORMAL_CHOLESKY;
  403. // Only SuiteSparse & EigenSparse currently support dynamic sparsity.
  404. options.sparse_linear_algebra_library_type =
  405. #if !defined(CERES_NO_SUITESPARSE)
  406. ceres::SUITE_SPARSE;
  407. #elif defined(CERES_USE_EIGEN_SPARSE)
  408. ceres::EIGEN_SPARSE;
  409. #endif
  410. // First, solve `X` and `t` jointly with dynamic_sparsity = true.
  411. Matrix X0 = X;
  412. Vector t0 = t;
  413. options.dynamic_sparsity = false;
  414. Solver::Summary static_summary;
  415. Solve(options, &problem, &static_summary);
  416. EXPECT_EQ(static_summary.termination_type, CONVERGENCE)
  417. << static_summary.FullReport();
  418. X = X0;
  419. t = t0;
  420. options.dynamic_sparsity = true;
  421. Solver::Summary dynamic_summary;
  422. Solve(options, &problem, &dynamic_summary);
  423. EXPECT_EQ(dynamic_summary.termination_type, CONVERGENCE)
  424. << dynamic_summary.FullReport();
  425. EXPECT_NEAR(static_summary.final_cost,
  426. dynamic_summary.final_cost,
  427. std::numeric_limits<double>::epsilon())
  428. << "Static: \n"
  429. << static_summary.FullReport() << "\nDynamic: \n"
  430. << dynamic_summary.FullReport();
  431. }
  432. } // namespace ceres::internal