test_decomp.py 106 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769277027712772277327742775277627772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904
  1. """ Test functions for linalg.decomp module
  2. """
  3. __usage__ = """
  4. Build linalg:
  5. python setup_linalg.py build
  6. Run tests if scipy is installed:
  7. python -c 'import scipy;scipy.linalg.test()'
  8. """
  9. import itertools
  10. import platform
  11. import sys
  12. import numpy as np
  13. from numpy.testing import (assert_equal, assert_almost_equal,
  14. assert_array_almost_equal, assert_array_equal,
  15. assert_, assert_allclose)
  16. import pytest
  17. from pytest import raises as assert_raises
  18. from scipy.linalg import (eig, eigvals, lu, svd, svdvals, cholesky, qr,
  19. schur, rsf2csf, lu_solve, lu_factor, solve, diagsvd,
  20. hessenberg, rq, eig_banded, eigvals_banded, eigh,
  21. eigvalsh, qr_multiply, qz, orth, ordqz,
  22. subspace_angles, hadamard, eigvalsh_tridiagonal,
  23. eigh_tridiagonal, null_space, cdf2rdf, LinAlgError)
  24. from scipy.linalg.lapack import (dgbtrf, dgbtrs, zgbtrf, zgbtrs, dsbev,
  25. dsbevd, dsbevx, zhbevd, zhbevx,
  26. get_lapack_funcs)
  27. from scipy.linalg._misc import norm
  28. from scipy.linalg._decomp_qz import _select_function
  29. from scipy.stats import ortho_group
  30. from numpy import (array, diag, ones, full, linalg, argsort, zeros, arange,
  31. float32, complex64, ravel, sqrt, iscomplex, shape, sort,
  32. sign, asarray, isfinite, ndarray, eye, dtype, triu, tril)
  33. from numpy.random import seed, random
  34. from scipy.linalg._testutils import assert_no_overwrite
  35. from scipy.sparse._sputils import matrix
  36. from scipy._lib._testutils import check_free_memory
  37. from scipy.linalg.blas import HAS_ILP64
  38. def _random_hermitian_matrix(n, posdef=False, dtype=float):
  39. "Generate random sym/hermitian array of the given size n"
  40. if dtype in COMPLEX_DTYPES:
  41. A = np.random.rand(n, n) + np.random.rand(n, n)*1.0j
  42. A = (A + A.conj().T)/2
  43. else:
  44. A = np.random.rand(n, n)
  45. A = (A + A.T)/2
  46. if posdef:
  47. A += sqrt(2*n)*np.eye(n)
  48. return A.astype(dtype)
  49. REAL_DTYPES = [np.float32, np.float64]
  50. COMPLEX_DTYPES = [np.complex64, np.complex128]
  51. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  52. def clear_fuss(ar, fuss_binary_bits=7):
  53. """Clears trailing `fuss_binary_bits` of mantissa of a floating number"""
  54. x = np.asanyarray(ar)
  55. if np.iscomplexobj(x):
  56. return clear_fuss(x.real) + 1j * clear_fuss(x.imag)
  57. significant_binary_bits = np.finfo(x.dtype).nmant
  58. x_mant, x_exp = np.frexp(x)
  59. f = 2.0**(significant_binary_bits - fuss_binary_bits)
  60. x_mant *= f
  61. np.rint(x_mant, out=x_mant)
  62. x_mant /= f
  63. return np.ldexp(x_mant, x_exp)
  64. # XXX: This function should be available through numpy.testing
  65. def assert_dtype_equal(act, des):
  66. if isinstance(act, ndarray):
  67. act = act.dtype
  68. else:
  69. act = dtype(act)
  70. if isinstance(des, ndarray):
  71. des = des.dtype
  72. else:
  73. des = dtype(des)
  74. assert_(act == des,
  75. 'dtype mismatch: "{}" (should be "{}")'.format(act, des))
  76. # XXX: This function should not be defined here, but somewhere in
  77. # scipy.linalg namespace
  78. def symrand(dim_or_eigv):
  79. """Return a random symmetric (Hermitian) matrix.
  80. If 'dim_or_eigv' is an integer N, return a NxN matrix, with eigenvalues
  81. uniformly distributed on (-1,1).
  82. If 'dim_or_eigv' is 1-D real array 'a', return a matrix whose
  83. eigenvalues are 'a'.
  84. """
  85. if isinstance(dim_or_eigv, int):
  86. dim = dim_or_eigv
  87. d = random(dim)*2 - 1
  88. elif (isinstance(dim_or_eigv, ndarray) and
  89. len(dim_or_eigv.shape) == 1):
  90. dim = dim_or_eigv.shape[0]
  91. d = dim_or_eigv
  92. else:
  93. raise TypeError("input type not supported.")
  94. v = ortho_group.rvs(dim)
  95. h = v.T.conj() @ diag(d) @ v
  96. # to avoid roundoff errors, symmetrize the matrix (again)
  97. h = 0.5*(h.T+h)
  98. return h
  99. def _complex_symrand(dim, dtype):
  100. a1, a2 = symrand(dim), symrand(dim)
  101. # add antisymmetric matrix as imag part
  102. a = a1 + 1j*(triu(a2)-tril(a2))
  103. return a.astype(dtype)
  104. class TestEigVals:
  105. def test_simple(self):
  106. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  107. w = eigvals(a)
  108. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  109. assert_array_almost_equal(w, exact_w)
  110. def test_simple_tr(self):
  111. a = array([[1, 2, 3], [1, 2, 3], [2, 5, 6]], 'd').T
  112. a = a.copy()
  113. a = a.T
  114. w = eigvals(a)
  115. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  116. assert_array_almost_equal(w, exact_w)
  117. def test_simple_complex(self):
  118. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]]
  119. w = eigvals(a)
  120. exact_w = [(9+1j+sqrt(92+6j))/2,
  121. 0,
  122. (9+1j-sqrt(92+6j))/2]
  123. assert_array_almost_equal(w, exact_w)
  124. def test_finite(self):
  125. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  126. w = eigvals(a, check_finite=False)
  127. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  128. assert_array_almost_equal(w, exact_w)
  129. class TestEig:
  130. def test_simple(self):
  131. a = array([[1, 2, 3], [1, 2, 3], [2, 5, 6]])
  132. w, v = eig(a)
  133. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  134. v0 = array([1, 1, (1+sqrt(93)/3)/2])
  135. v1 = array([3., 0, -1])
  136. v2 = array([1, 1, (1-sqrt(93)/3)/2])
  137. v0 = v0 / norm(v0)
  138. v1 = v1 / norm(v1)
  139. v2 = v2 / norm(v2)
  140. assert_array_almost_equal(w, exact_w)
  141. assert_array_almost_equal(v0, v[:, 0]*sign(v[0, 0]))
  142. assert_array_almost_equal(v1, v[:, 1]*sign(v[0, 1]))
  143. assert_array_almost_equal(v2, v[:, 2]*sign(v[0, 2]))
  144. for i in range(3):
  145. assert_array_almost_equal(a @ v[:, i], w[i]*v[:, i])
  146. w, v = eig(a, left=1, right=0)
  147. for i in range(3):
  148. assert_array_almost_equal(a.T @ v[:, i], w[i]*v[:, i])
  149. def test_simple_complex_eig(self):
  150. a = array([[1, 2], [-2, 1]])
  151. w, vl, vr = eig(a, left=1, right=1)
  152. assert_array_almost_equal(w, array([1+2j, 1-2j]))
  153. for i in range(2):
  154. assert_array_almost_equal(a @ vr[:, i], w[i]*vr[:, i])
  155. for i in range(2):
  156. assert_array_almost_equal(a.conj().T @ vl[:, i],
  157. w[i].conj()*vl[:, i])
  158. def test_simple_complex(self):
  159. a = array([[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]])
  160. w, vl, vr = eig(a, left=1, right=1)
  161. for i in range(3):
  162. assert_array_almost_equal(a @ vr[:, i], w[i]*vr[:, i])
  163. for i in range(3):
  164. assert_array_almost_equal(a.conj().T @ vl[:, i],
  165. w[i].conj()*vl[:, i])
  166. def test_gh_3054(self):
  167. a = [[1]]
  168. b = [[0]]
  169. w, vr = eig(a, b, homogeneous_eigvals=True)
  170. assert_allclose(w[1, 0], 0)
  171. assert_(w[0, 0] != 0)
  172. assert_allclose(vr, 1)
  173. w, vr = eig(a, b)
  174. assert_equal(w, np.inf)
  175. assert_allclose(vr, 1)
  176. def _check_gen_eig(self, A, B):
  177. if B is not None:
  178. A, B = asarray(A), asarray(B)
  179. B0 = B
  180. else:
  181. A = asarray(A)
  182. B0 = B
  183. B = np.eye(*A.shape)
  184. msg = "\n%r\n%r" % (A, B)
  185. # Eigenvalues in homogeneous coordinates
  186. w, vr = eig(A, B0, homogeneous_eigvals=True)
  187. wt = eigvals(A, B0, homogeneous_eigvals=True)
  188. val1 = A @ vr * w[1, :]
  189. val2 = B @ vr * w[0, :]
  190. for i in range(val1.shape[1]):
  191. assert_allclose(val1[:, i], val2[:, i],
  192. rtol=1e-13, atol=1e-13, err_msg=msg)
  193. if B0 is None:
  194. assert_allclose(w[1, :], 1)
  195. assert_allclose(wt[1, :], 1)
  196. perm = np.lexsort(w)
  197. permt = np.lexsort(wt)
  198. assert_allclose(w[:, perm], wt[:, permt], atol=1e-7, rtol=1e-7,
  199. err_msg=msg)
  200. length = np.empty(len(vr))
  201. for i in range(len(vr)):
  202. length[i] = norm(vr[:, i])
  203. assert_allclose(length, np.ones(length.size), err_msg=msg,
  204. atol=1e-7, rtol=1e-7)
  205. # Convert homogeneous coordinates
  206. beta_nonzero = (w[1, :] != 0)
  207. wh = w[0, beta_nonzero] / w[1, beta_nonzero]
  208. # Eigenvalues in standard coordinates
  209. w, vr = eig(A, B0)
  210. wt = eigvals(A, B0)
  211. val1 = A @ vr
  212. val2 = B @ vr * w
  213. res = val1 - val2
  214. for i in range(res.shape[1]):
  215. if np.all(isfinite(res[:, i])):
  216. assert_allclose(res[:, i], 0,
  217. rtol=1e-13, atol=1e-13, err_msg=msg)
  218. w_fin = w[isfinite(w)]
  219. wt_fin = wt[isfinite(wt)]
  220. perm = argsort(clear_fuss(w_fin))
  221. permt = argsort(clear_fuss(wt_fin))
  222. assert_allclose(w[perm], wt[permt],
  223. atol=1e-7, rtol=1e-7, err_msg=msg)
  224. length = np.empty(len(vr))
  225. for i in range(len(vr)):
  226. length[i] = norm(vr[:, i])
  227. assert_allclose(length, np.ones(length.size), err_msg=msg)
  228. # Compare homogeneous and nonhomogeneous versions
  229. assert_allclose(sort(wh), sort(w[np.isfinite(w)]))
  230. @pytest.mark.xfail(reason="See gh-2254")
  231. def test_singular(self):
  232. # Example taken from
  233. # https://web.archive.org/web/20040903121217/http://www.cs.umu.se/research/nla/singular_pairs/guptri/matlab.html
  234. A = array([[22, 34, 31, 31, 17],
  235. [45, 45, 42, 19, 29],
  236. [39, 47, 49, 26, 34],
  237. [27, 31, 26, 21, 15],
  238. [38, 44, 44, 24, 30]])
  239. B = array([[13, 26, 25, 17, 24],
  240. [31, 46, 40, 26, 37],
  241. [26, 40, 19, 25, 25],
  242. [16, 25, 27, 14, 23],
  243. [24, 35, 18, 21, 22]])
  244. with np.errstate(all='ignore'):
  245. self._check_gen_eig(A, B)
  246. def test_falker(self):
  247. # Test matrices giving some Nan generalized eigenvalues.
  248. M = diag(array(([1, 0, 3])))
  249. K = array(([2, -1, -1], [-1, 2, -1], [-1, -1, 2]))
  250. D = array(([1, -1, 0], [-1, 1, 0], [0, 0, 0]))
  251. Z = zeros((3, 3))
  252. I3 = eye(3)
  253. A = np.block([[I3, Z], [Z, -K]])
  254. B = np.block([[Z, I3], [M, D]])
  255. with np.errstate(all='ignore'):
  256. self._check_gen_eig(A, B)
  257. def test_bad_geneig(self):
  258. # Ticket #709 (strange return values from DGGEV)
  259. def matrices(omega):
  260. c1 = -9 + omega**2
  261. c2 = 2*omega
  262. A = [[1, 0, 0, 0],
  263. [0, 1, 0, 0],
  264. [0, 0, c1, 0],
  265. [0, 0, 0, c1]]
  266. B = [[0, 0, 1, 0],
  267. [0, 0, 0, 1],
  268. [1, 0, 0, -c2],
  269. [0, 1, c2, 0]]
  270. return A, B
  271. # With a buggy LAPACK, this can fail for different omega on different
  272. # machines -- so we need to test several values
  273. with np.errstate(all='ignore'):
  274. for k in range(100):
  275. A, B = matrices(omega=k*5./100)
  276. self._check_gen_eig(A, B)
  277. def test_make_eigvals(self):
  278. # Step through all paths in _make_eigvals
  279. seed(1234)
  280. # Real eigenvalues
  281. A = symrand(3)
  282. self._check_gen_eig(A, None)
  283. B = symrand(3)
  284. self._check_gen_eig(A, B)
  285. # Complex eigenvalues
  286. A = random((3, 3)) + 1j*random((3, 3))
  287. self._check_gen_eig(A, None)
  288. B = random((3, 3)) + 1j*random((3, 3))
  289. self._check_gen_eig(A, B)
  290. def test_check_finite(self):
  291. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  292. w, v = eig(a, check_finite=False)
  293. exact_w = [(9+sqrt(93))/2, 0, (9-sqrt(93))/2]
  294. v0 = array([1, 1, (1+sqrt(93)/3)/2])
  295. v1 = array([3., 0, -1])
  296. v2 = array([1, 1, (1-sqrt(93)/3)/2])
  297. v0 = v0 / norm(v0)
  298. v1 = v1 / norm(v1)
  299. v2 = v2 / norm(v2)
  300. assert_array_almost_equal(w, exact_w)
  301. assert_array_almost_equal(v0, v[:, 0]*sign(v[0, 0]))
  302. assert_array_almost_equal(v1, v[:, 1]*sign(v[0, 1]))
  303. assert_array_almost_equal(v2, v[:, 2]*sign(v[0, 2]))
  304. for i in range(3):
  305. assert_array_almost_equal(a @ v[:, i], w[i]*v[:, i])
  306. def test_not_square_error(self):
  307. """Check that passing a non-square array raises a ValueError."""
  308. A = np.arange(6).reshape(3, 2)
  309. assert_raises(ValueError, eig, A)
  310. def test_shape_mismatch(self):
  311. """Check that passing arrays of with different shapes
  312. raises a ValueError."""
  313. A = eye(2)
  314. B = np.arange(9.0).reshape(3, 3)
  315. assert_raises(ValueError, eig, A, B)
  316. assert_raises(ValueError, eig, B, A)
  317. class TestEigBanded:
  318. def setup_method(self):
  319. self.create_bandmat()
  320. def create_bandmat(self):
  321. """Create the full matrix `self.fullmat` and
  322. the corresponding band matrix `self.bandmat`."""
  323. N = 10
  324. self.KL = 2 # number of subdiagonals (below the diagonal)
  325. self.KU = 2 # number of superdiagonals (above the diagonal)
  326. # symmetric band matrix
  327. self.sym_mat = (diag(full(N, 1.0))
  328. + diag(full(N-1, -1.0), -1) + diag(full(N-1, -1.0), 1)
  329. + diag(full(N-2, -2.0), -2) + diag(full(N-2, -2.0), 2))
  330. # hermitian band matrix
  331. self.herm_mat = (diag(full(N, -1.0))
  332. + 1j*diag(full(N-1, 1.0), -1)
  333. - 1j*diag(full(N-1, 1.0), 1)
  334. + diag(full(N-2, -2.0), -2)
  335. + diag(full(N-2, -2.0), 2))
  336. # general real band matrix
  337. self.real_mat = (diag(full(N, 1.0))
  338. + diag(full(N-1, -1.0), -1) + diag(full(N-1, -3.0), 1)
  339. + diag(full(N-2, 2.0), -2) + diag(full(N-2, -2.0), 2))
  340. # general complex band matrix
  341. self.comp_mat = (1j*diag(full(N, 1.0))
  342. + diag(full(N-1, -1.0), -1)
  343. + 1j*diag(full(N-1, -3.0), 1)
  344. + diag(full(N-2, 2.0), -2)
  345. + diag(full(N-2, -2.0), 2))
  346. # Eigenvalues and -vectors from linalg.eig
  347. ew, ev = linalg.eig(self.sym_mat)
  348. ew = ew.real
  349. args = argsort(ew)
  350. self.w_sym_lin = ew[args]
  351. self.evec_sym_lin = ev[:, args]
  352. ew, ev = linalg.eig(self.herm_mat)
  353. ew = ew.real
  354. args = argsort(ew)
  355. self.w_herm_lin = ew[args]
  356. self.evec_herm_lin = ev[:, args]
  357. # Extract upper bands from symmetric and hermitian band matrices
  358. # (for use in dsbevd, dsbevx, zhbevd, zhbevx
  359. # and their single precision versions)
  360. LDAB = self.KU + 1
  361. self.bandmat_sym = zeros((LDAB, N), dtype=float)
  362. self.bandmat_herm = zeros((LDAB, N), dtype=complex)
  363. for i in range(LDAB):
  364. self.bandmat_sym[LDAB-i-1, i:N] = diag(self.sym_mat, i)
  365. self.bandmat_herm[LDAB-i-1, i:N] = diag(self.herm_mat, i)
  366. # Extract bands from general real and complex band matrix
  367. # (for use in dgbtrf, dgbtrs and their single precision versions)
  368. LDAB = 2*self.KL + self.KU + 1
  369. self.bandmat_real = zeros((LDAB, N), dtype=float)
  370. self.bandmat_real[2*self.KL, :] = diag(self.real_mat) # diagonal
  371. for i in range(self.KL):
  372. # superdiagonals
  373. self.bandmat_real[2*self.KL-1-i, i+1:N] = diag(self.real_mat, i+1)
  374. # subdiagonals
  375. self.bandmat_real[2*self.KL+1+i, 0:N-1-i] = diag(self.real_mat,
  376. -i-1)
  377. self.bandmat_comp = zeros((LDAB, N), dtype=complex)
  378. self.bandmat_comp[2*self.KL, :] = diag(self.comp_mat) # diagonal
  379. for i in range(self.KL):
  380. # superdiagonals
  381. self.bandmat_comp[2*self.KL-1-i, i+1:N] = diag(self.comp_mat, i+1)
  382. # subdiagonals
  383. self.bandmat_comp[2*self.KL+1+i, 0:N-1-i] = diag(self.comp_mat,
  384. -i-1)
  385. # absolute value for linear equation system A*x = b
  386. self.b = 1.0*arange(N)
  387. self.bc = self.b * (1 + 1j)
  388. #####################################################################
  389. def test_dsbev(self):
  390. """Compare dsbev eigenvalues and eigenvectors with
  391. the result of linalg.eig."""
  392. w, evec, info = dsbev(self.bandmat_sym, compute_v=1)
  393. evec_ = evec[:, argsort(w)]
  394. assert_array_almost_equal(sort(w), self.w_sym_lin)
  395. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  396. def test_dsbevd(self):
  397. """Compare dsbevd eigenvalues and eigenvectors with
  398. the result of linalg.eig."""
  399. w, evec, info = dsbevd(self.bandmat_sym, compute_v=1)
  400. evec_ = evec[:, argsort(w)]
  401. assert_array_almost_equal(sort(w), self.w_sym_lin)
  402. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  403. def test_dsbevx(self):
  404. """Compare dsbevx eigenvalues and eigenvectors
  405. with the result of linalg.eig."""
  406. N, N = shape(self.sym_mat)
  407. # Achtung: Argumente 0.0,0.0,range?
  408. w, evec, num, ifail, info = dsbevx(self.bandmat_sym, 0.0, 0.0, 1, N,
  409. compute_v=1, range=2)
  410. evec_ = evec[:, argsort(w)]
  411. assert_array_almost_equal(sort(w), self.w_sym_lin)
  412. assert_array_almost_equal(abs(evec_), abs(self.evec_sym_lin))
  413. def test_zhbevd(self):
  414. """Compare zhbevd eigenvalues and eigenvectors
  415. with the result of linalg.eig."""
  416. w, evec, info = zhbevd(self.bandmat_herm, compute_v=1)
  417. evec_ = evec[:, argsort(w)]
  418. assert_array_almost_equal(sort(w), self.w_herm_lin)
  419. assert_array_almost_equal(abs(evec_), abs(self.evec_herm_lin))
  420. def test_zhbevx(self):
  421. """Compare zhbevx eigenvalues and eigenvectors
  422. with the result of linalg.eig."""
  423. N, N = shape(self.herm_mat)
  424. # Achtung: Argumente 0.0,0.0,range?
  425. w, evec, num, ifail, info = zhbevx(self.bandmat_herm, 0.0, 0.0, 1, N,
  426. compute_v=1, range=2)
  427. evec_ = evec[:, argsort(w)]
  428. assert_array_almost_equal(sort(w), self.w_herm_lin)
  429. assert_array_almost_equal(abs(evec_), abs(self.evec_herm_lin))
  430. def test_eigvals_banded(self):
  431. """Compare eigenvalues of eigvals_banded with those of linalg.eig."""
  432. w_sym = eigvals_banded(self.bandmat_sym)
  433. w_sym = w_sym.real
  434. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  435. w_herm = eigvals_banded(self.bandmat_herm)
  436. w_herm = w_herm.real
  437. assert_array_almost_equal(sort(w_herm), self.w_herm_lin)
  438. # extracting eigenvalues with respect to an index range
  439. ind1 = 2
  440. ind2 = np.longlong(6)
  441. w_sym_ind = eigvals_banded(self.bandmat_sym,
  442. select='i', select_range=(ind1, ind2))
  443. assert_array_almost_equal(sort(w_sym_ind),
  444. self.w_sym_lin[ind1:ind2+1])
  445. w_herm_ind = eigvals_banded(self.bandmat_herm,
  446. select='i', select_range=(ind1, ind2))
  447. assert_array_almost_equal(sort(w_herm_ind),
  448. self.w_herm_lin[ind1:ind2+1])
  449. # extracting eigenvalues with respect to a value range
  450. v_lower = self.w_sym_lin[ind1] - 1.0e-5
  451. v_upper = self.w_sym_lin[ind2] + 1.0e-5
  452. w_sym_val = eigvals_banded(self.bandmat_sym,
  453. select='v', select_range=(v_lower, v_upper))
  454. assert_array_almost_equal(sort(w_sym_val),
  455. self.w_sym_lin[ind1:ind2+1])
  456. v_lower = self.w_herm_lin[ind1] - 1.0e-5
  457. v_upper = self.w_herm_lin[ind2] + 1.0e-5
  458. w_herm_val = eigvals_banded(self.bandmat_herm,
  459. select='v',
  460. select_range=(v_lower, v_upper))
  461. assert_array_almost_equal(sort(w_herm_val),
  462. self.w_herm_lin[ind1:ind2+1])
  463. w_sym = eigvals_banded(self.bandmat_sym, check_finite=False)
  464. w_sym = w_sym.real
  465. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  466. def test_eig_banded(self):
  467. """Compare eigenvalues and eigenvectors of eig_banded
  468. with those of linalg.eig. """
  469. w_sym, evec_sym = eig_banded(self.bandmat_sym)
  470. evec_sym_ = evec_sym[:, argsort(w_sym.real)]
  471. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  472. assert_array_almost_equal(abs(evec_sym_), abs(self.evec_sym_lin))
  473. w_herm, evec_herm = eig_banded(self.bandmat_herm)
  474. evec_herm_ = evec_herm[:, argsort(w_herm.real)]
  475. assert_array_almost_equal(sort(w_herm), self.w_herm_lin)
  476. assert_array_almost_equal(abs(evec_herm_), abs(self.evec_herm_lin))
  477. # extracting eigenvalues with respect to an index range
  478. ind1 = 2
  479. ind2 = 6
  480. w_sym_ind, evec_sym_ind = eig_banded(self.bandmat_sym,
  481. select='i',
  482. select_range=(ind1, ind2))
  483. assert_array_almost_equal(sort(w_sym_ind),
  484. self.w_sym_lin[ind1:ind2+1])
  485. assert_array_almost_equal(abs(evec_sym_ind),
  486. abs(self.evec_sym_lin[:, ind1:ind2+1]))
  487. w_herm_ind, evec_herm_ind = eig_banded(self.bandmat_herm,
  488. select='i',
  489. select_range=(ind1, ind2))
  490. assert_array_almost_equal(sort(w_herm_ind),
  491. self.w_herm_lin[ind1:ind2+1])
  492. assert_array_almost_equal(abs(evec_herm_ind),
  493. abs(self.evec_herm_lin[:, ind1:ind2+1]))
  494. # extracting eigenvalues with respect to a value range
  495. v_lower = self.w_sym_lin[ind1] - 1.0e-5
  496. v_upper = self.w_sym_lin[ind2] + 1.0e-5
  497. w_sym_val, evec_sym_val = eig_banded(self.bandmat_sym,
  498. select='v',
  499. select_range=(v_lower, v_upper))
  500. assert_array_almost_equal(sort(w_sym_val),
  501. self.w_sym_lin[ind1:ind2+1])
  502. assert_array_almost_equal(abs(evec_sym_val),
  503. abs(self.evec_sym_lin[:, ind1:ind2+1]))
  504. v_lower = self.w_herm_lin[ind1] - 1.0e-5
  505. v_upper = self.w_herm_lin[ind2] + 1.0e-5
  506. w_herm_val, evec_herm_val = eig_banded(self.bandmat_herm,
  507. select='v',
  508. select_range=(v_lower, v_upper))
  509. assert_array_almost_equal(sort(w_herm_val),
  510. self.w_herm_lin[ind1:ind2+1])
  511. assert_array_almost_equal(abs(evec_herm_val),
  512. abs(self.evec_herm_lin[:, ind1:ind2+1]))
  513. w_sym, evec_sym = eig_banded(self.bandmat_sym, check_finite=False)
  514. evec_sym_ = evec_sym[:, argsort(w_sym.real)]
  515. assert_array_almost_equal(sort(w_sym), self.w_sym_lin)
  516. assert_array_almost_equal(abs(evec_sym_), abs(self.evec_sym_lin))
  517. def test_dgbtrf(self):
  518. """Compare dgbtrf LU factorisation with the LU factorisation result
  519. of linalg.lu."""
  520. M, N = shape(self.real_mat)
  521. lu_symm_band, ipiv, info = dgbtrf(self.bandmat_real, self.KL, self.KU)
  522. # extract matrix u from lu_symm_band
  523. u = diag(lu_symm_band[2*self.KL, :])
  524. for i in range(self.KL + self.KU):
  525. u += diag(lu_symm_band[2*self.KL-1-i, i+1:N], i+1)
  526. p_lin, l_lin, u_lin = lu(self.real_mat, permute_l=0)
  527. assert_array_almost_equal(u, u_lin)
  528. def test_zgbtrf(self):
  529. """Compare zgbtrf LU factorisation with the LU factorisation result
  530. of linalg.lu."""
  531. M, N = shape(self.comp_mat)
  532. lu_symm_band, ipiv, info = zgbtrf(self.bandmat_comp, self.KL, self.KU)
  533. # extract matrix u from lu_symm_band
  534. u = diag(lu_symm_band[2*self.KL, :])
  535. for i in range(self.KL + self.KU):
  536. u += diag(lu_symm_band[2*self.KL-1-i, i+1:N], i+1)
  537. p_lin, l_lin, u_lin = lu(self.comp_mat, permute_l=0)
  538. assert_array_almost_equal(u, u_lin)
  539. def test_dgbtrs(self):
  540. """Compare dgbtrs solutions for linear equation system A*x = b
  541. with solutions of linalg.solve."""
  542. lu_symm_band, ipiv, info = dgbtrf(self.bandmat_real, self.KL, self.KU)
  543. y, info = dgbtrs(lu_symm_band, self.KL, self.KU, self.b, ipiv)
  544. y_lin = linalg.solve(self.real_mat, self.b)
  545. assert_array_almost_equal(y, y_lin)
  546. def test_zgbtrs(self):
  547. """Compare zgbtrs solutions for linear equation system A*x = b
  548. with solutions of linalg.solve."""
  549. lu_symm_band, ipiv, info = zgbtrf(self.bandmat_comp, self.KL, self.KU)
  550. y, info = zgbtrs(lu_symm_band, self.KL, self.KU, self.bc, ipiv)
  551. y_lin = linalg.solve(self.comp_mat, self.bc)
  552. assert_array_almost_equal(y, y_lin)
  553. class TestEigTridiagonal:
  554. def setup_method(self):
  555. self.create_trimat()
  556. def create_trimat(self):
  557. """Create the full matrix `self.fullmat`, `self.d`, and `self.e`."""
  558. N = 10
  559. # symmetric band matrix
  560. self.d = full(N, 1.0)
  561. self.e = full(N-1, -1.0)
  562. self.full_mat = (diag(self.d) + diag(self.e, -1) + diag(self.e, 1))
  563. ew, ev = linalg.eig(self.full_mat)
  564. ew = ew.real
  565. args = argsort(ew)
  566. self.w = ew[args]
  567. self.evec = ev[:, args]
  568. def test_degenerate(self):
  569. """Test error conditions."""
  570. # Wrong sizes
  571. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e[:-1])
  572. # Must be real
  573. assert_raises(TypeError, eigvalsh_tridiagonal, self.d, self.e * 1j)
  574. # Bad driver
  575. assert_raises(TypeError, eigvalsh_tridiagonal, self.d, self.e,
  576. lapack_driver=1.)
  577. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  578. lapack_driver='foo')
  579. # Bad bounds
  580. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  581. select='i', select_range=(0, -1))
  582. def test_eigvalsh_tridiagonal(self):
  583. """Compare eigenvalues of eigvalsh_tridiagonal with those of eig."""
  584. # can't use ?STERF with subselection
  585. for driver in ('sterf', 'stev', 'stebz', 'stemr', 'auto'):
  586. w = eigvalsh_tridiagonal(self.d, self.e, lapack_driver=driver)
  587. assert_array_almost_equal(sort(w), self.w)
  588. for driver in ('sterf', 'stev'):
  589. assert_raises(ValueError, eigvalsh_tridiagonal, self.d, self.e,
  590. lapack_driver='stev', select='i',
  591. select_range=(0, 1))
  592. for driver in ('stebz', 'stemr', 'auto'):
  593. # extracting eigenvalues with respect to the full index range
  594. w_ind = eigvalsh_tridiagonal(
  595. self.d, self.e, select='i', select_range=(0, len(self.d)-1),
  596. lapack_driver=driver)
  597. assert_array_almost_equal(sort(w_ind), self.w)
  598. # extracting eigenvalues with respect to an index range
  599. ind1 = 2
  600. ind2 = 6
  601. w_ind = eigvalsh_tridiagonal(
  602. self.d, self.e, select='i', select_range=(ind1, ind2),
  603. lapack_driver=driver)
  604. assert_array_almost_equal(sort(w_ind), self.w[ind1:ind2+1])
  605. # extracting eigenvalues with respect to a value range
  606. v_lower = self.w[ind1] - 1.0e-5
  607. v_upper = self.w[ind2] + 1.0e-5
  608. w_val = eigvalsh_tridiagonal(
  609. self.d, self.e, select='v', select_range=(v_lower, v_upper),
  610. lapack_driver=driver)
  611. assert_array_almost_equal(sort(w_val), self.w[ind1:ind2+1])
  612. def test_eigh_tridiagonal(self):
  613. """Compare eigenvalues and eigenvectors of eigh_tridiagonal
  614. with those of eig. """
  615. # can't use ?STERF when eigenvectors are requested
  616. assert_raises(ValueError, eigh_tridiagonal, self.d, self.e,
  617. lapack_driver='sterf')
  618. for driver in ('stebz', 'stev', 'stemr', 'auto'):
  619. w, evec = eigh_tridiagonal(self.d, self.e, lapack_driver=driver)
  620. evec_ = evec[:, argsort(w)]
  621. assert_array_almost_equal(sort(w), self.w)
  622. assert_array_almost_equal(abs(evec_), abs(self.evec))
  623. assert_raises(ValueError, eigh_tridiagonal, self.d, self.e,
  624. lapack_driver='stev', select='i', select_range=(0, 1))
  625. for driver in ('stebz', 'stemr', 'auto'):
  626. # extracting eigenvalues with respect to an index range
  627. ind1 = 0
  628. ind2 = len(self.d)-1
  629. w, evec = eigh_tridiagonal(
  630. self.d, self.e, select='i', select_range=(ind1, ind2),
  631. lapack_driver=driver)
  632. assert_array_almost_equal(sort(w), self.w)
  633. assert_array_almost_equal(abs(evec), abs(self.evec))
  634. ind1 = 2
  635. ind2 = 6
  636. w, evec = eigh_tridiagonal(
  637. self.d, self.e, select='i', select_range=(ind1, ind2),
  638. lapack_driver=driver)
  639. assert_array_almost_equal(sort(w), self.w[ind1:ind2+1])
  640. assert_array_almost_equal(abs(evec),
  641. abs(self.evec[:, ind1:ind2+1]))
  642. # extracting eigenvalues with respect to a value range
  643. v_lower = self.w[ind1] - 1.0e-5
  644. v_upper = self.w[ind2] + 1.0e-5
  645. w, evec = eigh_tridiagonal(
  646. self.d, self.e, select='v', select_range=(v_lower, v_upper),
  647. lapack_driver=driver)
  648. assert_array_almost_equal(sort(w), self.w[ind1:ind2+1])
  649. assert_array_almost_equal(abs(evec),
  650. abs(self.evec[:, ind1:ind2+1]))
  651. class TestEigh:
  652. def setup_class(self):
  653. seed(1234)
  654. def test_wrong_inputs(self):
  655. # Nonsquare a
  656. assert_raises(ValueError, eigh, np.ones([1, 2]))
  657. # Nonsquare b
  658. assert_raises(ValueError, eigh, np.ones([2, 2]), np.ones([2, 1]))
  659. # Incompatible a, b sizes
  660. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([2, 2]))
  661. # Wrong type parameter for generalized problem
  662. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  663. type=4)
  664. # Both value and index subsets requested
  665. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  666. subset_by_value=[1, 2], subset_by_index=[2, 4])
  667. with np.testing.suppress_warnings() as sup:
  668. sup.filter(DeprecationWarning, "Keyword argument 'eigvals")
  669. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  670. subset_by_value=[1, 2], eigvals=[2, 4])
  671. # Invalid upper index spec
  672. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  673. subset_by_index=[0, 4])
  674. with np.testing.suppress_warnings() as sup:
  675. sup.filter(DeprecationWarning, "Keyword argument 'eigvals")
  676. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  677. eigvals=[0, 4])
  678. # Invalid lower index
  679. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  680. subset_by_index=[-2, 2])
  681. with np.testing.suppress_warnings() as sup:
  682. sup.filter(DeprecationWarning, "Keyword argument 'eigvals")
  683. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  684. eigvals=[-2, 2])
  685. # Invalid index spec #2
  686. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  687. subset_by_index=[2, 0])
  688. with np.testing.suppress_warnings() as sup:
  689. sup.filter(DeprecationWarning, "Keyword argument 'eigvals")
  690. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  691. subset_by_index=[2, 0])
  692. # Invalid value spec
  693. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  694. subset_by_value=[2, 0])
  695. # Invalid driver name
  696. assert_raises(ValueError, eigh, np.ones([2, 2]), driver='wrong')
  697. # Generalized driver selection without b
  698. assert_raises(ValueError, eigh, np.ones([3, 3]), None, driver='gvx')
  699. # Standard driver with b
  700. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  701. driver='evr', turbo=False)
  702. # Subset request from invalid driver
  703. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  704. driver='gvd', subset_by_index=[1, 2], turbo=False)
  705. with np.testing.suppress_warnings() as sup:
  706. sup.filter(DeprecationWarning, "'eigh' keyword argument 'eigvals")
  707. assert_raises(ValueError, eigh, np.ones([3, 3]), np.ones([3, 3]),
  708. driver='gvd', subset_by_index=[1, 2], turbo=False)
  709. def test_nonpositive_b(self):
  710. assert_raises(LinAlgError, eigh, np.ones([3, 3]), np.ones([3, 3]))
  711. # index based subsets are done in the legacy test_eigh()
  712. def test_value_subsets(self):
  713. for ind, dt in enumerate(DTYPES):
  714. a = _random_hermitian_matrix(20, dtype=dt)
  715. w, v = eigh(a, subset_by_value=[-2, 2])
  716. assert_equal(v.shape[1], len(w))
  717. assert all((w > -2) & (w < 2))
  718. b = _random_hermitian_matrix(20, posdef=True, dtype=dt)
  719. w, v = eigh(a, b, subset_by_value=[-2, 2])
  720. assert_equal(v.shape[1], len(w))
  721. assert all((w > -2) & (w < 2))
  722. def test_eigh_integer(self):
  723. a = array([[1, 2], [2, 7]])
  724. b = array([[3, 1], [1, 5]])
  725. w, z = eigh(a)
  726. w, z = eigh(a, b)
  727. def test_eigh_of_sparse(self):
  728. # This tests the rejection of inputs that eigh cannot currently handle.
  729. import scipy.sparse
  730. a = scipy.sparse.identity(2).tocsc()
  731. b = np.atleast_2d(a)
  732. assert_raises(ValueError, eigh, a)
  733. assert_raises(ValueError, eigh, b)
  734. @pytest.mark.parametrize('dtype_', DTYPES)
  735. @pytest.mark.parametrize('driver', ("ev", "evd", "evr", "evx"))
  736. def test_various_drivers_standard(self, driver, dtype_):
  737. a = _random_hermitian_matrix(n=20, dtype=dtype_)
  738. w, v = eigh(a, driver=driver)
  739. assert_allclose(a @ v - (v * w), 0.,
  740. atol=1000*np.finfo(dtype_).eps,
  741. rtol=0.)
  742. @pytest.mark.parametrize('type', (1, 2, 3))
  743. @pytest.mark.parametrize('driver', ("gv", "gvd", "gvx"))
  744. def test_various_drivers_generalized(self, driver, type):
  745. atol = np.spacing(5000.)
  746. a = _random_hermitian_matrix(20)
  747. b = _random_hermitian_matrix(20, posdef=True)
  748. w, v = eigh(a=a, b=b, driver=driver, type=type)
  749. if type == 1:
  750. assert_allclose(a @ v - w*(b @ v), 0., atol=atol, rtol=0.)
  751. elif type == 2:
  752. assert_allclose(a @ b @ v - v * w, 0., atol=atol, rtol=0.)
  753. else:
  754. assert_allclose(b @ a @ v - v * w, 0., atol=atol, rtol=0.)
  755. def test_eigvalsh_new_args(self):
  756. a = _random_hermitian_matrix(5)
  757. w = eigvalsh(a, subset_by_index=[1, 2])
  758. assert_equal(len(w), 2)
  759. w2 = eigvalsh(a, subset_by_index=[1, 2])
  760. assert_equal(len(w2), 2)
  761. assert_allclose(w, w2)
  762. b = np.diag([1, 1.2, 1.3, 1.5, 2])
  763. w3 = eigvalsh(b, subset_by_value=[1, 1.4])
  764. assert_equal(len(w3), 2)
  765. assert_allclose(w3, np.array([1.2, 1.3]))
  766. @pytest.mark.parametrize("method", [eigh, eigvalsh])
  767. def test_deprecation_warnings(self, method):
  768. with pytest.warns(DeprecationWarning,
  769. match="Keyword argument 'turbo'"):
  770. method(np.zeros((2, 2)), turbo=True)
  771. with pytest.warns(DeprecationWarning,
  772. match="Keyword argument 'eigvals'"):
  773. method(np.zeros((2, 2)), eigvals=[0, 1])
  774. def test_deprecation_results(self):
  775. a = _random_hermitian_matrix(3)
  776. b = _random_hermitian_matrix(3, posdef=True)
  777. # check turbo gives same result as driver='gvd'
  778. with np.testing.suppress_warnings() as sup:
  779. sup.filter(DeprecationWarning, "Keyword argument 'turbo'")
  780. w_dep, v_dep = eigh(a, b, turbo=True)
  781. w, v = eigh(a, b, driver='gvd')
  782. assert_allclose(w_dep, w)
  783. assert_allclose(v_dep, v)
  784. # check eigvals gives the same result as subset_by_index
  785. with np.testing.suppress_warnings() as sup:
  786. sup.filter(DeprecationWarning, "Keyword argument 'eigvals'")
  787. w_dep, v_dep = eigh(a, eigvals=[0, 1])
  788. w, v = eigh(a, subset_by_index=[0, 1])
  789. assert_allclose(w_dep, w)
  790. assert_allclose(v_dep, v)
  791. class TestLU:
  792. def setup_method(self):
  793. self.a = array([[1, 2, 3], [1, 2, 3], [2, 5, 6]])
  794. self.ca = array([[1, 2, 3], [1, 2, 3], [2, 5j, 6]])
  795. # Those matrices are more robust to detect problems in permutation
  796. # matrices than the ones above
  797. self.b = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  798. self.cb = array([[1j, 2j, 3j], [4j, 5j, 6j], [7j, 8j, 9j]])
  799. # Reectangular matrices
  800. self.hrect = array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 12, 12]])
  801. self.chrect = 1.j * array([[1, 2, 3, 4],
  802. [5, 6, 7, 8],
  803. [9, 10, 12, 12]])
  804. self.vrect = array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 12, 12]])
  805. self.cvrect = 1.j * array([[1, 2, 3],
  806. [4, 5, 6],
  807. [7, 8, 9],
  808. [10, 12, 12]])
  809. # Medium sizes matrices
  810. self.med = random((30, 40))
  811. self.cmed = random((30, 40)) + 1.j * random((30, 40))
  812. def _test_common(self, data):
  813. p, l, u = lu(data)
  814. assert_array_almost_equal(p @ l @ u, data)
  815. pl, u = lu(data, permute_l=1)
  816. assert_array_almost_equal(pl @ u, data)
  817. def _test_common_lu_factor(self, data):
  818. l_and_u1, piv1 = lu_factor(data)
  819. (getrf,) = get_lapack_funcs(("getrf",), (data,))
  820. l_and_u2, piv2, _ = getrf(data, overwrite_a=False)
  821. assert_array_equal(l_and_u1, l_and_u2)
  822. assert_array_equal(piv1, piv2)
  823. # Simple tests.
  824. # For lu_factor gives a LinAlgWarning because these matrices are singular
  825. def test_simple(self):
  826. self._test_common(self.a)
  827. def test_simple_complex(self):
  828. self._test_common(self.ca)
  829. def test_simple2(self):
  830. self._test_common(self.b)
  831. def test_simple2_complex(self):
  832. self._test_common(self.cb)
  833. # rectangular matrices tests
  834. def test_hrectangular(self):
  835. self._test_common(self.hrect)
  836. self._test_common_lu_factor(self.hrect)
  837. def test_vrectangular(self):
  838. self._test_common(self.vrect)
  839. self._test_common_lu_factor(self.vrect)
  840. def test_hrectangular_complex(self):
  841. self._test_common(self.chrect)
  842. self._test_common_lu_factor(self.chrect)
  843. def test_vrectangular_complex(self):
  844. self._test_common(self.cvrect)
  845. self._test_common_lu_factor(self.cvrect)
  846. # Bigger matrices
  847. def test_medium1(self):
  848. """Check lu decomposition on medium size, rectangular matrix."""
  849. self._test_common(self.med)
  850. self._test_common_lu_factor(self.med)
  851. def test_medium1_complex(self):
  852. """Check lu decomposition on medium size, rectangular matrix."""
  853. self._test_common(self.cmed)
  854. self._test_common_lu_factor(self.cmed)
  855. def test_check_finite(self):
  856. p, l, u = lu(self.a, check_finite=False)
  857. assert_array_almost_equal(p @ l @ u, self.a)
  858. def test_simple_known(self):
  859. # Ticket #1458
  860. for order in ['C', 'F']:
  861. A = np.array([[2, 1], [0, 1.]], order=order)
  862. LU, P = lu_factor(A)
  863. assert_array_almost_equal(LU, np.array([[2, 1], [0, 1]]))
  864. assert_array_equal(P, np.array([0, 1]))
  865. class TestLUSingle(TestLU):
  866. """LU testers for single precision, real and double"""
  867. def setup_method(self):
  868. TestLU.setup_method(self)
  869. self.a = self.a.astype(float32)
  870. self.ca = self.ca.astype(complex64)
  871. self.b = self.b.astype(float32)
  872. self.cb = self.cb.astype(complex64)
  873. self.hrect = self.hrect.astype(float32)
  874. self.chrect = self.hrect.astype(complex64)
  875. self.vrect = self.vrect.astype(float32)
  876. self.cvrect = self.vrect.astype(complex64)
  877. self.med = self.vrect.astype(float32)
  878. self.cmed = self.vrect.astype(complex64)
  879. class TestLUSolve:
  880. def setup_method(self):
  881. seed(1234)
  882. def test_lu(self):
  883. a0 = random((10, 10))
  884. b = random((10,))
  885. for order in ['C', 'F']:
  886. a = np.array(a0, order=order)
  887. x1 = solve(a, b)
  888. lu_a = lu_factor(a)
  889. x2 = lu_solve(lu_a, b)
  890. assert_array_almost_equal(x1, x2)
  891. def test_check_finite(self):
  892. a = random((10, 10))
  893. b = random((10,))
  894. x1 = solve(a, b)
  895. lu_a = lu_factor(a, check_finite=False)
  896. x2 = lu_solve(lu_a, b, check_finite=False)
  897. assert_array_almost_equal(x1, x2)
  898. class TestSVD_GESDD:
  899. def setup_method(self):
  900. self.lapack_driver = 'gesdd'
  901. seed(1234)
  902. def test_degenerate(self):
  903. assert_raises(TypeError, svd, [[1.]], lapack_driver=1.)
  904. assert_raises(ValueError, svd, [[1.]], lapack_driver='foo')
  905. def test_simple(self):
  906. a = [[1, 2, 3], [1, 20, 3], [2, 5, 6]]
  907. for full_matrices in (True, False):
  908. u, s, vh = svd(a, full_matrices=full_matrices,
  909. lapack_driver=self.lapack_driver)
  910. assert_array_almost_equal(u.T @ u, eye(3))
  911. assert_array_almost_equal(vh.T @ vh, eye(3))
  912. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  913. for i in range(len(s)):
  914. sigma[i, i] = s[i]
  915. assert_array_almost_equal(u @ sigma @ vh, a)
  916. def test_simple_singular(self):
  917. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  918. for full_matrices in (True, False):
  919. u, s, vh = svd(a, full_matrices=full_matrices,
  920. lapack_driver=self.lapack_driver)
  921. assert_array_almost_equal(u.T @ u, eye(3))
  922. assert_array_almost_equal(vh.T @ vh, eye(3))
  923. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  924. for i in range(len(s)):
  925. sigma[i, i] = s[i]
  926. assert_array_almost_equal(u @ sigma @ vh, a)
  927. def test_simple_underdet(self):
  928. a = [[1, 2, 3], [4, 5, 6]]
  929. for full_matrices in (True, False):
  930. u, s, vh = svd(a, full_matrices=full_matrices,
  931. lapack_driver=self.lapack_driver)
  932. assert_array_almost_equal(u.T @ u, eye(u.shape[0]))
  933. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  934. for i in range(len(s)):
  935. sigma[i, i] = s[i]
  936. assert_array_almost_equal(u @ sigma @ vh, a)
  937. def test_simple_overdet(self):
  938. a = [[1, 2], [4, 5], [3, 4]]
  939. for full_matrices in (True, False):
  940. u, s, vh = svd(a, full_matrices=full_matrices,
  941. lapack_driver=self.lapack_driver)
  942. assert_array_almost_equal(u.T @ u, eye(u.shape[1]))
  943. assert_array_almost_equal(vh.T @ vh, eye(2))
  944. sigma = zeros((u.shape[1], vh.shape[0]), s.dtype.char)
  945. for i in range(len(s)):
  946. sigma[i, i] = s[i]
  947. assert_array_almost_equal(u @ sigma @ vh, a)
  948. def test_random(self):
  949. n = 20
  950. m = 15
  951. for i in range(3):
  952. for a in [random([n, m]), random([m, n])]:
  953. for full_matrices in (True, False):
  954. u, s, vh = svd(a, full_matrices=full_matrices,
  955. lapack_driver=self.lapack_driver)
  956. assert_array_almost_equal(u.T @ u, eye(u.shape[1]))
  957. assert_array_almost_equal(vh @ vh.T, eye(vh.shape[0]))
  958. sigma = zeros((u.shape[1], vh.shape[0]), s.dtype.char)
  959. for i in range(len(s)):
  960. sigma[i, i] = s[i]
  961. assert_array_almost_equal(u @ sigma @ vh, a)
  962. def test_simple_complex(self):
  963. a = [[1, 2, 3], [1, 2j, 3], [2, 5, 6]]
  964. for full_matrices in (True, False):
  965. u, s, vh = svd(a, full_matrices=full_matrices,
  966. lapack_driver=self.lapack_driver)
  967. assert_array_almost_equal(u.conj().T @ u, eye(u.shape[1]))
  968. assert_array_almost_equal(vh.conj().T @ vh, eye(vh.shape[0]))
  969. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  970. for i in range(len(s)):
  971. sigma[i, i] = s[i]
  972. assert_array_almost_equal(u @ sigma @ vh, a)
  973. def test_random_complex(self):
  974. n = 20
  975. m = 15
  976. for i in range(3):
  977. for full_matrices in (True, False):
  978. for a in [random([n, m]), random([m, n])]:
  979. a = a + 1j*random(list(a.shape))
  980. u, s, vh = svd(a, full_matrices=full_matrices,
  981. lapack_driver=self.lapack_driver)
  982. assert_array_almost_equal(u.conj().T @ u,
  983. eye(u.shape[1]))
  984. # This fails when [m,n]
  985. # assert_array_almost_equal(vh.conj().T @ vh,
  986. # eye(len(vh),dtype=vh.dtype.char))
  987. sigma = zeros((u.shape[1], vh.shape[0]), s.dtype.char)
  988. for i in range(len(s)):
  989. sigma[i, i] = s[i]
  990. assert_array_almost_equal(u @ sigma @ vh, a)
  991. def test_crash_1580(self):
  992. sizes = [(13, 23), (30, 50), (60, 100)]
  993. np.random.seed(1234)
  994. for sz in sizes:
  995. for dt in [np.float32, np.float64, np.complex64, np.complex128]:
  996. a = np.random.rand(*sz).astype(dt)
  997. # should not crash
  998. svd(a, lapack_driver=self.lapack_driver)
  999. def test_check_finite(self):
  1000. a = [[1, 2, 3], [1, 20, 3], [2, 5, 6]]
  1001. u, s, vh = svd(a, check_finite=False, lapack_driver=self.lapack_driver)
  1002. assert_array_almost_equal(u.T @ u, eye(3))
  1003. assert_array_almost_equal(vh.T @ vh, eye(3))
  1004. sigma = zeros((u.shape[0], vh.shape[0]), s.dtype.char)
  1005. for i in range(len(s)):
  1006. sigma[i, i] = s[i]
  1007. assert_array_almost_equal(u @ sigma @ vh, a)
  1008. def test_gh_5039(self):
  1009. # This is a smoke test for https://github.com/scipy/scipy/issues/5039
  1010. #
  1011. # The following is reported to raise "ValueError: On entry to DGESDD
  1012. # parameter number 12 had an illegal value".
  1013. # `interp1d([1,2,3,4], [1,2,3,4], kind='cubic')`
  1014. # This is reported to only show up on LAPACK 3.0.3.
  1015. #
  1016. # The matrix below is taken from the call to
  1017. # `B = _fitpack._bsplmat(order, xk)` in interpolate._find_smoothest
  1018. b = np.array(
  1019. [[0.16666667, 0.66666667, 0.16666667, 0., 0., 0.],
  1020. [0., 0.16666667, 0.66666667, 0.16666667, 0., 0.],
  1021. [0., 0., 0.16666667, 0.66666667, 0.16666667, 0.],
  1022. [0., 0., 0., 0.16666667, 0.66666667, 0.16666667]])
  1023. svd(b, lapack_driver=self.lapack_driver)
  1024. @pytest.mark.skipif(not HAS_ILP64, reason="64-bit LAPACK required")
  1025. @pytest.mark.slow
  1026. def test_large_matrix(self):
  1027. check_free_memory(free_mb=17000)
  1028. A = np.zeros([1, 2**31], dtype=np.float32)
  1029. A[0, -1] = 1
  1030. u, s, vh = svd(A, full_matrices=False)
  1031. assert_allclose(s[0], 1.0)
  1032. assert_allclose(u[0, 0] * vh[0, -1], 1.0)
  1033. class TestSVD_GESVD(TestSVD_GESDD):
  1034. def setup_method(self):
  1035. self.lapack_driver = 'gesvd'
  1036. seed(1234)
  1037. class TestSVDVals:
  1038. def test_empty(self):
  1039. for a in [[]], np.empty((2, 0)), np.ones((0, 3)):
  1040. s = svdvals(a)
  1041. assert_equal(s, np.empty(0))
  1042. def test_simple(self):
  1043. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  1044. s = svdvals(a)
  1045. assert_(len(s) == 3)
  1046. assert_(s[0] >= s[1] >= s[2])
  1047. def test_simple_underdet(self):
  1048. a = [[1, 2, 3], [4, 5, 6]]
  1049. s = svdvals(a)
  1050. assert_(len(s) == 2)
  1051. assert_(s[0] >= s[1])
  1052. def test_simple_overdet(self):
  1053. a = [[1, 2], [4, 5], [3, 4]]
  1054. s = svdvals(a)
  1055. assert_(len(s) == 2)
  1056. assert_(s[0] >= s[1])
  1057. def test_simple_complex(self):
  1058. a = [[1, 2, 3], [1, 20, 3j], [2, 5, 6]]
  1059. s = svdvals(a)
  1060. assert_(len(s) == 3)
  1061. assert_(s[0] >= s[1] >= s[2])
  1062. def test_simple_underdet_complex(self):
  1063. a = [[1, 2, 3], [4, 5j, 6]]
  1064. s = svdvals(a)
  1065. assert_(len(s) == 2)
  1066. assert_(s[0] >= s[1])
  1067. def test_simple_overdet_complex(self):
  1068. a = [[1, 2], [4, 5], [3j, 4]]
  1069. s = svdvals(a)
  1070. assert_(len(s) == 2)
  1071. assert_(s[0] >= s[1])
  1072. def test_check_finite(self):
  1073. a = [[1, 2, 3], [1, 2, 3], [2, 5, 6]]
  1074. s = svdvals(a, check_finite=False)
  1075. assert_(len(s) == 3)
  1076. assert_(s[0] >= s[1] >= s[2])
  1077. @pytest.mark.slow
  1078. def test_crash_2609(self):
  1079. np.random.seed(1234)
  1080. a = np.random.rand(1500, 2800)
  1081. # Shouldn't crash:
  1082. svdvals(a)
  1083. class TestDiagSVD:
  1084. def test_simple(self):
  1085. assert_array_almost_equal(diagsvd([1, 0, 0], 3, 3),
  1086. [[1, 0, 0], [0, 0, 0], [0, 0, 0]])
  1087. class TestQR:
  1088. def setup_method(self):
  1089. seed(1234)
  1090. def test_simple(self):
  1091. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1092. q, r = qr(a)
  1093. assert_array_almost_equal(q.T @ q, eye(3))
  1094. assert_array_almost_equal(q @ r, a)
  1095. def test_simple_left(self):
  1096. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1097. q, r = qr(a)
  1098. c = [1, 2, 3]
  1099. qc, r2 = qr_multiply(a, c, "left")
  1100. assert_array_almost_equal(q @ c, qc)
  1101. assert_array_almost_equal(r, r2)
  1102. qc, r2 = qr_multiply(a, eye(3), "left")
  1103. assert_array_almost_equal(q, qc)
  1104. def test_simple_right(self):
  1105. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1106. q, r = qr(a)
  1107. c = [1, 2, 3]
  1108. qc, r2 = qr_multiply(a, c)
  1109. assert_array_almost_equal(c @ q, qc)
  1110. assert_array_almost_equal(r, r2)
  1111. qc, r = qr_multiply(a, eye(3))
  1112. assert_array_almost_equal(q, qc)
  1113. def test_simple_pivoting(self):
  1114. a = np.asarray([[8, 2, 3], [2, 9, 3], [5, 3, 6]])
  1115. q, r, p = qr(a, pivoting=True)
  1116. d = abs(diag(r))
  1117. assert_(np.all(d[1:] <= d[:-1]))
  1118. assert_array_almost_equal(q.T @ q, eye(3))
  1119. assert_array_almost_equal(q @ r, a[:, p])
  1120. q2, r2 = qr(a[:, p])
  1121. assert_array_almost_equal(q, q2)
  1122. assert_array_almost_equal(r, r2)
  1123. def test_simple_left_pivoting(self):
  1124. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1125. q, r, jpvt = qr(a, pivoting=True)
  1126. c = [1, 2, 3]
  1127. qc, r, jpvt = qr_multiply(a, c, "left", True)
  1128. assert_array_almost_equal(q @ c, qc)
  1129. def test_simple_right_pivoting(self):
  1130. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1131. q, r, jpvt = qr(a, pivoting=True)
  1132. c = [1, 2, 3]
  1133. qc, r, jpvt = qr_multiply(a, c, pivoting=True)
  1134. assert_array_almost_equal(c @ q, qc)
  1135. def test_simple_trap(self):
  1136. a = [[8, 2, 3], [2, 9, 3]]
  1137. q, r = qr(a)
  1138. assert_array_almost_equal(q.T @ q, eye(2))
  1139. assert_array_almost_equal(q @ r, a)
  1140. def test_simple_trap_pivoting(self):
  1141. a = np.asarray([[8, 2, 3], [2, 9, 3]])
  1142. q, r, p = qr(a, pivoting=True)
  1143. d = abs(diag(r))
  1144. assert_(np.all(d[1:] <= d[:-1]))
  1145. assert_array_almost_equal(q.T @ q, eye(2))
  1146. assert_array_almost_equal(q @ r, a[:, p])
  1147. q2, r2 = qr(a[:, p])
  1148. assert_array_almost_equal(q, q2)
  1149. assert_array_almost_equal(r, r2)
  1150. def test_simple_tall(self):
  1151. # full version
  1152. a = [[8, 2], [2, 9], [5, 3]]
  1153. q, r = qr(a)
  1154. assert_array_almost_equal(q.T @ q, eye(3))
  1155. assert_array_almost_equal(q @ r, a)
  1156. def test_simple_tall_pivoting(self):
  1157. # full version pivoting
  1158. a = np.asarray([[8, 2], [2, 9], [5, 3]])
  1159. q, r, p = qr(a, pivoting=True)
  1160. d = abs(diag(r))
  1161. assert_(np.all(d[1:] <= d[:-1]))
  1162. assert_array_almost_equal(q.T @ q, eye(3))
  1163. assert_array_almost_equal(q @ r, a[:, p])
  1164. q2, r2 = qr(a[:, p])
  1165. assert_array_almost_equal(q, q2)
  1166. assert_array_almost_equal(r, r2)
  1167. def test_simple_tall_e(self):
  1168. # economy version
  1169. a = [[8, 2], [2, 9], [5, 3]]
  1170. q, r = qr(a, mode='economic')
  1171. assert_array_almost_equal(q.T @ q, eye(2))
  1172. assert_array_almost_equal(q @ r, a)
  1173. assert_equal(q.shape, (3, 2))
  1174. assert_equal(r.shape, (2, 2))
  1175. def test_simple_tall_e_pivoting(self):
  1176. # economy version pivoting
  1177. a = np.asarray([[8, 2], [2, 9], [5, 3]])
  1178. q, r, p = qr(a, pivoting=True, mode='economic')
  1179. d = abs(diag(r))
  1180. assert_(np.all(d[1:] <= d[:-1]))
  1181. assert_array_almost_equal(q.T @ q, eye(2))
  1182. assert_array_almost_equal(q @ r, a[:, p])
  1183. q2, r2 = qr(a[:, p], mode='economic')
  1184. assert_array_almost_equal(q, q2)
  1185. assert_array_almost_equal(r, r2)
  1186. def test_simple_tall_left(self):
  1187. a = [[8, 2], [2, 9], [5, 3]]
  1188. q, r = qr(a, mode="economic")
  1189. c = [1, 2]
  1190. qc, r2 = qr_multiply(a, c, "left")
  1191. assert_array_almost_equal(q @ c, qc)
  1192. assert_array_almost_equal(r, r2)
  1193. c = array([1, 2, 0])
  1194. qc, r2 = qr_multiply(a, c, "left", overwrite_c=True)
  1195. assert_array_almost_equal(q @ c[:2], qc)
  1196. qc, r = qr_multiply(a, eye(2), "left")
  1197. assert_array_almost_equal(qc, q)
  1198. def test_simple_tall_left_pivoting(self):
  1199. a = [[8, 2], [2, 9], [5, 3]]
  1200. q, r, jpvt = qr(a, mode="economic", pivoting=True)
  1201. c = [1, 2]
  1202. qc, r, kpvt = qr_multiply(a, c, "left", True)
  1203. assert_array_equal(jpvt, kpvt)
  1204. assert_array_almost_equal(q @ c, qc)
  1205. qc, r, jpvt = qr_multiply(a, eye(2), "left", True)
  1206. assert_array_almost_equal(qc, q)
  1207. def test_simple_tall_right(self):
  1208. a = [[8, 2], [2, 9], [5, 3]]
  1209. q, r = qr(a, mode="economic")
  1210. c = [1, 2, 3]
  1211. cq, r2 = qr_multiply(a, c)
  1212. assert_array_almost_equal(c @ q, cq)
  1213. assert_array_almost_equal(r, r2)
  1214. cq, r = qr_multiply(a, eye(3))
  1215. assert_array_almost_equal(cq, q)
  1216. def test_simple_tall_right_pivoting(self):
  1217. a = [[8, 2], [2, 9], [5, 3]]
  1218. q, r, jpvt = qr(a, pivoting=True, mode="economic")
  1219. c = [1, 2, 3]
  1220. cq, r, jpvt = qr_multiply(a, c, pivoting=True)
  1221. assert_array_almost_equal(c @ q, cq)
  1222. cq, r, jpvt = qr_multiply(a, eye(3), pivoting=True)
  1223. assert_array_almost_equal(cq, q)
  1224. def test_simple_fat(self):
  1225. # full version
  1226. a = [[8, 2, 5], [2, 9, 3]]
  1227. q, r = qr(a)
  1228. assert_array_almost_equal(q.T @ q, eye(2))
  1229. assert_array_almost_equal(q @ r, a)
  1230. assert_equal(q.shape, (2, 2))
  1231. assert_equal(r.shape, (2, 3))
  1232. def test_simple_fat_pivoting(self):
  1233. # full version pivoting
  1234. a = np.asarray([[8, 2, 5], [2, 9, 3]])
  1235. q, r, p = qr(a, pivoting=True)
  1236. d = abs(diag(r))
  1237. assert_(np.all(d[1:] <= d[:-1]))
  1238. assert_array_almost_equal(q.T @ q, eye(2))
  1239. assert_array_almost_equal(q @ r, a[:, p])
  1240. assert_equal(q.shape, (2, 2))
  1241. assert_equal(r.shape, (2, 3))
  1242. q2, r2 = qr(a[:, p])
  1243. assert_array_almost_equal(q, q2)
  1244. assert_array_almost_equal(r, r2)
  1245. def test_simple_fat_e(self):
  1246. # economy version
  1247. a = [[8, 2, 3], [2, 9, 5]]
  1248. q, r = qr(a, mode='economic')
  1249. assert_array_almost_equal(q.T @ q, eye(2))
  1250. assert_array_almost_equal(q @ r, a)
  1251. assert_equal(q.shape, (2, 2))
  1252. assert_equal(r.shape, (2, 3))
  1253. def test_simple_fat_e_pivoting(self):
  1254. # economy version pivoting
  1255. a = np.asarray([[8, 2, 3], [2, 9, 5]])
  1256. q, r, p = qr(a, pivoting=True, mode='economic')
  1257. d = abs(diag(r))
  1258. assert_(np.all(d[1:] <= d[:-1]))
  1259. assert_array_almost_equal(q.T @ q, eye(2))
  1260. assert_array_almost_equal(q @ r, a[:, p])
  1261. assert_equal(q.shape, (2, 2))
  1262. assert_equal(r.shape, (2, 3))
  1263. q2, r2 = qr(a[:, p], mode='economic')
  1264. assert_array_almost_equal(q, q2)
  1265. assert_array_almost_equal(r, r2)
  1266. def test_simple_fat_left(self):
  1267. a = [[8, 2, 3], [2, 9, 5]]
  1268. q, r = qr(a, mode="economic")
  1269. c = [1, 2]
  1270. qc, r2 = qr_multiply(a, c, "left")
  1271. assert_array_almost_equal(q @ c, qc)
  1272. assert_array_almost_equal(r, r2)
  1273. qc, r = qr_multiply(a, eye(2), "left")
  1274. assert_array_almost_equal(qc, q)
  1275. def test_simple_fat_left_pivoting(self):
  1276. a = [[8, 2, 3], [2, 9, 5]]
  1277. q, r, jpvt = qr(a, mode="economic", pivoting=True)
  1278. c = [1, 2]
  1279. qc, r, jpvt = qr_multiply(a, c, "left", True)
  1280. assert_array_almost_equal(q @ c, qc)
  1281. qc, r, jpvt = qr_multiply(a, eye(2), "left", True)
  1282. assert_array_almost_equal(qc, q)
  1283. def test_simple_fat_right(self):
  1284. a = [[8, 2, 3], [2, 9, 5]]
  1285. q, r = qr(a, mode="economic")
  1286. c = [1, 2]
  1287. cq, r2 = qr_multiply(a, c)
  1288. assert_array_almost_equal(c @ q, cq)
  1289. assert_array_almost_equal(r, r2)
  1290. cq, r = qr_multiply(a, eye(2))
  1291. assert_array_almost_equal(cq, q)
  1292. def test_simple_fat_right_pivoting(self):
  1293. a = [[8, 2, 3], [2, 9, 5]]
  1294. q, r, jpvt = qr(a, pivoting=True, mode="economic")
  1295. c = [1, 2]
  1296. cq, r, jpvt = qr_multiply(a, c, pivoting=True)
  1297. assert_array_almost_equal(c @ q, cq)
  1298. cq, r, jpvt = qr_multiply(a, eye(2), pivoting=True)
  1299. assert_array_almost_equal(cq, q)
  1300. def test_simple_complex(self):
  1301. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1302. q, r = qr(a)
  1303. assert_array_almost_equal(q.conj().T @ q, eye(3))
  1304. assert_array_almost_equal(q @ r, a)
  1305. def test_simple_complex_left(self):
  1306. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1307. q, r = qr(a)
  1308. c = [1, 2, 3+4j]
  1309. qc, r = qr_multiply(a, c, "left")
  1310. assert_array_almost_equal(q @ c, qc)
  1311. qc, r = qr_multiply(a, eye(3), "left")
  1312. assert_array_almost_equal(q, qc)
  1313. def test_simple_complex_right(self):
  1314. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1315. q, r = qr(a)
  1316. c = [1, 2, 3+4j]
  1317. qc, r = qr_multiply(a, c)
  1318. assert_array_almost_equal(c @ q, qc)
  1319. qc, r = qr_multiply(a, eye(3))
  1320. assert_array_almost_equal(q, qc)
  1321. def test_simple_tall_complex_left(self):
  1322. a = [[8, 2+3j], [2, 9], [5+7j, 3]]
  1323. q, r = qr(a, mode="economic")
  1324. c = [1, 2+2j]
  1325. qc, r2 = qr_multiply(a, c, "left")
  1326. assert_array_almost_equal(q @ c, qc)
  1327. assert_array_almost_equal(r, r2)
  1328. c = array([1, 2, 0])
  1329. qc, r2 = qr_multiply(a, c, "left", overwrite_c=True)
  1330. assert_array_almost_equal(q @ c[:2], qc)
  1331. qc, r = qr_multiply(a, eye(2), "left")
  1332. assert_array_almost_equal(qc, q)
  1333. def test_simple_complex_left_conjugate(self):
  1334. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1335. q, r = qr(a)
  1336. c = [1, 2, 3+4j]
  1337. qc, r = qr_multiply(a, c, "left", conjugate=True)
  1338. assert_array_almost_equal(q.conj() @ c, qc)
  1339. def test_simple_complex_tall_left_conjugate(self):
  1340. a = [[3, 3+4j], [5, 2+2j], [3, 2]]
  1341. q, r = qr(a, mode='economic')
  1342. c = [1, 3+4j]
  1343. qc, r = qr_multiply(a, c, "left", conjugate=True)
  1344. assert_array_almost_equal(q.conj() @ c, qc)
  1345. def test_simple_complex_right_conjugate(self):
  1346. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1347. q, r = qr(a)
  1348. c = np.array([1, 2, 3+4j])
  1349. qc, r = qr_multiply(a, c, conjugate=True)
  1350. assert_array_almost_equal(c @ q.conj(), qc)
  1351. def test_simple_complex_pivoting(self):
  1352. a = array([[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]])
  1353. q, r, p = qr(a, pivoting=True)
  1354. d = abs(diag(r))
  1355. assert_(np.all(d[1:] <= d[:-1]))
  1356. assert_array_almost_equal(q.conj().T @ q, eye(3))
  1357. assert_array_almost_equal(q @ r, a[:, p])
  1358. q2, r2 = qr(a[:, p])
  1359. assert_array_almost_equal(q, q2)
  1360. assert_array_almost_equal(r, r2)
  1361. def test_simple_complex_left_pivoting(self):
  1362. a = array([[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]])
  1363. q, r, jpvt = qr(a, pivoting=True)
  1364. c = [1, 2, 3+4j]
  1365. qc, r, jpvt = qr_multiply(a, c, "left", True)
  1366. assert_array_almost_equal(q @ c, qc)
  1367. def test_simple_complex_right_pivoting(self):
  1368. a = array([[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]])
  1369. q, r, jpvt = qr(a, pivoting=True)
  1370. c = [1, 2, 3+4j]
  1371. qc, r, jpvt = qr_multiply(a, c, pivoting=True)
  1372. assert_array_almost_equal(c @ q, qc)
  1373. def test_random(self):
  1374. n = 20
  1375. for k in range(2):
  1376. a = random([n, n])
  1377. q, r = qr(a)
  1378. assert_array_almost_equal(q.T @ q, eye(n))
  1379. assert_array_almost_equal(q @ r, a)
  1380. def test_random_left(self):
  1381. n = 20
  1382. for k in range(2):
  1383. a = random([n, n])
  1384. q, r = qr(a)
  1385. c = random([n])
  1386. qc, r = qr_multiply(a, c, "left")
  1387. assert_array_almost_equal(q @ c, qc)
  1388. qc, r = qr_multiply(a, eye(n), "left")
  1389. assert_array_almost_equal(q, qc)
  1390. def test_random_right(self):
  1391. n = 20
  1392. for k in range(2):
  1393. a = random([n, n])
  1394. q, r = qr(a)
  1395. c = random([n])
  1396. cq, r = qr_multiply(a, c)
  1397. assert_array_almost_equal(c @ q, cq)
  1398. cq, r = qr_multiply(a, eye(n))
  1399. assert_array_almost_equal(q, cq)
  1400. def test_random_pivoting(self):
  1401. n = 20
  1402. for k in range(2):
  1403. a = random([n, n])
  1404. q, r, p = qr(a, pivoting=True)
  1405. d = abs(diag(r))
  1406. assert_(np.all(d[1:] <= d[:-1]))
  1407. assert_array_almost_equal(q.T @ q, eye(n))
  1408. assert_array_almost_equal(q @ r, a[:, p])
  1409. q2, r2 = qr(a[:, p])
  1410. assert_array_almost_equal(q, q2)
  1411. assert_array_almost_equal(r, r2)
  1412. def test_random_tall(self):
  1413. # full version
  1414. m = 200
  1415. n = 100
  1416. for k in range(2):
  1417. a = random([m, n])
  1418. q, r = qr(a)
  1419. assert_array_almost_equal(q.T @ q, eye(m))
  1420. assert_array_almost_equal(q @ r, a)
  1421. def test_random_tall_left(self):
  1422. # full version
  1423. m = 200
  1424. n = 100
  1425. for k in range(2):
  1426. a = random([m, n])
  1427. q, r = qr(a, mode="economic")
  1428. c = random([n])
  1429. qc, r = qr_multiply(a, c, "left")
  1430. assert_array_almost_equal(q @ c, qc)
  1431. qc, r = qr_multiply(a, eye(n), "left")
  1432. assert_array_almost_equal(qc, q)
  1433. def test_random_tall_right(self):
  1434. # full version
  1435. m = 200
  1436. n = 100
  1437. for k in range(2):
  1438. a = random([m, n])
  1439. q, r = qr(a, mode="economic")
  1440. c = random([m])
  1441. cq, r = qr_multiply(a, c)
  1442. assert_array_almost_equal(c @ q, cq)
  1443. cq, r = qr_multiply(a, eye(m))
  1444. assert_array_almost_equal(cq, q)
  1445. def test_random_tall_pivoting(self):
  1446. # full version pivoting
  1447. m = 200
  1448. n = 100
  1449. for k in range(2):
  1450. a = random([m, n])
  1451. q, r, p = qr(a, pivoting=True)
  1452. d = abs(diag(r))
  1453. assert_(np.all(d[1:] <= d[:-1]))
  1454. assert_array_almost_equal(q.T @ q, eye(m))
  1455. assert_array_almost_equal(q @ r, a[:, p])
  1456. q2, r2 = qr(a[:, p])
  1457. assert_array_almost_equal(q, q2)
  1458. assert_array_almost_equal(r, r2)
  1459. def test_random_tall_e(self):
  1460. # economy version
  1461. m = 200
  1462. n = 100
  1463. for k in range(2):
  1464. a = random([m, n])
  1465. q, r = qr(a, mode='economic')
  1466. assert_array_almost_equal(q.T @ q, eye(n))
  1467. assert_array_almost_equal(q @ r, a)
  1468. assert_equal(q.shape, (m, n))
  1469. assert_equal(r.shape, (n, n))
  1470. def test_random_tall_e_pivoting(self):
  1471. # economy version pivoting
  1472. m = 200
  1473. n = 100
  1474. for k in range(2):
  1475. a = random([m, n])
  1476. q, r, p = qr(a, pivoting=True, mode='economic')
  1477. d = abs(diag(r))
  1478. assert_(np.all(d[1:] <= d[:-1]))
  1479. assert_array_almost_equal(q.T @ q, eye(n))
  1480. assert_array_almost_equal(q @ r, a[:, p])
  1481. assert_equal(q.shape, (m, n))
  1482. assert_equal(r.shape, (n, n))
  1483. q2, r2 = qr(a[:, p], mode='economic')
  1484. assert_array_almost_equal(q, q2)
  1485. assert_array_almost_equal(r, r2)
  1486. def test_random_trap(self):
  1487. m = 100
  1488. n = 200
  1489. for k in range(2):
  1490. a = random([m, n])
  1491. q, r = qr(a)
  1492. assert_array_almost_equal(q.T @ q, eye(m))
  1493. assert_array_almost_equal(q @ r, a)
  1494. def test_random_trap_pivoting(self):
  1495. m = 100
  1496. n = 200
  1497. for k in range(2):
  1498. a = random([m, n])
  1499. q, r, p = qr(a, pivoting=True)
  1500. d = abs(diag(r))
  1501. assert_(np.all(d[1:] <= d[:-1]))
  1502. assert_array_almost_equal(q.T @ q, eye(m))
  1503. assert_array_almost_equal(q @ r, a[:, p])
  1504. q2, r2 = qr(a[:, p])
  1505. assert_array_almost_equal(q, q2)
  1506. assert_array_almost_equal(r, r2)
  1507. def test_random_complex(self):
  1508. n = 20
  1509. for k in range(2):
  1510. a = random([n, n])+1j*random([n, n])
  1511. q, r = qr(a)
  1512. assert_array_almost_equal(q.conj().T @ q, eye(n))
  1513. assert_array_almost_equal(q @ r, a)
  1514. def test_random_complex_left(self):
  1515. n = 20
  1516. for k in range(2):
  1517. a = random([n, n])+1j*random([n, n])
  1518. q, r = qr(a)
  1519. c = random([n])+1j*random([n])
  1520. qc, r = qr_multiply(a, c, "left")
  1521. assert_array_almost_equal(q @ c, qc)
  1522. qc, r = qr_multiply(a, eye(n), "left")
  1523. assert_array_almost_equal(q, qc)
  1524. def test_random_complex_right(self):
  1525. n = 20
  1526. for k in range(2):
  1527. a = random([n, n])+1j*random([n, n])
  1528. q, r = qr(a)
  1529. c = random([n])+1j*random([n])
  1530. cq, r = qr_multiply(a, c)
  1531. assert_array_almost_equal(c @ q, cq)
  1532. cq, r = qr_multiply(a, eye(n))
  1533. assert_array_almost_equal(q, cq)
  1534. def test_random_complex_pivoting(self):
  1535. n = 20
  1536. for k in range(2):
  1537. a = random([n, n])+1j*random([n, n])
  1538. q, r, p = qr(a, pivoting=True)
  1539. d = abs(diag(r))
  1540. assert_(np.all(d[1:] <= d[:-1]))
  1541. assert_array_almost_equal(q.conj().T @ q, eye(n))
  1542. assert_array_almost_equal(q @ r, a[:, p])
  1543. q2, r2 = qr(a[:, p])
  1544. assert_array_almost_equal(q, q2)
  1545. assert_array_almost_equal(r, r2)
  1546. def test_check_finite(self):
  1547. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1548. q, r = qr(a, check_finite=False)
  1549. assert_array_almost_equal(q.T @ q, eye(3))
  1550. assert_array_almost_equal(q @ r, a)
  1551. def test_lwork(self):
  1552. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1553. # Get comparison values
  1554. q, r = qr(a, lwork=None)
  1555. # Test against minimum valid lwork
  1556. q2, r2 = qr(a, lwork=3)
  1557. assert_array_almost_equal(q2, q)
  1558. assert_array_almost_equal(r2, r)
  1559. # Test against larger lwork
  1560. q3, r3 = qr(a, lwork=10)
  1561. assert_array_almost_equal(q3, q)
  1562. assert_array_almost_equal(r3, r)
  1563. # Test against explicit lwork=-1
  1564. q4, r4 = qr(a, lwork=-1)
  1565. assert_array_almost_equal(q4, q)
  1566. assert_array_almost_equal(r4, r)
  1567. # Test against invalid lwork
  1568. assert_raises(Exception, qr, (a,), {'lwork': 0})
  1569. assert_raises(Exception, qr, (a,), {'lwork': 2})
  1570. class TestRQ:
  1571. def setup_method(self):
  1572. seed(1234)
  1573. def test_simple(self):
  1574. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1575. r, q = rq(a)
  1576. assert_array_almost_equal(q @ q.T, eye(3))
  1577. assert_array_almost_equal(r @ q, a)
  1578. def test_r(self):
  1579. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1580. r, q = rq(a)
  1581. r2 = rq(a, mode='r')
  1582. assert_array_almost_equal(r, r2)
  1583. def test_random(self):
  1584. n = 20
  1585. for k in range(2):
  1586. a = random([n, n])
  1587. r, q = rq(a)
  1588. assert_array_almost_equal(q @ q.T, eye(n))
  1589. assert_array_almost_equal(r @ q, a)
  1590. def test_simple_trap(self):
  1591. a = [[8, 2, 3], [2, 9, 3]]
  1592. r, q = rq(a)
  1593. assert_array_almost_equal(q.T @ q, eye(3))
  1594. assert_array_almost_equal(r @ q, a)
  1595. def test_simple_tall(self):
  1596. a = [[8, 2], [2, 9], [5, 3]]
  1597. r, q = rq(a)
  1598. assert_array_almost_equal(q.T @ q, eye(2))
  1599. assert_array_almost_equal(r @ q, a)
  1600. def test_simple_fat(self):
  1601. a = [[8, 2, 5], [2, 9, 3]]
  1602. r, q = rq(a)
  1603. assert_array_almost_equal(q @ q.T, eye(3))
  1604. assert_array_almost_equal(r @ q, a)
  1605. def test_simple_complex(self):
  1606. a = [[3, 3+4j, 5], [5, 2, 2+7j], [3, 2, 7]]
  1607. r, q = rq(a)
  1608. assert_array_almost_equal(q @ q.conj().T, eye(3))
  1609. assert_array_almost_equal(r @ q, a)
  1610. def test_random_tall(self):
  1611. m = 200
  1612. n = 100
  1613. for k in range(2):
  1614. a = random([m, n])
  1615. r, q = rq(a)
  1616. assert_array_almost_equal(q @ q.T, eye(n))
  1617. assert_array_almost_equal(r @ q, a)
  1618. def test_random_trap(self):
  1619. m = 100
  1620. n = 200
  1621. for k in range(2):
  1622. a = random([m, n])
  1623. r, q = rq(a)
  1624. assert_array_almost_equal(q @ q.T, eye(n))
  1625. assert_array_almost_equal(r @ q, a)
  1626. def test_random_trap_economic(self):
  1627. m = 100
  1628. n = 200
  1629. for k in range(2):
  1630. a = random([m, n])
  1631. r, q = rq(a, mode='economic')
  1632. assert_array_almost_equal(q @ q.T, eye(m))
  1633. assert_array_almost_equal(r @ q, a)
  1634. assert_equal(q.shape, (m, n))
  1635. assert_equal(r.shape, (m, m))
  1636. def test_random_complex(self):
  1637. n = 20
  1638. for k in range(2):
  1639. a = random([n, n])+1j*random([n, n])
  1640. r, q = rq(a)
  1641. assert_array_almost_equal(q @ q.conj().T, eye(n))
  1642. assert_array_almost_equal(r @ q, a)
  1643. def test_random_complex_economic(self):
  1644. m = 100
  1645. n = 200
  1646. for k in range(2):
  1647. a = random([m, n])+1j*random([m, n])
  1648. r, q = rq(a, mode='economic')
  1649. assert_array_almost_equal(q @ q.conj().T, eye(m))
  1650. assert_array_almost_equal(r @ q, a)
  1651. assert_equal(q.shape, (m, n))
  1652. assert_equal(r.shape, (m, m))
  1653. def test_check_finite(self):
  1654. a = [[8, 2, 3], [2, 9, 3], [5, 3, 6]]
  1655. r, q = rq(a, check_finite=False)
  1656. assert_array_almost_equal(q @ q.T, eye(3))
  1657. assert_array_almost_equal(r @ q, a)
  1658. class TestSchur:
  1659. def check_schur(self, a, t, u, rtol, atol):
  1660. # Check that the Schur decomposition is correct.
  1661. assert_allclose(u @ t @ u.conj().T, a, rtol=rtol, atol=atol,
  1662. err_msg="Schur decomposition does not match 'a'")
  1663. # The expected value of u @ u.H - I is all zeros, so test
  1664. # with absolute tolerance only.
  1665. assert_allclose(u @ u.conj().T - np.eye(len(u)), 0, rtol=0, atol=atol,
  1666. err_msg="u is not unitary")
  1667. def test_simple(self):
  1668. a = [[8, 12, 3], [2, 9, 3], [10, 3, 6]]
  1669. t, z = schur(a)
  1670. self.check_schur(a, t, z, rtol=1e-14, atol=5e-15)
  1671. tc, zc = schur(a, 'complex')
  1672. assert_(np.any(ravel(iscomplex(zc))) and np.any(ravel(iscomplex(tc))))
  1673. self.check_schur(a, tc, zc, rtol=1e-14, atol=5e-15)
  1674. tc2, zc2 = rsf2csf(tc, zc)
  1675. self.check_schur(a, tc2, zc2, rtol=1e-14, atol=5e-15)
  1676. @pytest.mark.parametrize(
  1677. 'sort, expected_diag',
  1678. [('lhp', [-np.sqrt(2), -0.5, np.sqrt(2), 0.5]),
  1679. ('rhp', [np.sqrt(2), 0.5, -np.sqrt(2), -0.5]),
  1680. ('iuc', [-0.5, 0.5, np.sqrt(2), -np.sqrt(2)]),
  1681. ('ouc', [np.sqrt(2), -np.sqrt(2), -0.5, 0.5]),
  1682. (lambda x: x >= 0.0, [np.sqrt(2), 0.5, -np.sqrt(2), -0.5])]
  1683. )
  1684. def test_sort(self, sort, expected_diag):
  1685. # The exact eigenvalues of this matrix are
  1686. # -sqrt(2), sqrt(2), -1/2, 1/2.
  1687. a = [[4., 3., 1., -1.],
  1688. [-4.5, -3.5, -1., 1.],
  1689. [9., 6., -4., 4.5],
  1690. [6., 4., -3., 3.5]]
  1691. t, u, sdim = schur(a, sort=sort)
  1692. self.check_schur(a, t, u, rtol=1e-14, atol=5e-15)
  1693. assert_allclose(np.diag(t), expected_diag, rtol=1e-12)
  1694. assert_equal(2, sdim)
  1695. def test_sort_errors(self):
  1696. a = [[4., 3., 1., -1.],
  1697. [-4.5, -3.5, -1., 1.],
  1698. [9., 6., -4., 4.5],
  1699. [6., 4., -3., 3.5]]
  1700. assert_raises(ValueError, schur, a, sort='unsupported')
  1701. assert_raises(ValueError, schur, a, sort=1)
  1702. def test_check_finite(self):
  1703. a = [[8, 12, 3], [2, 9, 3], [10, 3, 6]]
  1704. t, z = schur(a, check_finite=False)
  1705. assert_array_almost_equal(z @ t @ z.conj().T, a)
  1706. class TestHessenberg:
  1707. def test_simple(self):
  1708. a = [[-149, -50, -154],
  1709. [537, 180, 546],
  1710. [-27, -9, -25]]
  1711. h1 = [[-149.0000, 42.2037, -156.3165],
  1712. [-537.6783, 152.5511, -554.9272],
  1713. [0, 0.0728, 2.4489]]
  1714. h, q = hessenberg(a, calc_q=1)
  1715. assert_array_almost_equal(q.T @ a @ q, h)
  1716. assert_array_almost_equal(h, h1, decimal=4)
  1717. def test_simple_complex(self):
  1718. a = [[-149, -50, -154],
  1719. [537, 180j, 546],
  1720. [-27j, -9, -25]]
  1721. h, q = hessenberg(a, calc_q=1)
  1722. assert_array_almost_equal(q.conj().T @ a @ q, h)
  1723. def test_simple2(self):
  1724. a = [[1, 2, 3, 4, 5, 6, 7],
  1725. [0, 2, 3, 4, 6, 7, 2],
  1726. [0, 2, 2, 3, 0, 3, 2],
  1727. [0, 0, 2, 8, 0, 0, 2],
  1728. [0, 3, 1, 2, 0, 1, 2],
  1729. [0, 1, 2, 3, 0, 1, 0],
  1730. [0, 0, 0, 0, 0, 1, 2]]
  1731. h, q = hessenberg(a, calc_q=1)
  1732. assert_array_almost_equal(q.T @ a @ q, h)
  1733. def test_simple3(self):
  1734. a = np.eye(3)
  1735. a[-1, 0] = 2
  1736. h, q = hessenberg(a, calc_q=1)
  1737. assert_array_almost_equal(q.T @ a @ q, h)
  1738. def test_random(self):
  1739. n = 20
  1740. for k in range(2):
  1741. a = random([n, n])
  1742. h, q = hessenberg(a, calc_q=1)
  1743. assert_array_almost_equal(q.T @ a @ q, h)
  1744. def test_random_complex(self):
  1745. n = 20
  1746. for k in range(2):
  1747. a = random([n, n])+1j*random([n, n])
  1748. h, q = hessenberg(a, calc_q=1)
  1749. assert_array_almost_equal(q.conj().T @ a @ q, h)
  1750. def test_check_finite(self):
  1751. a = [[-149, -50, -154],
  1752. [537, 180, 546],
  1753. [-27, -9, -25]]
  1754. h1 = [[-149.0000, 42.2037, -156.3165],
  1755. [-537.6783, 152.5511, -554.9272],
  1756. [0, 0.0728, 2.4489]]
  1757. h, q = hessenberg(a, calc_q=1, check_finite=False)
  1758. assert_array_almost_equal(q.T @ a @ q, h)
  1759. assert_array_almost_equal(h, h1, decimal=4)
  1760. def test_2x2(self):
  1761. a = [[2, 1], [7, 12]]
  1762. h, q = hessenberg(a, calc_q=1)
  1763. assert_array_almost_equal(q, np.eye(2))
  1764. assert_array_almost_equal(h, a)
  1765. b = [[2-7j, 1+2j], [7+3j, 12-2j]]
  1766. h2, q2 = hessenberg(b, calc_q=1)
  1767. assert_array_almost_equal(q2, np.eye(2))
  1768. assert_array_almost_equal(h2, b)
  1769. class TestQZ:
  1770. def setup_method(self):
  1771. seed(12345)
  1772. @pytest.mark.xfail(sys.platform == 'darwin',
  1773. reason="gges[float32] broken for OpenBLAS on macOS, see gh-16949")
  1774. def test_qz_single(self):
  1775. n = 5
  1776. A = random([n, n]).astype(float32)
  1777. B = random([n, n]).astype(float32)
  1778. AA, BB, Q, Z = qz(A, B)
  1779. assert_array_almost_equal(Q @ AA @ Z.T, A, decimal=5)
  1780. assert_array_almost_equal(Q @ BB @ Z.T, B, decimal=5)
  1781. assert_array_almost_equal(Q @ Q.T, eye(n), decimal=5)
  1782. assert_array_almost_equal(Z @ Z.T, eye(n), decimal=5)
  1783. assert_(np.all(diag(BB) >= 0))
  1784. def test_qz_double(self):
  1785. n = 5
  1786. A = random([n, n])
  1787. B = random([n, n])
  1788. AA, BB, Q, Z = qz(A, B)
  1789. assert_array_almost_equal(Q @ AA @ Z.T, A)
  1790. assert_array_almost_equal(Q @ BB @ Z.T, B)
  1791. assert_array_almost_equal(Q @ Q.T, eye(n))
  1792. assert_array_almost_equal(Z @ Z.T, eye(n))
  1793. assert_(np.all(diag(BB) >= 0))
  1794. def test_qz_complex(self):
  1795. n = 5
  1796. A = random([n, n]) + 1j*random([n, n])
  1797. B = random([n, n]) + 1j*random([n, n])
  1798. AA, BB, Q, Z = qz(A, B)
  1799. assert_array_almost_equal(Q @ AA @ Z.conj().T, A)
  1800. assert_array_almost_equal(Q @ BB @ Z.conj().T, B)
  1801. assert_array_almost_equal(Q @ Q.conj().T, eye(n))
  1802. assert_array_almost_equal(Z @ Z.conj().T, eye(n))
  1803. assert_(np.all(diag(BB) >= 0))
  1804. assert_(np.all(diag(BB).imag == 0))
  1805. def test_qz_complex64(self):
  1806. n = 5
  1807. A = (random([n, n]) + 1j*random([n, n])).astype(complex64)
  1808. B = (random([n, n]) + 1j*random([n, n])).astype(complex64)
  1809. AA, BB, Q, Z = qz(A, B)
  1810. assert_array_almost_equal(Q @ AA @ Z.conj().T, A, decimal=5)
  1811. assert_array_almost_equal(Q @ BB @ Z.conj().T, B, decimal=5)
  1812. assert_array_almost_equal(Q @ Q.conj().T, eye(n), decimal=5)
  1813. assert_array_almost_equal(Z @ Z.conj().T, eye(n), decimal=5)
  1814. assert_(np.all(diag(BB) >= 0))
  1815. assert_(np.all(diag(BB).imag == 0))
  1816. def test_qz_double_complex(self):
  1817. n = 5
  1818. A = random([n, n])
  1819. B = random([n, n])
  1820. AA, BB, Q, Z = qz(A, B, output='complex')
  1821. aa = Q @ AA @ Z.conj().T
  1822. assert_array_almost_equal(aa.real, A)
  1823. assert_array_almost_equal(aa.imag, 0)
  1824. bb = Q @ BB @ Z.conj().T
  1825. assert_array_almost_equal(bb.real, B)
  1826. assert_array_almost_equal(bb.imag, 0)
  1827. assert_array_almost_equal(Q @ Q.conj().T, eye(n))
  1828. assert_array_almost_equal(Z @ Z.conj().T, eye(n))
  1829. assert_(np.all(diag(BB) >= 0))
  1830. def test_qz_double_sort(self):
  1831. # from https://www.nag.com/lapack-ex/node119.html
  1832. # NOTE: These matrices may be ill-conditioned and lead to a
  1833. # seg fault on certain python versions when compiled with
  1834. # sse2 or sse3 older ATLAS/LAPACK binaries for windows
  1835. # A = np.array([[3.9, 12.5, -34.5, -0.5],
  1836. # [ 4.3, 21.5, -47.5, 7.5],
  1837. # [ 4.3, 21.5, -43.5, 3.5],
  1838. # [ 4.4, 26.0, -46.0, 6.0 ]])
  1839. # B = np.array([[ 1.0, 2.0, -3.0, 1.0],
  1840. # [1.0, 3.0, -5.0, 4.0],
  1841. # [1.0, 3.0, -4.0, 3.0],
  1842. # [1.0, 3.0, -4.0, 4.0]])
  1843. A = np.array([[3.9, 12.5, -34.5, 2.5],
  1844. [4.3, 21.5, -47.5, 7.5],
  1845. [4.3, 1.5, -43.5, 3.5],
  1846. [4.4, 6.0, -46.0, 6.0]])
  1847. B = np.array([[1.0, 1.0, -3.0, 1.0],
  1848. [1.0, 3.0, -5.0, 4.4],
  1849. [1.0, 2.0, -4.0, 1.0],
  1850. [1.2, 3.0, -4.0, 4.0]])
  1851. assert_raises(ValueError, qz, A, B, sort=lambda ar, ai, beta: ai == 0)
  1852. if False:
  1853. AA, BB, Q, Z, sdim = qz(A, B, sort=lambda ar, ai, beta: ai == 0)
  1854. # assert_(sdim == 2)
  1855. assert_(sdim == 4)
  1856. assert_array_almost_equal(Q @ AA @ Z.T, A)
  1857. assert_array_almost_equal(Q @ BB @ Z.T, B)
  1858. # test absolute values bc the sign is ambiguous and
  1859. # might be platform dependent
  1860. assert_array_almost_equal(np.abs(AA), np.abs(np.array(
  1861. [[35.7864, -80.9061, -12.0629, -9.498],
  1862. [0., 2.7638, -2.3505, 7.3256],
  1863. [0., 0., 0.6258, -0.0398],
  1864. [0., 0., 0., -12.8217]])), 4)
  1865. assert_array_almost_equal(np.abs(BB), np.abs(np.array(
  1866. [[4.5324, -8.7878, 3.2357, -3.5526],
  1867. [0., 1.4314, -2.1894, 0.9709],
  1868. [0., 0., 1.3126, -0.3468],
  1869. [0., 0., 0., 0.559]])), 4)
  1870. assert_array_almost_equal(np.abs(Q), np.abs(np.array(
  1871. [[-0.4193, -0.605, -0.1894, -0.6498],
  1872. [-0.5495, 0.6987, 0.2654, -0.3734],
  1873. [-0.4973, -0.3682, 0.6194, 0.4832],
  1874. [-0.5243, 0.1008, -0.7142, 0.4526]])), 4)
  1875. assert_array_almost_equal(np.abs(Z), np.abs(np.array(
  1876. [[-0.9471, -0.2971, -0.1217, 0.0055],
  1877. [-0.0367, 0.1209, 0.0358, 0.9913],
  1878. [0.3171, -0.9041, -0.2547, 0.1312],
  1879. [0.0346, 0.2824, -0.9587, 0.0014]])), 4)
  1880. # test absolute values bc the sign is ambiguous and might be platform
  1881. # dependent
  1882. # assert_array_almost_equal(abs(AA), abs(np.array([
  1883. # [3.8009, -69.4505, 50.3135, -43.2884],
  1884. # [0.0000, 9.2033, -0.2001, 5.9881],
  1885. # [0.0000, 0.0000, 1.4279, 4.4453],
  1886. # [0.0000, 0.0000, 0.9019, -1.1962]])), 4)
  1887. # assert_array_almost_equal(abs(BB), abs(np.array([
  1888. # [1.9005, -10.2285, 0.8658, -5.2134],
  1889. # [0.0000, 2.3008, 0.7915, 0.4262],
  1890. # [0.0000, 0.0000, 0.8101, 0.0000],
  1891. # [0.0000, 0.0000, 0.0000, -0.2823]])), 4)
  1892. # assert_array_almost_equal(abs(Q), abs(np.array([
  1893. # [0.4642, 0.7886, 0.2915, -0.2786],
  1894. # [0.5002, -0.5986, 0.5638, -0.2713],
  1895. # [0.5002, 0.0154, -0.0107, 0.8657],
  1896. # [0.5331, -0.1395, -0.7727, -0.3151]])), 4)
  1897. # assert_array_almost_equal(dot(Q,Q.T), eye(4))
  1898. # assert_array_almost_equal(abs(Z), abs(np.array([
  1899. # [0.9961, -0.0014, 0.0887, -0.0026],
  1900. # [0.0057, -0.0404, -0.0938, -0.9948],
  1901. # [0.0626, 0.7194, -0.6908, 0.0363],
  1902. # [0.0626, -0.6934, -0.7114, 0.0956]])), 4)
  1903. # assert_array_almost_equal(dot(Z,Z.T), eye(4))
  1904. # def test_qz_complex_sort(self):
  1905. # cA = np.array([
  1906. # [-21.10+22.50*1j, 53.50+-50.50*1j, -34.50+127.50*1j, 7.50+ 0.50*1j],
  1907. # [-0.46+ -7.78*1j, -3.50+-37.50*1j, -15.50+ 58.50*1j,-10.50+ -1.50*1j],
  1908. # [ 4.30+ -5.50*1j, 39.70+-17.10*1j, -68.50+ 12.50*1j, -7.50+ -3.50*1j],
  1909. # [ 5.50+ 4.40*1j, 14.40+ 43.30*1j, -32.50+-46.00*1j,-19.00+-32.50*1j]])
  1910. # cB = np.array([
  1911. # [1.00+ -5.00*1j, 1.60+ 1.20*1j,-3.00+ 0.00*1j, 0.00+ -1.00*1j],
  1912. # [0.80+ -0.60*1j, 3.00+ -5.00*1j,-4.00+ 3.00*1j,-2.40+ -3.20*1j],
  1913. # [1.00+ 0.00*1j, 2.40+ 1.80*1j,-4.00+ -5.00*1j, 0.00+ -3.00*1j],
  1914. # [0.00+ 1.00*1j,-1.80+ 2.40*1j, 0.00+ -4.00*1j, 4.00+ -5.00*1j]])
  1915. # AAS,BBS,QS,ZS,sdim = qz(cA,cB,sort='lhp')
  1916. # eigenvalues = diag(AAS)/diag(BBS)
  1917. # assert_(np.all(np.real(eigenvalues[:sdim] < 0)))
  1918. # assert_(np.all(np.real(eigenvalues[sdim:] > 0)))
  1919. def test_check_finite(self):
  1920. n = 5
  1921. A = random([n, n])
  1922. B = random([n, n])
  1923. AA, BB, Q, Z = qz(A, B, check_finite=False)
  1924. assert_array_almost_equal(Q @ AA @ Z.T, A)
  1925. assert_array_almost_equal(Q @ BB @ Z.T, B)
  1926. assert_array_almost_equal(Q @ Q.T, eye(n))
  1927. assert_array_almost_equal(Z @ Z.T, eye(n))
  1928. assert_(np.all(diag(BB) >= 0))
  1929. def _make_pos(X):
  1930. # the decompositions can have different signs than verified results
  1931. return np.sign(X)*X
  1932. class TestOrdQZ:
  1933. @classmethod
  1934. def setup_class(cls):
  1935. # https://www.nag.com/lapack-ex/node119.html
  1936. A1 = np.array([[-21.10 - 22.50j, 53.5 - 50.5j, -34.5 + 127.5j,
  1937. 7.5 + 0.5j],
  1938. [-0.46 - 7.78j, -3.5 - 37.5j, -15.5 + 58.5j,
  1939. -10.5 - 1.5j],
  1940. [4.30 - 5.50j, 39.7 - 17.1j, -68.5 + 12.5j,
  1941. -7.5 - 3.5j],
  1942. [5.50 + 4.40j, 14.4 + 43.3j, -32.5 - 46.0j,
  1943. -19.0 - 32.5j]])
  1944. B1 = np.array([[1.0 - 5.0j, 1.6 + 1.2j, -3 + 0j, 0.0 - 1.0j],
  1945. [0.8 - 0.6j, .0 - 5.0j, -4 + 3j, -2.4 - 3.2j],
  1946. [1.0 + 0.0j, 2.4 + 1.8j, -4 - 5j, 0.0 - 3.0j],
  1947. [0.0 + 1.0j, -1.8 + 2.4j, 0 - 4j, 4.0 - 5.0j]])
  1948. # https://www.nag.com/numeric/fl/nagdoc_fl23/xhtml/F08/f08yuf.xml
  1949. A2 = np.array([[3.9, 12.5, -34.5, -0.5],
  1950. [4.3, 21.5, -47.5, 7.5],
  1951. [4.3, 21.5, -43.5, 3.5],
  1952. [4.4, 26.0, -46.0, 6.0]])
  1953. B2 = np.array([[1, 2, -3, 1],
  1954. [1, 3, -5, 4],
  1955. [1, 3, -4, 3],
  1956. [1, 3, -4, 4]])
  1957. # example with the eigenvalues
  1958. # -0.33891648, 1.61217396+0.74013521j, 1.61217396-0.74013521j,
  1959. # 0.61244091
  1960. # thus featuring:
  1961. # * one complex conjugate eigenvalue pair,
  1962. # * one eigenvalue in the lhp
  1963. # * 2 eigenvalues in the unit circle
  1964. # * 2 non-real eigenvalues
  1965. A3 = np.array([[5., 1., 3., 3.],
  1966. [4., 4., 2., 7.],
  1967. [7., 4., 1., 3.],
  1968. [0., 4., 8., 7.]])
  1969. B3 = np.array([[8., 10., 6., 10.],
  1970. [7., 7., 2., 9.],
  1971. [9., 1., 6., 6.],
  1972. [5., 1., 4., 7.]])
  1973. # example with infinite eigenvalues
  1974. A4 = np.eye(2)
  1975. B4 = np.diag([0, 1])
  1976. # example with (alpha, beta) = (0, 0)
  1977. A5 = np.diag([1, 0])
  1978. cls.A = [A1, A2, A3, A4, A5]
  1979. cls.B = [B1, B2, B3, B4, A5]
  1980. def qz_decomp(self, sort):
  1981. with np.errstate(all='raise'):
  1982. ret = [ordqz(Ai, Bi, sort=sort) for Ai, Bi in zip(self.A, self.B)]
  1983. return tuple(ret)
  1984. def check(self, A, B, sort, AA, BB, alpha, beta, Q, Z):
  1985. Id = np.eye(*A.shape)
  1986. # make sure Q and Z are orthogonal
  1987. assert_array_almost_equal(Q @ Q.T.conj(), Id)
  1988. assert_array_almost_equal(Z @ Z.T.conj(), Id)
  1989. # check factorization
  1990. assert_array_almost_equal(Q @ AA, A @ Z)
  1991. assert_array_almost_equal(Q @ BB, B @ Z)
  1992. # check shape of AA and BB
  1993. assert_array_equal(np.tril(AA, -2), np.zeros(AA.shape))
  1994. assert_array_equal(np.tril(BB, -1), np.zeros(BB.shape))
  1995. # check eigenvalues
  1996. for i in range(A.shape[0]):
  1997. # does the current diagonal element belong to a 2-by-2 block
  1998. # that was already checked?
  1999. if i > 0 and A[i, i - 1] != 0:
  2000. continue
  2001. # take care of 2-by-2 blocks
  2002. if i < AA.shape[0] - 1 and AA[i + 1, i] != 0:
  2003. evals, _ = eig(AA[i:i + 2, i:i + 2], BB[i:i + 2, i:i + 2])
  2004. # make sure the pair of complex conjugate eigenvalues
  2005. # is ordered consistently (positive imaginary part first)
  2006. if evals[0].imag < 0:
  2007. evals = evals[[1, 0]]
  2008. tmp = alpha[i:i + 2]/beta[i:i + 2]
  2009. if tmp[0].imag < 0:
  2010. tmp = tmp[[1, 0]]
  2011. assert_array_almost_equal(evals, tmp)
  2012. else:
  2013. if alpha[i] == 0 and beta[i] == 0:
  2014. assert_equal(AA[i, i], 0)
  2015. assert_equal(BB[i, i], 0)
  2016. elif beta[i] == 0:
  2017. assert_equal(BB[i, i], 0)
  2018. else:
  2019. assert_almost_equal(AA[i, i]/BB[i, i], alpha[i]/beta[i])
  2020. sortfun = _select_function(sort)
  2021. lastsort = True
  2022. for i in range(A.shape[0]):
  2023. cursort = sortfun(np.array([alpha[i]]), np.array([beta[i]]))
  2024. # once the sorting criterion was not matched all subsequent
  2025. # eigenvalues also shouldn't match
  2026. if not lastsort:
  2027. assert not cursort
  2028. lastsort = cursort
  2029. def check_all(self, sort):
  2030. ret = self.qz_decomp(sort)
  2031. for reti, Ai, Bi in zip(ret, self.A, self.B):
  2032. self.check(Ai, Bi, sort, *reti)
  2033. def test_lhp(self):
  2034. self.check_all('lhp')
  2035. def test_rhp(self):
  2036. self.check_all('rhp')
  2037. def test_iuc(self):
  2038. self.check_all('iuc')
  2039. def test_ouc(self):
  2040. self.check_all('ouc')
  2041. def test_ref(self):
  2042. # real eigenvalues first (top-left corner)
  2043. def sort(x, y):
  2044. out = np.empty_like(x, dtype=bool)
  2045. nonzero = (y != 0)
  2046. out[~nonzero] = False
  2047. out[nonzero] = (x[nonzero]/y[nonzero]).imag == 0
  2048. return out
  2049. self.check_all(sort)
  2050. def test_cef(self):
  2051. # complex eigenvalues first (top-left corner)
  2052. def sort(x, y):
  2053. out = np.empty_like(x, dtype=bool)
  2054. nonzero = (y != 0)
  2055. out[~nonzero] = False
  2056. out[nonzero] = (x[nonzero]/y[nonzero]).imag != 0
  2057. return out
  2058. self.check_all(sort)
  2059. def test_diff_input_types(self):
  2060. ret = ordqz(self.A[1], self.B[2], sort='lhp')
  2061. self.check(self.A[1], self.B[2], 'lhp', *ret)
  2062. ret = ordqz(self.B[2], self.A[1], sort='lhp')
  2063. self.check(self.B[2], self.A[1], 'lhp', *ret)
  2064. def test_sort_explicit(self):
  2065. # Test order of the eigenvalues in the 2 x 2 case where we can
  2066. # explicitly compute the solution
  2067. A1 = np.eye(2)
  2068. B1 = np.diag([-2, 0.5])
  2069. expected1 = [('lhp', [-0.5, 2]),
  2070. ('rhp', [2, -0.5]),
  2071. ('iuc', [-0.5, 2]),
  2072. ('ouc', [2, -0.5])]
  2073. A2 = np.eye(2)
  2074. B2 = np.diag([-2 + 1j, 0.5 + 0.5j])
  2075. expected2 = [('lhp', [1/(-2 + 1j), 1/(0.5 + 0.5j)]),
  2076. ('rhp', [1/(0.5 + 0.5j), 1/(-2 + 1j)]),
  2077. ('iuc', [1/(-2 + 1j), 1/(0.5 + 0.5j)]),
  2078. ('ouc', [1/(0.5 + 0.5j), 1/(-2 + 1j)])]
  2079. # 'lhp' is ambiguous so don't test it
  2080. A3 = np.eye(2)
  2081. B3 = np.diag([2, 0])
  2082. expected3 = [('rhp', [0.5, np.inf]),
  2083. ('iuc', [0.5, np.inf]),
  2084. ('ouc', [np.inf, 0.5])]
  2085. # 'rhp' is ambiguous so don't test it
  2086. A4 = np.eye(2)
  2087. B4 = np.diag([-2, 0])
  2088. expected4 = [('lhp', [-0.5, np.inf]),
  2089. ('iuc', [-0.5, np.inf]),
  2090. ('ouc', [np.inf, -0.5])]
  2091. A5 = np.diag([0, 1])
  2092. B5 = np.diag([0, 0.5])
  2093. # 'lhp' and 'iuc' are ambiguous so don't test them
  2094. expected5 = [('rhp', [2, np.nan]),
  2095. ('ouc', [2, np.nan])]
  2096. A = [A1, A2, A3, A4, A5]
  2097. B = [B1, B2, B3, B4, B5]
  2098. expected = [expected1, expected2, expected3, expected4, expected5]
  2099. for Ai, Bi, expectedi in zip(A, B, expected):
  2100. for sortstr, expected_eigvals in expectedi:
  2101. _, _, alpha, beta, _, _ = ordqz(Ai, Bi, sort=sortstr)
  2102. azero = (alpha == 0)
  2103. bzero = (beta == 0)
  2104. x = np.empty_like(alpha)
  2105. x[azero & bzero] = np.nan
  2106. x[~azero & bzero] = np.inf
  2107. x[~bzero] = alpha[~bzero]/beta[~bzero]
  2108. assert_allclose(expected_eigvals, x)
  2109. class TestOrdQZWorkspaceSize:
  2110. def setup_method(self):
  2111. seed(12345)
  2112. def test_decompose(self):
  2113. N = 202
  2114. # raises error if lwork parameter to dtrsen is too small
  2115. for ddtype in [np.float32, np.float64]:
  2116. A = random((N, N)).astype(ddtype)
  2117. B = random((N, N)).astype(ddtype)
  2118. # sort = lambda ar, ai, b: ar**2 + ai**2 < b**2
  2119. _ = ordqz(A, B, sort=lambda alpha, beta: alpha < beta,
  2120. output='real')
  2121. for ddtype in [np.complex128, np.complex64]:
  2122. A = random((N, N)).astype(ddtype)
  2123. B = random((N, N)).astype(ddtype)
  2124. _ = ordqz(A, B, sort=lambda alpha, beta: alpha < beta,
  2125. output='complex')
  2126. @pytest.mark.slow
  2127. def test_decompose_ouc(self):
  2128. N = 202
  2129. # segfaults if lwork parameter to dtrsen is too small
  2130. for ddtype in [np.float32, np.float64, np.complex128, np.complex64]:
  2131. A = random((N, N)).astype(ddtype)
  2132. B = random((N, N)).astype(ddtype)
  2133. S, T, alpha, beta, U, V = ordqz(A, B, sort='ouc')
  2134. class TestDatacopied:
  2135. def test_datacopied(self):
  2136. from scipy.linalg._decomp import _datacopied
  2137. M = matrix([[0, 1], [2, 3]])
  2138. A = asarray(M)
  2139. L = M.tolist()
  2140. M2 = M.copy()
  2141. class Fake1:
  2142. def __array__(self):
  2143. return A
  2144. class Fake2:
  2145. __array_interface__ = A.__array_interface__
  2146. F1 = Fake1()
  2147. F2 = Fake2()
  2148. for item, status in [(M, False), (A, False), (L, True),
  2149. (M2, False), (F1, False), (F2, False)]:
  2150. arr = asarray(item)
  2151. assert_equal(_datacopied(arr, item), status,
  2152. err_msg=repr(item))
  2153. def test_aligned_mem_float():
  2154. """Check linalg works with non-aligned memory (float32)"""
  2155. # Allocate 402 bytes of memory (allocated on boundary)
  2156. a = arange(402, dtype=np.uint8)
  2157. # Create an array with boundary offset 4
  2158. z = np.frombuffer(a.data, offset=2, count=100, dtype=float32)
  2159. z.shape = 10, 10
  2160. eig(z, overwrite_a=True)
  2161. eig(z.T, overwrite_a=True)
  2162. @pytest.mark.skipif(platform.machine() == 'ppc64le',
  2163. reason="crashes on ppc64le")
  2164. def test_aligned_mem():
  2165. """Check linalg works with non-aligned memory (float64)"""
  2166. # Allocate 804 bytes of memory (allocated on boundary)
  2167. a = arange(804, dtype=np.uint8)
  2168. # Create an array with boundary offset 4
  2169. z = np.frombuffer(a.data, offset=4, count=100, dtype=float)
  2170. z.shape = 10, 10
  2171. eig(z, overwrite_a=True)
  2172. eig(z.T, overwrite_a=True)
  2173. def test_aligned_mem_complex():
  2174. """Check that complex objects don't need to be completely aligned"""
  2175. # Allocate 1608 bytes of memory (allocated on boundary)
  2176. a = zeros(1608, dtype=np.uint8)
  2177. # Create an array with boundary offset 8
  2178. z = np.frombuffer(a.data, offset=8, count=100, dtype=complex)
  2179. z.shape = 10, 10
  2180. eig(z, overwrite_a=True)
  2181. # This does not need special handling
  2182. eig(z.T, overwrite_a=True)
  2183. def check_lapack_misaligned(func, args, kwargs):
  2184. args = list(args)
  2185. for i in range(len(args)):
  2186. a = args[:]
  2187. if isinstance(a[i], np.ndarray):
  2188. # Try misaligning a[i]
  2189. aa = np.zeros(a[i].size*a[i].dtype.itemsize+8, dtype=np.uint8)
  2190. aa = np.frombuffer(aa.data, offset=4, count=a[i].size,
  2191. dtype=a[i].dtype)
  2192. aa.shape = a[i].shape
  2193. aa[...] = a[i]
  2194. a[i] = aa
  2195. func(*a, **kwargs)
  2196. if len(a[i].shape) > 1:
  2197. a[i] = a[i].T
  2198. func(*a, **kwargs)
  2199. @pytest.mark.xfail(run=False,
  2200. reason="Ticket #1152, triggers a segfault in rare cases.")
  2201. def test_lapack_misaligned():
  2202. M = np.eye(10, dtype=float)
  2203. R = np.arange(100)
  2204. R.shape = 10, 10
  2205. S = np.arange(20000, dtype=np.uint8)
  2206. S = np.frombuffer(S.data, offset=4, count=100, dtype=float)
  2207. S.shape = 10, 10
  2208. b = np.ones(10)
  2209. LU, piv = lu_factor(S)
  2210. for (func, args, kwargs) in [
  2211. (eig, (S,), dict(overwrite_a=True)), # crash
  2212. (eigvals, (S,), dict(overwrite_a=True)), # no crash
  2213. (lu, (S,), dict(overwrite_a=True)), # no crash
  2214. (lu_factor, (S,), dict(overwrite_a=True)), # no crash
  2215. (lu_solve, ((LU, piv), b), dict(overwrite_b=True)),
  2216. (solve, (S, b), dict(overwrite_a=True, overwrite_b=True)),
  2217. (svd, (M,), dict(overwrite_a=True)), # no crash
  2218. (svd, (R,), dict(overwrite_a=True)), # no crash
  2219. (svd, (S,), dict(overwrite_a=True)), # crash
  2220. (svdvals, (S,), dict()), # no crash
  2221. (svdvals, (S,), dict(overwrite_a=True)), # crash
  2222. (cholesky, (M,), dict(overwrite_a=True)), # no crash
  2223. (qr, (S,), dict(overwrite_a=True)), # crash
  2224. (rq, (S,), dict(overwrite_a=True)), # crash
  2225. (hessenberg, (S,), dict(overwrite_a=True)), # crash
  2226. (schur, (S,), dict(overwrite_a=True)), # crash
  2227. ]:
  2228. check_lapack_misaligned(func, args, kwargs)
  2229. # not properly tested
  2230. # cholesky, rsf2csf, lu_solve, solve, eig_banded, eigvals_banded, eigh, diagsvd
  2231. class TestOverwrite:
  2232. def test_eig(self):
  2233. assert_no_overwrite(eig, [(3, 3)])
  2234. assert_no_overwrite(eig, [(3, 3), (3, 3)])
  2235. def test_eigh(self):
  2236. assert_no_overwrite(eigh, [(3, 3)])
  2237. assert_no_overwrite(eigh, [(3, 3), (3, 3)])
  2238. def test_eig_banded(self):
  2239. assert_no_overwrite(eig_banded, [(3, 2)])
  2240. def test_eigvals(self):
  2241. assert_no_overwrite(eigvals, [(3, 3)])
  2242. def test_eigvalsh(self):
  2243. assert_no_overwrite(eigvalsh, [(3, 3)])
  2244. def test_eigvals_banded(self):
  2245. assert_no_overwrite(eigvals_banded, [(3, 2)])
  2246. def test_hessenberg(self):
  2247. assert_no_overwrite(hessenberg, [(3, 3)])
  2248. def test_lu_factor(self):
  2249. assert_no_overwrite(lu_factor, [(3, 3)])
  2250. def test_lu_solve(self):
  2251. x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 8]])
  2252. xlu = lu_factor(x)
  2253. assert_no_overwrite(lambda b: lu_solve(xlu, b), [(3,)])
  2254. def test_lu(self):
  2255. assert_no_overwrite(lu, [(3, 3)])
  2256. def test_qr(self):
  2257. assert_no_overwrite(qr, [(3, 3)])
  2258. def test_rq(self):
  2259. assert_no_overwrite(rq, [(3, 3)])
  2260. def test_schur(self):
  2261. assert_no_overwrite(schur, [(3, 3)])
  2262. def test_schur_complex(self):
  2263. assert_no_overwrite(lambda a: schur(a, 'complex'), [(3, 3)],
  2264. dtypes=[np.float32, np.float64])
  2265. def test_svd(self):
  2266. assert_no_overwrite(svd, [(3, 3)])
  2267. assert_no_overwrite(lambda a: svd(a, lapack_driver='gesvd'), [(3, 3)])
  2268. def test_svdvals(self):
  2269. assert_no_overwrite(svdvals, [(3, 3)])
  2270. def _check_orth(n, dtype, skip_big=False):
  2271. X = np.ones((n, 2), dtype=float).astype(dtype)
  2272. eps = np.finfo(dtype).eps
  2273. tol = 1000 * eps
  2274. Y = orth(X)
  2275. assert_equal(Y.shape, (n, 1))
  2276. assert_allclose(Y, Y.mean(), atol=tol)
  2277. Y = orth(X.T)
  2278. assert_equal(Y.shape, (2, 1))
  2279. assert_allclose(Y, Y.mean(), atol=tol)
  2280. if n > 5 and not skip_big:
  2281. np.random.seed(1)
  2282. X = np.random.rand(n, 5) @ np.random.rand(5, n)
  2283. X = X + 1e-4 * np.random.rand(n, 1) @ np.random.rand(1, n)
  2284. X = X.astype(dtype)
  2285. Y = orth(X, rcond=1e-3)
  2286. assert_equal(Y.shape, (n, 5))
  2287. Y = orth(X, rcond=1e-6)
  2288. assert_equal(Y.shape, (n, 5 + 1))
  2289. @pytest.mark.slow
  2290. @pytest.mark.skipif(np.dtype(np.intp).itemsize < 8,
  2291. reason="test only on 64-bit, else too slow")
  2292. def test_orth_memory_efficiency():
  2293. # Pick n so that 16*n bytes is reasonable but 8*n*n bytes is unreasonable.
  2294. # Keep in mind that @pytest.mark.slow tests are likely to be running
  2295. # under configurations that support 4Gb+ memory for tests related to
  2296. # 32 bit overflow.
  2297. n = 10*1000*1000
  2298. try:
  2299. _check_orth(n, np.float64, skip_big=True)
  2300. except MemoryError as e:
  2301. raise AssertionError(
  2302. 'memory error perhaps caused by orth regression'
  2303. ) from e
  2304. def test_orth():
  2305. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  2306. sizes = [1, 2, 3, 10, 100]
  2307. for dt, n in itertools.product(dtypes, sizes):
  2308. _check_orth(n, dt)
  2309. def test_null_space():
  2310. np.random.seed(1)
  2311. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  2312. sizes = [1, 2, 3, 10, 100]
  2313. for dt, n in itertools.product(dtypes, sizes):
  2314. X = np.ones((2, n), dtype=dt)
  2315. eps = np.finfo(dt).eps
  2316. tol = 1000 * eps
  2317. Y = null_space(X)
  2318. assert_equal(Y.shape, (n, n-1))
  2319. assert_allclose(X @ Y, 0, atol=tol)
  2320. Y = null_space(X.T)
  2321. assert_equal(Y.shape, (2, 1))
  2322. assert_allclose(X.T @ Y, 0, atol=tol)
  2323. X = np.random.randn(1 + n//2, n)
  2324. Y = null_space(X)
  2325. assert_equal(Y.shape, (n, n - 1 - n//2))
  2326. assert_allclose(X @ Y, 0, atol=tol)
  2327. if n > 5:
  2328. np.random.seed(1)
  2329. X = np.random.rand(n, 5) @ np.random.rand(5, n)
  2330. X = X + 1e-4 * np.random.rand(n, 1) @ np.random.rand(1, n)
  2331. X = X.astype(dt)
  2332. Y = null_space(X, rcond=1e-3)
  2333. assert_equal(Y.shape, (n, n - 5))
  2334. Y = null_space(X, rcond=1e-6)
  2335. assert_equal(Y.shape, (n, n - 6))
  2336. def test_subspace_angles():
  2337. H = hadamard(8, float)
  2338. A = H[:, :3]
  2339. B = H[:, 3:]
  2340. assert_allclose(subspace_angles(A, B), [np.pi / 2.] * 3, atol=1e-14)
  2341. assert_allclose(subspace_angles(B, A), [np.pi / 2.] * 3, atol=1e-14)
  2342. for x in (A, B):
  2343. assert_allclose(subspace_angles(x, x), np.zeros(x.shape[1]),
  2344. atol=1e-14)
  2345. # From MATLAB function "subspace", which effectively only returns the
  2346. # last value that we calculate
  2347. x = np.array(
  2348. [[0.537667139546100, 0.318765239858981, 3.578396939725760, 0.725404224946106], # noqa: E501
  2349. [1.833885014595086, -1.307688296305273, 2.769437029884877, -0.063054873189656], # noqa: E501
  2350. [-2.258846861003648, -0.433592022305684, -1.349886940156521, 0.714742903826096], # noqa: E501
  2351. [0.862173320368121, 0.342624466538650, 3.034923466331855, -0.204966058299775]]) # noqa: E501
  2352. expected = 1.481454682101605
  2353. assert_allclose(subspace_angles(x[:, :2], x[:, 2:])[0], expected,
  2354. rtol=1e-12)
  2355. assert_allclose(subspace_angles(x[:, 2:], x[:, :2])[0], expected,
  2356. rtol=1e-12)
  2357. expected = 0.746361174247302
  2358. assert_allclose(subspace_angles(x[:, :2], x[:, [2]]), expected, rtol=1e-12)
  2359. assert_allclose(subspace_angles(x[:, [2]], x[:, :2]), expected, rtol=1e-12)
  2360. expected = 0.487163718534313
  2361. assert_allclose(subspace_angles(x[:, :3], x[:, [3]]), expected, rtol=1e-12)
  2362. assert_allclose(subspace_angles(x[:, [3]], x[:, :3]), expected, rtol=1e-12)
  2363. expected = 0.328950515907756
  2364. assert_allclose(subspace_angles(x[:, :2], x[:, 1:]), [expected, 0],
  2365. atol=1e-12)
  2366. # Degenerate conditions
  2367. assert_raises(ValueError, subspace_angles, x[0], x)
  2368. assert_raises(ValueError, subspace_angles, x, x[0])
  2369. assert_raises(ValueError, subspace_angles, x[:-1], x)
  2370. # Test branch if mask.any is True:
  2371. A = np.array([[1, 0, 0],
  2372. [0, 1, 0],
  2373. [0, 0, 1],
  2374. [0, 0, 0],
  2375. [0, 0, 0]])
  2376. B = np.array([[1, 0, 0],
  2377. [0, 1, 0],
  2378. [0, 0, 0],
  2379. [0, 0, 0],
  2380. [0, 0, 1]])
  2381. expected = np.array([np.pi/2, 0, 0])
  2382. assert_allclose(subspace_angles(A, B), expected, rtol=1e-12)
  2383. # Complex
  2384. # second column in "b" does not affect result, just there so that
  2385. # b can have more cols than a, and vice-versa (both conditional code paths)
  2386. a = [[1 + 1j], [0]]
  2387. b = [[1 - 1j, 0], [0, 1]]
  2388. assert_allclose(subspace_angles(a, b), 0., atol=1e-14)
  2389. assert_allclose(subspace_angles(b, a), 0., atol=1e-14)
  2390. class TestCDF2RDF:
  2391. def matmul(self, a, b):
  2392. return np.einsum('...ij,...jk->...ik', a, b)
  2393. def assert_eig_valid(self, w, v, x):
  2394. assert_array_almost_equal(
  2395. self.matmul(v, w),
  2396. self.matmul(x, v)
  2397. )
  2398. def test_single_array0x0real(self):
  2399. # eig doesn't support 0x0 in old versions of numpy
  2400. X = np.empty((0, 0))
  2401. w, v = np.empty(0), np.empty((0, 0))
  2402. wr, vr = cdf2rdf(w, v)
  2403. self.assert_eig_valid(wr, vr, X)
  2404. def test_single_array2x2_real(self):
  2405. X = np.array([[1, 2], [3, -1]])
  2406. w, v = np.linalg.eig(X)
  2407. wr, vr = cdf2rdf(w, v)
  2408. self.assert_eig_valid(wr, vr, X)
  2409. def test_single_array2x2_complex(self):
  2410. X = np.array([[1, 2], [-2, 1]])
  2411. w, v = np.linalg.eig(X)
  2412. wr, vr = cdf2rdf(w, v)
  2413. self.assert_eig_valid(wr, vr, X)
  2414. def test_single_array3x3_real(self):
  2415. X = np.array([[1, 2, 3], [1, 2, 3], [2, 5, 6]])
  2416. w, v = np.linalg.eig(X)
  2417. wr, vr = cdf2rdf(w, v)
  2418. self.assert_eig_valid(wr, vr, X)
  2419. def test_single_array3x3_complex(self):
  2420. X = np.array([[1, 2, 3], [0, 4, 5], [0, -5, 4]])
  2421. w, v = np.linalg.eig(X)
  2422. wr, vr = cdf2rdf(w, v)
  2423. self.assert_eig_valid(wr, vr, X)
  2424. def test_random_1d_stacked_arrays(self):
  2425. # cannot test M == 0 due to bug in old numpy
  2426. for M in range(1, 7):
  2427. np.random.seed(999999999)
  2428. X = np.random.rand(100, M, M)
  2429. w, v = np.linalg.eig(X)
  2430. wr, vr = cdf2rdf(w, v)
  2431. self.assert_eig_valid(wr, vr, X)
  2432. def test_random_2d_stacked_arrays(self):
  2433. # cannot test M == 0 due to bug in old numpy
  2434. for M in range(1, 7):
  2435. X = np.random.rand(10, 10, M, M)
  2436. w, v = np.linalg.eig(X)
  2437. wr, vr = cdf2rdf(w, v)
  2438. self.assert_eig_valid(wr, vr, X)
  2439. def test_low_dimensionality_error(self):
  2440. w, v = np.empty(()), np.array((2,))
  2441. assert_raises(ValueError, cdf2rdf, w, v)
  2442. def test_not_square_error(self):
  2443. # Check that passing a non-square array raises a ValueError.
  2444. w, v = np.arange(3), np.arange(6).reshape(3, 2)
  2445. assert_raises(ValueError, cdf2rdf, w, v)
  2446. def test_swapped_v_w_error(self):
  2447. # Check that exchanging places of w and v raises ValueError.
  2448. X = np.array([[1, 2, 3], [0, 4, 5], [0, -5, 4]])
  2449. w, v = np.linalg.eig(X)
  2450. assert_raises(ValueError, cdf2rdf, v, w)
  2451. def test_non_associated_error(self):
  2452. # Check that passing non-associated eigenvectors raises a ValueError.
  2453. w, v = np.arange(3), np.arange(16).reshape(4, 4)
  2454. assert_raises(ValueError, cdf2rdf, w, v)
  2455. def test_not_conjugate_pairs(self):
  2456. # Check that passing non-conjugate pairs raises a ValueError.
  2457. X = np.array([[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]])
  2458. w, v = np.linalg.eig(X)
  2459. assert_raises(ValueError, cdf2rdf, w, v)
  2460. # different arrays in the stack, so not conjugate
  2461. X = np.array([
  2462. [[1, 2, 3], [1, 2, 3], [2, 5, 6+1j]],
  2463. [[1, 2, 3], [1, 2, 3], [2, 5, 6-1j]],
  2464. ])
  2465. w, v = np.linalg.eig(X)
  2466. assert_raises(ValueError, cdf2rdf, w, v)