linsolve.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715
  1. from warnings import warn
  2. import numpy as np
  3. from numpy import asarray
  4. from scipy.sparse import (isspmatrix_csc, isspmatrix_csr, isspmatrix,
  5. SparseEfficiencyWarning, csc_matrix, csr_matrix)
  6. from scipy.sparse._sputils import is_pydata_spmatrix
  7. from scipy.linalg import LinAlgError
  8. import copy
  9. from . import _superlu
  10. noScikit = False
  11. try:
  12. import scikits.umfpack as umfpack
  13. except ImportError:
  14. noScikit = True
  15. useUmfpack = not noScikit
  16. __all__ = ['use_solver', 'spsolve', 'splu', 'spilu', 'factorized',
  17. 'MatrixRankWarning', 'spsolve_triangular']
  18. class MatrixRankWarning(UserWarning):
  19. pass
  20. def use_solver(**kwargs):
  21. """
  22. Select default sparse direct solver to be used.
  23. Parameters
  24. ----------
  25. useUmfpack : bool, optional
  26. Use UMFPACK [1]_, [2]_, [3]_, [4]_. over SuperLU. Has effect only
  27. if ``scikits.umfpack`` is installed. Default: True
  28. assumeSortedIndices : bool, optional
  29. Allow UMFPACK to skip the step of sorting indices for a CSR/CSC matrix.
  30. Has effect only if useUmfpack is True and ``scikits.umfpack`` is
  31. installed. Default: False
  32. Notes
  33. -----
  34. The default sparse solver is UMFPACK when available
  35. (``scikits.umfpack`` is installed). This can be changed by passing
  36. useUmfpack = False, which then causes the always present SuperLU
  37. based solver to be used.
  38. UMFPACK requires a CSR/CSC matrix to have sorted column/row indices. If
  39. sure that the matrix fulfills this, pass ``assumeSortedIndices=True``
  40. to gain some speed.
  41. References
  42. ----------
  43. .. [1] T. A. Davis, Algorithm 832: UMFPACK - an unsymmetric-pattern
  44. multifrontal method with a column pre-ordering strategy, ACM
  45. Trans. on Mathematical Software, 30(2), 2004, pp. 196--199.
  46. https://dl.acm.org/doi/abs/10.1145/992200.992206
  47. .. [2] T. A. Davis, A column pre-ordering strategy for the
  48. unsymmetric-pattern multifrontal method, ACM Trans.
  49. on Mathematical Software, 30(2), 2004, pp. 165--195.
  50. https://dl.acm.org/doi/abs/10.1145/992200.992205
  51. .. [3] T. A. Davis and I. S. Duff, A combined unifrontal/multifrontal
  52. method for unsymmetric sparse matrices, ACM Trans. on
  53. Mathematical Software, 25(1), 1999, pp. 1--19.
  54. https://doi.org/10.1145/305658.287640
  55. .. [4] T. A. Davis and I. S. Duff, An unsymmetric-pattern multifrontal
  56. method for sparse LU factorization, SIAM J. Matrix Analysis and
  57. Computations, 18(1), 1997, pp. 140--158.
  58. https://doi.org/10.1137/S0895479894246905T.
  59. Examples
  60. --------
  61. >>> import numpy as np
  62. >>> from scipy.sparse.linalg import use_solver, spsolve
  63. >>> from scipy.sparse import csc_matrix
  64. >>> R = np.random.randn(5, 5)
  65. >>> A = csc_matrix(R)
  66. >>> b = np.random.randn(5)
  67. >>> use_solver(useUmfpack=False) # enforce superLU over UMFPACK
  68. >>> x = spsolve(A, b)
  69. >>> np.allclose(A.dot(x), b)
  70. True
  71. >>> use_solver(useUmfpack=True) # reset umfPack usage to default
  72. """
  73. if 'useUmfpack' in kwargs:
  74. globals()['useUmfpack'] = kwargs['useUmfpack']
  75. if useUmfpack and 'assumeSortedIndices' in kwargs:
  76. umfpack.configure(assumeSortedIndices=kwargs['assumeSortedIndices'])
  77. def _get_umf_family(A):
  78. """Get umfpack family string given the sparse matrix dtype."""
  79. _families = {
  80. (np.float64, np.int32): 'di',
  81. (np.complex128, np.int32): 'zi',
  82. (np.float64, np.int64): 'dl',
  83. (np.complex128, np.int64): 'zl'
  84. }
  85. f_type = np.sctypeDict[A.dtype.name]
  86. i_type = np.sctypeDict[A.indices.dtype.name]
  87. try:
  88. family = _families[(f_type, i_type)]
  89. except KeyError as e:
  90. msg = 'only float64 or complex128 matrices with int32 or int64' \
  91. ' indices are supported! (got: matrix: %s, indices: %s)' \
  92. % (f_type, i_type)
  93. raise ValueError(msg) from e
  94. # See gh-8278. Considered converting only if
  95. # A.shape[0]*A.shape[1] > np.iinfo(np.int32).max,
  96. # but that didn't always fix the issue.
  97. family = family[0] + "l"
  98. A_new = copy.copy(A)
  99. A_new.indptr = np.array(A.indptr, copy=False, dtype=np.int64)
  100. A_new.indices = np.array(A.indices, copy=False, dtype=np.int64)
  101. return family, A_new
  102. def spsolve(A, b, permc_spec=None, use_umfpack=True):
  103. """Solve the sparse linear system Ax=b, where b may be a vector or a matrix.
  104. Parameters
  105. ----------
  106. A : ndarray or sparse matrix
  107. The square matrix A will be converted into CSC or CSR form
  108. b : ndarray or sparse matrix
  109. The matrix or vector representing the right hand side of the equation.
  110. If a vector, b.shape must be (n,) or (n, 1).
  111. permc_spec : str, optional
  112. How to permute the columns of the matrix for sparsity preservation.
  113. (default: 'COLAMD')
  114. - ``NATURAL``: natural ordering.
  115. - ``MMD_ATA``: minimum degree ordering on the structure of A^T A.
  116. - ``MMD_AT_PLUS_A``: minimum degree ordering on the structure of A^T+A.
  117. - ``COLAMD``: approximate minimum degree column ordering [1]_, [2]_.
  118. use_umfpack : bool, optional
  119. if True (default) then use UMFPACK for the solution [3]_, [4]_, [5]_,
  120. [6]_ . This is only referenced if b is a vector and
  121. ``scikits.umfpack`` is installed.
  122. Returns
  123. -------
  124. x : ndarray or sparse matrix
  125. the solution of the sparse linear equation.
  126. If b is a vector, then x is a vector of size A.shape[1]
  127. If b is a matrix, then x is a matrix of size (A.shape[1], b.shape[1])
  128. Notes
  129. -----
  130. For solving the matrix expression AX = B, this solver assumes the resulting
  131. matrix X is sparse, as is often the case for very sparse inputs. If the
  132. resulting X is dense, the construction of this sparse result will be
  133. relatively expensive. In that case, consider converting A to a dense
  134. matrix and using scipy.linalg.solve or its variants.
  135. References
  136. ----------
  137. .. [1] T. A. Davis, J. R. Gilbert, S. Larimore, E. Ng, Algorithm 836:
  138. COLAMD, an approximate column minimum degree ordering algorithm,
  139. ACM Trans. on Mathematical Software, 30(3), 2004, pp. 377--380.
  140. :doi:`10.1145/1024074.1024080`
  141. .. [2] T. A. Davis, J. R. Gilbert, S. Larimore, E. Ng, A column approximate
  142. minimum degree ordering algorithm, ACM Trans. on Mathematical
  143. Software, 30(3), 2004, pp. 353--376. :doi:`10.1145/1024074.1024079`
  144. .. [3] T. A. Davis, Algorithm 832: UMFPACK - an unsymmetric-pattern
  145. multifrontal method with a column pre-ordering strategy, ACM
  146. Trans. on Mathematical Software, 30(2), 2004, pp. 196--199.
  147. https://dl.acm.org/doi/abs/10.1145/992200.992206
  148. .. [4] T. A. Davis, A column pre-ordering strategy for the
  149. unsymmetric-pattern multifrontal method, ACM Trans.
  150. on Mathematical Software, 30(2), 2004, pp. 165--195.
  151. https://dl.acm.org/doi/abs/10.1145/992200.992205
  152. .. [5] T. A. Davis and I. S. Duff, A combined unifrontal/multifrontal
  153. method for unsymmetric sparse matrices, ACM Trans. on
  154. Mathematical Software, 25(1), 1999, pp. 1--19.
  155. https://doi.org/10.1145/305658.287640
  156. .. [6] T. A. Davis and I. S. Duff, An unsymmetric-pattern multifrontal
  157. method for sparse LU factorization, SIAM J. Matrix Analysis and
  158. Computations, 18(1), 1997, pp. 140--158.
  159. https://doi.org/10.1137/S0895479894246905T.
  160. Examples
  161. --------
  162. >>> import numpy as np
  163. >>> from scipy.sparse import csc_matrix
  164. >>> from scipy.sparse.linalg import spsolve
  165. >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
  166. >>> B = csc_matrix([[2, 0], [-1, 0], [2, 0]], dtype=float)
  167. >>> x = spsolve(A, B)
  168. >>> np.allclose(A.dot(x).toarray(), B.toarray())
  169. True
  170. """
  171. if is_pydata_spmatrix(A):
  172. A = A.to_scipy_sparse().tocsc()
  173. if not (isspmatrix_csc(A) or isspmatrix_csr(A)):
  174. A = csc_matrix(A)
  175. warn('spsolve requires A be CSC or CSR matrix format',
  176. SparseEfficiencyWarning)
  177. # b is a vector only if b have shape (n,) or (n, 1)
  178. b_is_sparse = isspmatrix(b) or is_pydata_spmatrix(b)
  179. if not b_is_sparse:
  180. b = asarray(b)
  181. b_is_vector = ((b.ndim == 1) or (b.ndim == 2 and b.shape[1] == 1))
  182. # sum duplicates for non-canonical format
  183. A.sum_duplicates()
  184. A = A.asfptype() # upcast to a floating point format
  185. result_dtype = np.promote_types(A.dtype, b.dtype)
  186. if A.dtype != result_dtype:
  187. A = A.astype(result_dtype)
  188. if b.dtype != result_dtype:
  189. b = b.astype(result_dtype)
  190. # validate input shapes
  191. M, N = A.shape
  192. if (M != N):
  193. raise ValueError("matrix must be square (has shape %s)" % ((M, N),))
  194. if M != b.shape[0]:
  195. raise ValueError("matrix - rhs dimension mismatch (%s - %s)"
  196. % (A.shape, b.shape[0]))
  197. use_umfpack = use_umfpack and useUmfpack
  198. if b_is_vector and use_umfpack:
  199. if b_is_sparse:
  200. b_vec = b.toarray()
  201. else:
  202. b_vec = b
  203. b_vec = asarray(b_vec, dtype=A.dtype).ravel()
  204. if noScikit:
  205. raise RuntimeError('Scikits.umfpack not installed.')
  206. if A.dtype.char not in 'dD':
  207. raise ValueError("convert matrix data to double, please, using"
  208. " .astype(), or set linsolve.useUmfpack = False")
  209. umf_family, A = _get_umf_family(A)
  210. umf = umfpack.UmfpackContext(umf_family)
  211. x = umf.linsolve(umfpack.UMFPACK_A, A, b_vec,
  212. autoTranspose=True)
  213. else:
  214. if b_is_vector and b_is_sparse:
  215. b = b.toarray()
  216. b_is_sparse = False
  217. if not b_is_sparse:
  218. if isspmatrix_csc(A):
  219. flag = 1 # CSC format
  220. else:
  221. flag = 0 # CSR format
  222. options = dict(ColPerm=permc_spec)
  223. x, info = _superlu.gssv(N, A.nnz, A.data, A.indices, A.indptr,
  224. b, flag, options=options)
  225. if info != 0:
  226. warn("Matrix is exactly singular", MatrixRankWarning)
  227. x.fill(np.nan)
  228. if b_is_vector:
  229. x = x.ravel()
  230. else:
  231. # b is sparse
  232. Afactsolve = factorized(A)
  233. if not (isspmatrix_csc(b) or is_pydata_spmatrix(b)):
  234. warn('spsolve is more efficient when sparse b '
  235. 'is in the CSC matrix format', SparseEfficiencyWarning)
  236. b = csc_matrix(b)
  237. # Create a sparse output matrix by repeatedly applying
  238. # the sparse factorization to solve columns of b.
  239. data_segs = []
  240. row_segs = []
  241. col_segs = []
  242. for j in range(b.shape[1]):
  243. # TODO: replace this with
  244. # bj = b[:, j].toarray().ravel()
  245. # once 1D sparse arrays are supported.
  246. # That is a slightly faster code path.
  247. bj = b[:, [j]].toarray().ravel()
  248. xj = Afactsolve(bj)
  249. w = np.flatnonzero(xj)
  250. segment_length = w.shape[0]
  251. row_segs.append(w)
  252. col_segs.append(np.full(segment_length, j, dtype=int))
  253. data_segs.append(np.asarray(xj[w], dtype=A.dtype))
  254. sparse_data = np.concatenate(data_segs)
  255. sparse_row = np.concatenate(row_segs)
  256. sparse_col = np.concatenate(col_segs)
  257. x = A.__class__((sparse_data, (sparse_row, sparse_col)),
  258. shape=b.shape, dtype=A.dtype)
  259. if is_pydata_spmatrix(b):
  260. x = b.__class__(x)
  261. return x
  262. def splu(A, permc_spec=None, diag_pivot_thresh=None,
  263. relax=None, panel_size=None, options=dict()):
  264. """
  265. Compute the LU decomposition of a sparse, square matrix.
  266. Parameters
  267. ----------
  268. A : sparse matrix
  269. Sparse matrix to factorize. Most efficient when provided in CSC
  270. format. Other formats will be converted to CSC before factorization.
  271. permc_spec : str, optional
  272. How to permute the columns of the matrix for sparsity preservation.
  273. (default: 'COLAMD')
  274. - ``NATURAL``: natural ordering.
  275. - ``MMD_ATA``: minimum degree ordering on the structure of A^T A.
  276. - ``MMD_AT_PLUS_A``: minimum degree ordering on the structure of A^T+A.
  277. - ``COLAMD``: approximate minimum degree column ordering
  278. diag_pivot_thresh : float, optional
  279. Threshold used for a diagonal entry to be an acceptable pivot.
  280. See SuperLU user's guide for details [1]_
  281. relax : int, optional
  282. Expert option for customizing the degree of relaxing supernodes.
  283. See SuperLU user's guide for details [1]_
  284. panel_size : int, optional
  285. Expert option for customizing the panel size.
  286. See SuperLU user's guide for details [1]_
  287. options : dict, optional
  288. Dictionary containing additional expert options to SuperLU.
  289. See SuperLU user guide [1]_ (section 2.4 on the 'Options' argument)
  290. for more details. For example, you can specify
  291. ``options=dict(Equil=False, IterRefine='SINGLE'))``
  292. to turn equilibration off and perform a single iterative refinement.
  293. Returns
  294. -------
  295. invA : scipy.sparse.linalg.SuperLU
  296. Object, which has a ``solve`` method.
  297. See also
  298. --------
  299. spilu : incomplete LU decomposition
  300. Notes
  301. -----
  302. This function uses the SuperLU library.
  303. References
  304. ----------
  305. .. [1] SuperLU https://portal.nersc.gov/project/sparse/superlu/
  306. Examples
  307. --------
  308. >>> import numpy as np
  309. >>> from scipy.sparse import csc_matrix
  310. >>> from scipy.sparse.linalg import splu
  311. >>> A = csc_matrix([[1., 0., 0.], [5., 0., 2.], [0., -1., 0.]], dtype=float)
  312. >>> B = splu(A)
  313. >>> x = np.array([1., 2., 3.], dtype=float)
  314. >>> B.solve(x)
  315. array([ 1. , -3. , -1.5])
  316. >>> A.dot(B.solve(x))
  317. array([ 1., 2., 3.])
  318. >>> B.solve(A.dot(x))
  319. array([ 1., 2., 3.])
  320. """
  321. if is_pydata_spmatrix(A):
  322. csc_construct_func = lambda *a, cls=type(A): cls(csc_matrix(*a))
  323. A = A.to_scipy_sparse().tocsc()
  324. else:
  325. csc_construct_func = csc_matrix
  326. if not isspmatrix_csc(A):
  327. A = csc_matrix(A)
  328. warn('splu converted its input to CSC format', SparseEfficiencyWarning)
  329. # sum duplicates for non-canonical format
  330. A.sum_duplicates()
  331. A = A.asfptype() # upcast to a floating point format
  332. M, N = A.shape
  333. if (M != N):
  334. raise ValueError("can only factor square matrices") # is this true?
  335. _options = dict(DiagPivotThresh=diag_pivot_thresh, ColPerm=permc_spec,
  336. PanelSize=panel_size, Relax=relax)
  337. if options is not None:
  338. _options.update(options)
  339. # Ensure that no column permutations are applied
  340. if (_options["ColPerm"] == "NATURAL"):
  341. _options["SymmetricMode"] = True
  342. return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr,
  343. csc_construct_func=csc_construct_func,
  344. ilu=False, options=_options)
  345. def spilu(A, drop_tol=None, fill_factor=None, drop_rule=None, permc_spec=None,
  346. diag_pivot_thresh=None, relax=None, panel_size=None, options=None):
  347. """
  348. Compute an incomplete LU decomposition for a sparse, square matrix.
  349. The resulting object is an approximation to the inverse of `A`.
  350. Parameters
  351. ----------
  352. A : (N, N) array_like
  353. Sparse matrix to factorize. Most efficient when provided in CSC format.
  354. Other formats will be converted to CSC before factorization.
  355. drop_tol : float, optional
  356. Drop tolerance (0 <= tol <= 1) for an incomplete LU decomposition.
  357. (default: 1e-4)
  358. fill_factor : float, optional
  359. Specifies the fill ratio upper bound (>= 1.0) for ILU. (default: 10)
  360. drop_rule : str, optional
  361. Comma-separated string of drop rules to use.
  362. Available rules: ``basic``, ``prows``, ``column``, ``area``,
  363. ``secondary``, ``dynamic``, ``interp``. (Default: ``basic,area``)
  364. See SuperLU documentation for details.
  365. Remaining other options
  366. Same as for `splu`
  367. Returns
  368. -------
  369. invA_approx : scipy.sparse.linalg.SuperLU
  370. Object, which has a ``solve`` method.
  371. See also
  372. --------
  373. splu : complete LU decomposition
  374. Notes
  375. -----
  376. To improve the better approximation to the inverse, you may need to
  377. increase `fill_factor` AND decrease `drop_tol`.
  378. This function uses the SuperLU library.
  379. Examples
  380. --------
  381. >>> import numpy as np
  382. >>> from scipy.sparse import csc_matrix
  383. >>> from scipy.sparse.linalg import spilu
  384. >>> A = csc_matrix([[1., 0., 0.], [5., 0., 2.], [0., -1., 0.]], dtype=float)
  385. >>> B = spilu(A)
  386. >>> x = np.array([1., 2., 3.], dtype=float)
  387. >>> B.solve(x)
  388. array([ 1. , -3. , -1.5])
  389. >>> A.dot(B.solve(x))
  390. array([ 1., 2., 3.])
  391. >>> B.solve(A.dot(x))
  392. array([ 1., 2., 3.])
  393. """
  394. if is_pydata_spmatrix(A):
  395. csc_construct_func = lambda *a, cls=type(A): cls(csc_matrix(*a))
  396. A = A.to_scipy_sparse().tocsc()
  397. else:
  398. csc_construct_func = csc_matrix
  399. if not isspmatrix_csc(A):
  400. A = csc_matrix(A)
  401. warn('spilu converted its input to CSC format',
  402. SparseEfficiencyWarning)
  403. # sum duplicates for non-canonical format
  404. A.sum_duplicates()
  405. A = A.asfptype() # upcast to a floating point format
  406. M, N = A.shape
  407. if (M != N):
  408. raise ValueError("can only factor square matrices") # is this true?
  409. _options = dict(ILU_DropRule=drop_rule, ILU_DropTol=drop_tol,
  410. ILU_FillFactor=fill_factor,
  411. DiagPivotThresh=diag_pivot_thresh, ColPerm=permc_spec,
  412. PanelSize=panel_size, Relax=relax)
  413. if options is not None:
  414. _options.update(options)
  415. # Ensure that no column permutations are applied
  416. if (_options["ColPerm"] == "NATURAL"):
  417. _options["SymmetricMode"] = True
  418. return _superlu.gstrf(N, A.nnz, A.data, A.indices, A.indptr,
  419. csc_construct_func=csc_construct_func,
  420. ilu=True, options=_options)
  421. def factorized(A):
  422. """
  423. Return a function for solving a sparse linear system, with A pre-factorized.
  424. Parameters
  425. ----------
  426. A : (N, N) array_like
  427. Input. A in CSC format is most efficient. A CSR format matrix will
  428. be converted to CSC before factorization.
  429. Returns
  430. -------
  431. solve : callable
  432. To solve the linear system of equations given in `A`, the `solve`
  433. callable should be passed an ndarray of shape (N,).
  434. Examples
  435. --------
  436. >>> import numpy as np
  437. >>> from scipy.sparse.linalg import factorized
  438. >>> A = np.array([[ 3. , 2. , -1. ],
  439. ... [ 2. , -2. , 4. ],
  440. ... [-1. , 0.5, -1. ]])
  441. >>> solve = factorized(A) # Makes LU decomposition.
  442. >>> rhs1 = np.array([1, -2, 0])
  443. >>> solve(rhs1) # Uses the LU factors.
  444. array([ 1., -2., -2.])
  445. """
  446. if is_pydata_spmatrix(A):
  447. A = A.to_scipy_sparse().tocsc()
  448. if useUmfpack:
  449. if noScikit:
  450. raise RuntimeError('Scikits.umfpack not installed.')
  451. if not isspmatrix_csc(A):
  452. A = csc_matrix(A)
  453. warn('splu converted its input to CSC format',
  454. SparseEfficiencyWarning)
  455. A = A.asfptype() # upcast to a floating point format
  456. if A.dtype.char not in 'dD':
  457. raise ValueError("convert matrix data to double, please, using"
  458. " .astype(), or set linsolve.useUmfpack = False")
  459. umf_family, A = _get_umf_family(A)
  460. umf = umfpack.UmfpackContext(umf_family)
  461. # Make LU decomposition.
  462. umf.numeric(A)
  463. def solve(b):
  464. with np.errstate(divide="ignore", invalid="ignore"):
  465. # Ignoring warnings with numpy >= 1.23.0, see gh-16523
  466. result = umf.solve(umfpack.UMFPACK_A, A, b, autoTranspose=True)
  467. return result
  468. return solve
  469. else:
  470. return splu(A).solve
  471. def spsolve_triangular(A, b, lower=True, overwrite_A=False, overwrite_b=False,
  472. unit_diagonal=False):
  473. """
  474. Solve the equation ``A x = b`` for `x`, assuming A is a triangular matrix.
  475. Parameters
  476. ----------
  477. A : (M, M) sparse matrix
  478. A sparse square triangular matrix. Should be in CSR format.
  479. b : (M,) or (M, N) array_like
  480. Right-hand side matrix in ``A x = b``
  481. lower : bool, optional
  482. Whether `A` is a lower or upper triangular matrix.
  483. Default is lower triangular matrix.
  484. overwrite_A : bool, optional
  485. Allow changing `A`. The indices of `A` are going to be sorted and zero
  486. entries are going to be removed.
  487. Enabling gives a performance gain. Default is False.
  488. overwrite_b : bool, optional
  489. Allow overwriting data in `b`.
  490. Enabling gives a performance gain. Default is False.
  491. If `overwrite_b` is True, it should be ensured that
  492. `b` has an appropriate dtype to be able to store the result.
  493. unit_diagonal : bool, optional
  494. If True, diagonal elements of `a` are assumed to be 1 and will not be
  495. referenced.
  496. .. versionadded:: 1.4.0
  497. Returns
  498. -------
  499. x : (M,) or (M, N) ndarray
  500. Solution to the system ``A x = b``. Shape of return matches shape
  501. of `b`.
  502. Raises
  503. ------
  504. LinAlgError
  505. If `A` is singular or not triangular.
  506. ValueError
  507. If shape of `A` or shape of `b` do not match the requirements.
  508. Notes
  509. -----
  510. .. versionadded:: 0.19.0
  511. Examples
  512. --------
  513. >>> import numpy as np
  514. >>> from scipy.sparse import csr_matrix
  515. >>> from scipy.sparse.linalg import spsolve_triangular
  516. >>> A = csr_matrix([[3, 0, 0], [1, -1, 0], [2, 0, 1]], dtype=float)
  517. >>> B = np.array([[2, 0], [-1, 0], [2, 0]], dtype=float)
  518. >>> x = spsolve_triangular(A, B)
  519. >>> np.allclose(A.dot(x), B)
  520. True
  521. """
  522. if is_pydata_spmatrix(A):
  523. A = A.to_scipy_sparse().tocsr()
  524. # Check the input for correct type and format.
  525. if not isspmatrix_csr(A):
  526. warn('CSR matrix format is required. Converting to CSR matrix.',
  527. SparseEfficiencyWarning)
  528. A = csr_matrix(A)
  529. elif not overwrite_A:
  530. A = A.copy()
  531. if A.shape[0] != A.shape[1]:
  532. raise ValueError(
  533. 'A must be a square matrix but its shape is {}.'.format(A.shape))
  534. # sum duplicates for non-canonical format
  535. A.sum_duplicates()
  536. b = np.asanyarray(b)
  537. if b.ndim not in [1, 2]:
  538. raise ValueError(
  539. 'b must have 1 or 2 dims but its shape is {}.'.format(b.shape))
  540. if A.shape[0] != b.shape[0]:
  541. raise ValueError(
  542. 'The size of the dimensions of A must be equal to '
  543. 'the size of the first dimension of b but the shape of A is '
  544. '{} and the shape of b is {}.'.format(A.shape, b.shape))
  545. # Init x as (a copy of) b.
  546. x_dtype = np.result_type(A.data, b, np.float64)
  547. if overwrite_b:
  548. if np.can_cast(b.dtype, x_dtype, casting='same_kind'):
  549. x = b
  550. else:
  551. raise ValueError(
  552. 'Cannot overwrite b (dtype {}) with result '
  553. 'of type {}.'.format(b.dtype, x_dtype))
  554. else:
  555. x = b.astype(x_dtype, copy=True)
  556. # Choose forward or backward order.
  557. if lower:
  558. row_indices = range(len(b))
  559. else:
  560. row_indices = range(len(b) - 1, -1, -1)
  561. # Fill x iteratively.
  562. for i in row_indices:
  563. # Get indices for i-th row.
  564. indptr_start = A.indptr[i]
  565. indptr_stop = A.indptr[i + 1]
  566. if lower:
  567. A_diagonal_index_row_i = indptr_stop - 1
  568. A_off_diagonal_indices_row_i = slice(indptr_start, indptr_stop - 1)
  569. else:
  570. A_diagonal_index_row_i = indptr_start
  571. A_off_diagonal_indices_row_i = slice(indptr_start + 1, indptr_stop)
  572. # Check regularity and triangularity of A.
  573. if not unit_diagonal and (indptr_stop <= indptr_start
  574. or A.indices[A_diagonal_index_row_i] < i):
  575. raise LinAlgError(
  576. 'A is singular: diagonal {} is zero.'.format(i))
  577. if not unit_diagonal and A.indices[A_diagonal_index_row_i] > i:
  578. raise LinAlgError(
  579. 'A is not triangular: A[{}, {}] is nonzero.'
  580. ''.format(i, A.indices[A_diagonal_index_row_i]))
  581. # Incorporate off-diagonal entries.
  582. A_column_indices_in_row_i = A.indices[A_off_diagonal_indices_row_i]
  583. A_values_in_row_i = A.data[A_off_diagonal_indices_row_i]
  584. x[i] -= np.dot(x[A_column_indices_in_row_i].T, A_values_in_row_i)
  585. # Compute i-th entry of x.
  586. if not unit_diagonal:
  587. x[i] /= A.data[A_diagonal_index_row_i]
  588. return x