test_lapack.py 122 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762277227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652265326542655265626572658265926602661266226632664266526662667266826692670267126722673267426752676267726782679268026812682268326842685268626872688268926902691269226932694269526962697269826992700270127022703270427052706270727082709271027112712271327142715271627172718271927202721272227232724272527262727272827292730273127322733273427352736273727382739274027412742274327442745274627472748274927502751275227532754275527562757275827592760276127622763276427652766276727682769277027712772277327742775277627772778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327732783279328032813282
  1. #
  2. # Created by: Pearu Peterson, September 2002
  3. #
  4. import sys
  5. from functools import reduce
  6. from numpy.testing import (assert_equal, assert_array_almost_equal, assert_,
  7. assert_allclose, assert_almost_equal,
  8. assert_array_equal)
  9. import pytest
  10. from pytest import raises as assert_raises
  11. import numpy as np
  12. from numpy import (eye, ones, zeros, zeros_like, triu, tril, tril_indices,
  13. triu_indices)
  14. from numpy.random import rand, randint, seed
  15. from scipy.linalg import (_flapack as flapack, lapack, inv, svd, cholesky,
  16. solve, ldl, norm, block_diag, qr, eigh)
  17. from scipy.linalg.lapack import _compute_lwork
  18. from scipy.stats import ortho_group, unitary_group
  19. import scipy.sparse as sps
  20. try:
  21. from scipy.linalg import _clapack as clapack
  22. except ImportError:
  23. clapack = None
  24. from scipy.linalg.lapack import get_lapack_funcs
  25. from scipy.linalg.blas import get_blas_funcs
  26. REAL_DTYPES = [np.float32, np.float64]
  27. COMPLEX_DTYPES = [np.complex64, np.complex128]
  28. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  29. def generate_random_dtype_array(shape, dtype):
  30. # generates a random matrix of desired data type of shape
  31. if dtype in COMPLEX_DTYPES:
  32. return (np.random.rand(*shape)
  33. + np.random.rand(*shape)*1.0j).astype(dtype)
  34. return np.random.rand(*shape).astype(dtype)
  35. def test_lapack_documented():
  36. """Test that all entries are in the doc."""
  37. if lapack.__doc__ is None: # just in case there is a python -OO
  38. pytest.skip('lapack.__doc__ is None')
  39. names = set(lapack.__doc__.split())
  40. ignore_list = set([
  41. 'absolute_import', 'clapack', 'division', 'find_best_lapack_type',
  42. 'flapack', 'print_function', 'HAS_ILP64',
  43. ])
  44. missing = list()
  45. for name in dir(lapack):
  46. if (not name.startswith('_') and name not in ignore_list and
  47. name not in names):
  48. missing.append(name)
  49. assert missing == [], 'Name(s) missing from lapack.__doc__ or ignore_list'
  50. class TestFlapackSimple:
  51. def test_gebal(self):
  52. a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  53. a1 = [[1, 0, 0, 3e-4],
  54. [4, 0, 0, 2e-3],
  55. [7, 1, 0, 0],
  56. [0, 1, 0, 0]]
  57. for p in 'sdzc':
  58. f = getattr(flapack, p+'gebal', None)
  59. if f is None:
  60. continue
  61. ba, lo, hi, pivscale, info = f(a)
  62. assert_(not info, repr(info))
  63. assert_array_almost_equal(ba, a)
  64. assert_equal((lo, hi), (0, len(a[0])-1))
  65. assert_array_almost_equal(pivscale, np.ones(len(a)))
  66. ba, lo, hi, pivscale, info = f(a1, permute=1, scale=1)
  67. assert_(not info, repr(info))
  68. # print(a1)
  69. # print(ba, lo, hi, pivscale)
  70. def test_gehrd(self):
  71. a = [[-149, -50, -154],
  72. [537, 180, 546],
  73. [-27, -9, -25]]
  74. for p in 'd':
  75. f = getattr(flapack, p+'gehrd', None)
  76. if f is None:
  77. continue
  78. ht, tau, info = f(a)
  79. assert_(not info, repr(info))
  80. def test_trsyl(self):
  81. a = np.array([[1, 2], [0, 4]])
  82. b = np.array([[5, 6], [0, 8]])
  83. c = np.array([[9, 10], [11, 12]])
  84. trans = 'T'
  85. # Test single and double implementations, including most
  86. # of the options
  87. for dtype in 'fdFD':
  88. a1, b1, c1 = a.astype(dtype), b.astype(dtype), c.astype(dtype)
  89. trsyl, = get_lapack_funcs(('trsyl',), (a1,))
  90. if dtype.isupper(): # is complex dtype
  91. a1[0] += 1j
  92. trans = 'C'
  93. x, scale, info = trsyl(a1, b1, c1)
  94. assert_array_almost_equal(np.dot(a1, x) + np.dot(x, b1),
  95. scale * c1)
  96. x, scale, info = trsyl(a1, b1, c1, trana=trans, tranb=trans)
  97. assert_array_almost_equal(
  98. np.dot(a1.conjugate().T, x) + np.dot(x, b1.conjugate().T),
  99. scale * c1, decimal=4)
  100. x, scale, info = trsyl(a1, b1, c1, isgn=-1)
  101. assert_array_almost_equal(np.dot(a1, x) - np.dot(x, b1),
  102. scale * c1, decimal=4)
  103. def test_lange(self):
  104. a = np.array([
  105. [-149, -50, -154],
  106. [537, 180, 546],
  107. [-27, -9, -25]])
  108. for dtype in 'fdFD':
  109. for norm_str in 'Mm1OoIiFfEe':
  110. a1 = a.astype(dtype)
  111. if dtype.isupper():
  112. # is complex dtype
  113. a1[0, 0] += 1j
  114. lange, = get_lapack_funcs(('lange',), (a1,))
  115. value = lange(norm_str, a1)
  116. if norm_str in 'FfEe':
  117. if dtype in 'Ff':
  118. decimal = 3
  119. else:
  120. decimal = 7
  121. ref = np.sqrt(np.sum(np.square(np.abs(a1))))
  122. assert_almost_equal(value, ref, decimal)
  123. else:
  124. if norm_str in 'Mm':
  125. ref = np.max(np.abs(a1))
  126. elif norm_str in '1Oo':
  127. ref = np.max(np.sum(np.abs(a1), axis=0))
  128. elif norm_str in 'Ii':
  129. ref = np.max(np.sum(np.abs(a1), axis=1))
  130. assert_equal(value, ref)
  131. class TestLapack:
  132. def test_flapack(self):
  133. if hasattr(flapack, 'empty_module'):
  134. # flapack module is empty
  135. pass
  136. def test_clapack(self):
  137. if hasattr(clapack, 'empty_module'):
  138. # clapack module is empty
  139. pass
  140. class TestLeastSquaresSolvers:
  141. def test_gels(self):
  142. seed(1234)
  143. # Test fat/tall matrix argument handling - gh-issue #8329
  144. for ind, dtype in enumerate(DTYPES):
  145. m = 10
  146. n = 20
  147. nrhs = 1
  148. a1 = rand(m, n).astype(dtype)
  149. b1 = rand(n).astype(dtype)
  150. gls, glslw = get_lapack_funcs(('gels', 'gels_lwork'), dtype=dtype)
  151. # Request of sizes
  152. lwork = _compute_lwork(glslw, m, n, nrhs)
  153. _, _, info = gls(a1, b1, lwork=lwork)
  154. assert_(info >= 0)
  155. _, _, info = gls(a1, b1, trans='TTCC'[ind], lwork=lwork)
  156. assert_(info >= 0)
  157. for dtype in REAL_DTYPES:
  158. a1 = np.array([[1.0, 2.0],
  159. [4.0, 5.0],
  160. [7.0, 8.0]], dtype=dtype)
  161. b1 = np.array([16.0, 17.0, 20.0], dtype=dtype)
  162. gels, gels_lwork, geqrf = get_lapack_funcs(
  163. ('gels', 'gels_lwork', 'geqrf'), (a1, b1))
  164. m, n = a1.shape
  165. if len(b1.shape) == 2:
  166. nrhs = b1.shape[1]
  167. else:
  168. nrhs = 1
  169. # Request of sizes
  170. lwork = _compute_lwork(gels_lwork, m, n, nrhs)
  171. lqr, x, info = gels(a1, b1, lwork=lwork)
  172. assert_allclose(x[:-1], np.array([-14.333333333333323,
  173. 14.999999999999991],
  174. dtype=dtype),
  175. rtol=25*np.finfo(dtype).eps)
  176. lqr_truth, _, _, _ = geqrf(a1)
  177. assert_array_equal(lqr, lqr_truth)
  178. for dtype in COMPLEX_DTYPES:
  179. a1 = np.array([[1.0+4.0j, 2.0],
  180. [4.0+0.5j, 5.0-3.0j],
  181. [7.0-2.0j, 8.0+0.7j]], dtype=dtype)
  182. b1 = np.array([16.0, 17.0+2.0j, 20.0-4.0j], dtype=dtype)
  183. gels, gels_lwork, geqrf = get_lapack_funcs(
  184. ('gels', 'gels_lwork', 'geqrf'), (a1, b1))
  185. m, n = a1.shape
  186. if len(b1.shape) == 2:
  187. nrhs = b1.shape[1]
  188. else:
  189. nrhs = 1
  190. # Request of sizes
  191. lwork = _compute_lwork(gels_lwork, m, n, nrhs)
  192. lqr, x, info = gels(a1, b1, lwork=lwork)
  193. assert_allclose(x[:-1],
  194. np.array([1.161753632288328-1.901075709391912j,
  195. 1.735882340522193+1.521240901196909j],
  196. dtype=dtype), rtol=25*np.finfo(dtype).eps)
  197. lqr_truth, _, _, _ = geqrf(a1)
  198. assert_array_equal(lqr, lqr_truth)
  199. def test_gelsd(self):
  200. for dtype in REAL_DTYPES:
  201. a1 = np.array([[1.0, 2.0],
  202. [4.0, 5.0],
  203. [7.0, 8.0]], dtype=dtype)
  204. b1 = np.array([16.0, 17.0, 20.0], dtype=dtype)
  205. gelsd, gelsd_lwork = get_lapack_funcs(('gelsd', 'gelsd_lwork'),
  206. (a1, b1))
  207. m, n = a1.shape
  208. if len(b1.shape) == 2:
  209. nrhs = b1.shape[1]
  210. else:
  211. nrhs = 1
  212. # Request of sizes
  213. work, iwork, info = gelsd_lwork(m, n, nrhs, -1)
  214. lwork = int(np.real(work))
  215. iwork_size = iwork
  216. x, s, rank, info = gelsd(a1, b1, lwork, iwork_size,
  217. -1, False, False)
  218. assert_allclose(x[:-1], np.array([-14.333333333333323,
  219. 14.999999999999991],
  220. dtype=dtype),
  221. rtol=25*np.finfo(dtype).eps)
  222. assert_allclose(s, np.array([12.596017180511966,
  223. 0.583396253199685], dtype=dtype),
  224. rtol=25*np.finfo(dtype).eps)
  225. for dtype in COMPLEX_DTYPES:
  226. a1 = np.array([[1.0+4.0j, 2.0],
  227. [4.0+0.5j, 5.0-3.0j],
  228. [7.0-2.0j, 8.0+0.7j]], dtype=dtype)
  229. b1 = np.array([16.0, 17.0+2.0j, 20.0-4.0j], dtype=dtype)
  230. gelsd, gelsd_lwork = get_lapack_funcs(('gelsd', 'gelsd_lwork'),
  231. (a1, b1))
  232. m, n = a1.shape
  233. if len(b1.shape) == 2:
  234. nrhs = b1.shape[1]
  235. else:
  236. nrhs = 1
  237. # Request of sizes
  238. work, rwork, iwork, info = gelsd_lwork(m, n, nrhs, -1)
  239. lwork = int(np.real(work))
  240. rwork_size = int(rwork)
  241. iwork_size = iwork
  242. x, s, rank, info = gelsd(a1, b1, lwork, rwork_size, iwork_size,
  243. -1, False, False)
  244. assert_allclose(x[:-1],
  245. np.array([1.161753632288328-1.901075709391912j,
  246. 1.735882340522193+1.521240901196909j],
  247. dtype=dtype), rtol=25*np.finfo(dtype).eps)
  248. assert_allclose(s,
  249. np.array([13.035514762572043, 4.337666985231382],
  250. dtype=dtype), rtol=25*np.finfo(dtype).eps)
  251. def test_gelss(self):
  252. for dtype in REAL_DTYPES:
  253. a1 = np.array([[1.0, 2.0],
  254. [4.0, 5.0],
  255. [7.0, 8.0]], dtype=dtype)
  256. b1 = np.array([16.0, 17.0, 20.0], dtype=dtype)
  257. gelss, gelss_lwork = get_lapack_funcs(('gelss', 'gelss_lwork'),
  258. (a1, b1))
  259. m, n = a1.shape
  260. if len(b1.shape) == 2:
  261. nrhs = b1.shape[1]
  262. else:
  263. nrhs = 1
  264. # Request of sizes
  265. work, info = gelss_lwork(m, n, nrhs, -1)
  266. lwork = int(np.real(work))
  267. v, x, s, rank, work, info = gelss(a1, b1, -1, lwork, False, False)
  268. assert_allclose(x[:-1], np.array([-14.333333333333323,
  269. 14.999999999999991],
  270. dtype=dtype),
  271. rtol=25*np.finfo(dtype).eps)
  272. assert_allclose(s, np.array([12.596017180511966,
  273. 0.583396253199685], dtype=dtype),
  274. rtol=25*np.finfo(dtype).eps)
  275. for dtype in COMPLEX_DTYPES:
  276. a1 = np.array([[1.0+4.0j, 2.0],
  277. [4.0+0.5j, 5.0-3.0j],
  278. [7.0-2.0j, 8.0+0.7j]], dtype=dtype)
  279. b1 = np.array([16.0, 17.0+2.0j, 20.0-4.0j], dtype=dtype)
  280. gelss, gelss_lwork = get_lapack_funcs(('gelss', 'gelss_lwork'),
  281. (a1, b1))
  282. m, n = a1.shape
  283. if len(b1.shape) == 2:
  284. nrhs = b1.shape[1]
  285. else:
  286. nrhs = 1
  287. # Request of sizes
  288. work, info = gelss_lwork(m, n, nrhs, -1)
  289. lwork = int(np.real(work))
  290. v, x, s, rank, work, info = gelss(a1, b1, -1, lwork, False, False)
  291. assert_allclose(x[:-1],
  292. np.array([1.161753632288328-1.901075709391912j,
  293. 1.735882340522193+1.521240901196909j],
  294. dtype=dtype),
  295. rtol=25*np.finfo(dtype).eps)
  296. assert_allclose(s, np.array([13.035514762572043,
  297. 4.337666985231382], dtype=dtype),
  298. rtol=25*np.finfo(dtype).eps)
  299. def test_gelsy(self):
  300. for dtype in REAL_DTYPES:
  301. a1 = np.array([[1.0, 2.0],
  302. [4.0, 5.0],
  303. [7.0, 8.0]], dtype=dtype)
  304. b1 = np.array([16.0, 17.0, 20.0], dtype=dtype)
  305. gelsy, gelsy_lwork = get_lapack_funcs(('gelsy', 'gelss_lwork'),
  306. (a1, b1))
  307. m, n = a1.shape
  308. if len(b1.shape) == 2:
  309. nrhs = b1.shape[1]
  310. else:
  311. nrhs = 1
  312. # Request of sizes
  313. work, info = gelsy_lwork(m, n, nrhs, 10*np.finfo(dtype).eps)
  314. lwork = int(np.real(work))
  315. jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
  316. v, x, j, rank, info = gelsy(a1, b1, jptv, np.finfo(dtype).eps,
  317. lwork, False, False)
  318. assert_allclose(x[:-1], np.array([-14.333333333333323,
  319. 14.999999999999991],
  320. dtype=dtype),
  321. rtol=25*np.finfo(dtype).eps)
  322. for dtype in COMPLEX_DTYPES:
  323. a1 = np.array([[1.0+4.0j, 2.0],
  324. [4.0+0.5j, 5.0-3.0j],
  325. [7.0-2.0j, 8.0+0.7j]], dtype=dtype)
  326. b1 = np.array([16.0, 17.0+2.0j, 20.0-4.0j], dtype=dtype)
  327. gelsy, gelsy_lwork = get_lapack_funcs(('gelsy', 'gelss_lwork'),
  328. (a1, b1))
  329. m, n = a1.shape
  330. if len(b1.shape) == 2:
  331. nrhs = b1.shape[1]
  332. else:
  333. nrhs = 1
  334. # Request of sizes
  335. work, info = gelsy_lwork(m, n, nrhs, 10*np.finfo(dtype).eps)
  336. lwork = int(np.real(work))
  337. jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
  338. v, x, j, rank, info = gelsy(a1, b1, jptv, np.finfo(dtype).eps,
  339. lwork, False, False)
  340. assert_allclose(x[:-1],
  341. np.array([1.161753632288328-1.901075709391912j,
  342. 1.735882340522193+1.521240901196909j],
  343. dtype=dtype),
  344. rtol=25*np.finfo(dtype).eps)
  345. @pytest.mark.parametrize('dtype', DTYPES)
  346. @pytest.mark.parametrize('shape', [(3, 4), (5, 2), (2**18, 2**18)])
  347. def test_geqrf_lwork(dtype, shape):
  348. geqrf_lwork = get_lapack_funcs(('geqrf_lwork'), dtype=dtype)
  349. m, n = shape
  350. lwork, info = geqrf_lwork(m=m, n=n)
  351. assert_equal(info, 0)
  352. class TestRegression:
  353. def test_ticket_1645(self):
  354. # Check that RQ routines have correct lwork
  355. for dtype in DTYPES:
  356. a = np.zeros((300, 2), dtype=dtype)
  357. gerqf, = get_lapack_funcs(['gerqf'], [a])
  358. assert_raises(Exception, gerqf, a, lwork=2)
  359. rq, tau, work, info = gerqf(a)
  360. if dtype in REAL_DTYPES:
  361. orgrq, = get_lapack_funcs(['orgrq'], [a])
  362. assert_raises(Exception, orgrq, rq[-2:], tau, lwork=1)
  363. orgrq(rq[-2:], tau, lwork=2)
  364. elif dtype in COMPLEX_DTYPES:
  365. ungrq, = get_lapack_funcs(['ungrq'], [a])
  366. assert_raises(Exception, ungrq, rq[-2:], tau, lwork=1)
  367. ungrq(rq[-2:], tau, lwork=2)
  368. class TestDpotr:
  369. def test_gh_2691(self):
  370. # 'lower' argument of dportf/dpotri
  371. for lower in [True, False]:
  372. for clean in [True, False]:
  373. np.random.seed(42)
  374. x = np.random.normal(size=(3, 3))
  375. a = x.dot(x.T)
  376. dpotrf, dpotri = get_lapack_funcs(("potrf", "potri"), (a, ))
  377. c, info = dpotrf(a, lower, clean=clean)
  378. dpt = dpotri(c, lower)[0]
  379. if lower:
  380. assert_allclose(np.tril(dpt), np.tril(inv(a)))
  381. else:
  382. assert_allclose(np.triu(dpt), np.triu(inv(a)))
  383. class TestDlasd4:
  384. def test_sing_val_update(self):
  385. sigmas = np.array([4., 3., 2., 0])
  386. m_vec = np.array([3.12, 5.7, -4.8, -2.2])
  387. M = np.hstack((np.vstack((np.diag(sigmas[0:-1]),
  388. np.zeros((1, len(m_vec) - 1)))),
  389. m_vec[:, np.newaxis]))
  390. SM = svd(M, full_matrices=False, compute_uv=False, overwrite_a=False,
  391. check_finite=False)
  392. it_len = len(sigmas)
  393. sgm = np.concatenate((sigmas[::-1], [sigmas[0] + it_len*norm(m_vec)]))
  394. mvc = np.concatenate((m_vec[::-1], (0,)))
  395. lasd4 = get_lapack_funcs('lasd4', (sigmas,))
  396. roots = []
  397. for i in range(0, it_len):
  398. res = lasd4(i, sgm, mvc)
  399. roots.append(res[1])
  400. assert_((res[3] <= 0), "LAPACK root finding dlasd4 failed to find \
  401. the singular value %i" % i)
  402. roots = np.array(roots)[::-1]
  403. assert_((not np.any(np.isnan(roots)), "There are NaN roots"))
  404. assert_allclose(SM, roots, atol=100*np.finfo(np.float64).eps,
  405. rtol=100*np.finfo(np.float64).eps)
  406. class TestTbtrs:
  407. @pytest.mark.parametrize('dtype', DTYPES)
  408. def test_nag_example_f07vef_f07vsf(self, dtype):
  409. """Test real (f07vef) and complex (f07vsf) examples from NAG
  410. Examples available from:
  411. * https://www.nag.com/numeric/fl/nagdoc_latest/html/f07/f07vef.html
  412. * https://www.nag.com/numeric/fl/nagdoc_latest/html/f07/f07vsf.html
  413. """
  414. if dtype in REAL_DTYPES:
  415. ab = np.array([[-4.16, 4.78, 6.32, 0.16],
  416. [-2.25, 5.86, -4.82, 0]],
  417. dtype=dtype)
  418. b = np.array([[-16.64, -4.16],
  419. [-13.78, -16.59],
  420. [13.10, -4.94],
  421. [-14.14, -9.96]],
  422. dtype=dtype)
  423. x_out = np.array([[4, 1],
  424. [-1, -3],
  425. [3, 2],
  426. [2, -2]],
  427. dtype=dtype)
  428. elif dtype in COMPLEX_DTYPES:
  429. ab = np.array([[-1.94+4.43j, 4.12-4.27j, 0.43-2.66j, 0.44+0.1j],
  430. [-3.39+3.44j, -1.84+5.52j, 1.74 - 0.04j, 0],
  431. [1.62+3.68j, -2.77-1.93j, 0, 0]],
  432. dtype=dtype)
  433. b = np.array([[-8.86 - 3.88j, -24.09 - 5.27j],
  434. [-15.57 - 23.41j, -57.97 + 8.14j],
  435. [-7.63 + 22.78j, 19.09 - 29.51j],
  436. [-14.74 - 2.40j, 19.17 + 21.33j]],
  437. dtype=dtype)
  438. x_out = np.array([[2j, 1 + 5j],
  439. [1 - 3j, -7 - 2j],
  440. [-4.001887 - 4.988417j, 3.026830 + 4.003182j],
  441. [1.996158 - 1.045105j, -6.103357 - 8.986653j]],
  442. dtype=dtype)
  443. else:
  444. raise ValueError(f"Datatype {dtype} not understood.")
  445. tbtrs = get_lapack_funcs(('tbtrs'), dtype=dtype)
  446. x, info = tbtrs(ab=ab, b=b, uplo='L')
  447. assert_equal(info, 0)
  448. assert_allclose(x, x_out, rtol=0, atol=1e-5)
  449. @pytest.mark.parametrize('dtype,trans',
  450. [(dtype, trans)
  451. for dtype in DTYPES for trans in ['N', 'T', 'C']
  452. if not (trans == 'C' and dtype in REAL_DTYPES)])
  453. @pytest.mark.parametrize('uplo', ['U', 'L'])
  454. @pytest.mark.parametrize('diag', ['N', 'U'])
  455. def test_random_matrices(self, dtype, trans, uplo, diag):
  456. seed(1724)
  457. # n, nrhs, kd are used to specify A and b.
  458. # A is of shape n x n with kd super/sub-diagonals
  459. # b is of shape n x nrhs matrix
  460. n, nrhs, kd = 4, 3, 2
  461. tbtrs = get_lapack_funcs('tbtrs', dtype=dtype)
  462. is_upper = (uplo == 'U')
  463. ku = kd * is_upper
  464. kl = kd - ku
  465. # Construct the diagonal and kd super/sub diagonals of A with
  466. # the corresponding offsets.
  467. band_offsets = range(ku, -kl - 1, -1)
  468. band_widths = [n - abs(x) for x in band_offsets]
  469. bands = [generate_random_dtype_array((width,), dtype)
  470. for width in band_widths]
  471. if diag == 'U': # A must be unit triangular
  472. bands[ku] = np.ones(n, dtype=dtype)
  473. # Construct the diagonal banded matrix A from the bands and offsets.
  474. a = sps.diags(bands, band_offsets, format='dia')
  475. # Convert A into banded storage form
  476. ab = np.zeros((kd + 1, n), dtype)
  477. for row, k in enumerate(band_offsets):
  478. ab[row, max(k, 0):min(n+k, n)] = a.diagonal(k)
  479. # The RHS values.
  480. b = generate_random_dtype_array((n, nrhs), dtype)
  481. x, info = tbtrs(ab=ab, b=b, uplo=uplo, trans=trans, diag=diag)
  482. assert_equal(info, 0)
  483. if trans == 'N':
  484. assert_allclose(a @ x, b, rtol=5e-5)
  485. elif trans == 'T':
  486. assert_allclose(a.T @ x, b, rtol=5e-5)
  487. elif trans == 'C':
  488. assert_allclose(a.H @ x, b, rtol=5e-5)
  489. else:
  490. raise ValueError('Invalid trans argument')
  491. @pytest.mark.parametrize('uplo,trans,diag',
  492. [['U', 'N', 'Invalid'],
  493. ['U', 'Invalid', 'N'],
  494. ['Invalid', 'N', 'N']])
  495. def test_invalid_argument_raises_exception(self, uplo, trans, diag):
  496. """Test if invalid values of uplo, trans and diag raise exceptions"""
  497. # Argument checks occur independently of used datatype.
  498. # This mean we must not parameterize all available datatypes.
  499. tbtrs = get_lapack_funcs('tbtrs', dtype=np.float64)
  500. ab = rand(4, 2)
  501. b = rand(2, 4)
  502. assert_raises(Exception, tbtrs, ab, b, uplo, trans, diag)
  503. def test_zero_element_in_diagonal(self):
  504. """Test if a matrix with a zero diagonal element is singular
  505. If the i-th diagonal of A is zero, ?tbtrs should return `i` in `info`
  506. indicating the provided matrix is singular.
  507. Note that ?tbtrs requires the matrix A to be stored in banded form.
  508. In this form the diagonal corresponds to the last row."""
  509. ab = np.ones((3, 4), dtype=float)
  510. b = np.ones(4, dtype=float)
  511. tbtrs = get_lapack_funcs('tbtrs', dtype=float)
  512. ab[-1, 3] = 0
  513. _, info = tbtrs(ab=ab, b=b, uplo='U')
  514. assert_equal(info, 4)
  515. @pytest.mark.parametrize('ldab,n,ldb,nrhs', [
  516. (5, 5, 0, 5),
  517. (5, 5, 3, 5)
  518. ])
  519. def test_invalid_matrix_shapes(self, ldab, n, ldb, nrhs):
  520. """Test ?tbtrs fails correctly if shapes are invalid."""
  521. ab = np.ones((ldab, n), dtype=float)
  522. b = np.ones((ldb, nrhs), dtype=float)
  523. tbtrs = get_lapack_funcs('tbtrs', dtype=float)
  524. assert_raises(Exception, tbtrs, ab, b)
  525. def test_lartg():
  526. for dtype in 'fdFD':
  527. lartg = get_lapack_funcs('lartg', dtype=dtype)
  528. f = np.array(3, dtype)
  529. g = np.array(4, dtype)
  530. if np.iscomplexobj(g):
  531. g *= 1j
  532. cs, sn, r = lartg(f, g)
  533. assert_allclose(cs, 3.0/5.0)
  534. assert_allclose(r, 5.0)
  535. if np.iscomplexobj(g):
  536. assert_allclose(sn, -4.0j/5.0)
  537. assert_(type(r) == complex)
  538. assert_(type(cs) == float)
  539. else:
  540. assert_allclose(sn, 4.0/5.0)
  541. def test_rot():
  542. # srot, drot from blas and crot and zrot from lapack.
  543. for dtype in 'fdFD':
  544. c = 0.6
  545. s = 0.8
  546. u = np.full(4, 3, dtype)
  547. v = np.full(4, 4, dtype)
  548. atol = 10**-(np.finfo(dtype).precision-1)
  549. if dtype in 'fd':
  550. rot = get_blas_funcs('rot', dtype=dtype)
  551. f = 4
  552. else:
  553. rot = get_lapack_funcs('rot', dtype=dtype)
  554. s *= -1j
  555. v *= 1j
  556. f = 4j
  557. assert_allclose(rot(u, v, c, s), [[5, 5, 5, 5],
  558. [0, 0, 0, 0]], atol=atol)
  559. assert_allclose(rot(u, v, c, s, n=2), [[5, 5, 3, 3],
  560. [0, 0, f, f]], atol=atol)
  561. assert_allclose(rot(u, v, c, s, offx=2, offy=2),
  562. [[3, 3, 5, 5], [f, f, 0, 0]], atol=atol)
  563. assert_allclose(rot(u, v, c, s, incx=2, offy=2, n=2),
  564. [[5, 3, 5, 3], [f, f, 0, 0]], atol=atol)
  565. assert_allclose(rot(u, v, c, s, offx=2, incy=2, n=2),
  566. [[3, 3, 5, 5], [0, f, 0, f]], atol=atol)
  567. assert_allclose(rot(u, v, c, s, offx=2, incx=2, offy=2, incy=2, n=1),
  568. [[3, 3, 5, 3], [f, f, 0, f]], atol=atol)
  569. assert_allclose(rot(u, v, c, s, incx=-2, incy=-2, n=2),
  570. [[5, 3, 5, 3], [0, f, 0, f]], atol=atol)
  571. a, b = rot(u, v, c, s, overwrite_x=1, overwrite_y=1)
  572. assert_(a is u)
  573. assert_(b is v)
  574. assert_allclose(a, [5, 5, 5, 5], atol=atol)
  575. assert_allclose(b, [0, 0, 0, 0], atol=atol)
  576. def test_larfg_larf():
  577. np.random.seed(1234)
  578. a0 = np.random.random((4, 4))
  579. a0 = a0.T.dot(a0)
  580. a0j = np.random.random((4, 4)) + 1j*np.random.random((4, 4))
  581. a0j = a0j.T.conj().dot(a0j)
  582. # our test here will be to do one step of reducing a hermetian matrix to
  583. # tridiagonal form using householder transforms.
  584. for dtype in 'fdFD':
  585. larfg, larf = get_lapack_funcs(['larfg', 'larf'], dtype=dtype)
  586. if dtype in 'FD':
  587. a = a0j.copy()
  588. else:
  589. a = a0.copy()
  590. # generate a householder transform to clear a[2:,0]
  591. alpha, x, tau = larfg(a.shape[0]-1, a[1, 0], a[2:, 0])
  592. # create expected output
  593. expected = np.zeros_like(a[:, 0])
  594. expected[0] = a[0, 0]
  595. expected[1] = alpha
  596. # assemble householder vector
  597. v = np.zeros_like(a[1:, 0])
  598. v[0] = 1.0
  599. v[1:] = x
  600. # apply transform from the left
  601. a[1:, :] = larf(v, tau.conjugate(), a[1:, :], np.zeros(a.shape[1]))
  602. # apply transform from the right
  603. a[:, 1:] = larf(v, tau, a[:, 1:], np.zeros(a.shape[0]), side='R')
  604. assert_allclose(a[:, 0], expected, atol=1e-5)
  605. assert_allclose(a[0, :], expected, atol=1e-5)
  606. def test_sgesdd_lwork_bug_workaround():
  607. # Test that SGESDD lwork is sufficiently large for LAPACK.
  608. #
  609. # This checks that _compute_lwork() correctly works around a bug in
  610. # LAPACK versions older than 3.10.1.
  611. sgesdd_lwork = get_lapack_funcs('gesdd_lwork', dtype=np.float32,
  612. ilp64='preferred')
  613. n = 9537
  614. lwork = _compute_lwork(sgesdd_lwork, n, n,
  615. compute_uv=True, full_matrices=True)
  616. # If we called the Fortran function SGESDD directly with IWORK=-1, the
  617. # LAPACK bug would result in lwork being 272929856, which was too small.
  618. # (The result was returned in a single precision float, which does not
  619. # have sufficient precision to represent the exact integer value that it
  620. # computed internally.) The work-around implemented in _compute_lwork()
  621. # will convert that to 272929888. If we are using LAPACK 3.10.1 or later
  622. # (such as in OpenBLAS 0.3.21 or later), the work-around will return
  623. # 272929920, because it does not know which version of LAPACK is being
  624. # used, so it always applies the correction to whatever it is given. We
  625. # will accept either 272929888 or 272929920.
  626. # Note that the acceptable values are a LAPACK implementation detail.
  627. # If a future version of LAPACK changes how SGESDD works, and therefore
  628. # changes the required LWORK size, the acceptable values might have to
  629. # be updated.
  630. assert lwork == 272929888 or lwork == 272929920
  631. class TestSytrd:
  632. @pytest.mark.parametrize('dtype', REAL_DTYPES)
  633. def test_sytrd_with_zero_dim_array(self, dtype):
  634. # Assert that a 0x0 matrix raises an error
  635. A = np.zeros((0, 0), dtype=dtype)
  636. sytrd = get_lapack_funcs('sytrd', (A,))
  637. assert_raises(ValueError, sytrd, A)
  638. @pytest.mark.parametrize('dtype', REAL_DTYPES)
  639. @pytest.mark.parametrize('n', (1, 3))
  640. def test_sytrd(self, dtype, n):
  641. A = np.zeros((n, n), dtype=dtype)
  642. sytrd, sytrd_lwork = \
  643. get_lapack_funcs(('sytrd', 'sytrd_lwork'), (A,))
  644. # some upper triangular array
  645. A[np.triu_indices_from(A)] = \
  646. np.arange(1, n*(n+1)//2+1, dtype=dtype)
  647. # query lwork
  648. lwork, info = sytrd_lwork(n)
  649. assert_equal(info, 0)
  650. # check lower=1 behavior (shouldn't do much since the matrix is
  651. # upper triangular)
  652. data, d, e, tau, info = sytrd(A, lower=1, lwork=lwork)
  653. assert_equal(info, 0)
  654. assert_allclose(data, A, atol=5*np.finfo(dtype).eps, rtol=1.0)
  655. assert_allclose(d, np.diag(A))
  656. assert_allclose(e, 0.0)
  657. assert_allclose(tau, 0.0)
  658. # and now for the proper test (lower=0 is the default)
  659. data, d, e, tau, info = sytrd(A, lwork=lwork)
  660. assert_equal(info, 0)
  661. # assert Q^T*A*Q = tridiag(e, d, e)
  662. # build tridiagonal matrix
  663. T = np.zeros_like(A, dtype=dtype)
  664. k = np.arange(A.shape[0])
  665. T[k, k] = d
  666. k2 = np.arange(A.shape[0]-1)
  667. T[k2+1, k2] = e
  668. T[k2, k2+1] = e
  669. # build Q
  670. Q = np.eye(n, n, dtype=dtype)
  671. for i in range(n-1):
  672. v = np.zeros(n, dtype=dtype)
  673. v[:i] = data[:i, i+1]
  674. v[i] = 1.0
  675. H = np.eye(n, n, dtype=dtype) - tau[i] * np.outer(v, v)
  676. Q = np.dot(H, Q)
  677. # Make matrix fully symmetric
  678. i_lower = np.tril_indices(n, -1)
  679. A[i_lower] = A.T[i_lower]
  680. QTAQ = np.dot(Q.T, np.dot(A, Q))
  681. # disable rtol here since some values in QTAQ and T are very close
  682. # to 0.
  683. assert_allclose(QTAQ, T, atol=5*np.finfo(dtype).eps, rtol=1.0)
  684. class TestHetrd:
  685. @pytest.mark.parametrize('complex_dtype', COMPLEX_DTYPES)
  686. def test_hetrd_with_zero_dim_array(self, complex_dtype):
  687. # Assert that a 0x0 matrix raises an error
  688. A = np.zeros((0, 0), dtype=complex_dtype)
  689. hetrd = get_lapack_funcs('hetrd', (A,))
  690. assert_raises(ValueError, hetrd, A)
  691. @pytest.mark.parametrize('real_dtype,complex_dtype',
  692. zip(REAL_DTYPES, COMPLEX_DTYPES))
  693. @pytest.mark.parametrize('n', (1, 3))
  694. def test_hetrd(self, n, real_dtype, complex_dtype):
  695. A = np.zeros((n, n), dtype=complex_dtype)
  696. hetrd, hetrd_lwork = \
  697. get_lapack_funcs(('hetrd', 'hetrd_lwork'), (A,))
  698. # some upper triangular array
  699. A[np.triu_indices_from(A)] = (
  700. np.arange(1, n*(n+1)//2+1, dtype=real_dtype)
  701. + 1j * np.arange(1, n*(n+1)//2+1, dtype=real_dtype)
  702. )
  703. np.fill_diagonal(A, np.real(np.diag(A)))
  704. # test query lwork
  705. for x in [0, 1]:
  706. _, info = hetrd_lwork(n, lower=x)
  707. assert_equal(info, 0)
  708. # lwork returns complex which segfaults hetrd call (gh-10388)
  709. # use the safe and recommended option
  710. lwork = _compute_lwork(hetrd_lwork, n)
  711. # check lower=1 behavior (shouldn't do much since the matrix is
  712. # upper triangular)
  713. data, d, e, tau, info = hetrd(A, lower=1, lwork=lwork)
  714. assert_equal(info, 0)
  715. assert_allclose(data, A, atol=5*np.finfo(real_dtype).eps, rtol=1.0)
  716. assert_allclose(d, np.real(np.diag(A)))
  717. assert_allclose(e, 0.0)
  718. assert_allclose(tau, 0.0)
  719. # and now for the proper test (lower=0 is the default)
  720. data, d, e, tau, info = hetrd(A, lwork=lwork)
  721. assert_equal(info, 0)
  722. # assert Q^T*A*Q = tridiag(e, d, e)
  723. # build tridiagonal matrix
  724. T = np.zeros_like(A, dtype=real_dtype)
  725. k = np.arange(A.shape[0], dtype=int)
  726. T[k, k] = d
  727. k2 = np.arange(A.shape[0]-1, dtype=int)
  728. T[k2+1, k2] = e
  729. T[k2, k2+1] = e
  730. # build Q
  731. Q = np.eye(n, n, dtype=complex_dtype)
  732. for i in range(n-1):
  733. v = np.zeros(n, dtype=complex_dtype)
  734. v[:i] = data[:i, i+1]
  735. v[i] = 1.0
  736. H = np.eye(n, n, dtype=complex_dtype) \
  737. - tau[i] * np.outer(v, np.conj(v))
  738. Q = np.dot(H, Q)
  739. # Make matrix fully Hermitian
  740. i_lower = np.tril_indices(n, -1)
  741. A[i_lower] = np.conj(A.T[i_lower])
  742. QHAQ = np.dot(np.conj(Q.T), np.dot(A, Q))
  743. # disable rtol here since some values in QTAQ and T are very close
  744. # to 0.
  745. assert_allclose(
  746. QHAQ, T, atol=10*np.finfo(real_dtype).eps, rtol=1.0
  747. )
  748. def test_gglse():
  749. # Example data taken from NAG manual
  750. for ind, dtype in enumerate(DTYPES):
  751. # DTYPES = <s,d,c,z> gglse
  752. func, func_lwork = get_lapack_funcs(('gglse', 'gglse_lwork'),
  753. dtype=dtype)
  754. lwork = _compute_lwork(func_lwork, m=6, n=4, p=2)
  755. # For <s,d>gglse
  756. if ind < 2:
  757. a = np.array([[-0.57, -1.28, -0.39, 0.25],
  758. [-1.93, 1.08, -0.31, -2.14],
  759. [2.30, 0.24, 0.40, -0.35],
  760. [-1.93, 0.64, -0.66, 0.08],
  761. [0.15, 0.30, 0.15, -2.13],
  762. [-0.02, 1.03, -1.43, 0.50]], dtype=dtype)
  763. c = np.array([-1.50, -2.14, 1.23, -0.54, -1.68, 0.82], dtype=dtype)
  764. d = np.array([0., 0.], dtype=dtype)
  765. # For <s,d>gglse
  766. else:
  767. a = np.array([[0.96-0.81j, -0.03+0.96j, -0.91+2.06j, -0.05+0.41j],
  768. [-0.98+1.98j, -1.20+0.19j, -0.66+0.42j, -0.81+0.56j],
  769. [0.62-0.46j, 1.01+0.02j, 0.63-0.17j, -1.11+0.60j],
  770. [0.37+0.38j, 0.19-0.54j, -0.98-0.36j, 0.22-0.20j],
  771. [0.83+0.51j, 0.20+0.01j, -0.17-0.46j, 1.47+1.59j],
  772. [1.08-0.28j, 0.20-0.12j, -0.07+1.23j, 0.26+0.26j]])
  773. c = np.array([[-2.54+0.09j],
  774. [1.65-2.26j],
  775. [-2.11-3.96j],
  776. [1.82+3.30j],
  777. [-6.41+3.77j],
  778. [2.07+0.66j]])
  779. d = np.zeros(2, dtype=dtype)
  780. b = np.array([[1., 0., -1., 0.], [0., 1., 0., -1.]], dtype=dtype)
  781. _, _, _, result, _ = func(a, b, c, d, lwork=lwork)
  782. if ind < 2:
  783. expected = np.array([0.48904455,
  784. 0.99754786,
  785. 0.48904455,
  786. 0.99754786])
  787. else:
  788. expected = np.array([1.08742917-1.96205783j,
  789. -0.74093902+3.72973919j,
  790. 1.08742917-1.96205759j,
  791. -0.74093896+3.72973895j])
  792. assert_array_almost_equal(result, expected, decimal=4)
  793. def test_sycon_hecon():
  794. seed(1234)
  795. for ind, dtype in enumerate(DTYPES+COMPLEX_DTYPES):
  796. # DTYPES + COMPLEX DTYPES = <s,d,c,z> sycon + <c,z>hecon
  797. n = 10
  798. # For <s,d,c,z>sycon
  799. if ind < 4:
  800. func_lwork = get_lapack_funcs('sytrf_lwork', dtype=dtype)
  801. funcon, functrf = get_lapack_funcs(('sycon', 'sytrf'), dtype=dtype)
  802. A = (rand(n, n)).astype(dtype)
  803. # For <c,z>hecon
  804. else:
  805. func_lwork = get_lapack_funcs('hetrf_lwork', dtype=dtype)
  806. funcon, functrf = get_lapack_funcs(('hecon', 'hetrf'), dtype=dtype)
  807. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  808. # Since sycon only refers to upper/lower part, conj() is safe here.
  809. A = (A + A.conj().T)/2 + 2*np.eye(n, dtype=dtype)
  810. anorm = norm(A, 1)
  811. lwork = _compute_lwork(func_lwork, n)
  812. ldu, ipiv, _ = functrf(A, lwork=lwork, lower=1)
  813. rcond, _ = funcon(a=ldu, ipiv=ipiv, anorm=anorm, lower=1)
  814. # The error is at most 1-fold
  815. assert_(abs(1/rcond - np.linalg.cond(A, p=1))*rcond < 1)
  816. def test_sygst():
  817. seed(1234)
  818. for ind, dtype in enumerate(REAL_DTYPES):
  819. # DTYPES = <s,d> sygst
  820. n = 10
  821. potrf, sygst, syevd, sygvd = get_lapack_funcs(('potrf', 'sygst',
  822. 'syevd', 'sygvd'),
  823. dtype=dtype)
  824. A = rand(n, n).astype(dtype)
  825. A = (A + A.T)/2
  826. # B must be positive definite
  827. B = rand(n, n).astype(dtype)
  828. B = (B + B.T)/2 + 2 * np.eye(n, dtype=dtype)
  829. # Perform eig (sygvd)
  830. eig_gvd, _, info = sygvd(A, B)
  831. assert_(info == 0)
  832. # Convert to std problem potrf
  833. b, info = potrf(B)
  834. assert_(info == 0)
  835. a, info = sygst(A, b)
  836. assert_(info == 0)
  837. eig, _, info = syevd(a)
  838. assert_(info == 0)
  839. assert_allclose(eig, eig_gvd, rtol=1e-4)
  840. def test_hegst():
  841. seed(1234)
  842. for ind, dtype in enumerate(COMPLEX_DTYPES):
  843. # DTYPES = <c,z> hegst
  844. n = 10
  845. potrf, hegst, heevd, hegvd = get_lapack_funcs(('potrf', 'hegst',
  846. 'heevd', 'hegvd'),
  847. dtype=dtype)
  848. A = rand(n, n).astype(dtype) + 1j * rand(n, n).astype(dtype)
  849. A = (A + A.conj().T)/2
  850. # B must be positive definite
  851. B = rand(n, n).astype(dtype) + 1j * rand(n, n).astype(dtype)
  852. B = (B + B.conj().T)/2 + 2 * np.eye(n, dtype=dtype)
  853. # Perform eig (hegvd)
  854. eig_gvd, _, info = hegvd(A, B)
  855. assert_(info == 0)
  856. # Convert to std problem potrf
  857. b, info = potrf(B)
  858. assert_(info == 0)
  859. a, info = hegst(A, b)
  860. assert_(info == 0)
  861. eig, _, info = heevd(a)
  862. assert_(info == 0)
  863. assert_allclose(eig, eig_gvd, rtol=1e-4)
  864. def test_tzrzf():
  865. """
  866. This test performs an RZ decomposition in which an m x n upper trapezoidal
  867. array M (m <= n) is factorized as M = [R 0] * Z where R is upper triangular
  868. and Z is unitary.
  869. """
  870. seed(1234)
  871. m, n = 10, 15
  872. for ind, dtype in enumerate(DTYPES):
  873. tzrzf, tzrzf_lw = get_lapack_funcs(('tzrzf', 'tzrzf_lwork'),
  874. dtype=dtype)
  875. lwork = _compute_lwork(tzrzf_lw, m, n)
  876. if ind < 2:
  877. A = triu(rand(m, n).astype(dtype))
  878. else:
  879. A = triu((rand(m, n) + rand(m, n)*1j).astype(dtype))
  880. # assert wrong shape arg, f2py returns generic error
  881. assert_raises(Exception, tzrzf, A.T)
  882. rz, tau, info = tzrzf(A, lwork=lwork)
  883. # Check success
  884. assert_(info == 0)
  885. # Get Z manually for comparison
  886. R = np.hstack((rz[:, :m], np.zeros((m, n-m), dtype=dtype)))
  887. V = np.hstack((np.eye(m, dtype=dtype), rz[:, m:]))
  888. Id = np.eye(n, dtype=dtype)
  889. ref = [Id-tau[x]*V[[x], :].T.dot(V[[x], :].conj()) for x in range(m)]
  890. Z = reduce(np.dot, ref)
  891. assert_allclose(R.dot(Z) - A, zeros_like(A, dtype=dtype),
  892. atol=10*np.spacing(dtype(1.0).real), rtol=0.)
  893. def test_tfsm():
  894. """
  895. Test for solving a linear system with the coefficient matrix is a
  896. triangular array stored in Full Packed (RFP) format.
  897. """
  898. seed(1234)
  899. for ind, dtype in enumerate(DTYPES):
  900. n = 20
  901. if ind > 1:
  902. A = triu(rand(n, n) + rand(n, n)*1j + eye(n)).astype(dtype)
  903. trans = 'C'
  904. else:
  905. A = triu(rand(n, n) + eye(n)).astype(dtype)
  906. trans = 'T'
  907. trttf, tfttr, tfsm = get_lapack_funcs(('trttf', 'tfttr', 'tfsm'),
  908. dtype=dtype)
  909. Afp, _ = trttf(A)
  910. B = rand(n, 2).astype(dtype)
  911. soln = tfsm(-1, Afp, B)
  912. assert_array_almost_equal(soln, solve(-A, B),
  913. decimal=4 if ind % 2 == 0 else 6)
  914. soln = tfsm(-1, Afp, B, trans=trans)
  915. assert_array_almost_equal(soln, solve(-A.conj().T, B),
  916. decimal=4 if ind % 2 == 0 else 6)
  917. # Make A, unit diagonal
  918. A[np.arange(n), np.arange(n)] = dtype(1.)
  919. soln = tfsm(-1, Afp, B, trans=trans, diag='U')
  920. assert_array_almost_equal(soln, solve(-A.conj().T, B),
  921. decimal=4 if ind % 2 == 0 else 6)
  922. # Change side
  923. B2 = rand(3, n).astype(dtype)
  924. soln = tfsm(-1, Afp, B2, trans=trans, diag='U', side='R')
  925. assert_array_almost_equal(soln, solve(-A, B2.T).conj().T,
  926. decimal=4 if ind % 2 == 0 else 6)
  927. def test_ormrz_unmrz():
  928. """
  929. This test performs a matrix multiplication with an arbitrary m x n matric C
  930. and a unitary matrix Q without explicitly forming the array. The array data
  931. is encoded in the rectangular part of A which is obtained from ?TZRZF. Q
  932. size is inferred by m, n, side keywords.
  933. """
  934. seed(1234)
  935. qm, qn, cn = 10, 15, 15
  936. for ind, dtype in enumerate(DTYPES):
  937. tzrzf, tzrzf_lw = get_lapack_funcs(('tzrzf', 'tzrzf_lwork'),
  938. dtype=dtype)
  939. lwork_rz = _compute_lwork(tzrzf_lw, qm, qn)
  940. if ind < 2:
  941. A = triu(rand(qm, qn).astype(dtype))
  942. C = rand(cn, cn).astype(dtype)
  943. orun_mrz, orun_mrz_lw = get_lapack_funcs(('ormrz', 'ormrz_lwork'),
  944. dtype=dtype)
  945. else:
  946. A = triu((rand(qm, qn) + rand(qm, qn)*1j).astype(dtype))
  947. C = (rand(cn, cn) + rand(cn, cn)*1j).astype(dtype)
  948. orun_mrz, orun_mrz_lw = get_lapack_funcs(('unmrz', 'unmrz_lwork'),
  949. dtype=dtype)
  950. lwork_mrz = _compute_lwork(orun_mrz_lw, cn, cn)
  951. rz, tau, info = tzrzf(A, lwork=lwork_rz)
  952. # Get Q manually for comparison
  953. V = np.hstack((np.eye(qm, dtype=dtype), rz[:, qm:]))
  954. Id = np.eye(qn, dtype=dtype)
  955. ref = [Id-tau[x]*V[[x], :].T.dot(V[[x], :].conj()) for x in range(qm)]
  956. Q = reduce(np.dot, ref)
  957. # Now that we have Q, we can test whether lapack results agree with
  958. # each case of CQ, CQ^H, QC, and QC^H
  959. trans = 'T' if ind < 2 else 'C'
  960. tol = 10*np.spacing(dtype(1.0).real)
  961. cq, info = orun_mrz(rz, tau, C, lwork=lwork_mrz)
  962. assert_(info == 0)
  963. assert_allclose(cq - Q.dot(C), zeros_like(C), atol=tol, rtol=0.)
  964. cq, info = orun_mrz(rz, tau, C, trans=trans, lwork=lwork_mrz)
  965. assert_(info == 0)
  966. assert_allclose(cq - Q.conj().T.dot(C), zeros_like(C), atol=tol,
  967. rtol=0.)
  968. cq, info = orun_mrz(rz, tau, C, side='R', lwork=lwork_mrz)
  969. assert_(info == 0)
  970. assert_allclose(cq - C.dot(Q), zeros_like(C), atol=tol, rtol=0.)
  971. cq, info = orun_mrz(rz, tau, C, side='R', trans=trans, lwork=lwork_mrz)
  972. assert_(info == 0)
  973. assert_allclose(cq - C.dot(Q.conj().T), zeros_like(C), atol=tol,
  974. rtol=0.)
  975. def test_tfttr_trttf():
  976. """
  977. Test conversion routines between the Rectengular Full Packed (RFP) format
  978. and Standard Triangular Array (TR)
  979. """
  980. seed(1234)
  981. for ind, dtype in enumerate(DTYPES):
  982. n = 20
  983. if ind > 1:
  984. A_full = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  985. transr = 'C'
  986. else:
  987. A_full = (rand(n, n)).astype(dtype)
  988. transr = 'T'
  989. trttf, tfttr = get_lapack_funcs(('trttf', 'tfttr'), dtype=dtype)
  990. A_tf_U, info = trttf(A_full)
  991. assert_(info == 0)
  992. A_tf_L, info = trttf(A_full, uplo='L')
  993. assert_(info == 0)
  994. A_tf_U_T, info = trttf(A_full, transr=transr, uplo='U')
  995. assert_(info == 0)
  996. A_tf_L_T, info = trttf(A_full, transr=transr, uplo='L')
  997. assert_(info == 0)
  998. # Create the RFP array manually (n is even!)
  999. A_tf_U_m = zeros((n+1, n//2), dtype=dtype)
  1000. A_tf_U_m[:-1, :] = triu(A_full)[:, n//2:]
  1001. A_tf_U_m[n//2+1:, :] += triu(A_full)[:n//2, :n//2].conj().T
  1002. A_tf_L_m = zeros((n+1, n//2), dtype=dtype)
  1003. A_tf_L_m[1:, :] = tril(A_full)[:, :n//2]
  1004. A_tf_L_m[:n//2, :] += tril(A_full)[n//2:, n//2:].conj().T
  1005. assert_array_almost_equal(A_tf_U, A_tf_U_m.reshape(-1, order='F'))
  1006. assert_array_almost_equal(A_tf_U_T,
  1007. A_tf_U_m.conj().T.reshape(-1, order='F'))
  1008. assert_array_almost_equal(A_tf_L, A_tf_L_m.reshape(-1, order='F'))
  1009. assert_array_almost_equal(A_tf_L_T,
  1010. A_tf_L_m.conj().T.reshape(-1, order='F'))
  1011. # Get the original array from RFP
  1012. A_tr_U, info = tfttr(n, A_tf_U)
  1013. assert_(info == 0)
  1014. A_tr_L, info = tfttr(n, A_tf_L, uplo='L')
  1015. assert_(info == 0)
  1016. A_tr_U_T, info = tfttr(n, A_tf_U_T, transr=transr, uplo='U')
  1017. assert_(info == 0)
  1018. A_tr_L_T, info = tfttr(n, A_tf_L_T, transr=transr, uplo='L')
  1019. assert_(info == 0)
  1020. assert_array_almost_equal(A_tr_U, triu(A_full))
  1021. assert_array_almost_equal(A_tr_U_T, triu(A_full))
  1022. assert_array_almost_equal(A_tr_L, tril(A_full))
  1023. assert_array_almost_equal(A_tr_L_T, tril(A_full))
  1024. def test_tpttr_trttp():
  1025. """
  1026. Test conversion routines between the Rectengular Full Packed (RFP) format
  1027. and Standard Triangular Array (TR)
  1028. """
  1029. seed(1234)
  1030. for ind, dtype in enumerate(DTYPES):
  1031. n = 20
  1032. if ind > 1:
  1033. A_full = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1034. else:
  1035. A_full = (rand(n, n)).astype(dtype)
  1036. trttp, tpttr = get_lapack_funcs(('trttp', 'tpttr'), dtype=dtype)
  1037. A_tp_U, info = trttp(A_full)
  1038. assert_(info == 0)
  1039. A_tp_L, info = trttp(A_full, uplo='L')
  1040. assert_(info == 0)
  1041. # Create the TP array manually
  1042. inds = tril_indices(n)
  1043. A_tp_U_m = zeros(n*(n+1)//2, dtype=dtype)
  1044. A_tp_U_m[:] = (triu(A_full).T)[inds]
  1045. inds = triu_indices(n)
  1046. A_tp_L_m = zeros(n*(n+1)//2, dtype=dtype)
  1047. A_tp_L_m[:] = (tril(A_full).T)[inds]
  1048. assert_array_almost_equal(A_tp_U, A_tp_U_m)
  1049. assert_array_almost_equal(A_tp_L, A_tp_L_m)
  1050. # Get the original array from TP
  1051. A_tr_U, info = tpttr(n, A_tp_U)
  1052. assert_(info == 0)
  1053. A_tr_L, info = tpttr(n, A_tp_L, uplo='L')
  1054. assert_(info == 0)
  1055. assert_array_almost_equal(A_tr_U, triu(A_full))
  1056. assert_array_almost_equal(A_tr_L, tril(A_full))
  1057. def test_pftrf():
  1058. """
  1059. Test Cholesky factorization of a positive definite Rectengular Full
  1060. Packed (RFP) format array
  1061. """
  1062. seed(1234)
  1063. for ind, dtype in enumerate(DTYPES):
  1064. n = 20
  1065. if ind > 1:
  1066. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1067. A = A + A.conj().T + n*eye(n)
  1068. else:
  1069. A = (rand(n, n)).astype(dtype)
  1070. A = A + A.T + n*eye(n)
  1071. pftrf, trttf, tfttr = get_lapack_funcs(('pftrf', 'trttf', 'tfttr'),
  1072. dtype=dtype)
  1073. # Get the original array from TP
  1074. Afp, info = trttf(A)
  1075. Achol_rfp, info = pftrf(n, Afp)
  1076. assert_(info == 0)
  1077. A_chol_r, _ = tfttr(n, Achol_rfp)
  1078. Achol = cholesky(A)
  1079. assert_array_almost_equal(A_chol_r, Achol)
  1080. def test_pftri():
  1081. """
  1082. Test Cholesky factorization of a positive definite Rectengular Full
  1083. Packed (RFP) format array to find its inverse
  1084. """
  1085. seed(1234)
  1086. for ind, dtype in enumerate(DTYPES):
  1087. n = 20
  1088. if ind > 1:
  1089. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1090. A = A + A.conj().T + n*eye(n)
  1091. else:
  1092. A = (rand(n, n)).astype(dtype)
  1093. A = A + A.T + n*eye(n)
  1094. pftri, pftrf, trttf, tfttr = get_lapack_funcs(('pftri',
  1095. 'pftrf',
  1096. 'trttf',
  1097. 'tfttr'),
  1098. dtype=dtype)
  1099. # Get the original array from TP
  1100. Afp, info = trttf(A)
  1101. A_chol_rfp, info = pftrf(n, Afp)
  1102. A_inv_rfp, info = pftri(n, A_chol_rfp)
  1103. assert_(info == 0)
  1104. A_inv_r, _ = tfttr(n, A_inv_rfp)
  1105. Ainv = inv(A)
  1106. assert_array_almost_equal(A_inv_r, triu(Ainv),
  1107. decimal=4 if ind % 2 == 0 else 6)
  1108. def test_pftrs():
  1109. """
  1110. Test Cholesky factorization of a positive definite Rectengular Full
  1111. Packed (RFP) format array and solve a linear system
  1112. """
  1113. seed(1234)
  1114. for ind, dtype in enumerate(DTYPES):
  1115. n = 20
  1116. if ind > 1:
  1117. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1118. A = A + A.conj().T + n*eye(n)
  1119. else:
  1120. A = (rand(n, n)).astype(dtype)
  1121. A = A + A.T + n*eye(n)
  1122. B = ones((n, 3), dtype=dtype)
  1123. Bf1 = ones((n+2, 3), dtype=dtype)
  1124. Bf2 = ones((n-2, 3), dtype=dtype)
  1125. pftrs, pftrf, trttf, tfttr = get_lapack_funcs(('pftrs',
  1126. 'pftrf',
  1127. 'trttf',
  1128. 'tfttr'),
  1129. dtype=dtype)
  1130. # Get the original array from TP
  1131. Afp, info = trttf(A)
  1132. A_chol_rfp, info = pftrf(n, Afp)
  1133. # larger B arrays shouldn't segfault
  1134. soln, info = pftrs(n, A_chol_rfp, Bf1)
  1135. assert_(info == 0)
  1136. assert_raises(Exception, pftrs, n, A_chol_rfp, Bf2)
  1137. soln, info = pftrs(n, A_chol_rfp, B)
  1138. assert_(info == 0)
  1139. assert_array_almost_equal(solve(A, B), soln,
  1140. decimal=4 if ind % 2 == 0 else 6)
  1141. def test_sfrk_hfrk():
  1142. """
  1143. Test for performing a symmetric rank-k operation for matrix in RFP format.
  1144. """
  1145. seed(1234)
  1146. for ind, dtype in enumerate(DTYPES):
  1147. n = 20
  1148. if ind > 1:
  1149. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1150. A = A + A.conj().T + n*eye(n)
  1151. else:
  1152. A = (rand(n, n)).astype(dtype)
  1153. A = A + A.T + n*eye(n)
  1154. prefix = 's'if ind < 2 else 'h'
  1155. trttf, tfttr, shfrk = get_lapack_funcs(('trttf', 'tfttr', '{}frk'
  1156. ''.format(prefix)),
  1157. dtype=dtype)
  1158. Afp, _ = trttf(A)
  1159. C = np.random.rand(n, 2).astype(dtype)
  1160. Afp_out = shfrk(n, 2, -1, C, 2, Afp)
  1161. A_out, _ = tfttr(n, Afp_out)
  1162. assert_array_almost_equal(A_out, triu(-C.dot(C.conj().T) + 2*A),
  1163. decimal=4 if ind % 2 == 0 else 6)
  1164. def test_syconv():
  1165. """
  1166. Test for going back and forth between the returned format of he/sytrf to
  1167. L and D factors/permutations.
  1168. """
  1169. seed(1234)
  1170. for ind, dtype in enumerate(DTYPES):
  1171. n = 10
  1172. if ind > 1:
  1173. A = (randint(-30, 30, (n, n)) +
  1174. randint(-30, 30, (n, n))*1j).astype(dtype)
  1175. A = A + A.conj().T
  1176. else:
  1177. A = randint(-30, 30, (n, n)).astype(dtype)
  1178. A = A + A.T + n*eye(n)
  1179. tol = 100*np.spacing(dtype(1.0).real)
  1180. syconv, trf, trf_lwork = get_lapack_funcs(('syconv', 'sytrf',
  1181. 'sytrf_lwork'), dtype=dtype)
  1182. lw = _compute_lwork(trf_lwork, n, lower=1)
  1183. L, D, perm = ldl(A, lower=1, hermitian=False)
  1184. lw = _compute_lwork(trf_lwork, n, lower=1)
  1185. ldu, ipiv, info = trf(A, lower=1, lwork=lw)
  1186. a, e, info = syconv(ldu, ipiv, lower=1)
  1187. assert_allclose(tril(a, -1,), tril(L[perm, :], -1), atol=tol, rtol=0.)
  1188. # Test also upper
  1189. U, D, perm = ldl(A, lower=0, hermitian=False)
  1190. ldu, ipiv, info = trf(A, lower=0)
  1191. a, e, info = syconv(ldu, ipiv, lower=0)
  1192. assert_allclose(triu(a, 1), triu(U[perm, :], 1), atol=tol, rtol=0.)
  1193. class TestBlockedQR:
  1194. """
  1195. Tests for the blocked QR factorization, namely through geqrt, gemqrt, tpqrt
  1196. and tpmqr.
  1197. """
  1198. def test_geqrt_gemqrt(self):
  1199. seed(1234)
  1200. for ind, dtype in enumerate(DTYPES):
  1201. n = 20
  1202. if ind > 1:
  1203. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1204. else:
  1205. A = (rand(n, n)).astype(dtype)
  1206. tol = 100*np.spacing(dtype(1.0).real)
  1207. geqrt, gemqrt = get_lapack_funcs(('geqrt', 'gemqrt'), dtype=dtype)
  1208. a, t, info = geqrt(n, A)
  1209. assert info == 0
  1210. # Extract elementary reflectors from lower triangle, adding the
  1211. # main diagonal of ones.
  1212. v = np.tril(a, -1) + np.eye(n, dtype=dtype)
  1213. # Generate the block Householder transform I - VTV^H
  1214. Q = np.eye(n, dtype=dtype) - v @ t @ v.T.conj()
  1215. R = np.triu(a)
  1216. # Test columns of Q are orthogonal
  1217. assert_allclose(Q.T.conj() @ Q, np.eye(n, dtype=dtype), atol=tol,
  1218. rtol=0.)
  1219. assert_allclose(Q @ R, A, atol=tol, rtol=0.)
  1220. if ind > 1:
  1221. C = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1222. transpose = 'C'
  1223. else:
  1224. C = (rand(n, n)).astype(dtype)
  1225. transpose = 'T'
  1226. for side in ('L', 'R'):
  1227. for trans in ('N', transpose):
  1228. c, info = gemqrt(a, t, C, side=side, trans=trans)
  1229. assert info == 0
  1230. if trans == transpose:
  1231. q = Q.T.conj()
  1232. else:
  1233. q = Q
  1234. if side == 'L':
  1235. qC = q @ C
  1236. else:
  1237. qC = C @ q
  1238. assert_allclose(c, qC, atol=tol, rtol=0.)
  1239. # Test default arguments
  1240. if (side, trans) == ('L', 'N'):
  1241. c_default, info = gemqrt(a, t, C)
  1242. assert info == 0
  1243. assert_equal(c_default, c)
  1244. # Test invalid side/trans
  1245. assert_raises(Exception, gemqrt, a, t, C, side='A')
  1246. assert_raises(Exception, gemqrt, a, t, C, trans='A')
  1247. def test_tpqrt_tpmqrt(self):
  1248. seed(1234)
  1249. for ind, dtype in enumerate(DTYPES):
  1250. n = 20
  1251. if ind > 1:
  1252. A = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1253. B = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1254. else:
  1255. A = (rand(n, n)).astype(dtype)
  1256. B = (rand(n, n)).astype(dtype)
  1257. tol = 100*np.spacing(dtype(1.0).real)
  1258. tpqrt, tpmqrt = get_lapack_funcs(('tpqrt', 'tpmqrt'), dtype=dtype)
  1259. # Test for the range of pentagonal B, from square to upper
  1260. # triangular
  1261. for l in (0, n // 2, n):
  1262. a, b, t, info = tpqrt(l, n, A, B)
  1263. assert info == 0
  1264. # Check that lower triangular part of A has not been modified
  1265. assert_equal(np.tril(a, -1), np.tril(A, -1))
  1266. # Check that elements not part of the pentagonal portion of B
  1267. # have not been modified.
  1268. assert_equal(np.tril(b, l - n - 1), np.tril(B, l - n - 1))
  1269. # Extract pentagonal portion of B
  1270. B_pent, b_pent = np.triu(B, l - n), np.triu(b, l - n)
  1271. # Generate elementary reflectors
  1272. v = np.concatenate((np.eye(n, dtype=dtype), b_pent))
  1273. # Generate the block Householder transform I - VTV^H
  1274. Q = np.eye(2 * n, dtype=dtype) - v @ t @ v.T.conj()
  1275. R = np.concatenate((np.triu(a), np.zeros_like(a)))
  1276. # Test columns of Q are orthogonal
  1277. assert_allclose(Q.T.conj() @ Q, np.eye(2 * n, dtype=dtype),
  1278. atol=tol, rtol=0.)
  1279. assert_allclose(Q @ R, np.concatenate((np.triu(A), B_pent)),
  1280. atol=tol, rtol=0.)
  1281. if ind > 1:
  1282. C = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1283. D = (rand(n, n) + rand(n, n)*1j).astype(dtype)
  1284. transpose = 'C'
  1285. else:
  1286. C = (rand(n, n)).astype(dtype)
  1287. D = (rand(n, n)).astype(dtype)
  1288. transpose = 'T'
  1289. for side in ('L', 'R'):
  1290. for trans in ('N', transpose):
  1291. c, d, info = tpmqrt(l, b, t, C, D, side=side,
  1292. trans=trans)
  1293. assert info == 0
  1294. if trans == transpose:
  1295. q = Q.T.conj()
  1296. else:
  1297. q = Q
  1298. if side == 'L':
  1299. cd = np.concatenate((c, d), axis=0)
  1300. CD = np.concatenate((C, D), axis=0)
  1301. qCD = q @ CD
  1302. else:
  1303. cd = np.concatenate((c, d), axis=1)
  1304. CD = np.concatenate((C, D), axis=1)
  1305. qCD = CD @ q
  1306. assert_allclose(cd, qCD, atol=tol, rtol=0.)
  1307. if (side, trans) == ('L', 'N'):
  1308. c_default, d_default, info = tpmqrt(l, b, t, C, D)
  1309. assert info == 0
  1310. assert_equal(c_default, c)
  1311. assert_equal(d_default, d)
  1312. # Test invalid side/trans
  1313. assert_raises(Exception, tpmqrt, l, b, t, C, D, side='A')
  1314. assert_raises(Exception, tpmqrt, l, b, t, C, D, trans='A')
  1315. def test_pstrf():
  1316. seed(1234)
  1317. for ind, dtype in enumerate(DTYPES):
  1318. # DTYPES = <s, d, c, z> pstrf
  1319. n = 10
  1320. r = 2
  1321. pstrf = get_lapack_funcs('pstrf', dtype=dtype)
  1322. # Create positive semidefinite A
  1323. if ind > 1:
  1324. A = rand(n, n-r).astype(dtype) + 1j * rand(n, n-r).astype(dtype)
  1325. A = A @ A.conj().T
  1326. else:
  1327. A = rand(n, n-r).astype(dtype)
  1328. A = A @ A.T
  1329. c, piv, r_c, info = pstrf(A)
  1330. U = triu(c)
  1331. U[r_c - n:, r_c - n:] = 0.
  1332. assert_equal(info, 1)
  1333. # python-dbg 3.5.2 runs cause trouble with the following assertion.
  1334. # assert_equal(r_c, n - r)
  1335. single_atol = 1000 * np.finfo(np.float32).eps
  1336. double_atol = 1000 * np.finfo(np.float64).eps
  1337. atol = single_atol if ind in [0, 2] else double_atol
  1338. assert_allclose(A[piv-1][:, piv-1], U.conj().T @ U, rtol=0., atol=atol)
  1339. c, piv, r_c, info = pstrf(A, lower=1)
  1340. L = tril(c)
  1341. L[r_c - n:, r_c - n:] = 0.
  1342. assert_equal(info, 1)
  1343. # assert_equal(r_c, n - r)
  1344. single_atol = 1000 * np.finfo(np.float32).eps
  1345. double_atol = 1000 * np.finfo(np.float64).eps
  1346. atol = single_atol if ind in [0, 2] else double_atol
  1347. assert_allclose(A[piv-1][:, piv-1], L @ L.conj().T, rtol=0., atol=atol)
  1348. def test_pstf2():
  1349. seed(1234)
  1350. for ind, dtype in enumerate(DTYPES):
  1351. # DTYPES = <s, d, c, z> pstf2
  1352. n = 10
  1353. r = 2
  1354. pstf2 = get_lapack_funcs('pstf2', dtype=dtype)
  1355. # Create positive semidefinite A
  1356. if ind > 1:
  1357. A = rand(n, n-r).astype(dtype) + 1j * rand(n, n-r).astype(dtype)
  1358. A = A @ A.conj().T
  1359. else:
  1360. A = rand(n, n-r).astype(dtype)
  1361. A = A @ A.T
  1362. c, piv, r_c, info = pstf2(A)
  1363. U = triu(c)
  1364. U[r_c - n:, r_c - n:] = 0.
  1365. assert_equal(info, 1)
  1366. # python-dbg 3.5.2 runs cause trouble with the commented assertions.
  1367. # assert_equal(r_c, n - r)
  1368. single_atol = 1000 * np.finfo(np.float32).eps
  1369. double_atol = 1000 * np.finfo(np.float64).eps
  1370. atol = single_atol if ind in [0, 2] else double_atol
  1371. assert_allclose(A[piv-1][:, piv-1], U.conj().T @ U, rtol=0., atol=atol)
  1372. c, piv, r_c, info = pstf2(A, lower=1)
  1373. L = tril(c)
  1374. L[r_c - n:, r_c - n:] = 0.
  1375. assert_equal(info, 1)
  1376. # assert_equal(r_c, n - r)
  1377. single_atol = 1000 * np.finfo(np.float32).eps
  1378. double_atol = 1000 * np.finfo(np.float64).eps
  1379. atol = single_atol if ind in [0, 2] else double_atol
  1380. assert_allclose(A[piv-1][:, piv-1], L @ L.conj().T, rtol=0., atol=atol)
  1381. def test_geequ():
  1382. desired_real = np.array([[0.6250, 1.0000, 0.0393, -0.4269],
  1383. [1.0000, -0.5619, -1.0000, -1.0000],
  1384. [0.5874, -1.0000, -0.0596, -0.5341],
  1385. [-1.0000, -0.5946, -0.0294, 0.9957]])
  1386. desired_cplx = np.array([[-0.2816+0.5359*1j,
  1387. 0.0812+0.9188*1j,
  1388. -0.7439-0.2561*1j],
  1389. [-0.3562-0.2954*1j,
  1390. 0.9566-0.0434*1j,
  1391. -0.0174+0.1555*1j],
  1392. [0.8607+0.1393*1j,
  1393. -0.2759+0.7241*1j,
  1394. -0.1642-0.1365*1j]])
  1395. for ind, dtype in enumerate(DTYPES):
  1396. if ind < 2:
  1397. # Use examples from the NAG documentation
  1398. A = np.array([[1.80e+10, 2.88e+10, 2.05e+00, -8.90e+09],
  1399. [5.25e+00, -2.95e+00, -9.50e-09, -3.80e+00],
  1400. [1.58e+00, -2.69e+00, -2.90e-10, -1.04e+00],
  1401. [-1.11e+00, -6.60e-01, -5.90e-11, 8.00e-01]])
  1402. A = A.astype(dtype)
  1403. else:
  1404. A = np.array([[-1.34e+00, 0.28e+10, -6.39e+00],
  1405. [-1.70e+00, 3.31e+10, -0.15e+00],
  1406. [2.41e-10, -0.56e+00, -0.83e-10]], dtype=dtype)
  1407. A += np.array([[2.55e+00, 3.17e+10, -2.20e+00],
  1408. [-1.41e+00, -0.15e+10, 1.34e+00],
  1409. [0.39e-10, 1.47e+00, -0.69e-10]])*1j
  1410. A = A.astype(dtype)
  1411. geequ = get_lapack_funcs('geequ', dtype=dtype)
  1412. r, c, rowcnd, colcnd, amax, info = geequ(A)
  1413. if ind < 2:
  1414. assert_allclose(desired_real.astype(dtype), r[:, None]*A*c,
  1415. rtol=0, atol=1e-4)
  1416. else:
  1417. assert_allclose(desired_cplx.astype(dtype), r[:, None]*A*c,
  1418. rtol=0, atol=1e-4)
  1419. def test_syequb():
  1420. desired_log2s = np.array([0, 0, 0, 0, 0, 0, -1, -1, -2, -3])
  1421. for ind, dtype in enumerate(DTYPES):
  1422. A = np.eye(10, dtype=dtype)
  1423. alpha = dtype(1. if ind < 2 else 1.j)
  1424. d = np.array([alpha * 2.**x for x in range(-5, 5)], dtype=dtype)
  1425. A += np.rot90(np.diag(d))
  1426. syequb = get_lapack_funcs('syequb', dtype=dtype)
  1427. s, scond, amax, info = syequb(A)
  1428. assert_equal(np.log2(s).astype(int), desired_log2s)
  1429. @pytest.mark.skipif(True,
  1430. reason="Failing on some OpenBLAS version, see gh-12276")
  1431. def test_heequb():
  1432. # zheequb has a bug for versions =< LAPACK 3.9.0
  1433. # See Reference-LAPACK gh-61 and gh-408
  1434. # Hence the zheequb test is customized accordingly to avoid
  1435. # work scaling.
  1436. A = np.diag([2]*5 + [1002]*5) + np.diag(np.ones(9), k=1)*1j
  1437. s, scond, amax, info = lapack.zheequb(A)
  1438. assert_equal(info, 0)
  1439. assert_allclose(np.log2(s), [0., -1.]*2 + [0.] + [-4]*5)
  1440. A = np.diag(2**np.abs(np.arange(-5, 6)) + 0j)
  1441. A[5, 5] = 1024
  1442. A[5, 0] = 16j
  1443. s, scond, amax, info = lapack.cheequb(A.astype(np.complex64), lower=1)
  1444. assert_equal(info, 0)
  1445. assert_allclose(np.log2(s), [-2, -1, -1, 0, 0, -5, 0, -1, -1, -2, -2])
  1446. def test_getc2_gesc2():
  1447. np.random.seed(42)
  1448. n = 10
  1449. desired_real = np.random.rand(n)
  1450. desired_cplx = np.random.rand(n) + np.random.rand(n)*1j
  1451. for ind, dtype in enumerate(DTYPES):
  1452. if ind < 2:
  1453. A = np.random.rand(n, n)
  1454. A = A.astype(dtype)
  1455. b = A @ desired_real
  1456. b = b.astype(dtype)
  1457. else:
  1458. A = np.random.rand(n, n) + np.random.rand(n, n)*1j
  1459. A = A.astype(dtype)
  1460. b = A @ desired_cplx
  1461. b = b.astype(dtype)
  1462. getc2 = get_lapack_funcs('getc2', dtype=dtype)
  1463. gesc2 = get_lapack_funcs('gesc2', dtype=dtype)
  1464. lu, ipiv, jpiv, info = getc2(A, overwrite_a=0)
  1465. x, scale = gesc2(lu, b, ipiv, jpiv, overwrite_rhs=0)
  1466. if ind < 2:
  1467. assert_array_almost_equal(desired_real.astype(dtype),
  1468. x/scale, decimal=4)
  1469. else:
  1470. assert_array_almost_equal(desired_cplx.astype(dtype),
  1471. x/scale, decimal=4)
  1472. @pytest.mark.parametrize('size', [(6, 5), (5, 5)])
  1473. @pytest.mark.parametrize('dtype', REAL_DTYPES)
  1474. @pytest.mark.parametrize('joba', range(6)) # 'C', 'E', 'F', 'G', 'A', 'R'
  1475. @pytest.mark.parametrize('jobu', range(4)) # 'U', 'F', 'W', 'N'
  1476. @pytest.mark.parametrize('jobv', range(4)) # 'V', 'J', 'W', 'N'
  1477. @pytest.mark.parametrize('jobr', [0, 1])
  1478. @pytest.mark.parametrize('jobp', [0, 1])
  1479. def test_gejsv_general(size, dtype, joba, jobu, jobv, jobr, jobp, jobt=0):
  1480. """Test the lapack routine ?gejsv.
  1481. This function tests that a singular value decomposition can be performed
  1482. on the random M-by-N matrix A. The test performs the SVD using ?gejsv
  1483. then performs the following checks:
  1484. * ?gejsv exist successfully (info == 0)
  1485. * The returned singular values are correct
  1486. * `A` can be reconstructed from `u`, `SIGMA`, `v`
  1487. * Ensure that u.T @ u is the identity matrix
  1488. * Ensure that v.T @ v is the identity matrix
  1489. * The reported matrix rank
  1490. * The reported number of singular values
  1491. * If denormalized floats are required
  1492. Notes
  1493. -----
  1494. joba specifies several choices effecting the calculation's accuracy
  1495. Although all arguments are tested, the tests only check that the correct
  1496. solution is returned - NOT that the prescribed actions are performed
  1497. internally.
  1498. jobt is, as of v3.9.0, still experimental and removed to cut down number of
  1499. test cases. However keyword itself is tested externally.
  1500. """
  1501. seed(42)
  1502. # Define some constants for later use:
  1503. m, n = size
  1504. atol = 100 * np.finfo(dtype).eps
  1505. A = generate_random_dtype_array(size, dtype)
  1506. gejsv = get_lapack_funcs('gejsv', dtype=dtype)
  1507. # Set up checks for invalid job? combinations
  1508. # if an invalid combination occurs we set the appropriate
  1509. # exit status.
  1510. lsvec = jobu < 2 # Calculate left singular vectors
  1511. rsvec = jobv < 2 # Calculate right singular vectors
  1512. l2tran = (jobt == 1) and (m == n)
  1513. is_complex = np.iscomplexobj(A)
  1514. invalid_real_jobv = (jobv == 1) and (not lsvec) and (not is_complex)
  1515. invalid_cplx_jobu = (jobu == 2) and not (rsvec and l2tran) and is_complex
  1516. invalid_cplx_jobv = (jobv == 2) and not (lsvec and l2tran) and is_complex
  1517. # Set the exit status to the expected value.
  1518. # Here we only check for invalid combinations, not individual
  1519. # parameters.
  1520. if invalid_cplx_jobu:
  1521. exit_status = -2
  1522. elif invalid_real_jobv or invalid_cplx_jobv:
  1523. exit_status = -3
  1524. else:
  1525. exit_status = 0
  1526. if (jobu > 1) and (jobv == 1):
  1527. assert_raises(Exception, gejsv, A, joba, jobu, jobv, jobr, jobt, jobp)
  1528. else:
  1529. sva, u, v, work, iwork, info = gejsv(A,
  1530. joba=joba,
  1531. jobu=jobu,
  1532. jobv=jobv,
  1533. jobr=jobr,
  1534. jobt=jobt,
  1535. jobp=jobp)
  1536. # Check that ?gejsv exited successfully/as expected
  1537. assert_equal(info, exit_status)
  1538. # If exit_status is non-zero the combination of jobs is invalid.
  1539. # We test this above but no calculations are performed.
  1540. if not exit_status:
  1541. # Check the returned singular values
  1542. sigma = (work[0] / work[1]) * sva[:n]
  1543. assert_allclose(sigma, svd(A, compute_uv=False), atol=atol)
  1544. if jobu == 1:
  1545. # If JOBU = 'F', then u contains the M-by-M matrix of
  1546. # the left singular vectors, including an ONB of the orthogonal
  1547. # complement of the Range(A)
  1548. # However, to recalculate A we are concerned about the
  1549. # first n singular values and so can ignore the latter.
  1550. # TODO: Add a test for ONB?
  1551. u = u[:, :n]
  1552. if lsvec and rsvec:
  1553. assert_allclose(u @ np.diag(sigma) @ v.conj().T, A, atol=atol)
  1554. if lsvec:
  1555. assert_allclose(u.conj().T @ u, np.identity(n), atol=atol)
  1556. if rsvec:
  1557. assert_allclose(v.conj().T @ v, np.identity(n), atol=atol)
  1558. assert_equal(iwork[0], np.linalg.matrix_rank(A))
  1559. assert_equal(iwork[1], np.count_nonzero(sigma))
  1560. # iwork[2] is non-zero if requested accuracy is not warranted for
  1561. # the data. This should never occur for these tests.
  1562. assert_equal(iwork[2], 0)
  1563. @pytest.mark.parametrize('dtype', REAL_DTYPES)
  1564. def test_gejsv_edge_arguments(dtype):
  1565. """Test edge arguments return expected status"""
  1566. gejsv = get_lapack_funcs('gejsv', dtype=dtype)
  1567. # scalar A
  1568. sva, u, v, work, iwork, info = gejsv(1.)
  1569. assert_equal(info, 0)
  1570. assert_equal(u.shape, (1, 1))
  1571. assert_equal(v.shape, (1, 1))
  1572. assert_equal(sva, np.array([1.], dtype=dtype))
  1573. # 1d A
  1574. A = np.ones((1,), dtype=dtype)
  1575. sva, u, v, work, iwork, info = gejsv(A)
  1576. assert_equal(info, 0)
  1577. assert_equal(u.shape, (1, 1))
  1578. assert_equal(v.shape, (1, 1))
  1579. assert_equal(sva, np.array([1.], dtype=dtype))
  1580. # 2d empty A
  1581. A = np.ones((1, 0), dtype=dtype)
  1582. sva, u, v, work, iwork, info = gejsv(A)
  1583. assert_equal(info, 0)
  1584. assert_equal(u.shape, (1, 0))
  1585. assert_equal(v.shape, (1, 0))
  1586. assert_equal(sva, np.array([], dtype=dtype))
  1587. # make sure "overwrite_a" is respected - user reported in gh-13191
  1588. A = np.sin(np.arange(100).reshape(10, 10)).astype(dtype)
  1589. A = np.asfortranarray(A + A.T) # make it symmetric and column major
  1590. Ac = A.copy('A')
  1591. _ = gejsv(A)
  1592. assert_allclose(A, Ac)
  1593. @pytest.mark.parametrize(('kwargs'),
  1594. ({'joba': 9},
  1595. {'jobu': 9},
  1596. {'jobv': 9},
  1597. {'jobr': 9},
  1598. {'jobt': 9},
  1599. {'jobp': 9})
  1600. )
  1601. def test_gejsv_invalid_job_arguments(kwargs):
  1602. """Test invalid job arguments raise an Exception"""
  1603. A = np.ones((2, 2), dtype=float)
  1604. gejsv = get_lapack_funcs('gejsv', dtype=float)
  1605. assert_raises(Exception, gejsv, A, **kwargs)
  1606. @pytest.mark.parametrize("A,sva_expect,u_expect,v_expect",
  1607. [(np.array([[2.27, -1.54, 1.15, -1.94],
  1608. [0.28, -1.67, 0.94, -0.78],
  1609. [-0.48, -3.09, 0.99, -0.21],
  1610. [1.07, 1.22, 0.79, 0.63],
  1611. [-2.35, 2.93, -1.45, 2.30],
  1612. [0.62, -7.39, 1.03, -2.57]]),
  1613. np.array([9.9966, 3.6831, 1.3569, 0.5000]),
  1614. np.array([[0.2774, -0.6003, -0.1277, 0.1323],
  1615. [0.2020, -0.0301, 0.2805, 0.7034],
  1616. [0.2918, 0.3348, 0.6453, 0.1906],
  1617. [-0.0938, -0.3699, 0.6781, -0.5399],
  1618. [-0.4213, 0.5266, 0.0413, -0.0575],
  1619. [0.7816, 0.3353, -0.1645, -0.3957]]),
  1620. np.array([[0.1921, -0.8030, 0.0041, -0.5642],
  1621. [-0.8794, -0.3926, -0.0752, 0.2587],
  1622. [0.2140, -0.2980, 0.7827, 0.5027],
  1623. [-0.3795, 0.3351, 0.6178, -0.6017]]))])
  1624. def test_gejsv_NAG(A, sva_expect, u_expect, v_expect):
  1625. """
  1626. This test implements the example found in the NAG manual, f08khf.
  1627. An example was not found for the complex case.
  1628. """
  1629. # NAG manual provides accuracy up to 4 decimals
  1630. atol = 1e-4
  1631. gejsv = get_lapack_funcs('gejsv', dtype=A.dtype)
  1632. sva, u, v, work, iwork, info = gejsv(A)
  1633. assert_allclose(sva_expect, sva, atol=atol)
  1634. assert_allclose(u_expect, u, atol=atol)
  1635. assert_allclose(v_expect, v, atol=atol)
  1636. @pytest.mark.parametrize("dtype", DTYPES)
  1637. def test_gttrf_gttrs(dtype):
  1638. # The test uses ?gttrf and ?gttrs to solve a random system for each dtype,
  1639. # tests that the output of ?gttrf define LU matricies, that input
  1640. # parameters are unmodified, transposal options function correctly, that
  1641. # incompatible matrix shapes raise an error, and singular matrices return
  1642. # non zero info.
  1643. seed(42)
  1644. n = 10
  1645. atol = 100 * np.finfo(dtype).eps
  1646. # create the matrix in accordance with the data type
  1647. du = generate_random_dtype_array((n-1,), dtype=dtype)
  1648. d = generate_random_dtype_array((n,), dtype=dtype)
  1649. dl = generate_random_dtype_array((n-1,), dtype=dtype)
  1650. diag_cpy = [dl.copy(), d.copy(), du.copy()]
  1651. A = np.diag(d) + np.diag(dl, -1) + np.diag(du, 1)
  1652. x = np.random.rand(n)
  1653. b = A @ x
  1654. gttrf, gttrs = get_lapack_funcs(('gttrf', 'gttrs'), dtype=dtype)
  1655. _dl, _d, _du, du2, ipiv, info = gttrf(dl, d, du)
  1656. # test to assure that the inputs of ?gttrf are unmodified
  1657. assert_array_equal(dl, diag_cpy[0])
  1658. assert_array_equal(d, diag_cpy[1])
  1659. assert_array_equal(du, diag_cpy[2])
  1660. # generate L and U factors from ?gttrf return values
  1661. # L/U are lower/upper triangular by construction (initially and at end)
  1662. U = np.diag(_d, 0) + np.diag(_du, 1) + np.diag(du2, 2)
  1663. L = np.eye(n, dtype=dtype)
  1664. for i, m in enumerate(_dl):
  1665. # L is given in a factored form.
  1666. # See
  1667. # www.hpcavf.uclan.ac.uk/softwaredoc/sgi_scsl_html/sgi_html/ch03.html
  1668. piv = ipiv[i] - 1
  1669. # right multiply by permutation matrix
  1670. L[:, [i, piv]] = L[:, [piv, i]]
  1671. # right multiply by Li, rank-one modification of identity
  1672. L[:, i] += L[:, i+1]*m
  1673. # one last permutation
  1674. i, piv = -1, ipiv[-1] - 1
  1675. # right multiply by final permutation matrix
  1676. L[:, [i, piv]] = L[:, [piv, i]]
  1677. # check that the outputs of ?gttrf define an LU decomposition of A
  1678. assert_allclose(A, L @ U, atol=atol)
  1679. b_cpy = b.copy()
  1680. x_gttrs, info = gttrs(_dl, _d, _du, du2, ipiv, b)
  1681. # test that the inputs of ?gttrs are unmodified
  1682. assert_array_equal(b, b_cpy)
  1683. # test that the result of ?gttrs matches the expected input
  1684. assert_allclose(x, x_gttrs, atol=atol)
  1685. # test that ?gttrf and ?gttrs work with transposal options
  1686. if dtype in REAL_DTYPES:
  1687. trans = "T"
  1688. b_trans = A.T @ x
  1689. else:
  1690. trans = "C"
  1691. b_trans = A.conj().T @ x
  1692. x_gttrs, info = gttrs(_dl, _d, _du, du2, ipiv, b_trans, trans=trans)
  1693. assert_allclose(x, x_gttrs, atol=atol)
  1694. # test that ValueError is raised with incompatible matrix shapes
  1695. with assert_raises(ValueError):
  1696. gttrf(dl[:-1], d, du)
  1697. with assert_raises(ValueError):
  1698. gttrf(dl, d[:-1], du)
  1699. with assert_raises(ValueError):
  1700. gttrf(dl, d, du[:-1])
  1701. # test that matrix of size n=2 raises exception
  1702. with assert_raises(Exception):
  1703. gttrf(dl[0], d[:1], du[0])
  1704. # test that singular (row of all zeroes) matrix fails via info
  1705. du[0] = 0
  1706. d[0] = 0
  1707. __dl, __d, __du, _du2, _ipiv, _info = gttrf(dl, d, du)
  1708. np.testing.assert_(__d[info - 1] == 0,
  1709. "?gttrf: _d[info-1] is {}, not the illegal value :0."
  1710. .format(__d[info - 1]))
  1711. @pytest.mark.parametrize("du, d, dl, du_exp, d_exp, du2_exp, ipiv_exp, b, x",
  1712. [(np.array([2.1, -1.0, 1.9, 8.0]),
  1713. np.array([3.0, 2.3, -5.0, -.9, 7.1]),
  1714. np.array([3.4, 3.6, 7.0, -6.0]),
  1715. np.array([2.3, -5, -.9, 7.1]),
  1716. np.array([3.4, 3.6, 7, -6, -1.015373]),
  1717. np.array([-1, 1.9, 8]),
  1718. np.array([2, 3, 4, 5, 5]),
  1719. np.array([[2.7, 6.6],
  1720. [-0.5, 10.8],
  1721. [2.6, -3.2],
  1722. [0.6, -11.2],
  1723. [2.7, 19.1]
  1724. ]),
  1725. np.array([[-4, 5],
  1726. [7, -4],
  1727. [3, -3],
  1728. [-4, -2],
  1729. [-3, 1]])),
  1730. (
  1731. np.array([2 - 1j, 2 + 1j, -1 + 1j, 1 - 1j]),
  1732. np.array([-1.3 + 1.3j, -1.3 + 1.3j,
  1733. -1.3 + 3.3j, - .3 + 4.3j,
  1734. -3.3 + 1.3j]),
  1735. np.array([1 - 2j, 1 + 1j, 2 - 3j, 1 + 1j]),
  1736. # du exp
  1737. np.array([-1.3 + 1.3j, -1.3 + 3.3j,
  1738. -0.3 + 4.3j, -3.3 + 1.3j]),
  1739. np.array([1 - 2j, 1 + 1j, 2 - 3j, 1 + 1j,
  1740. -1.3399 + 0.2875j]),
  1741. np.array([2 + 1j, -1 + 1j, 1 - 1j]),
  1742. np.array([2, 3, 4, 5, 5]),
  1743. np.array([[2.4 - 5j, 2.7 + 6.9j],
  1744. [3.4 + 18.2j, - 6.9 - 5.3j],
  1745. [-14.7 + 9.7j, - 6 - .6j],
  1746. [31.9 - 7.7j, -3.9 + 9.3j],
  1747. [-1 + 1.6j, -3 + 12.2j]]),
  1748. np.array([[1 + 1j, 2 - 1j],
  1749. [3 - 1j, 1 + 2j],
  1750. [4 + 5j, -1 + 1j],
  1751. [-1 - 2j, 2 + 1j],
  1752. [1 - 1j, 2 - 2j]])
  1753. )])
  1754. def test_gttrf_gttrs_NAG_f07cdf_f07cef_f07crf_f07csf(du, d, dl, du_exp, d_exp,
  1755. du2_exp, ipiv_exp, b, x):
  1756. # test to assure that wrapper is consistent with NAG Library Manual Mark 26
  1757. # example problems: f07cdf and f07cef (real)
  1758. # examples: f07crf and f07csf (complex)
  1759. # (Links may expire, so search for "NAG Library Manual Mark 26" online)
  1760. gttrf, gttrs = get_lapack_funcs(('gttrf', "gttrs"), (du[0], du[0]))
  1761. _dl, _d, _du, du2, ipiv, info = gttrf(dl, d, du)
  1762. assert_allclose(du2, du2_exp)
  1763. assert_allclose(_du, du_exp)
  1764. assert_allclose(_d, d_exp, atol=1e-4) # NAG examples provide 4 decimals.
  1765. assert_allclose(ipiv, ipiv_exp)
  1766. x_gttrs, info = gttrs(_dl, _d, _du, du2, ipiv, b)
  1767. assert_allclose(x_gttrs, x)
  1768. @pytest.mark.parametrize('dtype', DTYPES)
  1769. @pytest.mark.parametrize('shape', [(3, 7), (7, 3), (2**18, 2**18)])
  1770. def test_geqrfp_lwork(dtype, shape):
  1771. geqrfp_lwork = get_lapack_funcs(('geqrfp_lwork'), dtype=dtype)
  1772. m, n = shape
  1773. lwork, info = geqrfp_lwork(m=m, n=n)
  1774. assert_equal(info, 0)
  1775. @pytest.mark.parametrize("ddtype,dtype",
  1776. zip(REAL_DTYPES + REAL_DTYPES, DTYPES))
  1777. def test_pttrf_pttrs(ddtype, dtype):
  1778. seed(42)
  1779. # set test tolerance appropriate for dtype
  1780. atol = 100*np.finfo(dtype).eps
  1781. # n is the length diagonal of A
  1782. n = 10
  1783. # create diagonals according to size and dtype
  1784. # diagonal d should always be real.
  1785. # add 4 to d so it will be dominant for all dtypes
  1786. d = generate_random_dtype_array((n,), ddtype) + 4
  1787. # diagonal e may be real or complex.
  1788. e = generate_random_dtype_array((n-1,), dtype)
  1789. # assemble diagonals together into matrix
  1790. A = np.diag(d) + np.diag(e, -1) + np.diag(np.conj(e), 1)
  1791. # store a copy of diagonals to later verify
  1792. diag_cpy = [d.copy(), e.copy()]
  1793. pttrf = get_lapack_funcs('pttrf', dtype=dtype)
  1794. _d, _e, info = pttrf(d, e)
  1795. # test to assure that the inputs of ?pttrf are unmodified
  1796. assert_array_equal(d, diag_cpy[0])
  1797. assert_array_equal(e, diag_cpy[1])
  1798. assert_equal(info, 0, err_msg="pttrf: info = {}, should be 0".format(info))
  1799. # test that the factors from pttrf can be recombined to make A
  1800. L = np.diag(_e, -1) + np.diag(np.ones(n))
  1801. D = np.diag(_d)
  1802. assert_allclose(A, L@D@L.conjugate().T, atol=atol)
  1803. # generate random solution x
  1804. x = generate_random_dtype_array((n,), dtype)
  1805. # determine accompanying b to get soln x
  1806. b = A@x
  1807. # determine _x from pttrs
  1808. pttrs = get_lapack_funcs('pttrs', dtype=dtype)
  1809. _x, info = pttrs(_d, _e.conj(), b)
  1810. assert_equal(info, 0, err_msg="pttrs: info = {}, should be 0".format(info))
  1811. # test that _x from pttrs matches the expected x
  1812. assert_allclose(x, _x, atol=atol)
  1813. @pytest.mark.parametrize("ddtype,dtype",
  1814. zip(REAL_DTYPES + REAL_DTYPES, DTYPES))
  1815. def test_pttrf_pttrs_errors_incompatible_shape(ddtype, dtype):
  1816. n = 10
  1817. pttrf = get_lapack_funcs('pttrf', dtype=dtype)
  1818. d = generate_random_dtype_array((n,), ddtype) + 2
  1819. e = generate_random_dtype_array((n-1,), dtype)
  1820. # test that ValueError is raised with incompatible matrix shapes
  1821. assert_raises(ValueError, pttrf, d[:-1], e)
  1822. assert_raises(ValueError, pttrf, d, e[:-1])
  1823. @pytest.mark.parametrize("ddtype,dtype",
  1824. zip(REAL_DTYPES + REAL_DTYPES, DTYPES))
  1825. def test_pttrf_pttrs_errors_singular_nonSPD(ddtype, dtype):
  1826. n = 10
  1827. pttrf = get_lapack_funcs('pttrf', dtype=dtype)
  1828. d = generate_random_dtype_array((n,), ddtype) + 2
  1829. e = generate_random_dtype_array((n-1,), dtype)
  1830. # test that singular (row of all zeroes) matrix fails via info
  1831. d[0] = 0
  1832. e[0] = 0
  1833. _d, _e, info = pttrf(d, e)
  1834. assert_equal(_d[info - 1], 0,
  1835. "?pttrf: _d[info-1] is {}, not the illegal value :0."
  1836. .format(_d[info - 1]))
  1837. # test with non-spd matrix
  1838. d = generate_random_dtype_array((n,), ddtype)
  1839. _d, _e, info = pttrf(d, e)
  1840. assert_(info != 0, "?pttrf should fail with non-spd matrix, but didn't")
  1841. @pytest.mark.parametrize(("d, e, d_expect, e_expect, b, x_expect"), [
  1842. (np.array([4, 10, 29, 25, 5]),
  1843. np.array([-2, -6, 15, 8]),
  1844. np.array([4, 9, 25, 16, 1]),
  1845. np.array([-.5, -.6667, .6, .5]),
  1846. np.array([[6, 10], [9, 4], [2, 9], [14, 65],
  1847. [7, 23]]),
  1848. np.array([[2.5, 2], [2, -1], [1, -3], [-1, 6],
  1849. [3, -5]])
  1850. ), (
  1851. np.array([16, 41, 46, 21]),
  1852. np.array([16 + 16j, 18 - 9j, 1 - 4j]),
  1853. np.array([16, 9, 1, 4]),
  1854. np.array([1+1j, 2-1j, 1-4j]),
  1855. np.array([[64+16j, -16-32j], [93+62j, 61-66j],
  1856. [78-80j, 71-74j], [14-27j, 35+15j]]),
  1857. np.array([[2+1j, -3-2j], [1+1j, 1+1j], [1-2j, 1-2j],
  1858. [1-1j, 2+1j]])
  1859. )])
  1860. def test_pttrf_pttrs_NAG(d, e, d_expect, e_expect, b, x_expect):
  1861. # test to assure that wrapper is consistent with NAG Manual Mark 26
  1862. # example problems: f07jdf and f07jef (real)
  1863. # examples: f07jrf and f07csf (complex)
  1864. # NAG examples provide 4 decimals.
  1865. # (Links expire, so please search for "NAG Library Manual Mark 26" online)
  1866. atol = 1e-4
  1867. pttrf = get_lapack_funcs('pttrf', dtype=e[0])
  1868. _d, _e, info = pttrf(d, e)
  1869. assert_allclose(_d, d_expect, atol=atol)
  1870. assert_allclose(_e, e_expect, atol=atol)
  1871. pttrs = get_lapack_funcs('pttrs', dtype=e[0])
  1872. _x, info = pttrs(_d, _e.conj(), b)
  1873. assert_allclose(_x, x_expect, atol=atol)
  1874. # also test option `lower`
  1875. if e.dtype in COMPLEX_DTYPES:
  1876. _x, info = pttrs(_d, _e, b, lower=1)
  1877. assert_allclose(_x, x_expect, atol=atol)
  1878. def pteqr_get_d_e_A_z(dtype, realtype, n, compute_z):
  1879. # used by ?pteqr tests to build parameters
  1880. # returns tuple of (d, e, A, z)
  1881. if compute_z == 1:
  1882. # build Hermitian A from Q**T * tri * Q = A by creating Q and tri
  1883. A_eig = generate_random_dtype_array((n, n), dtype)
  1884. A_eig = A_eig + np.diag(np.zeros(n) + 4*n)
  1885. A_eig = (A_eig + A_eig.conj().T) / 2
  1886. # obtain right eigenvectors (orthogonal)
  1887. vr = eigh(A_eig)[1]
  1888. # create tridiagonal matrix
  1889. d = generate_random_dtype_array((n,), realtype) + 4
  1890. e = generate_random_dtype_array((n-1,), realtype)
  1891. tri = np.diag(d) + np.diag(e, 1) + np.diag(e, -1)
  1892. # Build A using these factors that sytrd would: (Q**T * tri * Q = A)
  1893. A = vr @ tri @ vr.conj().T
  1894. # vr is orthogonal
  1895. z = vr
  1896. else:
  1897. # d and e are always real per lapack docs.
  1898. d = generate_random_dtype_array((n,), realtype)
  1899. e = generate_random_dtype_array((n-1,), realtype)
  1900. # make SPD
  1901. d = d + 4
  1902. A = np.diag(d) + np.diag(e, 1) + np.diag(e, -1)
  1903. z = np.diag(d) + np.diag(e, -1) + np.diag(e, 1)
  1904. return (d, e, A, z)
  1905. @pytest.mark.parametrize("dtype,realtype",
  1906. zip(DTYPES, REAL_DTYPES + REAL_DTYPES))
  1907. @pytest.mark.parametrize("compute_z", range(3))
  1908. def test_pteqr(dtype, realtype, compute_z):
  1909. '''
  1910. Tests the ?pteqr lapack routine for all dtypes and compute_z parameters.
  1911. It generates random SPD matrix diagonals d and e, and then confirms
  1912. correct eigenvalues with scipy.linalg.eig. With applicable compute_z=2 it
  1913. tests that z can reform A.
  1914. '''
  1915. seed(42)
  1916. atol = 1000*np.finfo(dtype).eps
  1917. pteqr = get_lapack_funcs(('pteqr'), dtype=dtype)
  1918. n = 10
  1919. d, e, A, z = pteqr_get_d_e_A_z(dtype, realtype, n, compute_z)
  1920. d_pteqr, e_pteqr, z_pteqr, info = pteqr(d=d, e=e, z=z, compute_z=compute_z)
  1921. assert_equal(info, 0, "info = {}, should be 0.".format(info))
  1922. # compare the routine's eigenvalues with scipy.linalg.eig's.
  1923. assert_allclose(np.sort(eigh(A)[0]), np.sort(d_pteqr), atol=atol)
  1924. if compute_z:
  1925. # verify z_pteqr as orthogonal
  1926. assert_allclose(z_pteqr @ np.conj(z_pteqr).T, np.identity(n),
  1927. atol=atol)
  1928. # verify that z_pteqr recombines to A
  1929. assert_allclose(z_pteqr @ np.diag(d_pteqr) @ np.conj(z_pteqr).T,
  1930. A, atol=atol)
  1931. @pytest.mark.parametrize("dtype,realtype",
  1932. zip(DTYPES, REAL_DTYPES + REAL_DTYPES))
  1933. @pytest.mark.parametrize("compute_z", range(3))
  1934. def test_pteqr_error_non_spd(dtype, realtype, compute_z):
  1935. seed(42)
  1936. pteqr = get_lapack_funcs(('pteqr'), dtype=dtype)
  1937. n = 10
  1938. d, e, A, z = pteqr_get_d_e_A_z(dtype, realtype, n, compute_z)
  1939. # test with non-spd matrix
  1940. d_pteqr, e_pteqr, z_pteqr, info = pteqr(d - 4, e, z=z, compute_z=compute_z)
  1941. assert info > 0
  1942. @pytest.mark.parametrize("dtype,realtype",
  1943. zip(DTYPES, REAL_DTYPES + REAL_DTYPES))
  1944. @pytest.mark.parametrize("compute_z", range(3))
  1945. def test_pteqr_raise_error_wrong_shape(dtype, realtype, compute_z):
  1946. seed(42)
  1947. pteqr = get_lapack_funcs(('pteqr'), dtype=dtype)
  1948. n = 10
  1949. d, e, A, z = pteqr_get_d_e_A_z(dtype, realtype, n, compute_z)
  1950. # test with incorrect/incompatible array sizes
  1951. assert_raises(ValueError, pteqr, d[:-1], e, z=z, compute_z=compute_z)
  1952. assert_raises(ValueError, pteqr, d, e[:-1], z=z, compute_z=compute_z)
  1953. if compute_z:
  1954. assert_raises(ValueError, pteqr, d, e, z=z[:-1], compute_z=compute_z)
  1955. @pytest.mark.parametrize("dtype,realtype",
  1956. zip(DTYPES, REAL_DTYPES + REAL_DTYPES))
  1957. @pytest.mark.parametrize("compute_z", range(3))
  1958. def test_pteqr_error_singular(dtype, realtype, compute_z):
  1959. seed(42)
  1960. pteqr = get_lapack_funcs(('pteqr'), dtype=dtype)
  1961. n = 10
  1962. d, e, A, z = pteqr_get_d_e_A_z(dtype, realtype, n, compute_z)
  1963. # test with singular matrix
  1964. d[0] = 0
  1965. e[0] = 0
  1966. d_pteqr, e_pteqr, z_pteqr, info = pteqr(d, e, z=z, compute_z=compute_z)
  1967. assert info > 0
  1968. @pytest.mark.parametrize("compute_z,d,e,d_expect,z_expect",
  1969. [(2, # "I"
  1970. np.array([4.16, 5.25, 1.09, .62]),
  1971. np.array([3.17, -.97, .55]),
  1972. np.array([8.0023, 1.9926, 1.0014, 0.1237]),
  1973. np.array([[0.6326, 0.6245, -0.4191, 0.1847],
  1974. [0.7668, -0.4270, 0.4176, -0.2352],
  1975. [-0.1082, 0.6071, 0.4594, -0.6393],
  1976. [-0.0081, 0.2432, 0.6625, 0.7084]])),
  1977. ])
  1978. def test_pteqr_NAG_f08jgf(compute_z, d, e, d_expect, z_expect):
  1979. '''
  1980. Implements real (f08jgf) example from NAG Manual Mark 26.
  1981. Tests for correct outputs.
  1982. '''
  1983. # the NAG manual has 4 decimals accuracy
  1984. atol = 1e-4
  1985. pteqr = get_lapack_funcs(('pteqr'), dtype=d.dtype)
  1986. z = np.diag(d) + np.diag(e, 1) + np.diag(e, -1)
  1987. _d, _e, _z, info = pteqr(d=d, e=e, z=z, compute_z=compute_z)
  1988. assert_allclose(_d, d_expect, atol=atol)
  1989. assert_allclose(np.abs(_z), np.abs(z_expect), atol=atol)
  1990. @pytest.mark.parametrize('dtype', DTYPES)
  1991. @pytest.mark.parametrize('matrix_size', [(3, 4), (7, 6), (6, 6)])
  1992. def test_geqrfp(dtype, matrix_size):
  1993. # Tests for all dytpes, tall, wide, and square matrices.
  1994. # Using the routine with random matrix A, Q and R are obtained and then
  1995. # tested such that R is upper triangular and non-negative on the diagonal,
  1996. # and Q is an orthagonal matrix. Verifies that A=Q@R. It also
  1997. # tests against a matrix that for which the linalg.qr method returns
  1998. # negative diagonals, and for error messaging.
  1999. # set test tolerance appropriate for dtype
  2000. np.random.seed(42)
  2001. rtol = 250*np.finfo(dtype).eps
  2002. atol = 100*np.finfo(dtype).eps
  2003. # get appropriate ?geqrfp for dtype
  2004. geqrfp = get_lapack_funcs(('geqrfp'), dtype=dtype)
  2005. gqr = get_lapack_funcs(("orgqr"), dtype=dtype)
  2006. m, n = matrix_size
  2007. # create random matrix of dimentions m x n
  2008. A = generate_random_dtype_array((m, n), dtype=dtype)
  2009. # create qr matrix using geqrfp
  2010. qr_A, tau, info = geqrfp(A)
  2011. # obtain r from the upper triangular area
  2012. r = np.triu(qr_A)
  2013. # obtain q from the orgqr lapack routine
  2014. # based on linalg.qr's extraction strategy of q with orgqr
  2015. if m > n:
  2016. # this adds an extra column to the end of qr_A
  2017. # let qqr be an empty m x m matrix
  2018. qqr = np.zeros((m, m), dtype=dtype)
  2019. # set first n columns of qqr to qr_A
  2020. qqr[:, :n] = qr_A
  2021. # determine q from this qqr
  2022. # note that m is a sufficient for lwork based on LAPACK documentation
  2023. q = gqr(qqr, tau=tau, lwork=m)[0]
  2024. else:
  2025. q = gqr(qr_A[:, :m], tau=tau, lwork=m)[0]
  2026. # test that q and r still make A
  2027. assert_allclose(q@r, A, rtol=rtol)
  2028. # ensure that q is orthogonal (that q @ transposed q is the identity)
  2029. assert_allclose(np.eye(q.shape[0]), q@(q.conj().T), rtol=rtol,
  2030. atol=atol)
  2031. # ensure r is upper tri by comparing original r to r as upper triangular
  2032. assert_allclose(r, np.triu(r), rtol=rtol)
  2033. # make sure diagonals of r are positive for this random solution
  2034. assert_(np.all(np.diag(r) > np.zeros(len(np.diag(r)))))
  2035. # ensure that info is zero for this success
  2036. assert_(info == 0)
  2037. # test that this routine gives r diagonals that are positive for a
  2038. # matrix that returns negatives in the diagonal with scipy.linalg.rq
  2039. A_negative = generate_random_dtype_array((n, m), dtype=dtype) * -1
  2040. r_rq_neg, q_rq_neg = qr(A_negative)
  2041. rq_A_neg, tau_neg, info_neg = geqrfp(A_negative)
  2042. # assert that any of the entries on the diagonal from linalg.qr
  2043. # are negative and that all of geqrfp are positive.
  2044. assert_(np.any(np.diag(r_rq_neg) < 0) and
  2045. np.all(np.diag(r) > 0))
  2046. def test_geqrfp_errors_with_empty_array():
  2047. # check that empty array raises good error message
  2048. A_empty = np.array([])
  2049. geqrfp = get_lapack_funcs('geqrfp', dtype=A_empty.dtype)
  2050. assert_raises(Exception, geqrfp, A_empty)
  2051. @pytest.mark.parametrize("driver", ['ev', 'evd', 'evr', 'evx'])
  2052. @pytest.mark.parametrize("pfx", ['sy', 'he'])
  2053. def test_standard_eigh_lworks(pfx, driver):
  2054. n = 1200 # Some sufficiently big arbitrary number
  2055. dtype = REAL_DTYPES if pfx == 'sy' else COMPLEX_DTYPES
  2056. sc_dlw = get_lapack_funcs(pfx+driver+'_lwork', dtype=dtype[0])
  2057. dz_dlw = get_lapack_funcs(pfx+driver+'_lwork', dtype=dtype[1])
  2058. try:
  2059. _compute_lwork(sc_dlw, n, lower=1)
  2060. _compute_lwork(dz_dlw, n, lower=1)
  2061. except Exception as e:
  2062. pytest.fail("{}_lwork raised unexpected exception: {}"
  2063. "".format(pfx+driver, e))
  2064. @pytest.mark.parametrize("driver", ['gv', 'gvx'])
  2065. @pytest.mark.parametrize("pfx", ['sy', 'he'])
  2066. def test_generalized_eigh_lworks(pfx, driver):
  2067. n = 1200 # Some sufficiently big arbitrary number
  2068. dtype = REAL_DTYPES if pfx == 'sy' else COMPLEX_DTYPES
  2069. sc_dlw = get_lapack_funcs(pfx+driver+'_lwork', dtype=dtype[0])
  2070. dz_dlw = get_lapack_funcs(pfx+driver+'_lwork', dtype=dtype[1])
  2071. # Shouldn't raise any exceptions
  2072. try:
  2073. _compute_lwork(sc_dlw, n, uplo="L")
  2074. _compute_lwork(dz_dlw, n, uplo="L")
  2075. except Exception as e:
  2076. pytest.fail("{}_lwork raised unexpected exception: {}"
  2077. "".format(pfx+driver, e))
  2078. @pytest.mark.parametrize("dtype_", DTYPES)
  2079. @pytest.mark.parametrize("m", [1, 10, 100, 1000])
  2080. def test_orcsd_uncsd_lwork(dtype_, m):
  2081. seed(1234)
  2082. p = randint(0, m)
  2083. q = m - p
  2084. pfx = 'or' if dtype_ in REAL_DTYPES else 'un'
  2085. dlw = pfx + 'csd_lwork'
  2086. lw = get_lapack_funcs(dlw, dtype=dtype_)
  2087. lwval = _compute_lwork(lw, m, p, q)
  2088. lwval = lwval if pfx == 'un' else (lwval,)
  2089. assert all([x > 0 for x in lwval])
  2090. @pytest.mark.parametrize("dtype_", DTYPES)
  2091. def test_orcsd_uncsd(dtype_):
  2092. m, p, q = 250, 80, 170
  2093. pfx = 'or' if dtype_ in REAL_DTYPES else 'un'
  2094. X = ortho_group.rvs(m) if pfx == 'or' else unitary_group.rvs(m)
  2095. drv, dlw = get_lapack_funcs((pfx + 'csd', pfx + 'csd_lwork'), dtype=dtype_)
  2096. lwval = _compute_lwork(dlw, m, p, q)
  2097. lwvals = {'lwork': lwval} if pfx == 'or' else dict(zip(['lwork',
  2098. 'lrwork'], lwval))
  2099. cs11, cs12, cs21, cs22, theta, u1, u2, v1t, v2t, info =\
  2100. drv(X[:p, :q], X[:p, q:], X[p:, :q], X[p:, q:], **lwvals)
  2101. assert info == 0
  2102. U = block_diag(u1, u2)
  2103. VH = block_diag(v1t, v2t)
  2104. r = min(min(p, q), min(m-p, m-q))
  2105. n11 = min(p, q) - r
  2106. n12 = min(p, m-q) - r
  2107. n21 = min(m-p, q) - r
  2108. n22 = min(m-p, m-q) - r
  2109. S = np.zeros((m, m), dtype=dtype_)
  2110. one = dtype_(1.)
  2111. for i in range(n11):
  2112. S[i, i] = one
  2113. for i in range(n22):
  2114. S[p+i, q+i] = one
  2115. for i in range(n12):
  2116. S[i+n11+r, i+n11+r+n21+n22+r] = -one
  2117. for i in range(n21):
  2118. S[p+n22+r+i, n11+r+i] = one
  2119. for i in range(r):
  2120. S[i+n11, i+n11] = np.cos(theta[i])
  2121. S[p+n22+i, i+r+n21+n22] = np.cos(theta[i])
  2122. S[i+n11, i+n11+n21+n22+r] = -np.sin(theta[i])
  2123. S[p+n22+i, i+n11] = np.sin(theta[i])
  2124. Xc = U @ S @ VH
  2125. assert_allclose(X, Xc, rtol=0., atol=1e4*np.finfo(dtype_).eps)
  2126. @pytest.mark.parametrize("dtype", DTYPES)
  2127. @pytest.mark.parametrize("trans_bool", [False, True])
  2128. @pytest.mark.parametrize("fact", ["F", "N"])
  2129. def test_gtsvx(dtype, trans_bool, fact):
  2130. """
  2131. These tests uses ?gtsvx to solve a random Ax=b system for each dtype.
  2132. It tests that the outputs define an LU matrix, that inputs are unmodified,
  2133. transposal options, incompatible shapes, singular matrices, and
  2134. singular factorizations. It parametrizes DTYPES and the 'fact' value along
  2135. with the fact related inputs.
  2136. """
  2137. seed(42)
  2138. # set test tolerance appropriate for dtype
  2139. atol = 100 * np.finfo(dtype).eps
  2140. # obtain routine
  2141. gtsvx, gttrf = get_lapack_funcs(('gtsvx', 'gttrf'), dtype=dtype)
  2142. # Generate random tridiagonal matrix A
  2143. n = 10
  2144. dl = generate_random_dtype_array((n-1,), dtype=dtype)
  2145. d = generate_random_dtype_array((n,), dtype=dtype)
  2146. du = generate_random_dtype_array((n-1,), dtype=dtype)
  2147. A = np.diag(dl, -1) + np.diag(d) + np.diag(du, 1)
  2148. # generate random solution x
  2149. x = generate_random_dtype_array((n, 2), dtype=dtype)
  2150. # create b from x for equation Ax=b
  2151. trans = ("T" if dtype in REAL_DTYPES else "C") if trans_bool else "N"
  2152. b = (A.conj().T if trans_bool else A) @ x
  2153. # store a copy of the inputs to check they haven't been modified later
  2154. inputs_cpy = [dl.copy(), d.copy(), du.copy(), b.copy()]
  2155. # set these to None if fact = 'N', or the output of gttrf is fact = 'F'
  2156. dlf_, df_, duf_, du2f_, ipiv_, info_ = \
  2157. gttrf(dl, d, du) if fact == 'F' else [None]*6
  2158. gtsvx_out = gtsvx(dl, d, du, b, fact=fact, trans=trans, dlf=dlf_, df=df_,
  2159. duf=duf_, du2=du2f_, ipiv=ipiv_)
  2160. dlf, df, duf, du2f, ipiv, x_soln, rcond, ferr, berr, info = gtsvx_out
  2161. assert_(info == 0, "?gtsvx info = {}, should be zero".format(info))
  2162. # assure that inputs are unmodified
  2163. assert_array_equal(dl, inputs_cpy[0])
  2164. assert_array_equal(d, inputs_cpy[1])
  2165. assert_array_equal(du, inputs_cpy[2])
  2166. assert_array_equal(b, inputs_cpy[3])
  2167. # test that x_soln matches the expected x
  2168. assert_allclose(x, x_soln, atol=atol)
  2169. # assert that the outputs are of correct type or shape
  2170. # rcond should be a scalar
  2171. assert_(hasattr(rcond, "__len__") is not True,
  2172. "rcond should be scalar but is {}".format(rcond))
  2173. # ferr should be length of # of cols in x
  2174. assert_(ferr.shape[0] == b.shape[1], "ferr.shape is {} but shoud be {},"
  2175. .format(ferr.shape[0], b.shape[1]))
  2176. # berr should be length of # of cols in x
  2177. assert_(berr.shape[0] == b.shape[1], "berr.shape is {} but shoud be {},"
  2178. .format(berr.shape[0], b.shape[1]))
  2179. @pytest.mark.parametrize("dtype", DTYPES)
  2180. @pytest.mark.parametrize("trans_bool", [0, 1])
  2181. @pytest.mark.parametrize("fact", ["F", "N"])
  2182. def test_gtsvx_error_singular(dtype, trans_bool, fact):
  2183. seed(42)
  2184. # obtain routine
  2185. gtsvx, gttrf = get_lapack_funcs(('gtsvx', 'gttrf'), dtype=dtype)
  2186. # Generate random tridiagonal matrix A
  2187. n = 10
  2188. dl = generate_random_dtype_array((n-1,), dtype=dtype)
  2189. d = generate_random_dtype_array((n,), dtype=dtype)
  2190. du = generate_random_dtype_array((n-1,), dtype=dtype)
  2191. A = np.diag(dl, -1) + np.diag(d) + np.diag(du, 1)
  2192. # generate random solution x
  2193. x = generate_random_dtype_array((n, 2), dtype=dtype)
  2194. # create b from x for equation Ax=b
  2195. trans = "T" if dtype in REAL_DTYPES else "C"
  2196. b = (A.conj().T if trans_bool else A) @ x
  2197. # set these to None if fact = 'N', or the output of gttrf is fact = 'F'
  2198. dlf_, df_, duf_, du2f_, ipiv_, info_ = \
  2199. gttrf(dl, d, du) if fact == 'F' else [None]*6
  2200. gtsvx_out = gtsvx(dl, d, du, b, fact=fact, trans=trans, dlf=dlf_, df=df_,
  2201. duf=duf_, du2=du2f_, ipiv=ipiv_)
  2202. dlf, df, duf, du2f, ipiv, x_soln, rcond, ferr, berr, info = gtsvx_out
  2203. # test with singular matrix
  2204. # no need to test inputs with fact "F" since ?gttrf already does.
  2205. if fact == "N":
  2206. # Construct a singular example manually
  2207. d[-1] = 0
  2208. dl[-1] = 0
  2209. # solve using routine
  2210. gtsvx_out = gtsvx(dl, d, du, b)
  2211. dlf, df, duf, du2f, ipiv, x_soln, rcond, ferr, berr, info = gtsvx_out
  2212. # test for the singular matrix.
  2213. assert info > 0, "info should be > 0 for singular matrix"
  2214. elif fact == 'F':
  2215. # assuming that a singular factorization is input
  2216. df_[-1] = 0
  2217. duf_[-1] = 0
  2218. du2f_[-1] = 0
  2219. gtsvx_out = gtsvx(dl, d, du, b, fact=fact, dlf=dlf_, df=df_, duf=duf_,
  2220. du2=du2f_, ipiv=ipiv_)
  2221. dlf, df, duf, du2f, ipiv, x_soln, rcond, ferr, berr, info = gtsvx_out
  2222. # info should not be zero and should provide index of illegal value
  2223. assert info > 0, "info should be > 0 for singular matrix"
  2224. @pytest.mark.parametrize("dtype", DTYPES*2)
  2225. @pytest.mark.parametrize("trans_bool", [False, True])
  2226. @pytest.mark.parametrize("fact", ["F", "N"])
  2227. def test_gtsvx_error_incompatible_size(dtype, trans_bool, fact):
  2228. seed(42)
  2229. # obtain routine
  2230. gtsvx, gttrf = get_lapack_funcs(('gtsvx', 'gttrf'), dtype=dtype)
  2231. # Generate random tridiagonal matrix A
  2232. n = 10
  2233. dl = generate_random_dtype_array((n-1,), dtype=dtype)
  2234. d = generate_random_dtype_array((n,), dtype=dtype)
  2235. du = generate_random_dtype_array((n-1,), dtype=dtype)
  2236. A = np.diag(dl, -1) + np.diag(d) + np.diag(du, 1)
  2237. # generate random solution x
  2238. x = generate_random_dtype_array((n, 2), dtype=dtype)
  2239. # create b from x for equation Ax=b
  2240. trans = "T" if dtype in REAL_DTYPES else "C"
  2241. b = (A.conj().T if trans_bool else A) @ x
  2242. # set these to None if fact = 'N', or the output of gttrf is fact = 'F'
  2243. dlf_, df_, duf_, du2f_, ipiv_, info_ = \
  2244. gttrf(dl, d, du) if fact == 'F' else [None]*6
  2245. if fact == "N":
  2246. assert_raises(ValueError, gtsvx, dl[:-1], d, du, b,
  2247. fact=fact, trans=trans, dlf=dlf_, df=df_,
  2248. duf=duf_, du2=du2f_, ipiv=ipiv_)
  2249. assert_raises(ValueError, gtsvx, dl, d[:-1], du, b,
  2250. fact=fact, trans=trans, dlf=dlf_, df=df_,
  2251. duf=duf_, du2=du2f_, ipiv=ipiv_)
  2252. assert_raises(ValueError, gtsvx, dl, d, du[:-1], b,
  2253. fact=fact, trans=trans, dlf=dlf_, df=df_,
  2254. duf=duf_, du2=du2f_, ipiv=ipiv_)
  2255. assert_raises(Exception, gtsvx, dl, d, du, b[:-1],
  2256. fact=fact, trans=trans, dlf=dlf_, df=df_,
  2257. duf=duf_, du2=du2f_, ipiv=ipiv_)
  2258. else:
  2259. assert_raises(ValueError, gtsvx, dl, d, du, b,
  2260. fact=fact, trans=trans, dlf=dlf_[:-1], df=df_,
  2261. duf=duf_, du2=du2f_, ipiv=ipiv_)
  2262. assert_raises(ValueError, gtsvx, dl, d, du, b,
  2263. fact=fact, trans=trans, dlf=dlf_, df=df_[:-1],
  2264. duf=duf_, du2=du2f_, ipiv=ipiv_)
  2265. assert_raises(ValueError, gtsvx, dl, d, du, b,
  2266. fact=fact, trans=trans, dlf=dlf_, df=df_,
  2267. duf=duf_[:-1], du2=du2f_, ipiv=ipiv_)
  2268. assert_raises(ValueError, gtsvx, dl, d, du, b,
  2269. fact=fact, trans=trans, dlf=dlf_, df=df_,
  2270. duf=duf_, du2=du2f_[:-1], ipiv=ipiv_)
  2271. @pytest.mark.parametrize("du,d,dl,b,x",
  2272. [(np.array([2.1, -1.0, 1.9, 8.0]),
  2273. np.array([3.0, 2.3, -5.0, -0.9, 7.1]),
  2274. np.array([3.4, 3.6, 7.0, -6.0]),
  2275. np.array([[2.7, 6.6], [-.5, 10.8], [2.6, -3.2],
  2276. [.6, -11.2], [2.7, 19.1]]),
  2277. np.array([[-4, 5], [7, -4], [3, -3], [-4, -2],
  2278. [-3, 1]])),
  2279. (np.array([2 - 1j, 2 + 1j, -1 + 1j, 1 - 1j]),
  2280. np.array([-1.3 + 1.3j, -1.3 + 1.3j, -1.3 + 3.3j,
  2281. -.3 + 4.3j, -3.3 + 1.3j]),
  2282. np.array([1 - 2j, 1 + 1j, 2 - 3j, 1 + 1j]),
  2283. np.array([[2.4 - 5j, 2.7 + 6.9j],
  2284. [3.4 + 18.2j, -6.9 - 5.3j],
  2285. [-14.7 + 9.7j, -6 - .6j],
  2286. [31.9 - 7.7j, -3.9 + 9.3j],
  2287. [-1 + 1.6j, -3 + 12.2j]]),
  2288. np.array([[1 + 1j, 2 - 1j], [3 - 1j, 1 + 2j],
  2289. [4 + 5j, -1 + 1j], [-1 - 2j, 2 + 1j],
  2290. [1 - 1j, 2 - 2j]]))])
  2291. def test_gtsvx_NAG(du, d, dl, b, x):
  2292. # Test to ensure wrapper is consistent with NAG Manual Mark 26
  2293. # example problems: real (f07cbf) and complex (f07cpf)
  2294. gtsvx = get_lapack_funcs('gtsvx', dtype=d.dtype)
  2295. gtsvx_out = gtsvx(dl, d, du, b)
  2296. dlf, df, duf, du2f, ipiv, x_soln, rcond, ferr, berr, info = gtsvx_out
  2297. assert_array_almost_equal(x, x_soln)
  2298. @pytest.mark.parametrize("dtype,realtype", zip(DTYPES, REAL_DTYPES
  2299. + REAL_DTYPES))
  2300. @pytest.mark.parametrize("fact,df_de_lambda",
  2301. [("F",
  2302. lambda d, e:get_lapack_funcs('pttrf',
  2303. dtype=e.dtype)(d, e)),
  2304. ("N", lambda d, e: (None, None, None))])
  2305. def test_ptsvx(dtype, realtype, fact, df_de_lambda):
  2306. '''
  2307. This tests the ?ptsvx lapack routine wrapper to solve a random system
  2308. Ax = b for all dtypes and input variations. Tests for: unmodified
  2309. input parameters, fact options, incompatible matrix shapes raise an error,
  2310. and singular matrices return info of illegal value.
  2311. '''
  2312. seed(42)
  2313. # set test tolerance appropriate for dtype
  2314. atol = 100 * np.finfo(dtype).eps
  2315. ptsvx = get_lapack_funcs('ptsvx', dtype=dtype)
  2316. n = 5
  2317. # create diagonals according to size and dtype
  2318. d = generate_random_dtype_array((n,), realtype) + 4
  2319. e = generate_random_dtype_array((n-1,), dtype)
  2320. A = np.diag(d) + np.diag(e, -1) + np.diag(np.conj(e), 1)
  2321. x_soln = generate_random_dtype_array((n, 2), dtype=dtype)
  2322. b = A @ x_soln
  2323. # use lambda to determine what df, ef are
  2324. df, ef, info = df_de_lambda(d, e)
  2325. # create copy to later test that they are unmodified
  2326. diag_cpy = [d.copy(), e.copy(), b.copy()]
  2327. # solve using routine
  2328. df, ef, x, rcond, ferr, berr, info = ptsvx(d, e, b, fact=fact,
  2329. df=df, ef=ef)
  2330. # d, e, and b should be unmodified
  2331. assert_array_equal(d, diag_cpy[0])
  2332. assert_array_equal(e, diag_cpy[1])
  2333. assert_array_equal(b, diag_cpy[2])
  2334. assert_(info == 0, "info should be 0 but is {}.".format(info))
  2335. assert_array_almost_equal(x_soln, x)
  2336. # test that the factors from ptsvx can be recombined to make A
  2337. L = np.diag(ef, -1) + np.diag(np.ones(n))
  2338. D = np.diag(df)
  2339. assert_allclose(A, L@D@(np.conj(L).T), atol=atol)
  2340. # assert that the outputs are of correct type or shape
  2341. # rcond should be a scalar
  2342. assert not hasattr(rcond, "__len__"), \
  2343. "rcond should be scalar but is {}".format(rcond)
  2344. # ferr should be length of # of cols in x
  2345. assert_(ferr.shape == (2,), "ferr.shape is {} but shoud be ({},)"
  2346. .format(ferr.shape, x_soln.shape[1]))
  2347. # berr should be length of # of cols in x
  2348. assert_(berr.shape == (2,), "berr.shape is {} but shoud be ({},)"
  2349. .format(berr.shape, x_soln.shape[1]))
  2350. @pytest.mark.parametrize("dtype,realtype", zip(DTYPES, REAL_DTYPES
  2351. + REAL_DTYPES))
  2352. @pytest.mark.parametrize("fact,df_de_lambda",
  2353. [("F",
  2354. lambda d, e:get_lapack_funcs('pttrf',
  2355. dtype=e.dtype)(d, e)),
  2356. ("N", lambda d, e: (None, None, None))])
  2357. def test_ptsvx_error_raise_errors(dtype, realtype, fact, df_de_lambda):
  2358. seed(42)
  2359. ptsvx = get_lapack_funcs('ptsvx', dtype=dtype)
  2360. n = 5
  2361. # create diagonals according to size and dtype
  2362. d = generate_random_dtype_array((n,), realtype) + 4
  2363. e = generate_random_dtype_array((n-1,), dtype)
  2364. A = np.diag(d) + np.diag(e, -1) + np.diag(np.conj(e), 1)
  2365. x_soln = generate_random_dtype_array((n, 2), dtype=dtype)
  2366. b = A @ x_soln
  2367. # use lambda to determine what df, ef are
  2368. df, ef, info = df_de_lambda(d, e)
  2369. # test with malformatted array sizes
  2370. assert_raises(ValueError, ptsvx, d[:-1], e, b, fact=fact, df=df, ef=ef)
  2371. assert_raises(ValueError, ptsvx, d, e[:-1], b, fact=fact, df=df, ef=ef)
  2372. assert_raises(Exception, ptsvx, d, e, b[:-1], fact=fact, df=df, ef=ef)
  2373. @pytest.mark.parametrize("dtype,realtype", zip(DTYPES, REAL_DTYPES
  2374. + REAL_DTYPES))
  2375. @pytest.mark.parametrize("fact,df_de_lambda",
  2376. [("F",
  2377. lambda d, e:get_lapack_funcs('pttrf',
  2378. dtype=e.dtype)(d, e)),
  2379. ("N", lambda d, e: (None, None, None))])
  2380. def test_ptsvx_non_SPD_singular(dtype, realtype, fact, df_de_lambda):
  2381. seed(42)
  2382. ptsvx = get_lapack_funcs('ptsvx', dtype=dtype)
  2383. n = 5
  2384. # create diagonals according to size and dtype
  2385. d = generate_random_dtype_array((n,), realtype) + 4
  2386. e = generate_random_dtype_array((n-1,), dtype)
  2387. A = np.diag(d) + np.diag(e, -1) + np.diag(np.conj(e), 1)
  2388. x_soln = generate_random_dtype_array((n, 2), dtype=dtype)
  2389. b = A @ x_soln
  2390. # use lambda to determine what df, ef are
  2391. df, ef, info = df_de_lambda(d, e)
  2392. if fact == "N":
  2393. d[3] = 0
  2394. # obtain new df, ef
  2395. df, ef, info = df_de_lambda(d, e)
  2396. # solve using routine
  2397. df, ef, x, rcond, ferr, berr, info = ptsvx(d, e, b)
  2398. # test for the singular matrix.
  2399. assert info > 0 and info <= n
  2400. # non SPD matrix
  2401. d = generate_random_dtype_array((n,), realtype)
  2402. df, ef, x, rcond, ferr, berr, info = ptsvx(d, e, b)
  2403. assert info > 0 and info <= n
  2404. else:
  2405. # assuming that someone is using a singular factorization
  2406. df, ef, info = df_de_lambda(d, e)
  2407. df[0] = 0
  2408. ef[0] = 0
  2409. df, ef, x, rcond, ferr, berr, info = ptsvx(d, e, b, fact=fact,
  2410. df=df, ef=ef)
  2411. assert info > 0
  2412. @pytest.mark.parametrize('d,e,b,x',
  2413. [(np.array([4, 10, 29, 25, 5]),
  2414. np.array([-2, -6, 15, 8]),
  2415. np.array([[6, 10], [9, 4], [2, 9], [14, 65],
  2416. [7, 23]]),
  2417. np.array([[2.5, 2], [2, -1], [1, -3],
  2418. [-1, 6], [3, -5]])),
  2419. (np.array([16, 41, 46, 21]),
  2420. np.array([16 + 16j, 18 - 9j, 1 - 4j]),
  2421. np.array([[64 + 16j, -16 - 32j],
  2422. [93 + 62j, 61 - 66j],
  2423. [78 - 80j, 71 - 74j],
  2424. [14 - 27j, 35 + 15j]]),
  2425. np.array([[2 + 1j, -3 - 2j],
  2426. [1 + 1j, 1 + 1j],
  2427. [1 - 2j, 1 - 2j],
  2428. [1 - 1j, 2 + 1j]]))])
  2429. def test_ptsvx_NAG(d, e, b, x):
  2430. # test to assure that wrapper is consistent with NAG Manual Mark 26
  2431. # example problemss: f07jbf, f07jpf
  2432. # (Links expire, so please search for "NAG Library Manual Mark 26" online)
  2433. # obtain routine with correct type based on e.dtype
  2434. ptsvx = get_lapack_funcs('ptsvx', dtype=e.dtype)
  2435. # solve using routine
  2436. df, ef, x_ptsvx, rcond, ferr, berr, info = ptsvx(d, e, b)
  2437. # determine ptsvx's solution and x are the same.
  2438. assert_array_almost_equal(x, x_ptsvx)
  2439. @pytest.mark.parametrize('lower', [False, True])
  2440. @pytest.mark.parametrize('dtype', DTYPES)
  2441. def test_pptrs_pptri_pptrf_ppsv_ppcon(dtype, lower):
  2442. seed(1234)
  2443. atol = np.finfo(dtype).eps*100
  2444. # Manual conversion to/from packed format is feasible here.
  2445. n, nrhs = 10, 4
  2446. a = generate_random_dtype_array([n, n], dtype=dtype)
  2447. b = generate_random_dtype_array([n, nrhs], dtype=dtype)
  2448. a = a.conj().T + a + np.eye(n, dtype=dtype) * dtype(5.)
  2449. if lower:
  2450. inds = ([x for y in range(n) for x in range(y, n)],
  2451. [y for y in range(n) for x in range(y, n)])
  2452. else:
  2453. inds = ([x for y in range(1, n+1) for x in range(y)],
  2454. [y-1 for y in range(1, n+1) for x in range(y)])
  2455. ap = a[inds]
  2456. ppsv, pptrf, pptrs, pptri, ppcon = get_lapack_funcs(
  2457. ('ppsv', 'pptrf', 'pptrs', 'pptri', 'ppcon'),
  2458. dtype=dtype,
  2459. ilp64="preferred")
  2460. ul, info = pptrf(n, ap, lower=lower)
  2461. assert_equal(info, 0)
  2462. aul = cholesky(a, lower=lower)[inds]
  2463. assert_allclose(ul, aul, rtol=0, atol=atol)
  2464. uli, info = pptri(n, ul, lower=lower)
  2465. assert_equal(info, 0)
  2466. auli = inv(a)[inds]
  2467. assert_allclose(uli, auli, rtol=0, atol=atol)
  2468. x, info = pptrs(n, ul, b, lower=lower)
  2469. assert_equal(info, 0)
  2470. bx = solve(a, b)
  2471. assert_allclose(x, bx, rtol=0, atol=atol)
  2472. xv, info = ppsv(n, ap, b, lower=lower)
  2473. assert_equal(info, 0)
  2474. assert_allclose(xv, bx, rtol=0, atol=atol)
  2475. anorm = np.linalg.norm(a, 1)
  2476. rcond, info = ppcon(n, ap, anorm=anorm, lower=lower)
  2477. assert_equal(info, 0)
  2478. assert_(abs(1/rcond - np.linalg.cond(a, p=1))*rcond < 1)
  2479. @pytest.mark.parametrize('dtype', DTYPES)
  2480. def test_gees_trexc(dtype):
  2481. seed(1234)
  2482. atol = np.finfo(dtype).eps*100
  2483. n = 10
  2484. a = generate_random_dtype_array([n, n], dtype=dtype)
  2485. gees, trexc = get_lapack_funcs(('gees', 'trexc'), dtype=dtype)
  2486. result = gees(lambda x: None, a, overwrite_a=False)
  2487. assert_equal(result[-1], 0)
  2488. t = result[0]
  2489. z = result[-3]
  2490. d2 = t[6, 6]
  2491. if dtype in COMPLEX_DTYPES:
  2492. assert_allclose(t, np.triu(t), rtol=0, atol=atol)
  2493. assert_allclose(z @ t @ z.conj().T, a, rtol=0, atol=atol)
  2494. result = trexc(t, z, 7, 1)
  2495. assert_equal(result[-1], 0)
  2496. t = result[0]
  2497. z = result[-2]
  2498. if dtype in COMPLEX_DTYPES:
  2499. assert_allclose(t, np.triu(t), rtol=0, atol=atol)
  2500. assert_allclose(z @ t @ z.conj().T, a, rtol=0, atol=atol)
  2501. assert_allclose(t[0, 0], d2, rtol=0, atol=atol)
  2502. @pytest.mark.parametrize(
  2503. "t, expect, ifst, ilst",
  2504. [(np.array([[0.80, -0.11, 0.01, 0.03],
  2505. [0.00, -0.10, 0.25, 0.35],
  2506. [0.00, -0.65, -0.10, 0.20],
  2507. [0.00, 0.00, 0.00, -0.10]]),
  2508. np.array([[-0.1000, -0.6463, 0.0874, 0.2010],
  2509. [0.2514, -0.1000, 0.0927, 0.3505],
  2510. [0.0000, 0.0000, 0.8000, -0.0117],
  2511. [0.0000, 0.0000, 0.0000, -0.1000]]),
  2512. 2, 1),
  2513. (np.array([[-6.00 - 7.00j, 0.36 - 0.36j, -0.19 + 0.48j, 0.88 - 0.25j],
  2514. [0.00 + 0.00j, -5.00 + 2.00j, -0.03 - 0.72j, -0.23 + 0.13j],
  2515. [0.00 + 0.00j, 0.00 + 0.00j, 8.00 - 1.00j, 0.94 + 0.53j],
  2516. [0.00 + 0.00j, 0.00 + 0.00j, 0.00 + 0.00j, 3.00 - 4.00j]]),
  2517. np.array([[-5.0000 + 2.0000j, -0.1574 + 0.7143j,
  2518. 0.1781 - 0.1913j, 0.3950 + 0.3861j],
  2519. [0.0000 + 0.0000j, 8.0000 - 1.0000j,
  2520. 1.0742 + 0.1447j, 0.2515 - 0.3397j],
  2521. [0.0000 + 0.0000j, 0.0000 + 0.0000j,
  2522. 3.0000 - 4.0000j, 0.2264 + 0.8962j],
  2523. [0.0000 + 0.0000j, 0.0000 + 0.0000j,
  2524. 0.0000 + 0.0000j, -6.0000 - 7.0000j]]),
  2525. 1, 4)])
  2526. def test_trexc_NAG(t, ifst, ilst, expect):
  2527. """
  2528. This test implements the example found in the NAG manual,
  2529. f08qfc, f08qtc, f08qgc, f08quc.
  2530. """
  2531. # NAG manual provides accuracy up to 4 decimals
  2532. atol = 1e-4
  2533. trexc = get_lapack_funcs('trexc', dtype=t.dtype)
  2534. result = trexc(t, t, ifst, ilst, wantq=0)
  2535. assert_equal(result[-1], 0)
  2536. t = result[0]
  2537. assert_allclose(expect, t, atol=atol)
  2538. @pytest.mark.parametrize('dtype', DTYPES)
  2539. def test_gges_tgexc(dtype):
  2540. if dtype == np.float32 and sys.platform == 'darwin':
  2541. pytest.xfail("gges[float32] broken for OpenBLAS on macOS, see gh-16949")
  2542. seed(1234)
  2543. atol = np.finfo(dtype).eps*100
  2544. n = 10
  2545. a = generate_random_dtype_array([n, n], dtype=dtype)
  2546. b = generate_random_dtype_array([n, n], dtype=dtype)
  2547. gges, tgexc = get_lapack_funcs(('gges', 'tgexc'), dtype=dtype)
  2548. result = gges(lambda x: None, a, b, overwrite_a=False, overwrite_b=False)
  2549. assert_equal(result[-1], 0)
  2550. s = result[0]
  2551. t = result[1]
  2552. q = result[-4]
  2553. z = result[-3]
  2554. d1 = s[0, 0] / t[0, 0]
  2555. d2 = s[6, 6] / t[6, 6]
  2556. if dtype in COMPLEX_DTYPES:
  2557. assert_allclose(s, np.triu(s), rtol=0, atol=atol)
  2558. assert_allclose(t, np.triu(t), rtol=0, atol=atol)
  2559. assert_allclose(q @ s @ z.conj().T, a, rtol=0, atol=atol)
  2560. assert_allclose(q @ t @ z.conj().T, b, rtol=0, atol=atol)
  2561. result = tgexc(s, t, q, z, 7, 1)
  2562. assert_equal(result[-1], 0)
  2563. s = result[0]
  2564. t = result[1]
  2565. q = result[2]
  2566. z = result[3]
  2567. if dtype in COMPLEX_DTYPES:
  2568. assert_allclose(s, np.triu(s), rtol=0, atol=atol)
  2569. assert_allclose(t, np.triu(t), rtol=0, atol=atol)
  2570. assert_allclose(q @ s @ z.conj().T, a, rtol=0, atol=atol)
  2571. assert_allclose(q @ t @ z.conj().T, b, rtol=0, atol=atol)
  2572. assert_allclose(s[0, 0] / t[0, 0], d2, rtol=0, atol=atol)
  2573. assert_allclose(s[1, 1] / t[1, 1], d1, rtol=0, atol=atol)
  2574. @pytest.mark.parametrize('dtype', DTYPES)
  2575. def test_gees_trsen(dtype):
  2576. seed(1234)
  2577. atol = np.finfo(dtype).eps*100
  2578. n = 10
  2579. a = generate_random_dtype_array([n, n], dtype=dtype)
  2580. gees, trsen, trsen_lwork = get_lapack_funcs(
  2581. ('gees', 'trsen', 'trsen_lwork'), dtype=dtype)
  2582. result = gees(lambda x: None, a, overwrite_a=False)
  2583. assert_equal(result[-1], 0)
  2584. t = result[0]
  2585. z = result[-3]
  2586. d2 = t[6, 6]
  2587. if dtype in COMPLEX_DTYPES:
  2588. assert_allclose(t, np.triu(t), rtol=0, atol=atol)
  2589. assert_allclose(z @ t @ z.conj().T, a, rtol=0, atol=atol)
  2590. select = np.zeros(n)
  2591. select[6] = 1
  2592. lwork = _compute_lwork(trsen_lwork, select, t)
  2593. if dtype in COMPLEX_DTYPES:
  2594. result = trsen(select, t, z, lwork=lwork)
  2595. else:
  2596. result = trsen(select, t, z, lwork=lwork, liwork=lwork[1])
  2597. assert_equal(result[-1], 0)
  2598. t = result[0]
  2599. z = result[1]
  2600. if dtype in COMPLEX_DTYPES:
  2601. assert_allclose(t, np.triu(t), rtol=0, atol=atol)
  2602. assert_allclose(z @ t @ z.conj().T, a, rtol=0, atol=atol)
  2603. assert_allclose(t[0, 0], d2, rtol=0, atol=atol)
  2604. @pytest.mark.parametrize(
  2605. "t, q, expect, select, expect_s, expect_sep",
  2606. [(np.array([[0.7995, -0.1144, 0.0060, 0.0336],
  2607. [0.0000, -0.0994, 0.2478, 0.3474],
  2608. [0.0000, -0.6483, -0.0994, 0.2026],
  2609. [0.0000, 0.0000, 0.0000, -0.1007]]),
  2610. np.array([[0.6551, 0.1037, 0.3450, 0.6641],
  2611. [0.5236, -0.5807, -0.6141, -0.1068],
  2612. [-0.5362, -0.3073, -0.2935, 0.7293],
  2613. [0.0956, 0.7467, -0.6463, 0.1249]]),
  2614. np.array([[0.3500, 0.4500, -0.1400, -0.1700],
  2615. [0.0900, 0.0700, -0.5399, 0.3500],
  2616. [-0.4400, -0.3300, -0.0300, 0.1700],
  2617. [0.2500, -0.3200, -0.1300, 0.1100]]),
  2618. np.array([1, 0, 0, 1]),
  2619. 1.75e+00, 3.22e+00),
  2620. (np.array([[-6.0004 - 6.9999j, 0.3637 - 0.3656j,
  2621. -0.1880 + 0.4787j, 0.8785 - 0.2539j],
  2622. [0.0000 + 0.0000j, -5.0000 + 2.0060j,
  2623. -0.0307 - 0.7217j, -0.2290 + 0.1313j],
  2624. [0.0000 + 0.0000j, 0.0000 + 0.0000j,
  2625. 7.9982 - 0.9964j, 0.9357 + 0.5359j],
  2626. [0.0000 + 0.0000j, 0.0000 + 0.0000j,
  2627. 0.0000 + 0.0000j, 3.0023 - 3.9998j]]),
  2628. np.array([[-0.8347 - 0.1364j, -0.0628 + 0.3806j,
  2629. 0.2765 - 0.0846j, 0.0633 - 0.2199j],
  2630. [0.0664 - 0.2968j, 0.2365 + 0.5240j,
  2631. -0.5877 - 0.4208j, 0.0835 + 0.2183j],
  2632. [-0.0362 - 0.3215j, 0.3143 - 0.5473j,
  2633. 0.0576 - 0.5736j, 0.0057 - 0.4058j],
  2634. [0.0086 + 0.2958j, -0.3416 - 0.0757j,
  2635. -0.1900 - 0.1600j, 0.8327 - 0.1868j]]),
  2636. np.array([[-3.9702 - 5.0406j, -4.1108 + 3.7002j,
  2637. -0.3403 + 1.0098j, 1.2899 - 0.8590j],
  2638. [0.3397 - 1.5006j, 1.5201 - 0.4301j,
  2639. 1.8797 - 5.3804j, 3.3606 + 0.6498j],
  2640. [3.3101 - 3.8506j, 2.4996 + 3.4504j,
  2641. 0.8802 - 1.0802j, 0.6401 - 1.4800j],
  2642. [-1.0999 + 0.8199j, 1.8103 - 1.5905j,
  2643. 3.2502 + 1.3297j, 1.5701 - 3.4397j]]),
  2644. np.array([1, 0, 0, 1]),
  2645. 1.02e+00, 1.82e-01)])
  2646. def test_trsen_NAG(t, q, select, expect, expect_s, expect_sep):
  2647. """
  2648. This test implements the example found in the NAG manual,
  2649. f08qgc, f08quc.
  2650. """
  2651. # NAG manual provides accuracy up to 4 and 2 decimals
  2652. atol = 1e-4
  2653. atol2 = 1e-2
  2654. trsen, trsen_lwork = get_lapack_funcs(
  2655. ('trsen', 'trsen_lwork'), dtype=t.dtype)
  2656. lwork = _compute_lwork(trsen_lwork, select, t)
  2657. if t.dtype in COMPLEX_DTYPES:
  2658. result = trsen(select, t, q, lwork=lwork)
  2659. else:
  2660. result = trsen(select, t, q, lwork=lwork, liwork=lwork[1])
  2661. assert_equal(result[-1], 0)
  2662. t = result[0]
  2663. q = result[1]
  2664. if t.dtype in COMPLEX_DTYPES:
  2665. s = result[4]
  2666. sep = result[5]
  2667. else:
  2668. s = result[5]
  2669. sep = result[6]
  2670. assert_allclose(expect, q @ t @ q.conj().T, atol=atol)
  2671. assert_allclose(expect_s, 1 / s, atol=atol2)
  2672. assert_allclose(expect_sep, 1 / sep, atol=atol2)
  2673. @pytest.mark.parametrize('dtype', DTYPES)
  2674. def test_gges_tgsen(dtype):
  2675. if dtype == np.float32 and sys.platform == 'darwin':
  2676. pytest.xfail("gges[float32] broken for OpenBLAS on macOS, see gh-16949")
  2677. seed(1234)
  2678. atol = np.finfo(dtype).eps*100
  2679. n = 10
  2680. a = generate_random_dtype_array([n, n], dtype=dtype)
  2681. b = generate_random_dtype_array([n, n], dtype=dtype)
  2682. gges, tgsen, tgsen_lwork = get_lapack_funcs(
  2683. ('gges', 'tgsen', 'tgsen_lwork'), dtype=dtype)
  2684. result = gges(lambda x: None, a, b, overwrite_a=False, overwrite_b=False)
  2685. assert_equal(result[-1], 0)
  2686. s = result[0]
  2687. t = result[1]
  2688. q = result[-4]
  2689. z = result[-3]
  2690. d1 = s[0, 0] / t[0, 0]
  2691. d2 = s[6, 6] / t[6, 6]
  2692. if dtype in COMPLEX_DTYPES:
  2693. assert_allclose(s, np.triu(s), rtol=0, atol=atol)
  2694. assert_allclose(t, np.triu(t), rtol=0, atol=atol)
  2695. assert_allclose(q @ s @ z.conj().T, a, rtol=0, atol=atol)
  2696. assert_allclose(q @ t @ z.conj().T, b, rtol=0, atol=atol)
  2697. select = np.zeros(n)
  2698. select[6] = 1
  2699. lwork = _compute_lwork(tgsen_lwork, select, s, t)
  2700. # off-by-one error in LAPACK, see gh-issue #13397
  2701. lwork = (lwork[0]+1, lwork[1])
  2702. result = tgsen(select, s, t, q, z, lwork=lwork)
  2703. assert_equal(result[-1], 0)
  2704. s = result[0]
  2705. t = result[1]
  2706. q = result[-7]
  2707. z = result[-6]
  2708. if dtype in COMPLEX_DTYPES:
  2709. assert_allclose(s, np.triu(s), rtol=0, atol=atol)
  2710. assert_allclose(t, np.triu(t), rtol=0, atol=atol)
  2711. assert_allclose(q @ s @ z.conj().T, a, rtol=0, atol=atol)
  2712. assert_allclose(q @ t @ z.conj().T, b, rtol=0, atol=atol)
  2713. assert_allclose(s[0, 0] / t[0, 0], d2, rtol=0, atol=atol)
  2714. assert_allclose(s[1, 1] / t[1, 1], d1, rtol=0, atol=atol)