123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- """LU decomposition functions."""
- from warnings import warn
- from numpy import asarray, asarray_chkfinite
- # Local imports
- from ._misc import _datacopied, LinAlgWarning
- from .lapack import get_lapack_funcs
- from ._flinalg_py import get_flinalg_funcs
- __all__ = ['lu', 'lu_solve', 'lu_factor']
def lu_factor(a, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Matrix to decompose
    overwrite_a : bool, optional
        Whether to overwrite data in A (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    lu : (M, N) ndarray
        Matrix containing U in its upper triangle, and L in its lower triangle.
        The unit diagonal elements of L are not stored.
    piv : (K,) ndarray
        Pivot indices representing the permutation matrix P:
        row i of matrix was interchanged with row piv[i].
        K = min(M, N).

    Raises
    ------
    ValueError
        If the underlying ``getrf`` routine reports an illegal argument
        (negative ``info``).

    Warns
    -----
    LinAlgWarning
        If ``getrf`` reports a positive ``info``, i.e. a diagonal entry is
        exactly zero and the matrix is singular. The factorization is still
        returned.

    See Also
    --------
    lu : gives lu factorization in more user-friendly format
    lu_solve : solve an equation system using the LU factorization of a matrix

    Notes
    -----
    This is a wrapper to the ``*GETRF`` routines from LAPACK. Unlike
    :func:`lu`, it outputs the L and U factors into a single array
    and returns pivot indices instead of a permutation matrix.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> lu, piv = lu_factor(A)
    >>> piv
    array([2, 2, 3, 3], dtype=int32)

    Convert LAPACK's ``piv`` array to NumPy index and test the permutation

    >>> piv_py = [2, 0, 3, 1]
    >>> L, U = np.tril(lu, k=-1) + np.eye(4), np.triu(lu)
    >>> np.allclose(A[piv_py] - L @ U, np.zeros((4, 4)))
    True
    """
    if check_finite:
        a1 = asarray_chkfinite(a)
    else:
        a1 = asarray(a)
    # Presumably `_datacopied` reports that the conversion above produced a
    # fresh array; in that case the routine may work in place on it.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    getrf, = get_lapack_funcs(('getrf',), (a1,))
    lu, piv, info = getrf(a1, overwrite_a=overwrite_a)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal getrf (lu_factor)' % -info)
    if info > 0:
        # A singular factor is reported but still handed back to the caller.
        warn("Diagonal number %d is exactly zero. Singular matrix." % info,
             LinAlgWarning, stacklevel=2)
    return lu, piv
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve an equation system, a x = b, given the LU factorization of a.

    Parameters
    ----------
    lu_and_piv : tuple (lu, piv)
        Factorization of the coefficient matrix a, as given by lu_factor.
        ``lu`` may be any array_like; it is converted with ``asarray``.
    b : array_like
        Right-hand side. Its first dimension must match that of ``lu``.
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========
    overwrite_b : bool, optional
        Whether to overwrite data in b (may increase performance)
    check_finite : bool, optional
        Whether to check that the input matrices contain only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    x : array
        Solution to the system

    Raises
    ------
    ValueError
        If the leading dimensions of ``lu`` and ``b`` differ, or if the
        underlying ``getrs`` routine reports a non-zero ``info``.

    See Also
    --------
    lu_factor : LU factorize a matrix

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu_factor, lu_solve
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = lu_factor(A)
    >>> x = lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    True
    """
    (lu, piv) = lu_and_piv
    # Accept array_like factors (e.g. nested lists): `.shape` is used below.
    lu = asarray(lu)
    if check_finite:
        b1 = asarray_chkfinite(b)
    else:
        b1 = asarray(b)
    overwrite_b = overwrite_b or _datacopied(b1, b)
    if lu.shape[0] != b1.shape[0]:
        raise ValueError("Shapes of lu {} and b {} are incompatible"
                         .format(lu.shape, b1.shape))

    getrs, = get_lapack_funcs(('getrs',), (lu, b1))
    x, info = getrs(lu, piv, b1, trans=trans, overwrite_b=overwrite_b)
    if info == 0:
        return x
    # The routine actually invoked is getrs, so name it in the message.
    raise ValueError('illegal value in %dth argument of internal getrs '
                     '(lu_solve)' % -info)
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """
    Compute pivoted LU decomposition of a matrix.

    The decomposition is::

        A = P L U

    where P is a permutation matrix, L lower triangular with unit
    diagonal elements, and U upper triangular.

    Parameters
    ----------
    a : (M, N) array_like
        Array to decompose
    permute_l : bool, optional
        Perform the multiplication P*L (Default: do not permute)
    overwrite_a : bool, optional
        Whether to overwrite data in a (may improve performance)
    check_finite : bool, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

    Returns
    -------
    **(If permute_l == False)**

    p : (M, M) ndarray
        Permutation matrix
    l : (M, K) ndarray
        Lower triangular or trapezoidal matrix with unit diagonal.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    **(If permute_l == True)**

    pl : (M, K) ndarray
        Permuted L matrix.
        K = min(M, N)
    u : (K, N) ndarray
        Upper triangular or trapezoidal matrix

    Raises
    ------
    ValueError
        If the input is not two-dimensional, or if the internal routine
        reports an illegal argument (negative ``info``).

    Notes
    -----
    This is a LU factorization routine written for SciPy.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.linalg import lu
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> p, l, u = lu(A)
    >>> np.allclose(A - p @ l @ u, np.zeros((4, 4)))
    True
    """
    # Pick the conversion routine once, then apply it.
    to_array = asarray_chkfinite if check_finite else asarray
    mat = to_array(a)
    if mat.ndim != 2:
        raise ValueError('expected matrix')

    may_clobber = overwrite_a or _datacopied(mat, a)
    lu_kernel, = get_flinalg_funcs(('lu',), (mat,))
    perm, lower, upper, info = lu_kernel(mat, permute_l=permute_l,
                                         overwrite_a=may_clobber)
    if info < 0:
        raise ValueError('illegal value in %dth argument of '
                         'internal lu.getrf' % -info)

    # With permute_l the first factor is returned as the permuted L (see
    # Returns), so the separate permutation output is dropped.
    return (lower, upper) if permute_l else (perm, lower, upper)
|