# _decomp_lu.py (6.8 KB)
  1. """LU decomposition functions."""
  2. from warnings import warn
  3. from numpy import asarray, asarray_chkfinite
  4. # Local imports
  5. from ._misc import _datacopied, LinAlgWarning
  6. from .lapack import get_lapack_funcs
  7. from ._flinalg_py import get_flinalg_funcs
  8. __all__ = ['lu', 'lu_solve', 'lu_factor']
  9. def lu_factor(a, overwrite_a=False, check_finite=True):
  10. """
  11. Compute pivoted LU decomposition of a matrix.
  12. The decomposition is::
  13. A = P L U
  14. where P is a permutation matrix, L lower triangular with unit
  15. diagonal elements, and U upper triangular.
  16. Parameters
  17. ----------
  18. a : (M, N) array_like
  19. Matrix to decompose
  20. overwrite_a : bool, optional
  21. Whether to overwrite data in A (may increase performance)
  22. check_finite : bool, optional
  23. Whether to check that the input matrix contains only finite numbers.
  24. Disabling may give a performance gain, but may result in problems
  25. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  26. Returns
  27. -------
  28. lu : (M, N) ndarray
  29. Matrix containing U in its upper triangle, and L in its lower triangle.
  30. The unit diagonal elements of L are not stored.
  31. piv : (N,) ndarray
  32. Pivot indices representing the permutation matrix P:
  33. row i of matrix was interchanged with row piv[i].
  34. See Also
  35. --------
  36. lu : gives lu factorization in more user-friendly format
  37. lu_solve : solve an equation system using the LU factorization of a matrix
  38. Notes
  39. -----
  40. This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
  41. :func:`lu`, it outputs the L and U factors into a single array
  42. and returns pivot indices instead of a permutation matrix.
  43. Examples
  44. --------
  45. >>> import numpy as np
  46. >>> from scipy.linalg import lu_factor
  47. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  48. >>> lu, piv = lu_factor(A)
  49. >>> piv
  50. array([2, 2, 3, 3], dtype=int32)
  51. Convert LAPACK's ``piv`` array to NumPy index and test the permutation
  52. >>> piv_py = [2, 0, 3, 1]
  53. >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
  54. >>> np.allclose(A[piv_py] - L @ U, np.zeros((4, 4)))
  55. True
  56. """
  57. if check_finite:
  58. a1 = asarray_chkfinite(a)
  59. else:
  60. a1 = asarray(a)
  61. overwrite_a = overwrite_a or (_datacopied(a1, a))
  62. getrf, = get_lapack_funcs(('getrf',), (a1,))
  63. lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
  64. if info < 0:
  65. raise ValueError('illegal value in %dth argument of '
  66. 'internal getrf (lu_factor)' % -info)
  67. if info > 0:
  68. warn("Diagonal number %d is exactly zero. Singular matrix." % info,
  69. LinAlgWarning, stacklevel=2)
  70. return lu, piv
  71. def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
  72. """Solve an equation system, a x = b, given the LU factorization of a
  73. Parameters
  74. ----------
  75. (lu, piv)
  76. Factorization of the coefficient matrix a, as given by lu_factor
  77. b : array
  78. Right-hand side
  79. trans : {0, 1, 2}, optional
  80. Type of system to solve:
  81. ===== =========
  82. trans system
  83. ===== =========
  84. 0 a x = b
  85. 1 a^T x = b
  86. 2 a^H x = b
  87. ===== =========
  88. overwrite_b : bool, optional
  89. Whether to overwrite data in b (may increase performance)
  90. check_finite : bool, optional
  91. Whether to check that the input matrices contain only finite numbers.
  92. Disabling may give a performance gain, but may result in problems
  93. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  94. Returns
  95. -------
  96. x : array
  97. Solution to the system
  98. See Also
  99. --------
  100. lu_factor : LU factorize a matrix
  101. Examples
  102. --------
  103. >>> import numpy as np
  104. >>> from scipy.linalg import lu_factor, lu_solve
  105. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  106. >>> b = np.array([1, 1, 1, 1])
  107. >>> lu, piv = lu_factor(A)
  108. >>> x = lu_solve((lu, piv), b)
  109. >>> np.allclose(A @ x - b, np.zeros((4,)))
  110. True
  111. """
  112. (lu, piv) = lu_and_piv
  113. if check_finite:
  114. b1 = asarray_chkfinite(b)
  115. else:
  116. b1 = asarray(b)
  117. overwrite_b = overwrite_b or _datacopied(b1, b)
  118. if lu.shape[0] != b1.shape[0]:
  119. raise ValueError("Shapes of lu {} and b {} are incompatible"
  120. .format(lu.shape, b1.shape))
  121. getrs, = get_lapack_funcs(('getrs',), (lu, b1))
  122. x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
  123. if info == 0:
  124. return x
  125. raise ValueError('illegal value in %dth argument of internal gesv|posv'
  126. % -info)
  127. def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
  128. """
  129. Compute pivoted LU decomposition of a matrix.
  130. The decomposition is::
  131. A = P L U
  132. where P is a permutation matrix, L lower triangular with unit
  133. diagonal elements, and U upper triangular.
  134. Parameters
  135. ----------
  136. a : (M, N) array_like
  137. Array to decompose
  138. permute_l : bool, optional
  139. Perform the multiplication P*L (Default: do not permute)
  140. overwrite_a : bool, optional
  141. Whether to overwrite data in a (may improve performance)
  142. check_finite : bool, optional
  143. Whether to check that the input matrix contains only finite numbers.
  144. Disabling may give a performance gain, but may result in problems
  145. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  146. Returns
  147. -------
  148. **(If permute_l == False)**
  149. p : (M, M) ndarray
  150. Permutation matrix
  151. l : (M, K) ndarray
  152. Lower triangular or trapezoidal matrix with unit diagonal.
  153. K = min(M, N)
  154. u : (K, N) ndarray
  155. Upper triangular or trapezoidal matrix
  156. **(If permute_l == True)**
  157. pl : (M, K) ndarray
  158. Permuted L matrix.
  159. K = min(M, N)
  160. u : (K, N) ndarray
  161. Upper triangular or trapezoidal matrix
  162. Notes
  163. -----
  164. This is a LU factorization routine written for SciPy.
  165. Examples
  166. --------
  167. >>> import numpy as np
  168. >>> from scipy.linalg import lu
  169. >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
  170. >>> p, l, u = lu(A)
  171. >>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
  172. True
  173. """
  174. if check_finite:
  175. a1 = asarray_chkfinite(a)
  176. else:
  177. a1 = asarray(a)
  178. if len(a1.shape) != 2:
  179. raise ValueError('expected matrix')
  180. overwrite_a = overwrite_a or (_datacopied(a1, a))
  181. flu, = get_flinalg_funcs(('lu',), (a1,))
  182. p, l, u, info = flu(a1, permute_l=permute_l, overwrite_a=overwrite_a)
  183. if info < 0:
  184. raise ValueError('illegal value in %dth argument of '
  185. 'internal lu.getrf' % -info)
  186. if permute_l:
  187. return l, u
  188. return p, l, u