test_basic.py 63 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714
  1. import platform
  2. import itertools
  3. import warnings
  4. import numpy as np
  5. from numpy import (arange, array, dot, zeros, identity, conjugate, transpose,
  6. float32)
  7. import numpy.linalg as linalg
  8. from numpy.random import random
  9. from numpy.testing import (assert_equal, assert_almost_equal, assert_,
  10. assert_array_almost_equal, assert_allclose,
  11. assert_array_equal, suppress_warnings)
  12. import pytest
  13. from pytest import raises as assert_raises
  14. from scipy._lib import _pep440
  15. from scipy.linalg import (solve, inv, det, lstsq, pinv, pinvh, norm,
  16. solve_banded, solveh_banded, solve_triangular,
  17. solve_circulant, circulant, LinAlgError, block_diag,
  18. matrix_balance, qr, LinAlgWarning)
  19. from scipy.linalg._testutils import assert_no_overwrite
  20. from scipy._lib._testutils import check_free_memory
  21. from scipy.linalg.blas import HAS_ILP64
  22. REAL_DTYPES = (np.float32, np.float64, np.longdouble)
  23. COMPLEX_DTYPES = (np.complex64, np.complex128, np.clongdouble)
  24. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  25. def _eps_cast(dtyp):
  26. """Get the epsilon for dtype, possibly downcast to BLAS types."""
  27. dt = dtyp
  28. if dt == np.longdouble:
  29. dt = np.float64
  30. elif dt == np.clongdouble:
  31. dt = np.complex128
  32. return np.finfo(dt).eps
  33. class TestSolveBanded:
  34. def test_real(self):
  35. a = array([[1.0, 20, 0, 0],
  36. [-30, 4, 6, 0],
  37. [2, 1, 20, 2],
  38. [0, -1, 7, 14]])
  39. ab = array([[0.0, 20, 6, 2],
  40. [1, 4, 20, 14],
  41. [-30, 1, 7, 0],
  42. [2, -1, 0, 0]])
  43. l, u = 2, 1
  44. b4 = array([10.0, 0.0, 2.0, 14.0])
  45. b4by1 = b4.reshape(-1, 1)
  46. b4by2 = array([[2, 1],
  47. [-30, 4],
  48. [2, 3],
  49. [1, 3]])
  50. b4by4 = array([[1, 0, 0, 0],
  51. [0, 0, 0, 1],
  52. [0, 1, 0, 0],
  53. [0, 1, 0, 0]])
  54. for b in [b4, b4by1, b4by2, b4by4]:
  55. x = solve_banded((l, u), ab, b)
  56. assert_array_almost_equal(dot(a, x), b)
  57. def test_complex(self):
  58. a = array([[1.0, 20, 0, 0],
  59. [-30, 4, 6, 0],
  60. [2j, 1, 20, 2j],
  61. [0, -1, 7, 14]])
  62. ab = array([[0.0, 20, 6, 2j],
  63. [1, 4, 20, 14],
  64. [-30, 1, 7, 0],
  65. [2j, -1, 0, 0]])
  66. l, u = 2, 1
  67. b4 = array([10.0, 0.0, 2.0, 14.0j])
  68. b4by1 = b4.reshape(-1, 1)
  69. b4by2 = array([[2, 1],
  70. [-30, 4],
  71. [2, 3],
  72. [1, 3]])
  73. b4by4 = array([[1, 0, 0, 0],
  74. [0, 0, 0, 1j],
  75. [0, 1, 0, 0],
  76. [0, 1, 0, 0]])
  77. for b in [b4, b4by1, b4by2, b4by4]:
  78. x = solve_banded((l, u), ab, b)
  79. assert_array_almost_equal(dot(a, x), b)
  80. def test_tridiag_real(self):
  81. ab = array([[0.0, 20, 6, 2],
  82. [1, 4, 20, 14],
  83. [-30, 1, 7, 0]])
  84. a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag(
  85. ab[2, :-1], -1)
  86. b4 = array([10.0, 0.0, 2.0, 14.0])
  87. b4by1 = b4.reshape(-1, 1)
  88. b4by2 = array([[2, 1],
  89. [-30, 4],
  90. [2, 3],
  91. [1, 3]])
  92. b4by4 = array([[1, 0, 0, 0],
  93. [0, 0, 0, 1],
  94. [0, 1, 0, 0],
  95. [0, 1, 0, 0]])
  96. for b in [b4, b4by1, b4by2, b4by4]:
  97. x = solve_banded((1, 1), ab, b)
  98. assert_array_almost_equal(dot(a, x), b)
  99. def test_tridiag_complex(self):
  100. ab = array([[0.0, 20, 6, 2j],
  101. [1, 4, 20, 14],
  102. [-30, 1, 7, 0]])
  103. a = np.diag(ab[0, 1:], 1) + np.diag(ab[1, :], 0) + np.diag(
  104. ab[2, :-1], -1)
  105. b4 = array([10.0, 0.0, 2.0, 14.0j])
  106. b4by1 = b4.reshape(-1, 1)
  107. b4by2 = array([[2, 1],
  108. [-30, 4],
  109. [2, 3],
  110. [1, 3]])
  111. b4by4 = array([[1, 0, 0, 0],
  112. [0, 0, 0, 1],
  113. [0, 1, 0, 0],
  114. [0, 1, 0, 0]])
  115. for b in [b4, b4by1, b4by2, b4by4]:
  116. x = solve_banded((1, 1), ab, b)
  117. assert_array_almost_equal(dot(a, x), b)
  118. def test_check_finite(self):
  119. a = array([[1.0, 20, 0, 0],
  120. [-30, 4, 6, 0],
  121. [2, 1, 20, 2],
  122. [0, -1, 7, 14]])
  123. ab = array([[0.0, 20, 6, 2],
  124. [1, 4, 20, 14],
  125. [-30, 1, 7, 0],
  126. [2, -1, 0, 0]])
  127. l, u = 2, 1
  128. b4 = array([10.0, 0.0, 2.0, 14.0])
  129. x = solve_banded((l, u), ab, b4, check_finite=False)
  130. assert_array_almost_equal(dot(a, x), b4)
  131. def test_bad_shape(self):
  132. ab = array([[0.0, 20, 6, 2],
  133. [1, 4, 20, 14],
  134. [-30, 1, 7, 0],
  135. [2, -1, 0, 0]])
  136. l, u = 2, 1
  137. bad = array([1.0, 2.0, 3.0, 4.0]).reshape(-1, 4)
  138. assert_raises(ValueError, solve_banded, (l, u), ab, bad)
  139. assert_raises(ValueError, solve_banded, (l, u), ab, [1.0, 2.0])
  140. # Values of (l,u) are not compatible with ab.
  141. assert_raises(ValueError, solve_banded, (1, 1), ab, [1.0, 2.0])
  142. def test_1x1(self):
  143. b = array([[1., 2., 3.]])
  144. x = solve_banded((1, 1), [[0], [2], [0]], b)
  145. assert_array_equal(x, [[0.5, 1.0, 1.5]])
  146. assert_equal(x.dtype, np.dtype('f8'))
  147. assert_array_equal(b, [[1.0, 2.0, 3.0]])
  148. def test_native_list_arguments(self):
  149. a = [[1.0, 20, 0, 0],
  150. [-30, 4, 6, 0],
  151. [2, 1, 20, 2],
  152. [0, -1, 7, 14]]
  153. ab = [[0.0, 20, 6, 2],
  154. [1, 4, 20, 14],
  155. [-30, 1, 7, 0],
  156. [2, -1, 0, 0]]
  157. l, u = 2, 1
  158. b = [10.0, 0.0, 2.0, 14.0]
  159. x = solve_banded((l, u), ab, b)
  160. assert_array_almost_equal(dot(a, x), b)
  161. class TestSolveHBanded:
  162. def test_01_upper(self):
  163. # Solve
  164. # [ 4 1 2 0] [1]
  165. # [ 1 4 1 2] X = [4]
  166. # [ 2 1 4 1] [1]
  167. # [ 0 2 1 4] [2]
  168. # with the RHS as a 1D array.
  169. ab = array([[0.0, 0.0, 2.0, 2.0],
  170. [-99, 1.0, 1.0, 1.0],
  171. [4.0, 4.0, 4.0, 4.0]])
  172. b = array([1.0, 4.0, 1.0, 2.0])
  173. x = solveh_banded(ab, b)
  174. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  175. def test_02_upper(self):
  176. # Solve
  177. # [ 4 1 2 0] [1 6]
  178. # [ 1 4 1 2] X = [4 2]
  179. # [ 2 1 4 1] [1 6]
  180. # [ 0 2 1 4] [2 1]
  181. #
  182. ab = array([[0.0, 0.0, 2.0, 2.0],
  183. [-99, 1.0, 1.0, 1.0],
  184. [4.0, 4.0, 4.0, 4.0]])
  185. b = array([[1.0, 6.0],
  186. [4.0, 2.0],
  187. [1.0, 6.0],
  188. [2.0, 1.0]])
  189. x = solveh_banded(ab, b)
  190. expected = array([[0.0, 1.0],
  191. [1.0, 0.0],
  192. [0.0, 1.0],
  193. [0.0, 0.0]])
  194. assert_array_almost_equal(x, expected)
  195. def test_03_upper(self):
  196. # Solve
  197. # [ 4 1 2 0] [1]
  198. # [ 1 4 1 2] X = [4]
  199. # [ 2 1 4 1] [1]
  200. # [ 0 2 1 4] [2]
  201. # with the RHS as a 2D array with shape (3,1).
  202. ab = array([[0.0, 0.0, 2.0, 2.0],
  203. [-99, 1.0, 1.0, 1.0],
  204. [4.0, 4.0, 4.0, 4.0]])
  205. b = array([1.0, 4.0, 1.0, 2.0]).reshape(-1, 1)
  206. x = solveh_banded(ab, b)
  207. assert_array_almost_equal(x, array([0., 1., 0., 0.]).reshape(-1, 1))
  208. def test_01_lower(self):
  209. # Solve
  210. # [ 4 1 2 0] [1]
  211. # [ 1 4 1 2] X = [4]
  212. # [ 2 1 4 1] [1]
  213. # [ 0 2 1 4] [2]
  214. #
  215. ab = array([[4.0, 4.0, 4.0, 4.0],
  216. [1.0, 1.0, 1.0, -99],
  217. [2.0, 2.0, 0.0, 0.0]])
  218. b = array([1.0, 4.0, 1.0, 2.0])
  219. x = solveh_banded(ab, b, lower=True)
  220. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  221. def test_02_lower(self):
  222. # Solve
  223. # [ 4 1 2 0] [1 6]
  224. # [ 1 4 1 2] X = [4 2]
  225. # [ 2 1 4 1] [1 6]
  226. # [ 0 2 1 4] [2 1]
  227. #
  228. ab = array([[4.0, 4.0, 4.0, 4.0],
  229. [1.0, 1.0, 1.0, -99],
  230. [2.0, 2.0, 0.0, 0.0]])
  231. b = array([[1.0, 6.0],
  232. [4.0, 2.0],
  233. [1.0, 6.0],
  234. [2.0, 1.0]])
  235. x = solveh_banded(ab, b, lower=True)
  236. expected = array([[0.0, 1.0],
  237. [1.0, 0.0],
  238. [0.0, 1.0],
  239. [0.0, 0.0]])
  240. assert_array_almost_equal(x, expected)
  241. def test_01_float32(self):
  242. # Solve
  243. # [ 4 1 2 0] [1]
  244. # [ 1 4 1 2] X = [4]
  245. # [ 2 1 4 1] [1]
  246. # [ 0 2 1 4] [2]
  247. #
  248. ab = array([[0.0, 0.0, 2.0, 2.0],
  249. [-99, 1.0, 1.0, 1.0],
  250. [4.0, 4.0, 4.0, 4.0]], dtype=float32)
  251. b = array([1.0, 4.0, 1.0, 2.0], dtype=float32)
  252. x = solveh_banded(ab, b)
  253. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  254. def test_02_float32(self):
  255. # Solve
  256. # [ 4 1 2 0] [1 6]
  257. # [ 1 4 1 2] X = [4 2]
  258. # [ 2 1 4 1] [1 6]
  259. # [ 0 2 1 4] [2 1]
  260. #
  261. ab = array([[0.0, 0.0, 2.0, 2.0],
  262. [-99, 1.0, 1.0, 1.0],
  263. [4.0, 4.0, 4.0, 4.0]], dtype=float32)
  264. b = array([[1.0, 6.0],
  265. [4.0, 2.0],
  266. [1.0, 6.0],
  267. [2.0, 1.0]], dtype=float32)
  268. x = solveh_banded(ab, b)
  269. expected = array([[0.0, 1.0],
  270. [1.0, 0.0],
  271. [0.0, 1.0],
  272. [0.0, 0.0]])
  273. assert_array_almost_equal(x, expected)
  274. def test_01_complex(self):
  275. # Solve
  276. # [ 4 -j 2 0] [2-j]
  277. # [ j 4 -j 2] X = [4-j]
  278. # [ 2 j 4 -j] [4+j]
  279. # [ 0 2 j 4] [2+j]
  280. #
  281. ab = array([[0.0, 0.0, 2.0, 2.0],
  282. [-99, -1.0j, -1.0j, -1.0j],
  283. [4.0, 4.0, 4.0, 4.0]])
  284. b = array([2-1.0j, 4.0-1j, 4+1j, 2+1j])
  285. x = solveh_banded(ab, b)
  286. assert_array_almost_equal(x, [0.0, 1.0, 1.0, 0.0])
  287. def test_02_complex(self):
  288. # Solve
  289. # [ 4 -j 2 0] [2-j 2+4j]
  290. # [ j 4 -j 2] X = [4-j -1-j]
  291. # [ 2 j 4 -j] [4+j 4+2j]
  292. # [ 0 2 j 4] [2+j j]
  293. #
  294. ab = array([[0.0, 0.0, 2.0, 2.0],
  295. [-99, -1.0j, -1.0j, -1.0j],
  296. [4.0, 4.0, 4.0, 4.0]])
  297. b = array([[2-1j, 2+4j],
  298. [4.0-1j, -1-1j],
  299. [4.0+1j, 4+2j],
  300. [2+1j, 1j]])
  301. x = solveh_banded(ab, b)
  302. expected = array([[0.0, 1.0j],
  303. [1.0, 0.0],
  304. [1.0, 1.0],
  305. [0.0, 0.0]])
  306. assert_array_almost_equal(x, expected)
  307. def test_tridiag_01_upper(self):
  308. # Solve
  309. # [ 4 1 0] [1]
  310. # [ 1 4 1] X = [4]
  311. # [ 0 1 4] [1]
  312. # with the RHS as a 1D array.
  313. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  314. b = array([1.0, 4.0, 1.0])
  315. x = solveh_banded(ab, b)
  316. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  317. def test_tridiag_02_upper(self):
  318. # Solve
  319. # [ 4 1 0] [1 4]
  320. # [ 1 4 1] X = [4 2]
  321. # [ 0 1 4] [1 4]
  322. #
  323. ab = array([[-99, 1.0, 1.0],
  324. [4.0, 4.0, 4.0]])
  325. b = array([[1.0, 4.0],
  326. [4.0, 2.0],
  327. [1.0, 4.0]])
  328. x = solveh_banded(ab, b)
  329. expected = array([[0.0, 1.0],
  330. [1.0, 0.0],
  331. [0.0, 1.0]])
  332. assert_array_almost_equal(x, expected)
  333. def test_tridiag_03_upper(self):
  334. # Solve
  335. # [ 4 1 0] [1]
  336. # [ 1 4 1] X = [4]
  337. # [ 0 1 4] [1]
  338. # with the RHS as a 2D array with shape (3,1).
  339. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  340. b = array([1.0, 4.0, 1.0]).reshape(-1, 1)
  341. x = solveh_banded(ab, b)
  342. assert_array_almost_equal(x, array([0.0, 1.0, 0.0]).reshape(-1, 1))
  343. def test_tridiag_01_lower(self):
  344. # Solve
  345. # [ 4 1 0] [1]
  346. # [ 1 4 1] X = [4]
  347. # [ 0 1 4] [1]
  348. #
  349. ab = array([[4.0, 4.0, 4.0],
  350. [1.0, 1.0, -99]])
  351. b = array([1.0, 4.0, 1.0])
  352. x = solveh_banded(ab, b, lower=True)
  353. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  354. def test_tridiag_02_lower(self):
  355. # Solve
  356. # [ 4 1 0] [1 4]
  357. # [ 1 4 1] X = [4 2]
  358. # [ 0 1 4] [1 4]
  359. #
  360. ab = array([[4.0, 4.0, 4.0],
  361. [1.0, 1.0, -99]])
  362. b = array([[1.0, 4.0],
  363. [4.0, 2.0],
  364. [1.0, 4.0]])
  365. x = solveh_banded(ab, b, lower=True)
  366. expected = array([[0.0, 1.0],
  367. [1.0, 0.0],
  368. [0.0, 1.0]])
  369. assert_array_almost_equal(x, expected)
  370. def test_tridiag_01_float32(self):
  371. # Solve
  372. # [ 4 1 0] [1]
  373. # [ 1 4 1] X = [4]
  374. # [ 0 1 4] [1]
  375. #
  376. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]], dtype=float32)
  377. b = array([1.0, 4.0, 1.0], dtype=float32)
  378. x = solveh_banded(ab, b)
  379. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  380. def test_tridiag_02_float32(self):
  381. # Solve
  382. # [ 4 1 0] [1 4]
  383. # [ 1 4 1] X = [4 2]
  384. # [ 0 1 4] [1 4]
  385. #
  386. ab = array([[-99, 1.0, 1.0],
  387. [4.0, 4.0, 4.0]], dtype=float32)
  388. b = array([[1.0, 4.0],
  389. [4.0, 2.0],
  390. [1.0, 4.0]], dtype=float32)
  391. x = solveh_banded(ab, b)
  392. expected = array([[0.0, 1.0],
  393. [1.0, 0.0],
  394. [0.0, 1.0]])
  395. assert_array_almost_equal(x, expected)
  396. def test_tridiag_01_complex(self):
  397. # Solve
  398. # [ 4 -j 0] [ -j]
  399. # [ j 4 -j] X = [4-j]
  400. # [ 0 j 4] [4+j]
  401. #
  402. ab = array([[-99, -1.0j, -1.0j], [4.0, 4.0, 4.0]])
  403. b = array([-1.0j, 4.0-1j, 4+1j])
  404. x = solveh_banded(ab, b)
  405. assert_array_almost_equal(x, [0.0, 1.0, 1.0])
  406. def test_tridiag_02_complex(self):
  407. # Solve
  408. # [ 4 -j 0] [ -j 4j]
  409. # [ j 4 -j] X = [4-j -1-j]
  410. # [ 0 j 4] [4+j 4 ]
  411. #
  412. ab = array([[-99, -1.0j, -1.0j],
  413. [4.0, 4.0, 4.0]])
  414. b = array([[-1j, 4.0j],
  415. [4.0-1j, -1.0-1j],
  416. [4.0+1j, 4.0]])
  417. x = solveh_banded(ab, b)
  418. expected = array([[0.0, 1.0j],
  419. [1.0, 0.0],
  420. [1.0, 1.0]])
  421. assert_array_almost_equal(x, expected)
  422. def test_check_finite(self):
  423. # Solve
  424. # [ 4 1 0] [1]
  425. # [ 1 4 1] X = [4]
  426. # [ 0 1 4] [1]
  427. # with the RHS as a 1D array.
  428. ab = array([[-99, 1.0, 1.0], [4.0, 4.0, 4.0]])
  429. b = array([1.0, 4.0, 1.0])
  430. x = solveh_banded(ab, b, check_finite=False)
  431. assert_array_almost_equal(x, [0.0, 1.0, 0.0])
  432. def test_bad_shapes(self):
  433. ab = array([[-99, 1.0, 1.0],
  434. [4.0, 4.0, 4.0]])
  435. b = array([[1.0, 4.0],
  436. [4.0, 2.0]])
  437. assert_raises(ValueError, solveh_banded, ab, b)
  438. assert_raises(ValueError, solveh_banded, ab, [1.0, 2.0])
  439. assert_raises(ValueError, solveh_banded, ab, [1.0])
  440. def test_1x1(self):
  441. x = solveh_banded([[1]], [[1, 2, 3]])
  442. assert_array_equal(x, [[1.0, 2.0, 3.0]])
  443. assert_equal(x.dtype, np.dtype('f8'))
  444. def test_native_list_arguments(self):
  445. # Same as test_01_upper, using python's native list.
  446. ab = [[0.0, 0.0, 2.0, 2.0],
  447. [-99, 1.0, 1.0, 1.0],
  448. [4.0, 4.0, 4.0, 4.0]]
  449. b = [1.0, 4.0, 1.0, 2.0]
  450. x = solveh_banded(ab, b)
  451. assert_array_almost_equal(x, [0.0, 1.0, 0.0, 0.0])
  452. class TestSolve:
  453. def setup_method(self):
  454. np.random.seed(1234)
  455. def test_20Feb04_bug(self):
  456. a = [[1, 1], [1.0, 0]] # ok
  457. x0 = solve(a, [1, 0j])
  458. assert_array_almost_equal(dot(a, x0), [1, 0])
  459. # gives failure with clapack.zgesv(..,rowmajor=0)
  460. a = [[1, 1], [1.2, 0]]
  461. b = [1, 0j]
  462. x0 = solve(a, b)
  463. assert_array_almost_equal(dot(a, x0), [1, 0])
  464. def test_simple(self):
  465. a = [[1, 20], [-30, 4]]
  466. for b in ([[1, 0], [0, 1]],
  467. [1, 0],
  468. [[2, 1], [-30, 4]]
  469. ):
  470. x = solve(a, b)
  471. assert_array_almost_equal(dot(a, x), b)
  472. def test_simple_complex(self):
  473. a = array([[5, 2], [2j, 4]], 'D')
  474. for b in ([1j, 0],
  475. [[1j, 1j], [0, 2]],
  476. [1, 0j],
  477. array([1, 0], 'D'),
  478. ):
  479. x = solve(a, b)
  480. assert_array_almost_equal(dot(a, x), b)
  481. def test_simple_pos(self):
  482. a = [[2, 3], [3, 5]]
  483. for lower in [0, 1]:
  484. for b in ([[1, 0], [0, 1]],
  485. [1, 0]
  486. ):
  487. x = solve(a, b, assume_a='pos', lower=lower)
  488. assert_array_almost_equal(dot(a, x), b)
  489. def test_simple_pos_complexb(self):
  490. a = [[5, 2], [2, 4]]
  491. for b in ([1j, 0],
  492. [[1j, 1j], [0, 2]],
  493. ):
  494. x = solve(a, b, assume_a='pos')
  495. assert_array_almost_equal(dot(a, x), b)
  496. def test_simple_sym(self):
  497. a = [[2, 3], [3, -5]]
  498. for lower in [0, 1]:
  499. for b in ([[1, 0], [0, 1]],
  500. [1, 0]
  501. ):
  502. x = solve(a, b, assume_a='sym', lower=lower)
  503. assert_array_almost_equal(dot(a, x), b)
  504. def test_simple_sym_complexb(self):
  505. a = [[5, 2], [2, -4]]
  506. for b in ([1j, 0],
  507. [[1j, 1j],[0, 2]]
  508. ):
  509. x = solve(a, b, assume_a='sym')
  510. assert_array_almost_equal(dot(a, x), b)
  511. def test_simple_sym_complex(self):
  512. a = [[5, 2+1j], [2+1j, -4]]
  513. for b in ([1j, 0],
  514. [1, 0],
  515. [[1j, 1j], [0, 2]]
  516. ):
  517. x = solve(a, b, assume_a='sym')
  518. assert_array_almost_equal(dot(a, x), b)
  519. def test_simple_her_actuallysym(self):
  520. a = [[2, 3], [3, -5]]
  521. for lower in [0, 1]:
  522. for b in ([[1, 0], [0, 1]],
  523. [1, 0],
  524. [1j, 0],
  525. ):
  526. x = solve(a, b, assume_a='her', lower=lower)
  527. assert_array_almost_equal(dot(a, x), b)
  528. def test_simple_her(self):
  529. a = [[5, 2+1j], [2-1j, -4]]
  530. for b in ([1j, 0],
  531. [1, 0],
  532. [[1j, 1j], [0, 2]]
  533. ):
  534. x = solve(a, b, assume_a='her')
  535. assert_array_almost_equal(dot(a, x), b)
  536. def test_nils_20Feb04(self):
  537. n = 2
  538. A = random([n, n])+random([n, n])*1j
  539. X = zeros((n, n), 'D')
  540. Ainv = inv(A)
  541. R = identity(n)+identity(n)*0j
  542. for i in arange(0, n):
  543. r = R[:, i]
  544. X[:, i] = solve(A, r)
  545. assert_array_almost_equal(X, Ainv)
  546. def test_random(self):
  547. n = 20
  548. a = random([n, n])
  549. for i in range(n):
  550. a[i, i] = 20*(.1+a[i, i])
  551. for i in range(4):
  552. b = random([n, 3])
  553. x = solve(a, b)
  554. assert_array_almost_equal(dot(a, x), b)
  555. def test_random_complex(self):
  556. n = 20
  557. a = random([n, n]) + 1j * random([n, n])
  558. for i in range(n):
  559. a[i, i] = 20*(.1+a[i, i])
  560. for i in range(2):
  561. b = random([n, 3])
  562. x = solve(a, b)
  563. assert_array_almost_equal(dot(a, x), b)
  564. def test_sym_pos_dep(self):
  565. with pytest.warns(
  566. DeprecationWarning,
  567. match="The 'sym_pos' keyword is deprecated",
  568. ):
  569. solve([[1.]], [1], sym_pos=True)
  570. def test_random_sym(self):
  571. n = 20
  572. a = random([n, n])
  573. for i in range(n):
  574. a[i, i] = abs(20*(.1+a[i, i]))
  575. for j in range(i):
  576. a[i, j] = a[j, i]
  577. for i in range(4):
  578. b = random([n])
  579. x = solve(a, b, assume_a="pos")
  580. assert_array_almost_equal(dot(a, x), b)
  581. def test_random_sym_complex(self):
  582. n = 20
  583. a = random([n, n])
  584. a = a + 1j*random([n, n])
  585. for i in range(n):
  586. a[i, i] = abs(20*(.1+a[i, i]))
  587. for j in range(i):
  588. a[i, j] = conjugate(a[j, i])
  589. b = random([n])+2j*random([n])
  590. for i in range(2):
  591. x = solve(a, b, assume_a="pos")
  592. assert_array_almost_equal(dot(a, x), b)
  593. def test_check_finite(self):
  594. a = [[1, 20], [-30, 4]]
  595. for b in ([[1, 0], [0, 1]], [1, 0],
  596. [[2, 1], [-30, 4]]):
  597. x = solve(a, b, check_finite=False)
  598. assert_array_almost_equal(dot(a, x), b)
  599. def test_scalar_a_and_1D_b(self):
  600. a = 1
  601. b = [1, 2, 3]
  602. x = solve(a, b)
  603. assert_array_almost_equal(x.ravel(), b)
  604. assert_(x.shape == (3,), 'Scalar_a_1D_b test returned wrong shape')
  605. def test_simple2(self):
  606. a = np.array([[1.80, 2.88, 2.05, -0.89],
  607. [525.00, -295.00, -95.00, -380.00],
  608. [1.58, -2.69, -2.90, -1.04],
  609. [-1.11, -0.66, -0.59, 0.80]])
  610. b = np.array([[9.52, 18.47],
  611. [2435.00, 225.00],
  612. [0.77, -13.28],
  613. [-6.22, -6.21]])
  614. x = solve(a, b)
  615. assert_array_almost_equal(x, np.array([[1., -1, 3, -5],
  616. [3, 2, 4, 1]]).T)
  617. def test_simple_complex2(self):
  618. a = np.array([[-1.34+2.55j, 0.28+3.17j, -6.39-2.20j, 0.72-0.92j],
  619. [-1.70-14.10j, 33.10-1.50j, -1.50+13.40j, 12.90+13.80j],
  620. [-3.29-2.39j, -1.91+4.42j, -0.14-1.35j, 1.72+1.35j],
  621. [2.41+0.39j, -0.56+1.47j, -0.83-0.69j, -1.96+0.67j]])
  622. b = np.array([[26.26+51.78j, 31.32-6.70j],
  623. [64.30-86.80j, 158.60-14.20j],
  624. [-5.75+25.31j, -2.15+30.19j],
  625. [1.16+2.57j, -2.56+7.55j]])
  626. x = solve(a, b)
  627. assert_array_almost_equal(x, np. array([[1+1.j, -1-2.j],
  628. [2-3.j, 5+1.j],
  629. [-4-5.j, -3+4.j],
  630. [6.j, 2-3.j]]))
  631. def test_hermitian(self):
  632. # An upper triangular matrix will be used for hermitian matrix a
  633. a = np.array([[-1.84, 0.11-0.11j, -1.78-1.18j, 3.91-1.50j],
  634. [0, -4.63, -1.84+0.03j, 2.21+0.21j],
  635. [0, 0, -8.87, 1.58-0.90j],
  636. [0, 0, 0, -1.36]])
  637. b = np.array([[2.98-10.18j, 28.68-39.89j],
  638. [-9.58+3.88j, -24.79-8.40j],
  639. [-0.77-16.05j, 4.23-70.02j],
  640. [7.79+5.48j, -35.39+18.01j]])
  641. res = np.array([[2.+1j, -8+6j],
  642. [3.-2j, 7-2j],
  643. [-1+2j, -1+5j],
  644. [1.-1j, 3-4j]])
  645. x = solve(a, b, assume_a='her')
  646. assert_array_almost_equal(x, res)
  647. # Also conjugate a and test for lower triangular data
  648. x = solve(a.conj().T, b, assume_a='her', lower=True)
  649. assert_array_almost_equal(x, res)
  650. def test_pos_and_sym(self):
  651. A = np.arange(1, 10).reshape(3, 3)
  652. x = solve(np.tril(A)/9, np.ones(3), assume_a='pos')
  653. assert_array_almost_equal(x, [9., 1.8, 1.])
  654. x = solve(np.tril(A)/9, np.ones(3), assume_a='sym')
  655. assert_array_almost_equal(x, [9., 1.8, 1.])
  656. def test_singularity(self):
  657. a = np.array([[1, 0, 0, 0, 0, 0, 1, 0, 1],
  658. [1, 1, 1, 0, 0, 0, 1, 0, 1],
  659. [0, 1, 1, 0, 0, 0, 1, 0, 1],
  660. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  661. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  662. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  663. [1, 0, 1, 1, 1, 1, 0, 0, 0],
  664. [1, 1, 1, 1, 1, 1, 1, 1, 1],
  665. [1, 1, 1, 1, 1, 1, 1, 1, 1]])
  666. b = np.arange(9)[:, None]
  667. assert_raises(LinAlgError, solve, a, b)
  668. def test_ill_condition_warning(self):
  669. a = np.array([[1, 1], [1+1e-16, 1-1e-16]])
  670. b = np.ones(2)
  671. with warnings.catch_warnings():
  672. warnings.simplefilter('error')
  673. assert_raises(LinAlgWarning, solve, a, b)
  674. def test_empty_rhs(self):
  675. a = np.eye(2)
  676. b = [[], []]
  677. x = solve(a, b)
  678. assert_(x.size == 0, 'Returned array is not empty')
  679. assert_(x.shape == (2, 0), 'Returned empty array shape is wrong')
  680. def test_multiple_rhs(self):
  681. a = np.eye(2)
  682. b = np.random.rand(2, 3, 4)
  683. x = solve(a, b)
  684. assert_array_almost_equal(x, b)
  685. def test_transposed_keyword(self):
  686. A = np.arange(9).reshape(3, 3) + 1
  687. x = solve(np.tril(A)/9, np.ones(3), transposed=True)
  688. assert_array_almost_equal(x, [1.2, 0.2, 1])
  689. x = solve(np.tril(A)/9, np.ones(3), transposed=False)
  690. assert_array_almost_equal(x, [9, -5.4, -1.2])
  691. def test_transposed_notimplemented(self):
  692. a = np.eye(3).astype(complex)
  693. with assert_raises(NotImplementedError):
  694. solve(a, a, transposed=True)
  695. def test_nonsquare_a(self):
  696. assert_raises(ValueError, solve, [1, 2], 1)
  697. def test_size_mismatch_with_1D_b(self):
  698. assert_array_almost_equal(solve(np.eye(3), np.ones(3)), np.ones(3))
  699. assert_raises(ValueError, solve, np.eye(3), np.ones(4))
  700. def test_assume_a_keyword(self):
  701. assert_raises(ValueError, solve, 1, 1, assume_a='zxcv')
  702. @pytest.mark.skip(reason="Failure on OS X (gh-7500), "
  703. "crash on Windows (gh-8064)")
  704. def test_all_type_size_routine_combinations(self):
  705. sizes = [10, 100]
  706. assume_as = ['gen', 'sym', 'pos', 'her']
  707. dtypes = [np.float32, np.float64, np.complex64, np.complex128]
  708. for size, assume_a, dtype in itertools.product(sizes, assume_as,
  709. dtypes):
  710. is_complex = dtype in (np.complex64, np.complex128)
  711. if assume_a == 'her' and not is_complex:
  712. continue
  713. err_msg = ("Failed for size: {}, assume_a: {},"
  714. "dtype: {}".format(size, assume_a, dtype))
  715. a = np.random.randn(size, size).astype(dtype)
  716. b = np.random.randn(size).astype(dtype)
  717. if is_complex:
  718. a = a + (1j*np.random.randn(size, size)).astype(dtype)
  719. if assume_a == 'sym': # Can still be complex but only symmetric
  720. a = a + a.T
  721. elif assume_a == 'her': # Handle hermitian matrices here instead
  722. a = a + a.T.conj()
  723. elif assume_a == 'pos':
  724. a = a.conj().T.dot(a) + 0.1*np.eye(size)
  725. tol = 1e-12 if dtype in (np.float64, np.complex128) else 1e-6
  726. if assume_a in ['gen', 'sym', 'her']:
  727. # We revert the tolerance from before
  728. # 4b4a6e7c34fa4060533db38f9a819b98fa81476c
  729. if dtype in (np.float32, np.complex64):
  730. tol *= 10
  731. x = solve(a, b, assume_a=assume_a)
  732. assert_allclose(a.dot(x), b,
  733. atol=tol * size,
  734. rtol=tol * size,
  735. err_msg=err_msg)
  736. if assume_a == 'sym' and dtype not in (np.complex64,
  737. np.complex128):
  738. x = solve(a, b, assume_a=assume_a, transposed=True)
  739. assert_allclose(a.dot(x), b,
  740. atol=tol * size,
  741. rtol=tol * size,
  742. err_msg=err_msg)
  743. class TestSolveTriangular:
  744. def test_simple(self):
  745. """
  746. solve_triangular on a simple 2x2 matrix.
  747. """
  748. A = array([[1, 0], [1, 2]])
  749. b = [1, 1]
  750. sol = solve_triangular(A, b, lower=True)
  751. assert_array_almost_equal(sol, [1, 0])
  752. # check that it works also for non-contiguous matrices
  753. sol = solve_triangular(A.T, b, lower=False)
  754. assert_array_almost_equal(sol, [.5, .5])
  755. # and that it gives the same result as trans=1
  756. sol = solve_triangular(A, b, lower=True, trans=1)
  757. assert_array_almost_equal(sol, [.5, .5])
  758. b = identity(2)
  759. sol = solve_triangular(A, b, lower=True, trans=1)
  760. assert_array_almost_equal(sol, [[1., -.5], [0, 0.5]])
  761. def test_simple_complex(self):
  762. """
  763. solve_triangular on a simple 2x2 complex matrix
  764. """
  765. A = array([[1+1j, 0], [1j, 2]])
  766. b = identity(2)
  767. sol = solve_triangular(A, b, lower=True, trans=1)
  768. assert_array_almost_equal(sol, [[.5-.5j, -.25-.25j], [0, 0.5]])
  769. # check other option combinations with complex rhs
  770. b = np.diag([1+1j, 1+2j])
  771. sol = solve_triangular(A, b, lower=True, trans=0)
  772. assert_array_almost_equal(sol, [[1, 0], [-0.5j, 0.5+1j]])
  773. sol = solve_triangular(A, b, lower=True, trans=1)
  774. assert_array_almost_equal(sol, [[1, 0.25-0.75j], [0, 0.5+1j]])
  775. sol = solve_triangular(A, b, lower=True, trans=2)
  776. assert_array_almost_equal(sol, [[1j, -0.75-0.25j], [0, 0.5+1j]])
  777. sol = solve_triangular(A.T, b, lower=False, trans=0)
  778. assert_array_almost_equal(sol, [[1, 0.25-0.75j], [0, 0.5+1j]])
  779. sol = solve_triangular(A.T, b, lower=False, trans=1)
  780. assert_array_almost_equal(sol, [[1, 0], [-0.5j, 0.5+1j]])
  781. sol = solve_triangular(A.T, b, lower=False, trans=2)
  782. assert_array_almost_equal(sol, [[1j, 0], [-0.5, 0.5+1j]])
  783. def test_check_finite(self):
  784. """
  785. solve_triangular on a simple 2x2 matrix.
  786. """
  787. A = array([[1, 0], [1, 2]])
  788. b = [1, 1]
  789. sol = solve_triangular(A, b, lower=True, check_finite=False)
  790. assert_array_almost_equal(sol, [1, 0])
  791. class TestInv:
  792. def setup_method(self):
  793. np.random.seed(1234)
  794. def test_simple(self):
  795. a = [[1, 2], [3, 4]]
  796. a_inv = inv(a)
  797. assert_array_almost_equal(dot(a, a_inv), np.eye(2))
  798. a = [[1, 2, 3], [4, 5, 6], [7, 8, 10]]
  799. a_inv = inv(a)
  800. assert_array_almost_equal(dot(a, a_inv), np.eye(3))
  801. def test_random(self):
  802. n = 20
  803. for i in range(4):
  804. a = random([n, n])
  805. for i in range(n):
  806. a[i, i] = 20*(.1+a[i, i])
  807. a_inv = inv(a)
  808. assert_array_almost_equal(dot(a, a_inv),
  809. identity(n))
  810. def test_simple_complex(self):
  811. a = [[1, 2], [3, 4j]]
  812. a_inv = inv(a)
  813. assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]])
  814. def test_random_complex(self):
  815. n = 20
  816. for i in range(4):
  817. a = random([n, n])+2j*random([n, n])
  818. for i in range(n):
  819. a[i, i] = 20*(.1+a[i, i])
  820. a_inv = inv(a)
  821. assert_array_almost_equal(dot(a, a_inv),
  822. identity(n))
  823. def test_check_finite(self):
  824. a = [[1, 2], [3, 4]]
  825. a_inv = inv(a, check_finite=False)
  826. assert_array_almost_equal(dot(a, a_inv), [[1, 0], [0, 1]])
  827. class TestDet:
  828. def setup_method(self):
  829. np.random.seed(1234)
  830. def test_simple(self):
  831. a = [[1, 2], [3, 4]]
  832. a_det = det(a)
  833. assert_almost_equal(a_det, -2.0)
  834. def test_simple_complex(self):
  835. a = [[1, 2], [3, 4j]]
  836. a_det = det(a)
  837. assert_almost_equal(a_det, -6+4j)
  838. def test_random(self):
  839. basic_det = linalg.det
  840. n = 20
  841. for i in range(4):
  842. a = random([n, n])
  843. d1 = det(a)
  844. d2 = basic_det(a)
  845. assert_almost_equal(d1, d2)
  846. def test_random_complex(self):
  847. basic_det = linalg.det
  848. n = 20
  849. for i in range(4):
  850. a = random([n, n]) + 2j*random([n, n])
  851. d1 = det(a)
  852. d2 = basic_det(a)
  853. assert_allclose(d1, d2, rtol=1e-13)
  854. def test_check_finite(self):
  855. a = [[1, 2], [3, 4]]
  856. a_det = det(a, check_finite=False)
  857. assert_almost_equal(a_det, -2.0)
  858. def direct_lstsq(a, b, cmplx=0):
  859. at = transpose(a)
  860. if cmplx:
  861. at = conjugate(at)
  862. a1 = dot(at, a)
  863. b1 = dot(at, b)
  864. return solve(a1, b1)
  865. class TestLstsq:
  866. lapack_drivers = ('gelsd', 'gelss', 'gelsy', None)
  867. def setup_method(self):
  868. np.random.seed(1234)
  869. def test_simple_exact(self):
  870. for dtype in REAL_DTYPES:
  871. a = np.array([[1, 20], [-30, 4]], dtype=dtype)
  872. for lapack_driver in TestLstsq.lapack_drivers:
  873. for overwrite in (True, False):
  874. for bt in (((1, 0), (0, 1)), (1, 0),
  875. ((2, 1), (-30, 4))):
  876. # Store values in case they are overwritten
  877. # later
  878. a1 = a.copy()
  879. b = np.array(bt, dtype=dtype)
  880. b1 = b.copy()
  881. out = lstsq(a1, b1,
  882. lapack_driver=lapack_driver,
  883. overwrite_a=overwrite,
  884. overwrite_b=overwrite)
  885. x = out[0]
  886. r = out[2]
  887. assert_(r == 2,
  888. 'expected efficient rank 2, got %s' % r)
  889. assert_allclose(dot(a, x), b,
  890. atol=25 * _eps_cast(a1.dtype),
  891. rtol=25 * _eps_cast(a1.dtype),
  892. err_msg="driver: %s" % lapack_driver)
  893. def test_simple_overdet(self):
  894. for dtype in REAL_DTYPES:
  895. a = np.array([[1, 2], [4, 5], [3, 4]], dtype=dtype)
  896. b = np.array([1, 2, 3], dtype=dtype)
  897. for lapack_driver in TestLstsq.lapack_drivers:
  898. for overwrite in (True, False):
  899. # Store values in case they are overwritten later
  900. a1 = a.copy()
  901. b1 = b.copy()
  902. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  903. overwrite_a=overwrite,
  904. overwrite_b=overwrite)
  905. x = out[0]
  906. if lapack_driver == 'gelsy':
  907. residuals = np.sum((b - a.dot(x))**2)
  908. else:
  909. residuals = out[1]
  910. r = out[2]
  911. assert_(r == 2, 'expected efficient rank 2, got %s' % r)
  912. assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0),
  913. residuals,
  914. rtol=25 * _eps_cast(a1.dtype),
  915. atol=25 * _eps_cast(a1.dtype),
  916. err_msg="driver: %s" % lapack_driver)
  917. assert_allclose(x, (-0.428571428571429, 0.85714285714285),
  918. rtol=25 * _eps_cast(a1.dtype),
  919. atol=25 * _eps_cast(a1.dtype),
  920. err_msg="driver: %s" % lapack_driver)
  921. def test_simple_overdet_complex(self):
  922. for dtype in COMPLEX_DTYPES:
  923. a = np.array([[1+2j, 2], [4, 5], [3, 4]], dtype=dtype)
  924. b = np.array([1, 2+4j, 3], dtype=dtype)
  925. for lapack_driver in TestLstsq.lapack_drivers:
  926. for overwrite in (True, False):
  927. # Store values in case they are overwritten later
  928. a1 = a.copy()
  929. b1 = b.copy()
  930. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  931. overwrite_a=overwrite,
  932. overwrite_b=overwrite)
  933. x = out[0]
  934. if lapack_driver == 'gelsy':
  935. res = b - a.dot(x)
  936. residuals = np.sum(res * res.conj())
  937. else:
  938. residuals = out[1]
  939. r = out[2]
  940. assert_(r == 2, 'expected efficient rank 2, got %s' % r)
  941. assert_allclose(abs((dot(a, x) - b)**2).sum(axis=0),
  942. residuals,
  943. rtol=25 * _eps_cast(a1.dtype),
  944. atol=25 * _eps_cast(a1.dtype),
  945. err_msg="driver: %s" % lapack_driver)
  946. assert_allclose(
  947. x, (-0.4831460674157303 + 0.258426966292135j,
  948. 0.921348314606741 + 0.292134831460674j),
  949. rtol=25 * _eps_cast(a1.dtype),
  950. atol=25 * _eps_cast(a1.dtype),
  951. err_msg="driver: %s" % lapack_driver)
  952. def test_simple_underdet(self):
  953. for dtype in REAL_DTYPES:
  954. a = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
  955. b = np.array([1, 2], dtype=dtype)
  956. for lapack_driver in TestLstsq.lapack_drivers:
  957. for overwrite in (True, False):
  958. # Store values in case they are overwritten later
  959. a1 = a.copy()
  960. b1 = b.copy()
  961. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  962. overwrite_a=overwrite,
  963. overwrite_b=overwrite)
  964. x = out[0]
  965. r = out[2]
  966. assert_(r == 2, 'expected efficient rank 2, got %s' % r)
  967. assert_allclose(x, (-0.055555555555555, 0.111111111111111,
  968. 0.277777777777777),
  969. rtol=25 * _eps_cast(a1.dtype),
  970. atol=25 * _eps_cast(a1.dtype),
  971. err_msg="driver: %s" % lapack_driver)
  972. def test_random_exact(self):
  973. for dtype in REAL_DTYPES:
  974. for n in (20, 200):
  975. for lapack_driver in TestLstsq.lapack_drivers:
  976. for overwrite in (True, False):
  977. a = np.asarray(random([n, n]), dtype=dtype)
  978. for i in range(n):
  979. a[i, i] = 20 * (0.1 + a[i, i])
  980. for i in range(4):
  981. b = np.asarray(random([n, 3]), dtype=dtype)
  982. # Store values in case they are overwritten later
  983. a1 = a.copy()
  984. b1 = b.copy()
  985. out = lstsq(a1, b1,
  986. lapack_driver=lapack_driver,
  987. overwrite_a=overwrite,
  988. overwrite_b=overwrite)
  989. x = out[0]
  990. r = out[2]
  991. assert_(r == n, 'expected efficient rank %s, '
  992. 'got %s' % (n, r))
  993. if dtype is np.float32:
  994. assert_allclose(
  995. dot(a, x), b,
  996. rtol=500 * _eps_cast(a1.dtype),
  997. atol=500 * _eps_cast(a1.dtype),
  998. err_msg="driver: %s" % lapack_driver)
  999. else:
  1000. assert_allclose(
  1001. dot(a, x), b,
  1002. rtol=1000 * _eps_cast(a1.dtype),
  1003. atol=1000 * _eps_cast(a1.dtype),
  1004. err_msg="driver: %s" % lapack_driver)
  1005. def test_random_complex_exact(self):
  1006. if platform.system() != "Windows":
  1007. if _pep440.parse(np.__version__) >= _pep440.Version("1.24.0"):
  1008. libc_flavor = platform.libc_ver()[0]
  1009. if libc_flavor != "glibc":
  1010. pytest.skip("segfault observed on alpine per gh-17630")
  1011. for dtype in COMPLEX_DTYPES:
  1012. for n in (20, 200):
  1013. for lapack_driver in TestLstsq.lapack_drivers:
  1014. for overwrite in (True, False):
  1015. a = np.asarray(random([n, n]) + 1j*random([n, n]),
  1016. dtype=dtype)
  1017. for i in range(n):
  1018. a[i, i] = 20 * (0.1 + a[i, i])
  1019. for i in range(2):
  1020. b = np.asarray(random([n, 3]), dtype=dtype)
  1021. # Store values in case they are overwritten later
  1022. a1 = a.copy()
  1023. b1 = b.copy()
  1024. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1025. overwrite_a=overwrite,
  1026. overwrite_b=overwrite)
  1027. x = out[0]
  1028. r = out[2]
  1029. assert_(r == n, 'expected efficient rank %s, '
  1030. 'got %s' % (n, r))
  1031. if dtype is np.complex64:
  1032. assert_allclose(
  1033. dot(a, x), b,
  1034. rtol=400 * _eps_cast(a1.dtype),
  1035. atol=400 * _eps_cast(a1.dtype),
  1036. err_msg="driver: %s" % lapack_driver)
  1037. else:
  1038. assert_allclose(
  1039. dot(a, x), b,
  1040. rtol=1000 * _eps_cast(a1.dtype),
  1041. atol=1000 * _eps_cast(a1.dtype),
  1042. err_msg="driver: %s" % lapack_driver)
  1043. def test_random_overdet(self):
  1044. for dtype in REAL_DTYPES:
  1045. for (n, m) in ((20, 15), (200, 2)):
  1046. for lapack_driver in TestLstsq.lapack_drivers:
  1047. for overwrite in (True, False):
  1048. a = np.asarray(random([n, m]), dtype=dtype)
  1049. for i in range(m):
  1050. a[i, i] = 20 * (0.1 + a[i, i])
  1051. for i in range(4):
  1052. b = np.asarray(random([n, 3]), dtype=dtype)
  1053. # Store values in case they are overwritten later
  1054. a1 = a.copy()
  1055. b1 = b.copy()
  1056. out = lstsq(a1, b1,
  1057. lapack_driver=lapack_driver,
  1058. overwrite_a=overwrite,
  1059. overwrite_b=overwrite)
  1060. x = out[0]
  1061. r = out[2]
  1062. assert_(r == m, 'expected efficient rank %s, '
  1063. 'got %s' % (m, r))
  1064. assert_allclose(
  1065. x, direct_lstsq(a, b, cmplx=0),
  1066. rtol=25 * _eps_cast(a1.dtype),
  1067. atol=25 * _eps_cast(a1.dtype),
  1068. err_msg="driver: %s" % lapack_driver)
  1069. def test_random_complex_overdet(self):
  1070. for dtype in COMPLEX_DTYPES:
  1071. for (n, m) in ((20, 15), (200, 2)):
  1072. for lapack_driver in TestLstsq.lapack_drivers:
  1073. for overwrite in (True, False):
  1074. a = np.asarray(random([n, m]) + 1j*random([n, m]),
  1075. dtype=dtype)
  1076. for i in range(m):
  1077. a[i, i] = 20 * (0.1 + a[i, i])
  1078. for i in range(2):
  1079. b = np.asarray(random([n, 3]), dtype=dtype)
  1080. # Store values in case they are overwritten
  1081. # later
  1082. a1 = a.copy()
  1083. b1 = b.copy()
  1084. out = lstsq(a1, b1,
  1085. lapack_driver=lapack_driver,
  1086. overwrite_a=overwrite,
  1087. overwrite_b=overwrite)
  1088. x = out[0]
  1089. r = out[2]
  1090. assert_(r == m, 'expected efficient rank %s, '
  1091. 'got %s' % (m, r))
  1092. assert_allclose(
  1093. x, direct_lstsq(a, b, cmplx=1),
  1094. rtol=25 * _eps_cast(a1.dtype),
  1095. atol=25 * _eps_cast(a1.dtype),
  1096. err_msg="driver: %s" % lapack_driver)
  1097. def test_check_finite(self):
  1098. with suppress_warnings() as sup:
  1099. # On (some) OSX this tests triggers a warning (gh-7538)
  1100. sup.filter(RuntimeWarning,
  1101. "internal gelsd driver lwork query error,.*"
  1102. "Falling back to 'gelss' driver.")
  1103. at = np.array(((1, 20), (-30, 4)))
  1104. for dtype, bt, lapack_driver, overwrite, check_finite in \
  1105. itertools.product(REAL_DTYPES,
  1106. (((1, 0), (0, 1)), (1, 0), ((2, 1), (-30, 4))),
  1107. TestLstsq.lapack_drivers,
  1108. (True, False),
  1109. (True, False)):
  1110. a = at.astype(dtype)
  1111. b = np.array(bt, dtype=dtype)
  1112. # Store values in case they are overwritten
  1113. # later
  1114. a1 = a.copy()
  1115. b1 = b.copy()
  1116. out = lstsq(a1, b1, lapack_driver=lapack_driver,
  1117. check_finite=check_finite, overwrite_a=overwrite,
  1118. overwrite_b=overwrite)
  1119. x = out[0]
  1120. r = out[2]
  1121. assert_(r == 2, 'expected efficient rank 2, got %s' % r)
  1122. assert_allclose(dot(a, x), b,
  1123. rtol=25 * _eps_cast(a.dtype),
  1124. atol=25 * _eps_cast(a.dtype),
  1125. err_msg="driver: %s" % lapack_driver)
  1126. def test_zero_size(self):
  1127. for a_shape, b_shape in (((0, 2), (0,)),
  1128. ((0, 4), (0, 2)),
  1129. ((4, 0), (4,)),
  1130. ((4, 0), (4, 2))):
  1131. b = np.ones(b_shape)
  1132. x, residues, rank, s = lstsq(np.zeros(a_shape), b)
  1133. assert_equal(x, np.zeros((a_shape[1],) + b_shape[1:]))
  1134. residues_should_be = (np.empty((0,)) if a_shape[1]
  1135. else np.linalg.norm(b, axis=0)**2)
  1136. assert_equal(residues, residues_should_be)
  1137. assert_(rank == 0, 'expected rank 0')
  1138. assert_equal(s, np.empty((0,)))
  1139. class TestPinv:
  1140. def setup_method(self):
  1141. np.random.seed(1234)
  1142. def test_simple_real(self):
  1143. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1144. a_pinv = pinv(a)
  1145. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1146. def test_simple_complex(self):
  1147. a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]],
  1148. dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]],
  1149. dtype=float))
  1150. a_pinv = pinv(a)
  1151. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1152. def test_simple_singular(self):
  1153. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
  1154. a_pinv = pinv(a)
  1155. expected = array([[-6.38888889e-01, -1.66666667e-01, 3.05555556e-01],
  1156. [-5.55555556e-02, 1.30136518e-16, 5.55555556e-02],
  1157. [5.27777778e-01, 1.66666667e-01, -1.94444444e-01]])
  1158. assert_array_almost_equal(a_pinv, expected)
  1159. def test_simple_cols(self):
  1160. a = array([[1, 2, 3], [4, 5, 6]], dtype=float)
  1161. a_pinv = pinv(a)
  1162. expected = array([[-0.94444444, 0.44444444],
  1163. [-0.11111111, 0.11111111],
  1164. [0.72222222, -0.22222222]])
  1165. assert_array_almost_equal(a_pinv, expected)
  1166. def test_simple_rows(self):
  1167. a = array([[1, 2], [3, 4], [5, 6]], dtype=float)
  1168. a_pinv = pinv(a)
  1169. expected = array([[-1.33333333, -0.33333333, 0.66666667],
  1170. [1.08333333, 0.33333333, -0.41666667]])
  1171. assert_array_almost_equal(a_pinv, expected)
  1172. def test_check_finite(self):
  1173. a = array([[1, 2, 3], [4, 5, 6.], [7, 8, 10]])
  1174. a_pinv = pinv(a, check_finite=False)
  1175. assert_array_almost_equal(dot(a, a_pinv), np.eye(3))
  1176. def test_native_list_argument(self):
  1177. a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
  1178. a_pinv = pinv(a)
  1179. expected = array([[-6.38888889e-01, -1.66666667e-01, 3.05555556e-01],
  1180. [-5.55555556e-02, 1.30136518e-16, 5.55555556e-02],
  1181. [5.27777778e-01, 1.66666667e-01, -1.94444444e-01]])
  1182. assert_array_almost_equal(a_pinv, expected)
  1183. def test_atol_rtol(self):
  1184. n = 12
  1185. # get a random ortho matrix for shuffling
  1186. q, _ = qr(np.random.rand(n, n))
  1187. a_m = np.arange(35.0).reshape(7,5)
  1188. a = a_m.copy()
  1189. a[0,0] = 0.001
  1190. atol = 1e-5
  1191. rtol = 0.05
  1192. # svds of a_m is ~ [116.906, 4.234, tiny, tiny, tiny]
  1193. # svds of a is ~ [116.906, 4.234, 4.62959e-04, tiny, tiny]
  1194. # Just abs cutoff such that we arrive at a_modified
  1195. a_p = pinv(a_m, atol=atol, rtol=0.)
  1196. adiff1 = a @ a_p @ a - a
  1197. adiff2 = a_m @ a_p @ a_m - a_m
  1198. # Now adiff1 should be around atol value while adiff2 should be
  1199. # relatively tiny
  1200. assert_allclose(np.linalg.norm(adiff1), 5e-4, atol=5.e-4)
  1201. assert_allclose(np.linalg.norm(adiff2), 5e-14, atol=5.e-14)
  1202. # Now do the same but remove another sv ~4.234 via rtol
  1203. a_p = pinv(a_m, atol=atol, rtol=rtol)
  1204. adiff1 = a @ a_p @ a - a
  1205. adiff2 = a_m @ a_p @ a_m - a_m
  1206. assert_allclose(np.linalg.norm(adiff1), 4.233, rtol=0.01)
  1207. assert_allclose(np.linalg.norm(adiff2), 4.233, rtol=0.01)
  1208. class TestPinvSymmetric:
  1209. def setup_method(self):
  1210. np.random.seed(1234)
  1211. def test_simple_real(self):
  1212. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1213. a = np.dot(a, a.T)
  1214. a_pinv = pinvh(a)
  1215. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1216. def test_nonpositive(self):
  1217. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=float)
  1218. a = np.dot(a, a.T)
  1219. u, s, vt = np.linalg.svd(a)
  1220. s[0] *= -1
  1221. a = np.dot(u * s, vt) # a is now symmetric non-positive and singular
  1222. a_pinv = pinv(a)
  1223. a_pinvh = pinvh(a)
  1224. assert_array_almost_equal(a_pinv, a_pinvh)
  1225. def test_simple_complex(self):
  1226. a = (array([[1, 2, 3], [4, 5, 6], [7, 8, 10]],
  1227. dtype=float) + 1j * array([[10, 8, 7], [6, 5, 4], [3, 2, 1]],
  1228. dtype=float))
  1229. a = np.dot(a, a.conj().T)
  1230. a_pinv = pinvh(a)
  1231. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1232. def test_native_list_argument(self):
  1233. a = array([[1, 2, 3], [4, 5, 6], [7, 8, 10]], dtype=float)
  1234. a = np.dot(a, a.T)
  1235. a_pinv = pinvh(a.tolist())
  1236. assert_array_almost_equal(np.dot(a, a_pinv), np.eye(3))
  1237. def test_atol_rtol(self):
  1238. n = 12
  1239. # get a random ortho matrix for shuffling
  1240. q, _ = qr(np.random.rand(n, n))
  1241. a = np.diag([4, 3, 2, 1, 0.99e-4, 0.99e-5] + [0.99e-6]*(n-6))
  1242. a = q.T @ a @ q
  1243. a_m = np.diag([4, 3, 2, 1, 0.99e-4, 0.] + [0.]*(n-6))
  1244. a_m = q.T @ a_m @ q
  1245. atol = 1e-5
  1246. rtol = (4.01e-4 - 4e-5)/4
  1247. # Just abs cutoff such that we arrive at a_modified
  1248. a_p = pinvh(a, atol=atol, rtol=0.)
  1249. adiff1 = a @ a_p @ a - a
  1250. adiff2 = a_m @ a_p @ a_m - a_m
  1251. # Now adiff1 should dance around atol value since truncation
  1252. # while adiff2 should be relatively tiny
  1253. assert_allclose(norm(adiff1), atol, rtol=0.1)
  1254. assert_allclose(norm(adiff2), 1e-12, atol=1e-11)
  1255. # Now do the same but through rtol cancelling atol value
  1256. a_p = pinvh(a, atol=atol, rtol=rtol)
  1257. adiff1 = a @ a_p @ a - a
  1258. adiff2 = a_m @ a_p @ a_m - a_m
  1259. # adiff1 and adiff2 should be elevated to ~1e-4 due to mismatch
  1260. assert_allclose(norm(adiff1), 1e-4, rtol=0.1)
  1261. assert_allclose(norm(adiff2), 1e-4, rtol=0.1)
  1262. @pytest.mark.parametrize('scale', (1e-20, 1., 1e20))
  1263. @pytest.mark.parametrize('pinv_', (pinv, pinvh))
  1264. def test_auto_rcond(scale, pinv_):
  1265. x = np.array([[1, 0], [0, 1e-10]]) * scale
  1266. expected = np.diag(1. / np.diag(x))
  1267. x_inv = pinv_(x)
  1268. assert_allclose(x_inv, expected)
  1269. class TestVectorNorms:
  1270. def test_types(self):
  1271. for dtype in np.typecodes['AllFloat']:
  1272. x = np.array([1, 2, 3], dtype=dtype)
  1273. tol = max(1e-15, np.finfo(dtype).eps.real * 20)
  1274. assert_allclose(norm(x), np.sqrt(14), rtol=tol)
  1275. assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol)
  1276. for dtype in np.typecodes['Complex']:
  1277. x = np.array([1j, 2j, 3j], dtype=dtype)
  1278. tol = max(1e-15, np.finfo(dtype).eps.real * 20)
  1279. assert_allclose(norm(x), np.sqrt(14), rtol=tol)
  1280. assert_allclose(norm(x, 2), np.sqrt(14), rtol=tol)
  1281. def test_overflow(self):
  1282. # unlike numpy's norm, this one is
  1283. # safer on overflow
  1284. a = array([1e20], dtype=float32)
  1285. assert_almost_equal(norm(a), a)
  1286. def test_stable(self):
  1287. # more stable than numpy's norm
  1288. a = array([1e4] + [1]*10000, dtype=float32)
  1289. try:
  1290. # snrm in double precision; we obtain the same as for float64
  1291. # -- large atol needed due to varying blas implementations
  1292. assert_allclose(norm(a) - 1e4, 0.5, atol=1e-2)
  1293. except AssertionError:
  1294. # snrm implemented in single precision, == np.linalg.norm result
  1295. msg = ": Result should equal either 0.0 or 0.5 (depending on " \
  1296. "implementation of snrm2)."
  1297. assert_almost_equal(norm(a) - 1e4, 0.0, err_msg=msg)
  1298. def test_zero_norm(self):
  1299. assert_equal(norm([1, 0, 3], 0), 2)
  1300. assert_equal(norm([1, 2, 3], 0), 3)
  1301. def test_axis_kwd(self):
  1302. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1303. assert_allclose(norm(a, axis=1), [[3.60555128, 4.12310563]] * 2)
  1304. assert_allclose(norm(a, 1, axis=1), [[5.] * 2] * 2)
  1305. def test_keepdims_kwd(self):
  1306. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1307. b = norm(a, axis=1, keepdims=True)
  1308. assert_allclose(b, [[[3.60555128, 4.12310563]]] * 2)
  1309. assert_(b.shape == (2, 1, 2))
  1310. assert_allclose(norm(a, 1, axis=2, keepdims=True), [[[3.], [7.]]] * 2)
  1311. @pytest.mark.skipif(not HAS_ILP64, reason="64-bit BLAS required")
  1312. def test_large_vector(self):
  1313. check_free_memory(free_mb=17000)
  1314. x = np.zeros([2**31], dtype=np.float64)
  1315. x[-1] = 1
  1316. res = norm(x)
  1317. del x
  1318. assert_allclose(res, 1.0)
  1319. class TestMatrixNorms:
  1320. def test_matrix_norms(self):
  1321. # Not all of these are matrix norms in the most technical sense.
  1322. np.random.seed(1234)
  1323. for n, m in (1, 1), (1, 3), (3, 1), (4, 4), (4, 5), (5, 4):
  1324. for t in np.single, np.double, np.csingle, np.cdouble, np.int64:
  1325. A = 10 * np.random.randn(n, m).astype(t)
  1326. if np.issubdtype(A.dtype, np.complexfloating):
  1327. A = (A + 10j * np.random.randn(n, m)).astype(t)
  1328. t_high = np.cdouble
  1329. else:
  1330. t_high = np.double
  1331. for order in (None, 'fro', 1, -1, 2, -2, np.inf, -np.inf):
  1332. actual = norm(A, ord=order)
  1333. desired = np.linalg.norm(A, ord=order)
  1334. # SciPy may return higher precision matrix norms.
  1335. # This is a consequence of using LAPACK.
  1336. if not np.allclose(actual, desired):
  1337. desired = np.linalg.norm(A.astype(t_high), ord=order)
  1338. assert_allclose(actual, desired)
  1339. def test_axis_kwd(self):
  1340. a = np.array([[[2, 1], [3, 4]]] * 2, 'd')
  1341. b = norm(a, ord=np.inf, axis=(1, 0))
  1342. c = norm(np.swapaxes(a, 0, 1), ord=np.inf, axis=(0, 1))
  1343. d = norm(a, ord=1, axis=(0, 1))
  1344. assert_allclose(b, c)
  1345. assert_allclose(c, d)
  1346. assert_allclose(b, d)
  1347. assert_(b.shape == c.shape == d.shape)
  1348. b = norm(a, ord=1, axis=(1, 0))
  1349. c = norm(np.swapaxes(a, 0, 1), ord=1, axis=(0, 1))
  1350. d = norm(a, ord=np.inf, axis=(0, 1))
  1351. assert_allclose(b, c)
  1352. assert_allclose(c, d)
  1353. assert_allclose(b, d)
  1354. assert_(b.shape == c.shape == d.shape)
  1355. def test_keepdims_kwd(self):
  1356. a = np.arange(120, dtype='d').reshape(2, 3, 4, 5)
  1357. b = norm(a, ord=np.inf, axis=(1, 0), keepdims=True)
  1358. c = norm(a, ord=1, axis=(0, 1), keepdims=True)
  1359. assert_allclose(b, c)
  1360. assert_(b.shape == c.shape)
  1361. class TestOverwrite:
  1362. def test_solve(self):
  1363. assert_no_overwrite(solve, [(3, 3), (3,)])
  1364. def test_solve_triangular(self):
  1365. assert_no_overwrite(solve_triangular, [(3, 3), (3,)])
  1366. def test_solve_banded(self):
  1367. assert_no_overwrite(lambda ab, b: solve_banded((2, 1), ab, b),
  1368. [(4, 6), (6,)])
  1369. def test_solveh_banded(self):
  1370. assert_no_overwrite(solveh_banded, [(2, 6), (6,)])
  1371. def test_inv(self):
  1372. assert_no_overwrite(inv, [(3, 3)])
  1373. def test_det(self):
  1374. assert_no_overwrite(det, [(3, 3)])
  1375. def test_lstsq(self):
  1376. assert_no_overwrite(lstsq, [(3, 2), (3,)])
  1377. def test_pinv(self):
  1378. assert_no_overwrite(pinv, [(3, 3)])
  1379. def test_pinvh(self):
  1380. assert_no_overwrite(pinvh, [(3, 3)])
  1381. class TestSolveCirculant:
  1382. def test_basic1(self):
  1383. c = np.array([1, 2, 3, 5])
  1384. b = np.array([1, -1, 1, 0])
  1385. x = solve_circulant(c, b)
  1386. y = solve(circulant(c), b)
  1387. assert_allclose(x, y)
  1388. def test_basic2(self):
  1389. # b is a 2-d matrix.
  1390. c = np.array([1, 2, -3, -5])
  1391. b = np.arange(12).reshape(4, 3)
  1392. x = solve_circulant(c, b)
  1393. y = solve(circulant(c), b)
  1394. assert_allclose(x, y)
  1395. def test_basic3(self):
  1396. # b is a 3-d matrix.
  1397. c = np.array([1, 2, -3, -5])
  1398. b = np.arange(24).reshape(4, 3, 2)
  1399. x = solve_circulant(c, b)
  1400. y = solve(circulant(c), b)
  1401. assert_allclose(x, y)
  1402. def test_complex(self):
  1403. # Complex b and c
  1404. c = np.array([1+2j, -3, 4j, 5])
  1405. b = np.arange(8).reshape(4, 2) + 0.5j
  1406. x = solve_circulant(c, b)
  1407. y = solve(circulant(c), b)
  1408. assert_allclose(x, y)
  1409. def test_random_b_and_c(self):
  1410. # Random b and c
  1411. np.random.seed(54321)
  1412. c = np.random.randn(50)
  1413. b = np.random.randn(50)
  1414. x = solve_circulant(c, b)
  1415. y = solve(circulant(c), b)
  1416. assert_allclose(x, y)
  1417. def test_singular(self):
  1418. # c gives a singular circulant matrix.
  1419. c = np.array([1, 1, 0, 0])
  1420. b = np.array([1, 2, 3, 4])
  1421. x = solve_circulant(c, b, singular='lstsq')
  1422. y, res, rnk, s = lstsq(circulant(c), b)
  1423. assert_allclose(x, y)
  1424. assert_raises(LinAlgError, solve_circulant, x, y)
  1425. def test_axis_args(self):
  1426. # Test use of caxis, baxis and outaxis.
  1427. # c has shape (2, 1, 4)
  1428. c = np.array([[[-1, 2.5, 3, 3.5]], [[1, 6, 6, 6.5]]])
  1429. # b has shape (3, 4)
  1430. b = np.array([[0, 0, 1, 1], [1, 1, 0, 0], [1, -1, 0, 0]])
  1431. x = solve_circulant(c, b, baxis=1)
  1432. assert_equal(x.shape, (4, 2, 3))
  1433. expected = np.empty_like(x)
  1434. expected[:, 0, :] = solve(circulant(c[0]), b.T)
  1435. expected[:, 1, :] = solve(circulant(c[1]), b.T)
  1436. assert_allclose(x, expected)
  1437. x = solve_circulant(c, b, baxis=1, outaxis=-1)
  1438. assert_equal(x.shape, (2, 3, 4))
  1439. assert_allclose(np.moveaxis(x, -1, 0), expected)
  1440. # np.swapaxes(c, 1, 2) has shape (2, 4, 1); b.T has shape (4, 3).
  1441. x = solve_circulant(np.swapaxes(c, 1, 2), b.T, caxis=1)
  1442. assert_equal(x.shape, (4, 2, 3))
  1443. assert_allclose(x, expected)
  1444. def test_native_list_arguments(self):
  1445. # Same as test_basic1 using python's native list.
  1446. c = [1, 2, 3, 5]
  1447. b = [1, -1, 1, 0]
  1448. x = solve_circulant(c, b)
  1449. y = solve(circulant(c), b)
  1450. assert_allclose(x, y)
  1451. class TestMatrix_Balance:
  1452. def test_string_arg(self):
  1453. assert_raises(ValueError, matrix_balance, 'Some string for fail')
  1454. def test_infnan_arg(self):
  1455. assert_raises(ValueError, matrix_balance,
  1456. np.array([[1, 2], [3, np.inf]]))
  1457. assert_raises(ValueError, matrix_balance,
  1458. np.array([[1, 2], [3, np.nan]]))
  1459. def test_scaling(self):
  1460. _, y = matrix_balance(np.array([[1000, 1], [1000, 0]]))
  1461. # Pre/post LAPACK 3.5.0 gives the same result up to an offset
  1462. # since in each case col norm is x1000 greater and
  1463. # 1000 / 32 ~= 1 * 32 hence balanced with 2 ** 5.
  1464. assert_allclose(np.diff(np.log2(np.diag(y))), [5])
  1465. def test_scaling_order(self):
  1466. A = np.array([[1, 0, 1e-4], [1, 1, 1e-2], [1e4, 1e2, 1]])
  1467. x, y = matrix_balance(A)
  1468. assert_allclose(solve(y, A).dot(y), x)
  1469. def test_separate(self):
  1470. _, (y, z) = matrix_balance(np.array([[1000, 1], [1000, 0]]),
  1471. separate=1)
  1472. assert_equal(np.diff(np.log2(y)), [5])
  1473. assert_allclose(z, np.arange(2))
  1474. def test_permutation(self):
  1475. A = block_diag(np.ones((2, 2)), np.tril(np.ones((2, 2))),
  1476. np.ones((3, 3)))
  1477. x, (y, z) = matrix_balance(A, separate=1)
  1478. assert_allclose(y, np.ones_like(y))
  1479. assert_allclose(z, np.array([0, 1, 6, 5, 4, 3, 2]))
  1480. def test_perm_and_scaling(self):
  1481. # Matrix with its diagonal removed
  1482. cases = ( # Case 0
  1483. np.array([[0., 0., 0., 0., 0.000002],
  1484. [0., 0., 0., 0., 0.],
  1485. [2., 2., 0., 0., 0.],
  1486. [2., 2., 0., 0., 0.],
  1487. [0., 0., 0.000002, 0., 0.]]),
  1488. # Case 1 user reported GH-7258
  1489. np.array([[-0.5, 0., 0., 0.],
  1490. [0., -1., 0., 0.],
  1491. [1., 0., -0.5, 0.],
  1492. [0., 1., 0., -1.]]),
  1493. # Case 2 user reported GH-7258
  1494. np.array([[-3., 0., 1., 0.],
  1495. [-1., -1., -0., 1.],
  1496. [-3., -0., -0., 0.],
  1497. [-1., -0., 1., -1.]])
  1498. )
  1499. for A in cases:
  1500. x, y = matrix_balance(A)
  1501. x, (s, p) = matrix_balance(A, separate=1)
  1502. ip = np.empty_like(p)
  1503. ip[p] = np.arange(A.shape[0])
  1504. assert_allclose(y, np.diag(s)[ip, :])
  1505. assert_allclose(solve(y, A).dot(y), x)