_decomp_qr.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. """QR decomposition functions."""
  2. import numpy
  3. # Local imports
  4. from .lapack import get_lapack_funcs
  5. from ._misc import _datacopied
  6. __all__ = ['qr', 'qr_multiply', 'rq']
  7. def safecall(f, name, *args, **kwargs):
  8. """Call a LAPACK routine, determining lwork automatically and handling
  9. error return values"""
  10. lwork = kwargs.get("lwork", None)
  11. if lwork in (None, -1):
  12. kwargs['lwork'] = -1
  13. ret = f(*args, **kwargs)
  14. kwargs['lwork'] = ret[-2][0].real.astype(numpy.int_)
  15. ret = f(*args, **kwargs)
  16. if ret[-1] < 0:
  17. raise ValueError("illegal value in %dth argument of internal %s"
  18. % (-ret[-1], name))
  19. return ret[:-2]
  20. def qr(a, overwrite_a=False, lwork=None, mode='full', pivoting=False,
  21. check_finite=True):
  22. """
  23. Compute QR decomposition of a matrix.
  24. Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
  25. and R upper triangular.
  26. Parameters
  27. ----------
  28. a : (M, N) array_like
  29. Matrix to be decomposed
  30. overwrite_a : bool, optional
  31. Whether data in `a` is overwritten (may improve performance if
  32. `overwrite_a` is set to True by reusing the existing input data
  33. structure rather than creating a new one.)
  34. lwork : int, optional
  35. Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
  36. is computed.
  37. mode : {'full', 'r', 'economic', 'raw'}, optional
  38. Determines what information is to be returned: either both Q and R
  39. ('full', default), only R ('r') or both Q and R but computed in
  40. economy-size ('economic', see Notes). The final option 'raw'
  41. (added in SciPy 0.11) makes the function return two matrices
  42. (Q, TAU) in the internal format used by LAPACK.
  43. pivoting : bool, optional
  44. Whether or not factorization should include pivoting for rank-revealing
  45. qr decomposition. If pivoting, compute the decomposition
  46. ``A P = Q R`` as above, but where P is chosen such that the diagonal
  47. of R is non-increasing.
  48. check_finite : bool, optional
  49. Whether to check that the input matrix contains only finite numbers.
  50. Disabling may give a performance gain, but may result in problems
  51. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  52. Returns
  53. -------
  54. Q : float or complex ndarray
  55. Of shape (M, M), or (M, K) for ``mode='economic'``. Not returned
  56. if ``mode='r'``.
  57. R : float or complex ndarray
  58. Of shape (M, N), or (K, N) for ``mode='economic'``. ``K = min(M, N)``.
  59. P : int ndarray
  60. Of shape (N,) for ``pivoting=True``. Not returned if
  61. ``pivoting=False``.
  62. Raises
  63. ------
  64. LinAlgError
  65. Raised if decomposition fails
  66. Notes
  67. -----
  68. This is an interface to the LAPACK routines dgeqrf, zgeqrf,
  69. dorgqr, zungqr, dgeqp3, and zgeqp3.
  70. If ``mode=economic``, the shapes of Q and R are (M, K) and (K, N) instead
  71. of (M,M) and (M,N), with ``K=min(M,N)``.
  72. Examples
  73. --------
  74. >>> import numpy as np
  75. >>> from scipy import linalg
  76. >>> rng = np.random.default_rng()
  77. >>> a = rng.standard_normal((9, 6))
  78. >>> q, r = linalg.qr(a)
  79. >>> np.allclose(a, np.dot(q, r))
  80. True
  81. >>> q.shape, r.shape
  82. ((9, 9), (9, 6))
  83. >>> r2 = linalg.qr(a, mode='r')
  84. >>> np.allclose(r, r2)
  85. True
  86. >>> q3, r3 = linalg.qr(a, mode='economic')
  87. >>> q3.shape, r3.shape
  88. ((9, 6), (6, 6))
  89. >>> q4, r4, p4 = linalg.qr(a, pivoting=True)
  90. >>> d = np.abs(np.diag(r4))
  91. >>> np.all(d[1:] <= d[:-1])
  92. True
  93. >>> np.allclose(a[:, p4], np.dot(q4, r4))
  94. True
  95. >>> q4.shape, r4.shape, p4.shape
  96. ((9, 9), (9, 6), (6,))
  97. >>> q5, r5, p5 = linalg.qr(a, mode='economic', pivoting=True)
  98. >>> q5.shape, r5.shape, p5.shape
  99. ((9, 6), (6, 6), (6,))
  100. """
  101. # 'qr' was the old default, equivalent to 'full'. Neither 'full' nor
  102. # 'qr' are used below.
  103. # 'raw' is used internally by qr_multiply
  104. if mode not in ['full', 'qr', 'r', 'economic', 'raw']:
  105. raise ValueError("Mode argument should be one of ['full', 'r',"
  106. "'economic', 'raw']")
  107. if check_finite:
  108. a1 = numpy.asarray_chkfinite(a)
  109. else:
  110. a1 = numpy.asarray(a)
  111. if len(a1.shape) != 2:
  112. raise ValueError("expected a 2-D array")
  113. M, N = a1.shape
  114. overwrite_a = overwrite_a or (_datacopied(a1, a))
  115. if pivoting:
  116. geqp3, = get_lapack_funcs(('geqp3',), (a1,))
  117. qr, jpvt, tau = safecall(geqp3, "geqp3", a1, overwrite_a=overwrite_a)
  118. jpvt -= 1 # geqp3 returns a 1-based index array, so subtract 1
  119. else:
  120. geqrf, = get_lapack_funcs(('geqrf',), (a1,))
  121. qr, tau = safecall(geqrf, "geqrf", a1, lwork=lwork,
  122. overwrite_a=overwrite_a)
  123. if mode not in ['economic', 'raw'] or M < N:
  124. R = numpy.triu(qr)
  125. else:
  126. R = numpy.triu(qr[:N, :])
  127. if pivoting:
  128. Rj = R, jpvt
  129. else:
  130. Rj = R,
  131. if mode == 'r':
  132. return Rj
  133. elif mode == 'raw':
  134. return ((qr, tau),) + Rj
  135. gor_un_gqr, = get_lapack_funcs(('orgqr',), (qr,))
  136. if M < N:
  137. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr[:, :M], tau,
  138. lwork=lwork, overwrite_a=1)
  139. elif mode == 'economic':
  140. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qr, tau, lwork=lwork,
  141. overwrite_a=1)
  142. else:
  143. t = qr.dtype.char
  144. qqr = numpy.empty((M, M), dtype=t)
  145. qqr[:, :N] = qr
  146. Q, = safecall(gor_un_gqr, "gorgqr/gungqr", qqr, tau, lwork=lwork,
  147. overwrite_a=1)
  148. return (Q,) + Rj
  149. def qr_multiply(a, c, mode='right', pivoting=False, conjugate=False,
  150. overwrite_a=False, overwrite_c=False):
  151. """
  152. Calculate the QR decomposition and multiply Q with a matrix.
  153. Calculate the decomposition ``A = Q R`` where Q is unitary/orthogonal
  154. and R upper triangular. Multiply Q with a vector or a matrix c.
  155. Parameters
  156. ----------
  157. a : (M, N), array_like
  158. Input array
  159. c : array_like
  160. Input array to be multiplied by ``q``.
  161. mode : {'left', 'right'}, optional
  162. ``Q @ c`` is returned if mode is 'left', ``c @ Q`` is returned if
  163. mode is 'right'.
  164. The shape of c must be appropriate for the matrix multiplications,
  165. if mode is 'left', ``min(a.shape) == c.shape[0]``,
  166. if mode is 'right', ``a.shape[0] == c.shape[1]``.
  167. pivoting : bool, optional
  168. Whether or not factorization should include pivoting for rank-revealing
  169. qr decomposition, see the documentation of qr.
  170. conjugate : bool, optional
  171. Whether Q should be complex-conjugated. This might be faster
  172. than explicit conjugation.
  173. overwrite_a : bool, optional
  174. Whether data in a is overwritten (may improve performance)
  175. overwrite_c : bool, optional
  176. Whether data in c is overwritten (may improve performance).
  177. If this is used, c must be big enough to keep the result,
  178. i.e. ``c.shape[0]`` = ``a.shape[0]`` if mode is 'left'.
  179. Returns
  180. -------
  181. CQ : ndarray
  182. The product of ``Q`` and ``c``.
  183. R : (K, N), ndarray
  184. R array of the resulting QR factorization where ``K = min(M, N)``.
  185. P : (N,) ndarray
  186. Integer pivot array. Only returned when ``pivoting=True``.
  187. Raises
  188. ------
  189. LinAlgError
  190. Raised if QR decomposition fails.
  191. Notes
  192. -----
  193. This is an interface to the LAPACK routines ``?GEQRF``, ``?ORMQR``,
  194. ``?UNMQR``, and ``?GEQP3``.
  195. .. versionadded:: 0.11.0
  196. Examples
  197. --------
  198. >>> import numpy as np
  199. >>> from scipy.linalg import qr_multiply, qr
  200. >>> A = np.array([[1, 3, 3], [2, 3, 2], [2, 3, 3], [1, 3, 2]])
  201. >>> qc, r1, piv1 = qr_multiply(A, 2*np.eye(4), pivoting=1)
  202. >>> qc
  203. array([[-1., 1., -1.],
  204. [-1., -1., 1.],
  205. [-1., -1., -1.],
  206. [-1., 1., 1.]])
  207. >>> r1
  208. array([[-6., -3., -5. ],
  209. [ 0., -1., -1.11022302e-16],
  210. [ 0., 0., -1. ]])
  211. >>> piv1
  212. array([1, 0, 2], dtype=int32)
  213. >>> q2, r2, piv2 = qr(A, mode='economic', pivoting=1)
  214. >>> np.allclose(2*q2 - qc, np.zeros((4, 3)))
  215. True
  216. """
  217. if mode not in ['left', 'right']:
  218. raise ValueError("Mode argument can only be 'left' or 'right' but "
  219. "not '{}'".format(mode))
  220. c = numpy.asarray_chkfinite(c)
  221. if c.ndim < 2:
  222. onedim = True
  223. c = numpy.atleast_2d(c)
  224. if mode == "left":
  225. c = c.T
  226. else:
  227. onedim = False
  228. a = numpy.atleast_2d(numpy.asarray(a)) # chkfinite done in qr
  229. M, N = a.shape
  230. if mode == 'left':
  231. if c.shape[0] != min(M, N + overwrite_c*(M-N)):
  232. raise ValueError('Array shapes are not compatible for Q @ c'
  233. ' operation: {} vs {}'.format(a.shape, c.shape))
  234. else:
  235. if M != c.shape[1]:
  236. raise ValueError('Array shapes are not compatible for c @ Q'
  237. ' operation: {} vs {}'.format(c.shape, a.shape))
  238. raw = qr(a, overwrite_a, None, "raw", pivoting)
  239. Q, tau = raw[0]
  240. gor_un_mqr, = get_lapack_funcs(('ormqr',), (Q,))
  241. if gor_un_mqr.typecode in ('s', 'd'):
  242. trans = "T"
  243. else:
  244. trans = "C"
  245. Q = Q[:, :min(M, N)]
  246. if M > N and mode == "left" and not overwrite_c:
  247. if conjugate:
  248. cc = numpy.zeros((c.shape[1], M), dtype=c.dtype, order="F")
  249. cc[:, :N] = c.T
  250. else:
  251. cc = numpy.zeros((M, c.shape[1]), dtype=c.dtype, order="F")
  252. cc[:N, :] = c
  253. trans = "N"
  254. if conjugate:
  255. lr = "R"
  256. else:
  257. lr = "L"
  258. overwrite_c = True
  259. elif c.flags["C_CONTIGUOUS"] and trans == "T" or conjugate:
  260. cc = c.T
  261. if mode == "left":
  262. lr = "R"
  263. else:
  264. lr = "L"
  265. else:
  266. trans = "N"
  267. cc = c
  268. if mode == "left":
  269. lr = "L"
  270. else:
  271. lr = "R"
  272. cQ, = safecall(gor_un_mqr, "gormqr/gunmqr", lr, trans, Q, tau, cc,
  273. overwrite_c=overwrite_c)
  274. if trans != "N":
  275. cQ = cQ.T
  276. if mode == "right":
  277. cQ = cQ[:, :min(M, N)]
  278. if onedim:
  279. cQ = cQ.ravel()
  280. return (cQ,) + raw[1:]
  281. def rq(a, overwrite_a=False, lwork=None, mode='full', check_finite=True):
  282. """
  283. Compute RQ decomposition of a matrix.
  284. Calculate the decomposition ``A = R Q`` where Q is unitary/orthogonal
  285. and R upper triangular.
  286. Parameters
  287. ----------
  288. a : (M, N) array_like
  289. Matrix to be decomposed
  290. overwrite_a : bool, optional
  291. Whether data in a is overwritten (may improve performance)
  292. lwork : int, optional
  293. Work array size, lwork >= a.shape[1]. If None or -1, an optimal size
  294. is computed.
  295. mode : {'full', 'r', 'economic'}, optional
  296. Determines what information is to be returned: either both Q and R
  297. ('full', default), only R ('r') or both Q and R but computed in
  298. economy-size ('economic', see Notes).
  299. check_finite : bool, optional
  300. Whether to check that the input matrix contains only finite numbers.
  301. Disabling may give a performance gain, but may result in problems
  302. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  303. Returns
  304. -------
  305. R : float or complex ndarray
  306. Of shape (M, N) or (M, K) for ``mode='economic'``. ``K = min(M, N)``.
  307. Q : float or complex ndarray
  308. Of shape (N, N) or (K, N) for ``mode='economic'``. Not returned
  309. if ``mode='r'``.
  310. Raises
  311. ------
  312. LinAlgError
  313. If decomposition fails.
  314. Notes
  315. -----
  316. This is an interface to the LAPACK routines sgerqf, dgerqf, cgerqf, zgerqf,
  317. sorgrq, dorgrq, cungrq and zungrq.
  318. If ``mode=economic``, the shapes of Q and R are (K, N) and (M, K) instead
  319. of (N,N) and (M,N), with ``K=min(M,N)``.
  320. Examples
  321. --------
  322. >>> import numpy as np
  323. >>> from scipy import linalg
  324. >>> rng = np.random.default_rng()
  325. >>> a = rng.standard_normal((6, 9))
  326. >>> r, q = linalg.rq(a)
  327. >>> np.allclose(a, r @ q)
  328. True
  329. >>> r.shape, q.shape
  330. ((6, 9), (9, 9))
  331. >>> r2 = linalg.rq(a, mode='r')
  332. >>> np.allclose(r, r2)
  333. True
  334. >>> r3, q3 = linalg.rq(a, mode='economic')
  335. >>> r3.shape, q3.shape
  336. ((6, 6), (6, 9))
  337. """
  338. if mode not in ['full', 'r', 'economic']:
  339. raise ValueError(
  340. "Mode argument should be one of ['full', 'r', 'economic']")
  341. if check_finite:
  342. a1 = numpy.asarray_chkfinite(a)
  343. else:
  344. a1 = numpy.asarray(a)
  345. if len(a1.shape) != 2:
  346. raise ValueError('expected matrix')
  347. M, N = a1.shape
  348. overwrite_a = overwrite_a or (_datacopied(a1, a))
  349. gerqf, = get_lapack_funcs(('gerqf',), (a1,))
  350. rq, tau = safecall(gerqf, 'gerqf', a1, lwork=lwork,
  351. overwrite_a=overwrite_a)
  352. if not mode == 'economic' or N < M:
  353. R = numpy.triu(rq, N-M)
  354. else:
  355. R = numpy.triu(rq[-M:, -M:])
  356. if mode == 'r':
  357. return R
  358. gor_un_grq, = get_lapack_funcs(('orgrq',), (rq,))
  359. if N < M:
  360. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq[-N:], tau, lwork=lwork,
  361. overwrite_a=1)
  362. elif mode == 'economic':
  363. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq, tau, lwork=lwork,
  364. overwrite_a=1)
  365. else:
  366. rq1 = numpy.empty((N, N), dtype=rq.dtype)
  367. rq1[-M:] = rq
  368. Q, = safecall(gor_un_grq, "gorgrq/gungrq", rq1, tau, lwork=lwork,
  369. overwrite_a=1)
  370. return R, Q