_decomp_cossin.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # -*- coding: utf-8 -*-
  2. from collections.abc import Iterable
  3. import numpy as np
  4. from scipy._lib._util import _asarray_validated
  5. from scipy.linalg import block_diag, LinAlgError
  6. from .lapack import _compute_lwork, get_lapack_funcs
  7. __all__ = ['cossin']
  8. def cossin(X, p=None, q=None, separate=False,
  9. swap_sign=False, compute_u=True, compute_vh=True):
  10. """
  11. Compute the cosine-sine (CS) decomposition of an orthogonal/unitary matrix.
  12. X is an ``(m, m)`` orthogonal/unitary matrix, partitioned as the following
  13. where upper left block has the shape of ``(p, q)``::
  14. ┌ ┐
  15. │ I 0 0 │ 0 0 0 │
  16. ┌ ┐ ┌ ┐│ 0 C 0 │ 0 -S 0 │┌ ┐*
  17. │ X11 │ X12 │ │ U1 │ ││ 0 0 0 │ 0 0 -I ││ V1 │ │
  18. │ ────┼──── │ = │────┼────││─────────┼─────────││────┼────│
  19. │ X21 │ X22 │ │ │ U2 ││ 0 0 0 │ I 0 0 ││ │ V2 │
  20. └ ┘ └ ┘│ 0 S 0 │ 0 C 0 │└ ┘
  21. │ 0 0 I │ 0 0 0 │
  22. └ ┘
  23. ``U1``, ``U2``, ``V1``, ``V2`` are square orthogonal/unitary matrices of
  24. dimensions ``(p,p)``, ``(m-p,m-p)``, ``(q,q)``, and ``(m-q,m-q)``
  25. respectively, and ``C`` and ``S`` are ``(r, r)`` nonnegative diagonal
  26. matrices satisfying ``C^2 + S^2 = I`` where ``r = min(p, m-p, q, m-q)``.
  27. Moreover, the rank of the identity matrices are ``min(p, q) - r``,
  28. ``min(p, m - q) - r``, ``min(m - p, q) - r``, and ``min(m - p, m - q) - r``
  29. respectively.
  30. X can be supplied either by itself and block specifications p, q or its
  31. subblocks in an iterable from which the shapes would be derived. See the
  32. examples below.
  33. Parameters
  34. ----------
  35. X : array_like, iterable
  36. complex unitary or real orthogonal matrix to be decomposed, or iterable
  37. of subblocks ``X11``, ``X12``, ``X21``, ``X22``, when ``p``, ``q`` are
  38. omitted.
  39. p : int, optional
  40. Number of rows of the upper left block ``X11``, used only when X is
  41. given as an array.
  42. q : int, optional
  43. Number of columns of the upper left block ``X11``, used only when X is
  44. given as an array.
  45. separate : bool, optional
  46. if ``True``, the low level components are returned instead of the
  47. matrix factors, i.e. ``(u1,u2)``, ``theta``, ``(v1h,v2h)`` instead of
  48. ``u``, ``cs``, ``vh``.
  49. swap_sign : bool, optional
  50. if ``True``, the ``-S``, ``-I`` block will be the bottom left,
  51. otherwise (by default) they will be in the upper right block.
  52. compute_u : bool, optional
  53. if ``False``, ``u`` won't be computed and an empty array is returned.
  54. compute_vh : bool, optional
  55. if ``False``, ``vh`` won't be computed and an empty array is returned.
  56. Returns
  57. -------
  58. u : ndarray
  59. When ``compute_u=True``, contains the block diagonal orthogonal/unitary
  60. matrix consisting of the blocks ``U1`` (``p`` x ``p``) and ``U2``
  61. (``m-p`` x ``m-p``) orthogonal/unitary matrices. If ``separate=True``,
  62. this contains the tuple of ``(U1, U2)``.
  63. cs : ndarray
  64. The cosine-sine factor with the structure described above.
  65. If ``separate=True``, this contains the ``theta`` array containing the
  66. angles in radians.
  67. vh : ndarray
  68. When ``compute_vh=True`, contains the block diagonal orthogonal/unitary
  69. matrix consisting of the blocks ``V1H`` (``q`` x ``q``) and ``V2H``
  70. (``m-q`` x ``m-q``) orthogonal/unitary matrices. If ``separate=True``,
  71. this contains the tuple of ``(V1H, V2H)``.
  72. References
  73. ----------
  74. .. [1] Brian D. Sutton. Computing the complete CS decomposition. Numer.
  75. Algorithms, 50(1):33-65, 2009.
  76. Examples
  77. --------
  78. >>> import numpy as np
  79. >>> from scipy.linalg import cossin
  80. >>> from scipy.stats import unitary_group
  81. >>> x = unitary_group.rvs(4)
  82. >>> u, cs, vdh = cossin(x, p=2, q=2)
  83. >>> np.allclose(x, u @ cs @ vdh)
  84. True
  85. Same can be entered via subblocks without the need of ``p`` and ``q``. Also
  86. let's skip the computation of ``u``
  87. >>> ue, cs, vdh = cossin((x[:2, :2], x[:2, 2:], x[2:, :2], x[2:, 2:]),
  88. ... compute_u=False)
  89. >>> print(ue)
  90. []
  91. >>> np.allclose(x, u @ cs @ vdh)
  92. True
  93. """
  94. if p or q:
  95. p = 1 if p is None else int(p)
  96. q = 1 if q is None else int(q)
  97. X = _asarray_validated(X, check_finite=True)
  98. if not np.equal(*X.shape):
  99. raise ValueError("Cosine Sine decomposition only supports square"
  100. " matrices, got {}".format(X.shape))
  101. m = X.shape[0]
  102. if p >= m or p <= 0:
  103. raise ValueError("invalid p={}, 0<p<{} must hold"
  104. .format(p, X.shape[0]))
  105. if q >= m or q <= 0:
  106. raise ValueError("invalid q={}, 0<q<{} must hold"
  107. .format(q, X.shape[0]))
  108. x11, x12, x21, x22 = X[:p, :q], X[:p, q:], X[p:, :q], X[p:, q:]
  109. elif not isinstance(X, Iterable):
  110. raise ValueError("When p and q are None, X must be an Iterable"
  111. " containing the subblocks of X")
  112. else:
  113. if len(X) != 4:
  114. raise ValueError("When p and q are None, exactly four arrays"
  115. " should be in X, got {}".format(len(X)))
  116. x11, x12, x21, x22 = [np.atleast_2d(x) for x in X]
  117. for name, block in zip(["x11", "x12", "x21", "x22"],
  118. [x11, x12, x21, x22]):
  119. if block.shape[1] == 0:
  120. raise ValueError("{} can't be empty".format(name))
  121. p, q = x11.shape
  122. mmp, mmq = x22.shape
  123. if x12.shape != (p, mmq):
  124. raise ValueError("Invalid x12 dimensions: desired {}, "
  125. "got {}".format((p, mmq), x12.shape))
  126. if x21.shape != (mmp, q):
  127. raise ValueError("Invalid x21 dimensions: desired {}, "
  128. "got {}".format((mmp, q), x21.shape))
  129. if p + mmp != q + mmq:
  130. raise ValueError("The subblocks have compatible sizes but "
  131. "don't form a square array (instead they form a"
  132. " {}x{} array). This might be due to missing "
  133. "p, q arguments.".format(p + mmp, q + mmq))
  134. m = p + mmp
  135. cplx = any([np.iscomplexobj(x) for x in [x11, x12, x21, x22]])
  136. driver = "uncsd" if cplx else "orcsd"
  137. csd, csd_lwork = get_lapack_funcs([driver, driver + "_lwork"],
  138. [x11, x12, x21, x22])
  139. lwork = _compute_lwork(csd_lwork, m=m, p=p, q=q)
  140. lwork_args = ({'lwork': lwork[0], 'lrwork': lwork[1]} if cplx else
  141. {'lwork': lwork})
  142. *_, theta, u1, u2, v1h, v2h, info = csd(x11=x11, x12=x12, x21=x21, x22=x22,
  143. compute_u1=compute_u,
  144. compute_u2=compute_u,
  145. compute_v1t=compute_vh,
  146. compute_v2t=compute_vh,
  147. trans=False, signs=swap_sign,
  148. **lwork_args)
  149. method_name = csd.typecode + driver
  150. if info < 0:
  151. raise ValueError('illegal value in argument {} of internal {}'
  152. .format(-info, method_name))
  153. if info > 0:
  154. raise LinAlgError("{} did not converge: {}".format(method_name, info))
  155. if separate:
  156. return (u1, u2), theta, (v1h, v2h)
  157. U = block_diag(u1, u2)
  158. VDH = block_diag(v1h, v2h)
  159. # Construct the middle factor CS
  160. c = np.diag(np.cos(theta))
  161. s = np.diag(np.sin(theta))
  162. r = min(p, q, m - p, m - q)
  163. n11 = min(p, q) - r
  164. n12 = min(p, m - q) - r
  165. n21 = min(m - p, q) - r
  166. n22 = min(m - p, m - q) - r
  167. Id = np.eye(np.max([n11, n12, n21, n22, r]), dtype=theta.dtype)
  168. CS = np.zeros((m, m), dtype=theta.dtype)
  169. CS[:n11, :n11] = Id[:n11, :n11]
  170. xs = n11 + r
  171. xe = n11 + r + n12
  172. ys = n11 + n21 + n22 + 2 * r
  173. ye = n11 + n21 + n22 + 2 * r + n12
  174. CS[xs: xe, ys:ye] = Id[:n12, :n12] if swap_sign else -Id[:n12, :n12]
  175. xs = p + n22 + r
  176. xe = p + n22 + r + + n21
  177. ys = n11 + r
  178. ye = n11 + r + n21
  179. CS[xs:xe, ys:ye] = -Id[:n21, :n21] if swap_sign else Id[:n21, :n21]
  180. CS[p:p + n22, q:q + n22] = Id[:n22, :n22]
  181. CS[n11:n11 + r, n11:n11 + r] = c
  182. CS[p + n22:p + n22 + r, r + n21 + n22:2 * r + n21 + n22] = c
  183. xs = n11
  184. xe = n11 + r
  185. ys = n11 + n21 + n22 + r
  186. ye = n11 + n21 + n22 + 2 * r
  187. CS[xs:xe, ys:ye] = s if swap_sign else -s
  188. CS[p + n22:p + n22 + r, n11:n11 + r] = -s if swap_sign else s
  189. return U, CS, VDH