lsmr.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. """
  2. Copyright (C) 2010 David Fong and Michael Saunders
  3. LSMR uses an iterative method.
  4. 07 Jun 2010: Documentation updated
  5. 03 Jun 2010: First release version in Python
  6. David Chin-lung Fong clfong@stanford.edu
  7. Institute for Computational and Mathematical Engineering
  8. Stanford University
  9. Michael Saunders saunders@stanford.edu
  10. Systems Optimization Laboratory
  11. Dept of MS&E, Stanford University.
  12. """
  13. __all__ = ['lsmr']
  14. from numpy import zeros, infty, atleast_1d, result_type
  15. from numpy.linalg import norm
  16. from math import sqrt
  17. from scipy.sparse.linalg._interface import aslinearoperator
  18. from scipy.sparse.linalg._isolve.lsqr import _sym_ortho
  19. def lsmr(A, b, damp=0.0, atol=1e-6, btol=1e-6, conlim=1e8,
  20. maxiter=None, show=False, x0=None):
  21. """Iterative solver for least-squares problems.
  22. lsmr solves the system of linear equations ``Ax = b``. If the system
  23. is inconsistent, it solves the least-squares problem ``min ||b - Ax||_2``.
  24. ``A`` is a rectangular matrix of dimension m-by-n, where all cases are
  25. allowed: m = n, m > n, or m < n. ``b`` is a vector of length m.
  26. The matrix A may be dense or sparse (usually sparse).
  27. Parameters
  28. ----------
  29. A : {sparse matrix, ndarray, LinearOperator}
  30. Matrix A in the linear system.
  31. Alternatively, ``A`` can be a linear operator which can
  32. produce ``Ax`` and ``A^H x`` using, e.g.,
  33. ``scipy.sparse.linalg.LinearOperator``.
  34. b : array_like, shape (m,)
  35. Vector ``b`` in the linear system.
  36. damp : float
  37. Damping factor for regularized least-squares. `lsmr` solves
  38. the regularized least-squares problem::
  39. min ||(b) - ( A )x||
  40. ||(0) (damp*I) ||_2
  41. where damp is a scalar. If damp is None or 0, the system
  42. is solved without regularization. Default is 0.
  43. atol, btol : float, optional
  44. Stopping tolerances. `lsmr` continues iterations until a
  45. certain backward error estimate is smaller than some quantity
  46. depending on atol and btol. Let ``r = b - Ax`` be the
  47. residual vector for the current approximate solution ``x``.
  48. If ``Ax = b`` seems to be consistent, `lsmr` terminates
  49. when ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``.
  50. Otherwise, `lsmr` terminates when ``norm(A^H r) <=
  51. atol * norm(A) * norm(r)``. If both tolerances are 1.0e-6 (default),
  52. the final ``norm(r)`` should be accurate to about 6
  53. digits. (The final ``x`` will usually have fewer correct digits,
  54. depending on ``cond(A)`` and the size of LAMBDA.) If `atol`
  55. or `btol` is None, a default value of 1.0e-6 will be used.
  56. Ideally, they should be estimates of the relative error in the
  57. entries of ``A`` and ``b`` respectively. For example, if the entries
  58. of ``A`` have 7 correct digits, set ``atol = 1e-7``. This prevents
  59. the algorithm from doing unnecessary work beyond the
  60. uncertainty of the input data.
  61. conlim : float, optional
  62. `lsmr` terminates if an estimate of ``cond(A)`` exceeds
  63. `conlim`. For compatible systems ``Ax = b``, conlim could be
  64. as large as 1.0e+12 (say). For least-squares problems,
  65. `conlim` should be less than 1.0e+8. If `conlim` is None, the
  66. default value is 1e+8. Maximum precision can be obtained by
  67. setting ``atol = btol = conlim = 0``, but the number of
  68. iterations may then be excessive. Default is 1e8.
  69. maxiter : int, optional
  70. `lsmr` terminates if the number of iterations reaches
  71. `maxiter`. The default is ``maxiter = min(m, n)``. For
  72. ill-conditioned systems, a larger value of `maxiter` may be
  73. needed. Default is False.
  74. show : bool, optional
  75. Print iterations logs if ``show=True``. Default is False.
  76. x0 : array_like, shape (n,), optional
  77. Initial guess of ``x``, if None zeros are used. Default is None.
  78. .. versionadded:: 1.0.0
  79. Returns
  80. -------
  81. x : ndarray of float
  82. Least-square solution returned.
  83. istop : int
  84. istop gives the reason for stopping::
  85. istop = 0 means x=0 is a solution. If x0 was given, then x=x0 is a
  86. solution.
  87. = 1 means x is an approximate solution to A@x = B,
  88. according to atol and btol.
  89. = 2 means x approximately solves the least-squares problem
  90. according to atol.
  91. = 3 means COND(A) seems to be greater than CONLIM.
  92. = 4 is the same as 1 with atol = btol = eps (machine
  93. precision)
  94. = 5 is the same as 2 with atol = eps.
  95. = 6 is the same as 3 with CONLIM = 1/eps.
  96. = 7 means ITN reached maxiter before the other stopping
  97. conditions were satisfied.
  98. itn : int
  99. Number of iterations used.
  100. normr : float
  101. ``norm(b-Ax)``
  102. normar : float
  103. ``norm(A^H (b - Ax))``
  104. norma : float
  105. ``norm(A)``
  106. conda : float
  107. Condition number of A.
  108. normx : float
  109. ``norm(x)``
  110. Notes
  111. -----
  112. .. versionadded:: 0.11.0
  113. References
  114. ----------
  115. .. [1] D. C.-L. Fong and M. A. Saunders,
  116. "LSMR: An iterative algorithm for sparse least-squares problems",
  117. SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011.
  118. :arxiv:`1006.0758`
  119. .. [2] LSMR Software, https://web.stanford.edu/group/SOL/software/lsmr/
  120. Examples
  121. --------
  122. >>> import numpy as np
  123. >>> from scipy.sparse import csc_matrix
  124. >>> from scipy.sparse.linalg import lsmr
  125. >>> A = csc_matrix([[1., 0.], [1., 1.], [0., 1.]], dtype=float)
  126. The first example has the trivial solution ``[0, 0]``
  127. >>> b = np.array([0., 0., 0.], dtype=float)
  128. >>> x, istop, itn, normr = lsmr(A, b)[:4]
  129. >>> istop
  130. 0
  131. >>> x
  132. array([0., 0.])
  133. The stopping code `istop=0` returned indicates that a vector of zeros was
  134. found as a solution. The returned solution `x` indeed contains
  135. ``[0., 0.]``. The next example has a non-trivial solution:
  136. >>> b = np.array([1., 0., -1.], dtype=float)
  137. >>> x, istop, itn, normr = lsmr(A, b)[:4]
  138. >>> istop
  139. 1
  140. >>> x
  141. array([ 1., -1.])
  142. >>> itn
  143. 1
  144. >>> normr
  145. 4.440892098500627e-16
  146. As indicated by `istop=1`, `lsmr` found a solution obeying the tolerance
  147. limits. The given solution ``[1., -1.]`` obviously solves the equation. The
  148. remaining return values include information about the number of iterations
  149. (`itn=1`) and the remaining difference of left and right side of the solved
  150. equation.
  151. The final example demonstrates the behavior in the case where there is no
  152. solution for the equation:
  153. >>> b = np.array([1., 0.01, -1.], dtype=float)
  154. >>> x, istop, itn, normr = lsmr(A, b)[:4]
  155. >>> istop
  156. 2
  157. >>> x
  158. array([ 1.00333333, -0.99666667])
  159. >>> A.dot(x)-b
  160. array([ 0.00333333, -0.00333333, 0.00333333])
  161. >>> normr
  162. 0.005773502691896255
  163. `istop` indicates that the system is inconsistent and thus `x` is rather an
  164. approximate solution to the corresponding least-squares problem. `normr`
  165. contains the minimal distance that was found.
  166. """
  167. A = aslinearoperator(A)
  168. b = atleast_1d(b)
  169. if b.ndim > 1:
  170. b = b.squeeze()
  171. msg = ('The exact solution is x = 0, or x = x0, if x0 was given ',
  172. 'Ax - b is small enough, given atol, btol ',
  173. 'The least-squares solution is good enough, given atol ',
  174. 'The estimate of cond(Abar) has exceeded conlim ',
  175. 'Ax - b is small enough for this machine ',
  176. 'The least-squares solution is good enough for this machine',
  177. 'Cond(Abar) seems to be too large for this machine ',
  178. 'The iteration limit has been reached ')
  179. hdg1 = ' itn x(1) norm r norm Ar'
  180. hdg2 = ' compatible LS norm A cond A'
  181. pfreq = 20 # print frequency (for repeating the heading)
  182. pcount = 0 # print counter
  183. m, n = A.shape
  184. # stores the num of singular values
  185. minDim = min([m, n])
  186. if maxiter is None:
  187. maxiter = minDim
  188. if x0 is None:
  189. dtype = result_type(A, b, float)
  190. else:
  191. dtype = result_type(A, b, x0, float)
  192. if show:
  193. print(' ')
  194. print('LSMR Least-squares solution of Ax = b\n')
  195. print(f'The matrix A has {m} rows and {n} columns')
  196. print('damp = %20.14e\n' % (damp))
  197. print('atol = %8.2e conlim = %8.2e\n' % (atol, conlim))
  198. print('btol = %8.2e maxiter = %8g\n' % (btol, maxiter))
  199. u = b
  200. normb = norm(b)
  201. if x0 is None:
  202. x = zeros(n, dtype)
  203. beta = normb.copy()
  204. else:
  205. x = atleast_1d(x0.copy())
  206. u = u - A.matvec(x)
  207. beta = norm(u)
  208. if beta > 0:
  209. u = (1 / beta) * u
  210. v = A.rmatvec(u)
  211. alpha = norm(v)
  212. else:
  213. v = zeros(n, dtype)
  214. alpha = 0
  215. if alpha > 0:
  216. v = (1 / alpha) * v
  217. # Initialize variables for 1st iteration.
  218. itn = 0
  219. zetabar = alpha * beta
  220. alphabar = alpha
  221. rho = 1
  222. rhobar = 1
  223. cbar = 1
  224. sbar = 0
  225. h = v.copy()
  226. hbar = zeros(n, dtype)
  227. # Initialize variables for estimation of ||r||.
  228. betadd = beta
  229. betad = 0
  230. rhodold = 1
  231. tautildeold = 0
  232. thetatilde = 0
  233. zeta = 0
  234. d = 0
  235. # Initialize variables for estimation of ||A|| and cond(A)
  236. normA2 = alpha * alpha
  237. maxrbar = 0
  238. minrbar = 1e+100
  239. normA = sqrt(normA2)
  240. condA = 1
  241. normx = 0
  242. # Items for use in stopping rules, normb set earlier
  243. istop = 0
  244. ctol = 0
  245. if conlim > 0:
  246. ctol = 1 / conlim
  247. normr = beta
  248. # Reverse the order here from the original matlab code because
  249. # there was an error on return when arnorm==0
  250. normar = alpha * beta
  251. if normar == 0:
  252. if show:
  253. print(msg[0])
  254. return x, istop, itn, normr, normar, normA, condA, normx
  255. if normb == 0:
  256. x[()] = 0
  257. return x, istop, itn, normr, normar, normA, condA, normx
  258. if show:
  259. print(' ')
  260. print(hdg1, hdg2)
  261. test1 = 1
  262. test2 = alpha / beta
  263. str1 = '%6g %12.5e' % (itn, x[0])
  264. str2 = ' %10.3e %10.3e' % (normr, normar)
  265. str3 = ' %8.1e %8.1e' % (test1, test2)
  266. print(''.join([str1, str2, str3]))
  267. # Main iteration loop.
  268. while itn < maxiter:
  269. itn = itn + 1
  270. # Perform the next step of the bidiagonalization to obtain the
  271. # next beta, u, alpha, v. These satisfy the relations
  272. # beta*u = A@v - alpha*u,
  273. # alpha*v = A'@u - beta*v.
  274. u *= -alpha
  275. u += A.matvec(v)
  276. beta = norm(u)
  277. if beta > 0:
  278. u *= (1 / beta)
  279. v *= -beta
  280. v += A.rmatvec(u)
  281. alpha = norm(v)
  282. if alpha > 0:
  283. v *= (1 / alpha)
  284. # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.
  285. # Construct rotation Qhat_{k,2k+1}.
  286. chat, shat, alphahat = _sym_ortho(alphabar, damp)
  287. # Use a plane rotation (Q_i) to turn B_i to R_i
  288. rhoold = rho
  289. c, s, rho = _sym_ortho(alphahat, beta)
  290. thetanew = s*alpha
  291. alphabar = c*alpha
  292. # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar
  293. rhobarold = rhobar
  294. zetaold = zeta
  295. thetabar = sbar * rho
  296. rhotemp = cbar * rho
  297. cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew)
  298. zeta = cbar * zetabar
  299. zetabar = - sbar * zetabar
  300. # Update h, h_hat, x.
  301. hbar *= - (thetabar * rho / (rhoold * rhobarold))
  302. hbar += h
  303. x += (zeta / (rho * rhobar)) * hbar
  304. h *= - (thetanew / rho)
  305. h += v
  306. # Estimate of ||r||.
  307. # Apply rotation Qhat_{k,2k+1}.
  308. betaacute = chat * betadd
  309. betacheck = -shat * betadd
  310. # Apply rotation Q_{k,k+1}.
  311. betahat = c * betaacute
  312. betadd = -s * betaacute
  313. # Apply rotation Qtilde_{k-1}.
  314. # betad = betad_{k-1} here.
  315. thetatildeold = thetatilde
  316. ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar)
  317. thetatilde = stildeold * rhobar
  318. rhodold = ctildeold * rhobar
  319. betad = - stildeold * betad + ctildeold * betahat
  320. # betad = betad_k here.
  321. # rhodold = rhod_k here.
  322. tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold
  323. taud = (zeta - thetatilde * tautildeold) / rhodold
  324. d = d + betacheck * betacheck
  325. normr = sqrt(d + (betad - taud)**2 + betadd * betadd)
  326. # Estimate ||A||.
  327. normA2 = normA2 + beta * beta
  328. normA = sqrt(normA2)
  329. normA2 = normA2 + alpha * alpha
  330. # Estimate cond(A).
  331. maxrbar = max(maxrbar, rhobarold)
  332. if itn > 1:
  333. minrbar = min(minrbar, rhobarold)
  334. condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp)
  335. # Test for convergence.
  336. # Compute norms for convergence testing.
  337. normar = abs(zetabar)
  338. normx = norm(x)
  339. # Now use these norms to estimate certain other quantities,
  340. # some of which will be small near a solution.
  341. test1 = normr / normb
  342. if (normA * normr) != 0:
  343. test2 = normar / (normA * normr)
  344. else:
  345. test2 = infty
  346. test3 = 1 / condA
  347. t1 = test1 / (1 + normA * normx / normb)
  348. rtol = btol + atol * normA * normx / normb
  349. # The following tests guard against extremely small values of
  350. # atol, btol or ctol. (The user may have set any or all of
  351. # the parameters atol, btol, conlim to 0.)
  352. # The effect is equivalent to the normAl tests using
  353. # atol = eps, btol = eps, conlim = 1/eps.
  354. if itn >= maxiter:
  355. istop = 7
  356. if 1 + test3 <= 1:
  357. istop = 6
  358. if 1 + test2 <= 1:
  359. istop = 5
  360. if 1 + t1 <= 1:
  361. istop = 4
  362. # Allow for tolerances set by the user.
  363. if test3 <= ctol:
  364. istop = 3
  365. if test2 <= atol:
  366. istop = 2
  367. if test1 <= rtol:
  368. istop = 1
  369. # See if it is time to print something.
  370. if show:
  371. if (n <= 40) or (itn <= 10) or (itn >= maxiter - 10) or \
  372. (itn % 10 == 0) or (test3 <= 1.1 * ctol) or \
  373. (test2 <= 1.1 * atol) or (test1 <= 1.1 * rtol) or \
  374. (istop != 0):
  375. if pcount >= pfreq:
  376. pcount = 0
  377. print(' ')
  378. print(hdg1, hdg2)
  379. pcount = pcount + 1
  380. str1 = '%6g %12.5e' % (itn, x[0])
  381. str2 = ' %10.3e %10.3e' % (normr, normar)
  382. str3 = ' %8.1e %8.1e' % (test1, test2)
  383. str4 = ' %8.1e %8.1e' % (normA, condA)
  384. print(''.join([str1, str2, str3, str4]))
  385. if istop > 0:
  386. break
  387. # Print the stopping condition.
  388. if show:
  389. print(' ')
  390. print('LSMR finished')
  391. print(msg[istop])
  392. print('istop =%8g normr =%8.1e' % (istop, normr))
  393. print(' normA =%8.1e normAr =%8.1e' % (normA, normar))
  394. print('itn =%8g condA =%8.1e' % (itn, condA))
  395. print(' normx =%8.1e' % (normx))
  396. print(str1, str2)
  397. print(str3, str4)
  398. return x, istop, itn, normr, normar, normA, condA, normx