_matfuncs_sqrtm.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. """
  2. Matrix square root for general matrices and for upper triangular matrices.
  3. This module exists to avoid cyclic imports.
  4. """
  5. __all__ = ['sqrtm']
  6. import numpy as np
  7. from scipy._lib._util import _asarray_validated
  8. # Local imports
  9. from ._misc import norm
  10. from .lapack import ztrsyl, dtrsyl
  11. from ._decomp_schur import schur, rsf2csf
  12. class SqrtmError(np.linalg.LinAlgError):
  13. pass
  14. from ._matfuncs_sqrtm_triu import within_block_loop
  15. def _sqrtm_triu(T, blocksize=64):
  16. """
  17. Matrix square root of an upper triangular matrix.
  18. This is a helper function for `sqrtm` and `logm`.
  19. Parameters
  20. ----------
  21. T : (N, N) array_like upper triangular
  22. Matrix whose square root to evaluate
  23. blocksize : int, optional
  24. If the blocksize is not degenerate with respect to the
  25. size of the input array, then use a blocked algorithm. (Default: 64)
  26. Returns
  27. -------
  28. sqrtm : (N, N) ndarray
  29. Value of the sqrt function at `T`
  30. References
  31. ----------
  32. .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
  33. "Blocked Schur Algorithms for Computing the Matrix Square Root,
  34. Lecture Notes in Computer Science, 7782. pp. 171-182.
  35. """
  36. T_diag = np.diag(T)
  37. keep_it_real = np.isrealobj(T) and np.min(T_diag) >= 0
  38. # Cast to complex as necessary + ensure double precision
  39. if not keep_it_real:
  40. T = np.asarray(T, dtype=np.complex128, order="C")
  41. T_diag = np.asarray(T_diag, dtype=np.complex128)
  42. else:
  43. T = np.asarray(T, dtype=np.float64, order="C")
  44. T_diag = np.asarray(T_diag, dtype=np.float64)
  45. R = np.diag(np.sqrt(T_diag))
  46. # Compute the number of blocks to use; use at least one block.
  47. n, n = T.shape
  48. nblocks = max(n // blocksize, 1)
  49. # Compute the smaller of the two sizes of blocks that
  50. # we will actually use, and compute the number of large blocks.
  51. bsmall, nlarge = divmod(n, nblocks)
  52. blarge = bsmall + 1
  53. nsmall = nblocks - nlarge
  54. if nsmall * bsmall + nlarge * blarge != n:
  55. raise Exception('internal inconsistency')
  56. # Define the index range covered by each block.
  57. start_stop_pairs = []
  58. start = 0
  59. for count, size in ((nsmall, bsmall), (nlarge, blarge)):
  60. for i in range(count):
  61. start_stop_pairs.append((start, start + size))
  62. start += size
  63. # Within-block interactions (Cythonized)
  64. try:
  65. within_block_loop(R, T, start_stop_pairs, nblocks)
  66. except RuntimeError as e:
  67. raise SqrtmError(*e.args) from e
  68. # Between-block interactions (Cython would give no significant speedup)
  69. for j in range(nblocks):
  70. jstart, jstop = start_stop_pairs[j]
  71. for i in range(j-1, -1, -1):
  72. istart, istop = start_stop_pairs[i]
  73. S = T[istart:istop, jstart:jstop]
  74. if j - i > 1:
  75. S = S - R[istart:istop, istop:jstart].dot(R[istop:jstart,
  76. jstart:jstop])
  77. # Invoke LAPACK.
  78. # For more details, see the solve_sylvester implemention
  79. # and the fortran dtrsyl and ztrsyl docs.
  80. Rii = R[istart:istop, istart:istop]
  81. Rjj = R[jstart:jstop, jstart:jstop]
  82. if keep_it_real:
  83. x, scale, info = dtrsyl(Rii, Rjj, S)
  84. else:
  85. x, scale, info = ztrsyl(Rii, Rjj, S)
  86. R[istart:istop, jstart:jstop] = x * scale
  87. # Return the matrix square root.
  88. return R
  89. def sqrtm(A, disp=True, blocksize=64):
  90. """
  91. Matrix square root.
  92. Parameters
  93. ----------
  94. A : (N, N) array_like
  95. Matrix whose square root to evaluate
  96. disp : bool, optional
  97. Print warning if error in the result is estimated large
  98. instead of returning estimated error. (Default: True)
  99. blocksize : integer, optional
  100. If the blocksize is not degenerate with respect to the
  101. size of the input array, then use a blocked algorithm. (Default: 64)
  102. Returns
  103. -------
  104. sqrtm : (N, N) ndarray
  105. Value of the sqrt function at `A`. The dtype is float or complex.
  106. The precision (data size) is determined based on the precision of
  107. input `A`. When the dtype is float, the precision is same as `A`.
  108. When the dtype is complex, the precition is double as `A`. The
  109. precision might be cliped by each dtype precision range.
  110. errest : float
  111. (if disp == False)
  112. Frobenius norm of the estimated error, ||err||_F / ||A||_F
  113. References
  114. ----------
  115. .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
  116. "Blocked Schur Algorithms for Computing the Matrix Square Root,
  117. Lecture Notes in Computer Science, 7782. pp. 171-182.
  118. Examples
  119. --------
  120. >>> import numpy as np
  121. >>> from scipy.linalg import sqrtm
  122. >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
  123. >>> r = sqrtm(a)
  124. >>> r
  125. array([[ 0.75592895, 1.13389342],
  126. [ 0.37796447, 1.88982237]])
  127. >>> r.dot(r)
  128. array([[ 1., 3.],
  129. [ 1., 4.]])
  130. """
  131. byte_size = np.asarray(A).dtype.itemsize
  132. A = _asarray_validated(A, check_finite=True, as_inexact=True)
  133. if len(A.shape) != 2:
  134. raise ValueError("Non-matrix input to matrix function.")
  135. if blocksize < 1:
  136. raise ValueError("The blocksize should be at least 1.")
  137. keep_it_real = np.isrealobj(A)
  138. if keep_it_real:
  139. T, Z = schur(A)
  140. if not np.array_equal(T, np.triu(T)):
  141. T, Z = rsf2csf(T, Z)
  142. else:
  143. T, Z = schur(A, output='complex')
  144. failflag = False
  145. try:
  146. R = _sqrtm_triu(T, blocksize=blocksize)
  147. ZH = np.conjugate(Z).T
  148. X = Z.dot(R).dot(ZH)
  149. if not np.iscomplexobj(X):
  150. # float byte size range: f2 ~ f16
  151. X = X.astype(f"f{np.clip(byte_size, 2, 16)}", copy=False)
  152. else:
  153. # complex byte size range: c8 ~ c32.
  154. # c32(complex256) might not be supported in some environments.
  155. if hasattr(np, 'complex256'):
  156. X = X.astype(f"c{np.clip(byte_size*2, 8, 32)}", copy=False)
  157. else:
  158. X = X.astype(f"c{np.clip(byte_size*2, 8, 16)}", copy=False)
  159. except SqrtmError:
  160. failflag = True
  161. X = np.empty_like(A)
  162. X.fill(np.nan)
  163. if disp:
  164. if failflag:
  165. print("Failed to find a square root.")
  166. return X
  167. else:
  168. try:
  169. arg2 = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')
  170. except ValueError:
  171. # NaNs in matrix
  172. arg2 = np.inf
  173. return X, arg2