123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- """
- Matrix square root for general matrices and for upper triangular matrices.
- This module exists to avoid cyclic imports.
- """
- __all__ = ['sqrtm']
- import numpy as np
- from scipy._lib._util import _asarray_validated
- # Local imports
- from ._misc import norm
- from .lapack import ztrsyl, dtrsyl
- from ._decomp_schur import schur, rsf2csf
- class SqrtmError(np.linalg.LinAlgError):
- pass
- from ._matfuncs_sqrtm_triu import within_block_loop
- def _sqrtm_triu(T, blocksize=64):
- """
- Matrix square root of an upper triangular matrix.
- This is a helper function for `sqrtm` and `logm`.
- Parameters
- ----------
- T : (N, N) array_like upper triangular
- Matrix whose square root to evaluate
- blocksize : int, optional
- If the blocksize is not degenerate with respect to the
- size of the input array, then use a blocked algorithm. (Default: 64)
- Returns
- -------
- sqrtm : (N, N) ndarray
- Value of the sqrt function at `T`
- References
- ----------
- .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
- "Blocked Schur Algorithms for Computing the Matrix Square Root,
- Lecture Notes in Computer Science, 7782. pp. 171-182.
- """
- T_diag = np.diag(T)
- keep_it_real = np.isrealobj(T) and np.min(T_diag) >= 0
- # Cast to complex as necessary + ensure double precision
- if not keep_it_real:
- T = np.asarray(T, dtype=np.complex128, order="C")
- T_diag = np.asarray(T_diag, dtype=np.complex128)
- else:
- T = np.asarray(T, dtype=np.float64, order="C")
- T_diag = np.asarray(T_diag, dtype=np.float64)
- R = np.diag(np.sqrt(T_diag))
- # Compute the number of blocks to use; use at least one block.
- n, n = T.shape
- nblocks = max(n // blocksize, 1)
- # Compute the smaller of the two sizes of blocks that
- # we will actually use, and compute the number of large blocks.
- bsmall, nlarge = divmod(n, nblocks)
- blarge = bsmall + 1
- nsmall = nblocks - nlarge
- if nsmall * bsmall + nlarge * blarge != n:
- raise Exception('internal inconsistency')
- # Define the index range covered by each block.
- start_stop_pairs = []
- start = 0
- for count, size in ((nsmall, bsmall), (nlarge, blarge)):
- for i in range(count):
- start_stop_pairs.append((start, start + size))
- start += size
- # Within-block interactions (Cythonized)
- try:
- within_block_loop(R, T, start_stop_pairs, nblocks)
- except RuntimeError as e:
- raise SqrtmError(*e.args) from e
- # Between-block interactions (Cython would give no significant speedup)
- for j in range(nblocks):
- jstart, jstop = start_stop_pairs[j]
- for i in range(j-1, -1, -1):
- istart, istop = start_stop_pairs[i]
- S = T[istart:istop, jstart:jstop]
- if j - i > 1:
- S = S - R[istart:istop, istop:jstart].dot(R[istop:jstart,
- jstart:jstop])
- # Invoke LAPACK.
- # For more details, see the solve_sylvester implemention
- # and the fortran dtrsyl and ztrsyl docs.
- Rii = R[istart:istop, istart:istop]
- Rjj = R[jstart:jstop, jstart:jstop]
- if keep_it_real:
- x, scale, info = dtrsyl(Rii, Rjj, S)
- else:
- x, scale, info = ztrsyl(Rii, Rjj, S)
- R[istart:istop, jstart:jstop] = x * scale
- # Return the matrix square root.
- return R
- def sqrtm(A, disp=True, blocksize=64):
- """
- Matrix square root.
- Parameters
- ----------
- A : (N, N) array_like
- Matrix whose square root to evaluate
- disp : bool, optional
- Print warning if error in the result is estimated large
- instead of returning estimated error. (Default: True)
- blocksize : integer, optional
- If the blocksize is not degenerate with respect to the
- size of the input array, then use a blocked algorithm. (Default: 64)
- Returns
- -------
- sqrtm : (N, N) ndarray
- Value of the sqrt function at `A`. The dtype is float or complex.
- The precision (data size) is determined based on the precision of
- input `A`. When the dtype is float, the precision is same as `A`.
- When the dtype is complex, the precition is double as `A`. The
- precision might be cliped by each dtype precision range.
- errest : float
- (if disp == False)
- Frobenius norm of the estimated error, ||err||_F / ||A||_F
- References
- ----------
- .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
- "Blocked Schur Algorithms for Computing the Matrix Square Root,
- Lecture Notes in Computer Science, 7782. pp. 171-182.
- Examples
- --------
- >>> import numpy as np
- >>> from scipy.linalg import sqrtm
- >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
- >>> r = sqrtm(a)
- >>> r
- array([[ 0.75592895, 1.13389342],
- [ 0.37796447, 1.88982237]])
- >>> r.dot(r)
- array([[ 1., 3.],
- [ 1., 4.]])
- """
- byte_size = np.asarray(A).dtype.itemsize
- A = _asarray_validated(A, check_finite=True, as_inexact=True)
- if len(A.shape) != 2:
- raise ValueError("Non-matrix input to matrix function.")
- if blocksize < 1:
- raise ValueError("The blocksize should be at least 1.")
- keep_it_real = np.isrealobj(A)
- if keep_it_real:
- T, Z = schur(A)
- if not np.array_equal(T, np.triu(T)):
- T, Z = rsf2csf(T, Z)
- else:
- T, Z = schur(A, output='complex')
- failflag = False
- try:
- R = _sqrtm_triu(T, blocksize=blocksize)
- ZH = np.conjugate(Z).T
- X = Z.dot(R).dot(ZH)
- if not np.iscomplexobj(X):
- # float byte size range: f2 ~ f16
- X = X.astype(f"f{np.clip(byte_size, 2, 16)}", copy=False)
- else:
- # complex byte size range: c8 ~ c32.
- # c32(complex256) might not be supported in some environments.
- if hasattr(np, 'complex256'):
- X = X.astype(f"c{np.clip(byte_size*2, 8, 32)}", copy=False)
- else:
- X = X.astype(f"c{np.clip(byte_size*2, 8, 16)}", copy=False)
- except SqrtmError:
- failflag = True
- X = np.empty_like(A)
- X.fill(np.nan)
- if disp:
- if failflag:
- print("Failed to find a square root.")
- return X
- else:
- try:
- arg2 = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')
- except ValueError:
- # NaNs in matrix
- arg2 = np.inf
- return X, arg2
|