# _decomp_schur.py (10 KB)
  1. """Schur decomposition functions."""
  2. import numpy
  3. from numpy import asarray_chkfinite, single, asarray, array
  4. from numpy.linalg import norm
  5. # Local imports.
  6. from ._misc import LinAlgError, _datacopied
  7. from .lapack import get_lapack_funcs
  8. from ._decomp import eigvals
  9. __all__ = ['schur', 'rsf2csf']
  10. _double_precision = ['i', 'l', 'd']
  11. def schur(a, output='real', lwork=None, overwrite_a=False, sort=None,
  12. check_finite=True):
  13. """
  14. Compute Schur decomposition of a matrix.
  15. The Schur decomposition is::
  16. A = Z T Z^H
  17. where Z is unitary and T is either upper-triangular, or for real
  18. Schur decomposition (output='real'), quasi-upper triangular. In
  19. the quasi-triangular form, 2x2 blocks describing complex-valued
  20. eigenvalue pairs may extrude from the diagonal.
  21. Parameters
  22. ----------
  23. a : (M, M) array_like
  24. Matrix to decompose
  25. output : {'real', 'complex'}, optional
  26. Construct the real or complex Schur decomposition (for real matrices).
  27. lwork : int, optional
  28. Work array size. If None or -1, it is automatically computed.
  29. overwrite_a : bool, optional
  30. Whether to overwrite data in a (may improve performance).
  31. sort : {None, callable, 'lhp', 'rhp', 'iuc', 'ouc'}, optional
  32. Specifies whether the upper eigenvalues should be sorted. A callable
  33. may be passed that, given a eigenvalue, returns a boolean denoting
  34. whether the eigenvalue should be sorted to the top-left (True).
  35. Alternatively, string parameters may be used::
  36. 'lhp' Left-hand plane (x.real < 0.0)
  37. 'rhp' Right-hand plane (x.real > 0.0)
  38. 'iuc' Inside the unit circle (x*x.conjugate() <= 1.0)
  39. 'ouc' Outside the unit circle (x*x.conjugate() > 1.0)
  40. Defaults to None (no sorting).
  41. check_finite : bool, optional
  42. Whether to check that the input matrix contains only finite numbers.
  43. Disabling may give a performance gain, but may result in problems
  44. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  45. Returns
  46. -------
  47. T : (M, M) ndarray
  48. Schur form of A. It is real-valued for the real Schur decomposition.
  49. Z : (M, M) ndarray
  50. An unitary Schur transformation matrix for A.
  51. It is real-valued for the real Schur decomposition.
  52. sdim : int
  53. If and only if sorting was requested, a third return value will
  54. contain the number of eigenvalues satisfying the sort condition.
  55. Raises
  56. ------
  57. LinAlgError
  58. Error raised under three conditions:
  59. 1. The algorithm failed due to a failure of the QR algorithm to
  60. compute all eigenvalues.
  61. 2. If eigenvalue sorting was requested, the eigenvalues could not be
  62. reordered due to a failure to separate eigenvalues, usually because
  63. of poor conditioning.
  64. 3. If eigenvalue sorting was requested, roundoff errors caused the
  65. leading eigenvalues to no longer satisfy the sorting condition.
  66. See Also
  67. --------
  68. rsf2csf : Convert real Schur form to complex Schur form
  69. Examples
  70. --------
  71. >>> import numpy as np
  72. >>> from scipy.linalg import schur, eigvals
  73. >>> A = np.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]])
  74. >>> T, Z = schur(A)
  75. >>> T
  76. array([[ 2.65896708, 1.42440458, -1.92933439],
  77. [ 0. , -0.32948354, -0.49063704],
  78. [ 0. , 1.31178921, -0.32948354]])
  79. >>> Z
  80. array([[0.72711591, -0.60156188, 0.33079564],
  81. [0.52839428, 0.79801892, 0.28976765],
  82. [0.43829436, 0.03590414, -0.89811411]])
  83. >>> T2, Z2 = schur(A, output='complex')
  84. >>> T2
  85. array([[ 2.65896708, -1.22839825+1.32378589j, 0.42590089+1.51937378j],
  86. [ 0. , -0.32948354+0.80225456j, -0.59877807+0.56192146j],
  87. [ 0. , 0. , -0.32948354-0.80225456j]])
  88. >>> eigvals(T2)
  89. array([2.65896708, -0.32948354+0.80225456j, -0.32948354-0.80225456j])
  90. An arbitrary custom eig-sorting condition, having positive imaginary part,
  91. which is satisfied by only one eigenvalue
  92. >>> T3, Z3, sdim = schur(A, output='complex', sort=lambda x: x.imag > 0)
  93. >>> sdim
  94. 1
  95. """
  96. if output not in ['real', 'complex', 'r', 'c']:
  97. raise ValueError("argument must be 'real', or 'complex'")
  98. if check_finite:
  99. a1 = asarray_chkfinite(a)
  100. else:
  101. a1 = asarray(a)
  102. if len(a1.shape) != 2 or (a1.shape[0] != a1.shape[1]):
  103. raise ValueError('expected square matrix')
  104. typ = a1.dtype.char
  105. if output in ['complex', 'c'] and typ not in ['F', 'D']:
  106. if typ in _double_precision:
  107. a1 = a1.astype('D')
  108. typ = 'D'
  109. else:
  110. a1 = a1.astype('F')
  111. typ = 'F'
  112. overwrite_a = overwrite_a or (_datacopied(a1, a))
  113. gees, = get_lapack_funcs(('gees',), (a1,))
  114. if lwork is None or lwork == -1:
  115. # get optimal work array
  116. result = gees(lambda x: None, a1, lwork=-1)
  117. lwork = result[-2][0].real.astype(numpy.int_)
  118. if sort is None:
  119. sort_t = 0
  120. sfunction = lambda x: None
  121. else:
  122. sort_t = 1
  123. if callable(sort):
  124. sfunction = sort
  125. elif sort == 'lhp':
  126. sfunction = lambda x: (x.real < 0.0)
  127. elif sort == 'rhp':
  128. sfunction = lambda x: (x.real >= 0.0)
  129. elif sort == 'iuc':
  130. sfunction = lambda x: (abs(x) <= 1.0)
  131. elif sort == 'ouc':
  132. sfunction = lambda x: (abs(x) > 1.0)
  133. else:
  134. raise ValueError("'sort' parameter must either be 'None', or a "
  135. "callable, or one of ('lhp','rhp','iuc','ouc')")
  136. result = gees(sfunction, a1, lwork=lwork, overwrite_a=overwrite_a,
  137. sort_t=sort_t)
  138. info = result[-1]
  139. if info < 0:
  140. raise ValueError('illegal value in {}-th argument of internal gees'
  141. ''.format(-info))
  142. elif info == a1.shape[0] + 1:
  143. raise LinAlgError('Eigenvalues could not be separated for reordering.')
  144. elif info == a1.shape[0] + 2:
  145. raise LinAlgError('Leading eigenvalues do not satisfy sort condition.')
  146. elif info > 0:
  147. raise LinAlgError("Schur form not found. Possibly ill-conditioned.")
  148. if sort_t == 0:
  149. return result[0], result[-3]
  150. else:
  151. return result[0], result[-3], result[1]
  152. eps = numpy.finfo(float).eps
  153. feps = numpy.finfo(single).eps
  154. _array_kind = {'b': 0, 'h': 0, 'B': 0, 'i': 0, 'l': 0,
  155. 'f': 0, 'd': 0, 'F': 1, 'D': 1}
  156. _array_precision = {'i': 1, 'l': 1, 'f': 0, 'd': 1, 'F': 0, 'D': 1}
  157. _array_type = [['f', 'd'], ['F', 'D']]
  158. def _commonType(*arrays):
  159. kind = 0
  160. precision = 0
  161. for a in arrays:
  162. t = a.dtype.char
  163. kind = max(kind, _array_kind[t])
  164. precision = max(precision, _array_precision[t])
  165. return _array_type[kind][precision]
  166. def _castCopy(type, *arrays):
  167. cast_arrays = ()
  168. for a in arrays:
  169. if a.dtype.char == type:
  170. cast_arrays = cast_arrays + (a.copy(),)
  171. else:
  172. cast_arrays = cast_arrays + (a.astype(type),)
  173. if len(cast_arrays) == 1:
  174. return cast_arrays[0]
  175. else:
  176. return cast_arrays
  177. def rsf2csf(T, Z, check_finite=True):
  178. """
  179. Convert real Schur form to complex Schur form.
  180. Convert a quasi-diagonal real-valued Schur form to the upper-triangular
  181. complex-valued Schur form.
  182. Parameters
  183. ----------
  184. T : (M, M) array_like
  185. Real Schur form of the original array
  186. Z : (M, M) array_like
  187. Schur transformation matrix
  188. check_finite : bool, optional
  189. Whether to check that the input arrays contain only finite numbers.
  190. Disabling may give a performance gain, but may result in problems
  191. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  192. Returns
  193. -------
  194. T : (M, M) ndarray
  195. Complex Schur form of the original array
  196. Z : (M, M) ndarray
  197. Schur transformation matrix corresponding to the complex form
  198. See Also
  199. --------
  200. schur : Schur decomposition of an array
  201. Examples
  202. --------
  203. >>> import numpy as np
  204. >>> from scipy.linalg import schur, rsf2csf
  205. >>> A = np.array([[0, 2, 2], [0, 1, 2], [1, 0, 1]])
  206. >>> T, Z = schur(A)
  207. >>> T
  208. array([[ 2.65896708, 1.42440458, -1.92933439],
  209. [ 0. , -0.32948354, -0.49063704],
  210. [ 0. , 1.31178921, -0.32948354]])
  211. >>> Z
  212. array([[0.72711591, -0.60156188, 0.33079564],
  213. [0.52839428, 0.79801892, 0.28976765],
  214. [0.43829436, 0.03590414, -0.89811411]])
  215. >>> T2 , Z2 = rsf2csf(T, Z)
  216. >>> T2
  217. array([[2.65896708+0.j, -1.64592781+0.743164187j, -1.21516887+1.00660462j],
  218. [0.+0.j , -0.32948354+8.02254558e-01j, -0.82115218-2.77555756e-17j],
  219. [0.+0.j , 0.+0.j, -0.32948354-0.802254558j]])
  220. >>> Z2
  221. array([[0.72711591+0.j, 0.28220393-0.31385693j, 0.51319638-0.17258824j],
  222. [0.52839428+0.j, 0.24720268+0.41635578j, -0.68079517-0.15118243j],
  223. [0.43829436+0.j, -0.76618703+0.01873251j, -0.03063006+0.46857912j]])
  224. """
  225. if check_finite:
  226. Z, T = map(asarray_chkfinite, (Z, T))
  227. else:
  228. Z, T = map(asarray, (Z, T))
  229. for ind, X in enumerate([Z, T]):
  230. if X.ndim != 2 or X.shape[0] != X.shape[1]:
  231. raise ValueError("Input '{}' must be square.".format('ZT'[ind]))
  232. if T.shape[0] != Z.shape[0]:
  233. raise ValueError("Input array shapes must match: Z: {} vs. T: {}"
  234. "".format(Z.shape, T.shape))
  235. N = T.shape[0]
  236. t = _commonType(Z, T, array([3.0], 'F'))
  237. Z, T = _castCopy(t, Z, T)
  238. for m in range(N-1, 0, -1):
  239. if abs(T[m, m-1]) > eps*(abs(T[m-1, m-1]) + abs(T[m, m])):
  240. mu = eigvals(T[m-1:m+1, m-1:m+1]) - T[m, m]
  241. r = norm([mu[0], T[m, m-1]])
  242. c = mu[0] / r
  243. s = T[m, m-1] / r
  244. G = array([[c.conj(), s], [-s, c]], dtype=t)
  245. T[m-1:m+1, m-1:] = G.dot(T[m-1:m+1, m-1:])
  246. T[:m+1, m-1:m+1] = T[:m+1, m-1:m+1].dot(G.conj().T)
  247. Z[:, m-1:m+1] = Z[:, m-1:m+1].dot(G.conj().T)
  248. T[m, m-1] = 0.0
  249. return T, Z