test_blas.py 39 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096
  1. #
  2. # Created by: Pearu Peterson, April 2002
  3. #
  4. __usage__ = """
  5. Build linalg:
  6. python setup.py build
  7. Run tests if scipy is installed:
  8. python -c 'import scipy;scipy.linalg.test()'
  9. """
  10. import math
  11. import pytest
  12. import numpy as np
  13. from numpy.testing import (assert_equal, assert_almost_equal, assert_,
  14. assert_array_almost_equal, assert_allclose)
  15. from pytest import raises as assert_raises
  16. from numpy import float32, float64, complex64, complex128, arange, triu, \
  17. tril, zeros, tril_indices, ones, mod, diag, append, eye, \
  18. nonzero
  19. from numpy.random import rand, seed
  20. from scipy.linalg import _fblas as fblas, get_blas_funcs, toeplitz, solve
  21. try:
  22. from scipy.linalg import _cblas as cblas
  23. except ImportError:
  24. cblas = None
  25. REAL_DTYPES = [float32, float64]
  26. COMPLEX_DTYPES = [complex64, complex128]
  27. DTYPES = REAL_DTYPES + COMPLEX_DTYPES
  28. def test_get_blas_funcs():
  29. # check that it returns Fortran code for arrays that are
  30. # fortran-ordered
  31. f1, f2, f3 = get_blas_funcs(
  32. ('axpy', 'axpy', 'axpy'),
  33. (np.empty((2, 2), dtype=np.complex64, order='F'),
  34. np.empty((2, 2), dtype=np.complex128, order='C'))
  35. )
  36. # get_blas_funcs will choose libraries depending on most generic
  37. # array
  38. assert_equal(f1.typecode, 'z')
  39. assert_equal(f2.typecode, 'z')
  40. if cblas is not None:
  41. assert_equal(f1.module_name, 'cblas')
  42. assert_equal(f2.module_name, 'cblas')
  43. # check defaults.
  44. f1 = get_blas_funcs('rotg')
  45. assert_equal(f1.typecode, 'd')
  46. # check also dtype interface
  47. f1 = get_blas_funcs('gemm', dtype=np.complex64)
  48. assert_equal(f1.typecode, 'c')
  49. f1 = get_blas_funcs('gemm', dtype='F')
  50. assert_equal(f1.typecode, 'c')
  51. # extended precision complex
  52. f1 = get_blas_funcs('gemm', dtype=np.longcomplex)
  53. assert_equal(f1.typecode, 'z')
  54. # check safe complex upcasting
  55. f1 = get_blas_funcs('axpy',
  56. (np.empty((2, 2), dtype=np.float64),
  57. np.empty((2, 2), dtype=np.complex64))
  58. )
  59. assert_equal(f1.typecode, 'z')
  60. def test_get_blas_funcs_alias():
  61. # check alias for get_blas_funcs
  62. f, g = get_blas_funcs(('nrm2', 'dot'), dtype=np.complex64)
  63. assert f.typecode == 'c'
  64. assert g.typecode == 'c'
  65. f, g, h = get_blas_funcs(('dot', 'dotc', 'dotu'), dtype=np.float64)
  66. assert f is g
  67. assert f is h
  68. class TestCBLAS1Simple:
  69. def test_axpy(self):
  70. for p in 'sd':
  71. f = getattr(cblas, p+'axpy', None)
  72. if f is None:
  73. continue
  74. assert_array_almost_equal(f([1, 2, 3], [2, -1, 3], a=5),
  75. [7, 9, 18])
  76. for p in 'cz':
  77. f = getattr(cblas, p+'axpy', None)
  78. if f is None:
  79. continue
  80. assert_array_almost_equal(f([1, 2j, 3], [2, -1, 3], a=5),
  81. [7, 10j-1, 18])
  82. class TestFBLAS1Simple:
  83. def test_axpy(self):
  84. for p in 'sd':
  85. f = getattr(fblas, p+'axpy', None)
  86. if f is None:
  87. continue
  88. assert_array_almost_equal(f([1, 2, 3], [2, -1, 3], a=5),
  89. [7, 9, 18])
  90. for p in 'cz':
  91. f = getattr(fblas, p+'axpy', None)
  92. if f is None:
  93. continue
  94. assert_array_almost_equal(f([1, 2j, 3], [2, -1, 3], a=5),
  95. [7, 10j-1, 18])
  96. def test_copy(self):
  97. for p in 'sd':
  98. f = getattr(fblas, p+'copy', None)
  99. if f is None:
  100. continue
  101. assert_array_almost_equal(f([3, 4, 5], [8]*3), [3, 4, 5])
  102. for p in 'cz':
  103. f = getattr(fblas, p+'copy', None)
  104. if f is None:
  105. continue
  106. assert_array_almost_equal(f([3, 4j, 5+3j], [8]*3), [3, 4j, 5+3j])
  107. def test_asum(self):
  108. for p in 'sd':
  109. f = getattr(fblas, p+'asum', None)
  110. if f is None:
  111. continue
  112. assert_almost_equal(f([3, -4, 5]), 12)
  113. for p in ['sc', 'dz']:
  114. f = getattr(fblas, p+'asum', None)
  115. if f is None:
  116. continue
  117. assert_almost_equal(f([3j, -4, 3-4j]), 14)
  118. def test_dot(self):
  119. for p in 'sd':
  120. f = getattr(fblas, p+'dot', None)
  121. if f is None:
  122. continue
  123. assert_almost_equal(f([3, -4, 5], [2, 5, 1]), -9)
  124. def test_complex_dotu(self):
  125. for p in 'cz':
  126. f = getattr(fblas, p+'dotu', None)
  127. if f is None:
  128. continue
  129. assert_almost_equal(f([3j, -4, 3-4j], [2, 3, 1]), -9+2j)
  130. def test_complex_dotc(self):
  131. for p in 'cz':
  132. f = getattr(fblas, p+'dotc', None)
  133. if f is None:
  134. continue
  135. assert_almost_equal(f([3j, -4, 3-4j], [2, 3j, 1]), 3-14j)
  136. def test_nrm2(self):
  137. for p in 'sd':
  138. f = getattr(fblas, p+'nrm2', None)
  139. if f is None:
  140. continue
  141. assert_almost_equal(f([3, -4, 5]), math.sqrt(50))
  142. for p in ['c', 'z', 'sc', 'dz']:
  143. f = getattr(fblas, p+'nrm2', None)
  144. if f is None:
  145. continue
  146. assert_almost_equal(f([3j, -4, 3-4j]), math.sqrt(50))
  147. def test_scal(self):
  148. for p in 'sd':
  149. f = getattr(fblas, p+'scal', None)
  150. if f is None:
  151. continue
  152. assert_array_almost_equal(f(2, [3, -4, 5]), [6, -8, 10])
  153. for p in 'cz':
  154. f = getattr(fblas, p+'scal', None)
  155. if f is None:
  156. continue
  157. assert_array_almost_equal(f(3j, [3j, -4, 3-4j]), [-9, -12j, 12+9j])
  158. for p in ['cs', 'zd']:
  159. f = getattr(fblas, p+'scal', None)
  160. if f is None:
  161. continue
  162. assert_array_almost_equal(f(3, [3j, -4, 3-4j]), [9j, -12, 9-12j])
  163. def test_swap(self):
  164. for p in 'sd':
  165. f = getattr(fblas, p+'swap', None)
  166. if f is None:
  167. continue
  168. x, y = [2, 3, 1], [-2, 3, 7]
  169. x1, y1 = f(x, y)
  170. assert_array_almost_equal(x1, y)
  171. assert_array_almost_equal(y1, x)
  172. for p in 'cz':
  173. f = getattr(fblas, p+'swap', None)
  174. if f is None:
  175. continue
  176. x, y = [2, 3j, 1], [-2, 3, 7-3j]
  177. x1, y1 = f(x, y)
  178. assert_array_almost_equal(x1, y)
  179. assert_array_almost_equal(y1, x)
  180. def test_amax(self):
  181. for p in 'sd':
  182. f = getattr(fblas, 'i'+p+'amax')
  183. assert_equal(f([-2, 4, 3]), 1)
  184. for p in 'cz':
  185. f = getattr(fblas, 'i'+p+'amax')
  186. assert_equal(f([-5, 4+3j, 6]), 1)
  187. # XXX: need tests for rot,rotm,rotg,rotmg
  188. class TestFBLAS2Simple:
  189. def test_gemv(self):
  190. for p in 'sd':
  191. f = getattr(fblas, p+'gemv', None)
  192. if f is None:
  193. continue
  194. assert_array_almost_equal(f(3, [[3]], [-4]), [-36])
  195. assert_array_almost_equal(f(3, [[3]], [-4], 3, [5]), [-21])
  196. for p in 'cz':
  197. f = getattr(fblas, p+'gemv', None)
  198. if f is None:
  199. continue
  200. assert_array_almost_equal(f(3j, [[3-4j]], [-4]), [-48-36j])
  201. assert_array_almost_equal(f(3j, [[3-4j]], [-4], 3, [5j]),
  202. [-48-21j])
  203. def test_ger(self):
  204. for p in 'sd':
  205. f = getattr(fblas, p+'ger', None)
  206. if f is None:
  207. continue
  208. assert_array_almost_equal(f(1, [1, 2], [3, 4]), [[3, 4], [6, 8]])
  209. assert_array_almost_equal(f(2, [1, 2, 3], [3, 4]),
  210. [[6, 8], [12, 16], [18, 24]])
  211. assert_array_almost_equal(f(1, [1, 2], [3, 4],
  212. a=[[1, 2], [3, 4]]), [[4, 6], [9, 12]])
  213. for p in 'cz':
  214. f = getattr(fblas, p+'geru', None)
  215. if f is None:
  216. continue
  217. assert_array_almost_equal(f(1, [1j, 2], [3, 4]),
  218. [[3j, 4j], [6, 8]])
  219. assert_array_almost_equal(f(-2, [1j, 2j, 3j], [3j, 4j]),
  220. [[6, 8], [12, 16], [18, 24]])
  221. for p in 'cz':
  222. for name in ('ger', 'gerc'):
  223. f = getattr(fblas, p+name, None)
  224. if f is None:
  225. continue
  226. assert_array_almost_equal(f(1, [1j, 2], [3, 4]),
  227. [[3j, 4j], [6, 8]])
  228. assert_array_almost_equal(f(2, [1j, 2j, 3j], [3j, 4j]),
  229. [[6, 8], [12, 16], [18, 24]])
  230. def test_syr_her(self):
  231. x = np.arange(1, 5, dtype='d')
  232. resx = np.triu(x[:, np.newaxis] * x)
  233. resx_reverse = np.triu(x[::-1, np.newaxis] * x[::-1])
  234. y = np.linspace(0, 8.5, 17, endpoint=False)
  235. z = np.arange(1, 9, dtype='d').view('D')
  236. resz = np.triu(z[:, np.newaxis] * z)
  237. resz_reverse = np.triu(z[::-1, np.newaxis] * z[::-1])
  238. rehz = np.triu(z[:, np.newaxis] * z.conj())
  239. rehz_reverse = np.triu(z[::-1, np.newaxis] * z[::-1].conj())
  240. w = np.c_[np.zeros(4), z, np.zeros(4)].ravel()
  241. for p, rtol in zip('sd', [1e-7, 1e-14]):
  242. f = getattr(fblas, p+'syr', None)
  243. if f is None:
  244. continue
  245. assert_allclose(f(1.0, x), resx, rtol=rtol)
  246. assert_allclose(f(1.0, x, lower=True), resx.T, rtol=rtol)
  247. assert_allclose(f(1.0, y, incx=2, offx=2, n=4), resx, rtol=rtol)
  248. # negative increments imply reversed vectors in blas
  249. assert_allclose(f(1.0, y, incx=-2, offx=2, n=4),
  250. resx_reverse, rtol=rtol)
  251. a = np.zeros((4, 4), 'f' if p == 's' else 'd', 'F')
  252. b = f(1.0, x, a=a, overwrite_a=True)
  253. assert_allclose(a, resx, rtol=rtol)
  254. b = f(2.0, x, a=a)
  255. assert_(a is not b)
  256. assert_allclose(b, 3*resx, rtol=rtol)
  257. assert_raises(Exception, f, 1.0, x, incx=0)
  258. assert_raises(Exception, f, 1.0, x, offx=5)
  259. assert_raises(Exception, f, 1.0, x, offx=-2)
  260. assert_raises(Exception, f, 1.0, x, n=-2)
  261. assert_raises(Exception, f, 1.0, x, n=5)
  262. assert_raises(Exception, f, 1.0, x, lower=2)
  263. assert_raises(Exception, f, 1.0, x, a=np.zeros((2, 2), 'd', 'F'))
  264. for p, rtol in zip('cz', [1e-7, 1e-14]):
  265. f = getattr(fblas, p+'syr', None)
  266. if f is None:
  267. continue
  268. assert_allclose(f(1.0, z), resz, rtol=rtol)
  269. assert_allclose(f(1.0, z, lower=True), resz.T, rtol=rtol)
  270. assert_allclose(f(1.0, w, incx=3, offx=1, n=4), resz, rtol=rtol)
  271. # negative increments imply reversed vectors in blas
  272. assert_allclose(f(1.0, w, incx=-3, offx=1, n=4),
  273. resz_reverse, rtol=rtol)
  274. a = np.zeros((4, 4), 'F' if p == 'c' else 'D', 'F')
  275. b = f(1.0, z, a=a, overwrite_a=True)
  276. assert_allclose(a, resz, rtol=rtol)
  277. b = f(2.0, z, a=a)
  278. assert_(a is not b)
  279. assert_allclose(b, 3*resz, rtol=rtol)
  280. assert_raises(Exception, f, 1.0, x, incx=0)
  281. assert_raises(Exception, f, 1.0, x, offx=5)
  282. assert_raises(Exception, f, 1.0, x, offx=-2)
  283. assert_raises(Exception, f, 1.0, x, n=-2)
  284. assert_raises(Exception, f, 1.0, x, n=5)
  285. assert_raises(Exception, f, 1.0, x, lower=2)
  286. assert_raises(Exception, f, 1.0, x, a=np.zeros((2, 2), 'd', 'F'))
  287. for p, rtol in zip('cz', [1e-7, 1e-14]):
  288. f = getattr(fblas, p+'her', None)
  289. if f is None:
  290. continue
  291. assert_allclose(f(1.0, z), rehz, rtol=rtol)
  292. assert_allclose(f(1.0, z, lower=True), rehz.T.conj(), rtol=rtol)
  293. assert_allclose(f(1.0, w, incx=3, offx=1, n=4), rehz, rtol=rtol)
  294. # negative increments imply reversed vectors in blas
  295. assert_allclose(f(1.0, w, incx=-3, offx=1, n=4),
  296. rehz_reverse, rtol=rtol)
  297. a = np.zeros((4, 4), 'F' if p == 'c' else 'D', 'F')
  298. b = f(1.0, z, a=a, overwrite_a=True)
  299. assert_allclose(a, rehz, rtol=rtol)
  300. b = f(2.0, z, a=a)
  301. assert_(a is not b)
  302. assert_allclose(b, 3*rehz, rtol=rtol)
  303. assert_raises(Exception, f, 1.0, x, incx=0)
  304. assert_raises(Exception, f, 1.0, x, offx=5)
  305. assert_raises(Exception, f, 1.0, x, offx=-2)
  306. assert_raises(Exception, f, 1.0, x, n=-2)
  307. assert_raises(Exception, f, 1.0, x, n=5)
  308. assert_raises(Exception, f, 1.0, x, lower=2)
  309. assert_raises(Exception, f, 1.0, x, a=np.zeros((2, 2), 'd', 'F'))
  310. def test_syr2(self):
  311. x = np.arange(1, 5, dtype='d')
  312. y = np.arange(5, 9, dtype='d')
  313. resxy = np.triu(x[:, np.newaxis] * y + y[:, np.newaxis] * x)
  314. resxy_reverse = np.triu(x[::-1, np.newaxis] * y[::-1]
  315. + y[::-1, np.newaxis] * x[::-1])
  316. q = np.linspace(0, 8.5, 17, endpoint=False)
  317. for p, rtol in zip('sd', [1e-7, 1e-14]):
  318. f = getattr(fblas, p+'syr2', None)
  319. if f is None:
  320. continue
  321. assert_allclose(f(1.0, x, y), resxy, rtol=rtol)
  322. assert_allclose(f(1.0, x, y, n=3), resxy[:3, :3], rtol=rtol)
  323. assert_allclose(f(1.0, x, y, lower=True), resxy.T, rtol=rtol)
  324. assert_allclose(f(1.0, q, q, incx=2, offx=2, incy=2, offy=10),
  325. resxy, rtol=rtol)
  326. assert_allclose(f(1.0, q, q, incx=2, offx=2, incy=2, offy=10, n=3),
  327. resxy[:3, :3], rtol=rtol)
  328. # negative increments imply reversed vectors in blas
  329. assert_allclose(f(1.0, q, q, incx=-2, offx=2, incy=-2, offy=10),
  330. resxy_reverse, rtol=rtol)
  331. a = np.zeros((4, 4), 'f' if p == 's' else 'd', 'F')
  332. b = f(1.0, x, y, a=a, overwrite_a=True)
  333. assert_allclose(a, resxy, rtol=rtol)
  334. b = f(2.0, x, y, a=a)
  335. assert_(a is not b)
  336. assert_allclose(b, 3*resxy, rtol=rtol)
  337. assert_raises(Exception, f, 1.0, x, y, incx=0)
  338. assert_raises(Exception, f, 1.0, x, y, offx=5)
  339. assert_raises(Exception, f, 1.0, x, y, offx=-2)
  340. assert_raises(Exception, f, 1.0, x, y, incy=0)
  341. assert_raises(Exception, f, 1.0, x, y, offy=5)
  342. assert_raises(Exception, f, 1.0, x, y, offy=-2)
  343. assert_raises(Exception, f, 1.0, x, y, n=-2)
  344. assert_raises(Exception, f, 1.0, x, y, n=5)
  345. assert_raises(Exception, f, 1.0, x, y, lower=2)
  346. assert_raises(Exception, f, 1.0, x, y,
  347. a=np.zeros((2, 2), 'd', 'F'))
  348. def test_her2(self):
  349. x = np.arange(1, 9, dtype='d').view('D')
  350. y = np.arange(9, 17, dtype='d').view('D')
  351. resxy = x[:, np.newaxis] * y.conj() + y[:, np.newaxis] * x.conj()
  352. resxy = np.triu(resxy)
  353. resxy_reverse = x[::-1, np.newaxis] * y[::-1].conj()
  354. resxy_reverse += y[::-1, np.newaxis] * x[::-1].conj()
  355. resxy_reverse = np.triu(resxy_reverse)
  356. u = np.c_[np.zeros(4), x, np.zeros(4)].ravel()
  357. v = np.c_[np.zeros(4), y, np.zeros(4)].ravel()
  358. for p, rtol in zip('cz', [1e-7, 1e-14]):
  359. f = getattr(fblas, p+'her2', None)
  360. if f is None:
  361. continue
  362. assert_allclose(f(1.0, x, y), resxy, rtol=rtol)
  363. assert_allclose(f(1.0, x, y, n=3), resxy[:3, :3], rtol=rtol)
  364. assert_allclose(f(1.0, x, y, lower=True), resxy.T.conj(),
  365. rtol=rtol)
  366. assert_allclose(f(1.0, u, v, incx=3, offx=1, incy=3, offy=1),
  367. resxy, rtol=rtol)
  368. assert_allclose(f(1.0, u, v, incx=3, offx=1, incy=3, offy=1, n=3),
  369. resxy[:3, :3], rtol=rtol)
  370. # negative increments imply reversed vectors in blas
  371. assert_allclose(f(1.0, u, v, incx=-3, offx=1, incy=-3, offy=1),
  372. resxy_reverse, rtol=rtol)
  373. a = np.zeros((4, 4), 'F' if p == 'c' else 'D', 'F')
  374. b = f(1.0, x, y, a=a, overwrite_a=True)
  375. assert_allclose(a, resxy, rtol=rtol)
  376. b = f(2.0, x, y, a=a)
  377. assert_(a is not b)
  378. assert_allclose(b, 3*resxy, rtol=rtol)
  379. assert_raises(Exception, f, 1.0, x, y, incx=0)
  380. assert_raises(Exception, f, 1.0, x, y, offx=5)
  381. assert_raises(Exception, f, 1.0, x, y, offx=-2)
  382. assert_raises(Exception, f, 1.0, x, y, incy=0)
  383. assert_raises(Exception, f, 1.0, x, y, offy=5)
  384. assert_raises(Exception, f, 1.0, x, y, offy=-2)
  385. assert_raises(Exception, f, 1.0, x, y, n=-2)
  386. assert_raises(Exception, f, 1.0, x, y, n=5)
  387. assert_raises(Exception, f, 1.0, x, y, lower=2)
  388. assert_raises(Exception, f, 1.0, x, y,
  389. a=np.zeros((2, 2), 'd', 'F'))
  390. def test_gbmv(self):
  391. seed(1234)
  392. for ind, dtype in enumerate(DTYPES):
  393. n = 7
  394. m = 5
  395. kl = 1
  396. ku = 2
  397. # fake a banded matrix via toeplitz
  398. A = toeplitz(append(rand(kl+1), zeros(m-kl-1)),
  399. append(rand(ku+1), zeros(n-ku-1)))
  400. A = A.astype(dtype)
  401. Ab = zeros((kl+ku+1, n), dtype=dtype)
  402. # Form the banded storage
  403. Ab[2, :5] = A[0, 0] # diag
  404. Ab[1, 1:6] = A[0, 1] # sup1
  405. Ab[0, 2:7] = A[0, 2] # sup2
  406. Ab[3, :4] = A[1, 0] # sub1
  407. x = rand(n).astype(dtype)
  408. y = rand(m).astype(dtype)
  409. alpha, beta = dtype(3), dtype(-5)
  410. func, = get_blas_funcs(('gbmv',), dtype=dtype)
  411. y1 = func(m=m, n=n, ku=ku, kl=kl, alpha=alpha, a=Ab,
  412. x=x, y=y, beta=beta)
  413. y2 = alpha * A.dot(x) + beta * y
  414. assert_array_almost_equal(y1, y2)
  415. def test_sbmv_hbmv(self):
  416. seed(1234)
  417. for ind, dtype in enumerate(DTYPES):
  418. n = 6
  419. k = 2
  420. A = zeros((n, n), dtype=dtype)
  421. Ab = zeros((k+1, n), dtype=dtype)
  422. # Form the array and its packed banded storage
  423. A[arange(n), arange(n)] = rand(n)
  424. for ind2 in range(1, k+1):
  425. temp = rand(n-ind2)
  426. A[arange(n-ind2), arange(ind2, n)] = temp
  427. Ab[-1-ind2, ind2:] = temp
  428. A = A.astype(dtype)
  429. A = A + A.T if ind < 2 else A + A.conj().T
  430. Ab[-1, :] = diag(A)
  431. x = rand(n).astype(dtype)
  432. y = rand(n).astype(dtype)
  433. alpha, beta = dtype(1.25), dtype(3)
  434. if ind > 1:
  435. func, = get_blas_funcs(('hbmv',), dtype=dtype)
  436. else:
  437. func, = get_blas_funcs(('sbmv',), dtype=dtype)
  438. y1 = func(k=k, alpha=alpha, a=Ab, x=x, y=y, beta=beta)
  439. y2 = alpha * A.dot(x) + beta * y
  440. assert_array_almost_equal(y1, y2)
  441. def test_spmv_hpmv(self):
  442. seed(1234)
  443. for ind, dtype in enumerate(DTYPES+COMPLEX_DTYPES):
  444. n = 3
  445. A = rand(n, n).astype(dtype)
  446. if ind > 1:
  447. A += rand(n, n)*1j
  448. A = A.astype(dtype)
  449. A = A + A.T if ind < 4 else A + A.conj().T
  450. c, r = tril_indices(n)
  451. Ap = A[r, c]
  452. x = rand(n).astype(dtype)
  453. y = rand(n).astype(dtype)
  454. xlong = arange(2*n).astype(dtype)
  455. ylong = ones(2*n).astype(dtype)
  456. alpha, beta = dtype(1.25), dtype(2)
  457. if ind > 3:
  458. func, = get_blas_funcs(('hpmv',), dtype=dtype)
  459. else:
  460. func, = get_blas_funcs(('spmv',), dtype=dtype)
  461. y1 = func(n=n, alpha=alpha, ap=Ap, x=x, y=y, beta=beta)
  462. y2 = alpha * A.dot(x) + beta * y
  463. assert_array_almost_equal(y1, y2)
  464. # Test inc and offsets
  465. y1 = func(n=n-1, alpha=alpha, beta=beta, x=xlong, y=ylong, ap=Ap,
  466. incx=2, incy=2, offx=n, offy=n)
  467. y2 = (alpha * A[:-1, :-1]).dot(xlong[3::2]) + beta * ylong[3::2]
  468. assert_array_almost_equal(y1[3::2], y2)
  469. assert_almost_equal(y1[4], ylong[4])
  470. def test_spr_hpr(self):
  471. seed(1234)
  472. for ind, dtype in enumerate(DTYPES+COMPLEX_DTYPES):
  473. n = 3
  474. A = rand(n, n).astype(dtype)
  475. if ind > 1:
  476. A += rand(n, n)*1j
  477. A = A.astype(dtype)
  478. A = A + A.T if ind < 4 else A + A.conj().T
  479. c, r = tril_indices(n)
  480. Ap = A[r, c]
  481. x = rand(n).astype(dtype)
  482. alpha = (DTYPES+COMPLEX_DTYPES)[mod(ind, 4)](2.5)
  483. if ind > 3:
  484. func, = get_blas_funcs(('hpr',), dtype=dtype)
  485. y2 = alpha * x[:, None].dot(x[None, :].conj()) + A
  486. else:
  487. func, = get_blas_funcs(('spr',), dtype=dtype)
  488. y2 = alpha * x[:, None].dot(x[None, :]) + A
  489. y1 = func(n=n, alpha=alpha, ap=Ap, x=x)
  490. y1f = zeros((3, 3), dtype=dtype)
  491. y1f[r, c] = y1
  492. y1f[c, r] = y1.conj() if ind > 3 else y1
  493. assert_array_almost_equal(y1f, y2)
  494. def test_spr2_hpr2(self):
  495. seed(1234)
  496. for ind, dtype in enumerate(DTYPES):
  497. n = 3
  498. A = rand(n, n).astype(dtype)
  499. if ind > 1:
  500. A += rand(n, n)*1j
  501. A = A.astype(dtype)
  502. A = A + A.T if ind < 2 else A + A.conj().T
  503. c, r = tril_indices(n)
  504. Ap = A[r, c]
  505. x = rand(n).astype(dtype)
  506. y = rand(n).astype(dtype)
  507. alpha = dtype(2)
  508. if ind > 1:
  509. func, = get_blas_funcs(('hpr2',), dtype=dtype)
  510. else:
  511. func, = get_blas_funcs(('spr2',), dtype=dtype)
  512. u = alpha.conj() * x[:, None].dot(y[None, :].conj())
  513. y2 = A + u + u.conj().T
  514. y1 = func(n=n, alpha=alpha, x=x, y=y, ap=Ap)
  515. y1f = zeros((3, 3), dtype=dtype)
  516. y1f[r, c] = y1
  517. y1f[[1, 2, 2], [0, 0, 1]] = y1[[1, 3, 4]].conj()
  518. assert_array_almost_equal(y1f, y2)
  519. def test_tbmv(self):
  520. seed(1234)
  521. for ind, dtype in enumerate(DTYPES):
  522. n = 10
  523. k = 3
  524. x = rand(n).astype(dtype)
  525. A = zeros((n, n), dtype=dtype)
  526. # Banded upper triangular array
  527. for sup in range(k+1):
  528. A[arange(n-sup), arange(sup, n)] = rand(n-sup)
  529. # Add complex parts for c,z
  530. if ind > 1:
  531. A[nonzero(A)] += 1j * rand((k+1)*n-(k*(k+1)//2)).astype(dtype)
  532. # Form the banded storage
  533. Ab = zeros((k+1, n), dtype=dtype)
  534. for row in range(k+1):
  535. Ab[-row-1, row:] = diag(A, k=row)
  536. func, = get_blas_funcs(('tbmv',), dtype=dtype)
  537. y1 = func(k=k, a=Ab, x=x)
  538. y2 = A.dot(x)
  539. assert_array_almost_equal(y1, y2)
  540. y1 = func(k=k, a=Ab, x=x, diag=1)
  541. A[arange(n), arange(n)] = dtype(1)
  542. y2 = A.dot(x)
  543. assert_array_almost_equal(y1, y2)
  544. y1 = func(k=k, a=Ab, x=x, diag=1, trans=1)
  545. y2 = A.T.dot(x)
  546. assert_array_almost_equal(y1, y2)
  547. y1 = func(k=k, a=Ab, x=x, diag=1, trans=2)
  548. y2 = A.conj().T.dot(x)
  549. assert_array_almost_equal(y1, y2)
  550. def test_tbsv(self):
  551. seed(1234)
  552. for ind, dtype in enumerate(DTYPES):
  553. n = 6
  554. k = 3
  555. x = rand(n).astype(dtype)
  556. A = zeros((n, n), dtype=dtype)
  557. # Banded upper triangular array
  558. for sup in range(k+1):
  559. A[arange(n-sup), arange(sup, n)] = rand(n-sup)
  560. # Add complex parts for c,z
  561. if ind > 1:
  562. A[nonzero(A)] += 1j * rand((k+1)*n-(k*(k+1)//2)).astype(dtype)
  563. # Form the banded storage
  564. Ab = zeros((k+1, n), dtype=dtype)
  565. for row in range(k+1):
  566. Ab[-row-1, row:] = diag(A, k=row)
  567. func, = get_blas_funcs(('tbsv',), dtype=dtype)
  568. y1 = func(k=k, a=Ab, x=x)
  569. y2 = solve(A, x)
  570. assert_array_almost_equal(y1, y2)
  571. y1 = func(k=k, a=Ab, x=x, diag=1)
  572. A[arange(n), arange(n)] = dtype(1)
  573. y2 = solve(A, x)
  574. assert_array_almost_equal(y1, y2)
  575. y1 = func(k=k, a=Ab, x=x, diag=1, trans=1)
  576. y2 = solve(A.T, x)
  577. assert_array_almost_equal(y1, y2)
  578. y1 = func(k=k, a=Ab, x=x, diag=1, trans=2)
  579. y2 = solve(A.conj().T, x)
  580. assert_array_almost_equal(y1, y2)
  581. def test_tpmv(self):
  582. seed(1234)
  583. for ind, dtype in enumerate(DTYPES):
  584. n = 10
  585. x = rand(n).astype(dtype)
  586. # Upper triangular array
  587. A = triu(rand(n, n)) if ind < 2 else triu(rand(n, n)+rand(n, n)*1j)
  588. # Form the packed storage
  589. c, r = tril_indices(n)
  590. Ap = A[r, c]
  591. func, = get_blas_funcs(('tpmv',), dtype=dtype)
  592. y1 = func(n=n, ap=Ap, x=x)
  593. y2 = A.dot(x)
  594. assert_array_almost_equal(y1, y2)
  595. y1 = func(n=n, ap=Ap, x=x, diag=1)
  596. A[arange(n), arange(n)] = dtype(1)
  597. y2 = A.dot(x)
  598. assert_array_almost_equal(y1, y2)
  599. y1 = func(n=n, ap=Ap, x=x, diag=1, trans=1)
  600. y2 = A.T.dot(x)
  601. assert_array_almost_equal(y1, y2)
  602. y1 = func(n=n, ap=Ap, x=x, diag=1, trans=2)
  603. y2 = A.conj().T.dot(x)
  604. assert_array_almost_equal(y1, y2)
  605. def test_tpsv(self):
  606. seed(1234)
  607. for ind, dtype in enumerate(DTYPES):
  608. n = 10
  609. x = rand(n).astype(dtype)
  610. # Upper triangular array
  611. A = triu(rand(n, n)) if ind < 2 else triu(rand(n, n)+rand(n, n)*1j)
  612. A += eye(n)
  613. # Form the packed storage
  614. c, r = tril_indices(n)
  615. Ap = A[r, c]
  616. func, = get_blas_funcs(('tpsv',), dtype=dtype)
  617. y1 = func(n=n, ap=Ap, x=x)
  618. y2 = solve(A, x)
  619. assert_array_almost_equal(y1, y2)
  620. y1 = func(n=n, ap=Ap, x=x, diag=1)
  621. A[arange(n), arange(n)] = dtype(1)
  622. y2 = solve(A, x)
  623. assert_array_almost_equal(y1, y2)
  624. y1 = func(n=n, ap=Ap, x=x, diag=1, trans=1)
  625. y2 = solve(A.T, x)
  626. assert_array_almost_equal(y1, y2)
  627. y1 = func(n=n, ap=Ap, x=x, diag=1, trans=2)
  628. y2 = solve(A.conj().T, x)
  629. assert_array_almost_equal(y1, y2)
  630. def test_trmv(self):
  631. seed(1234)
  632. for ind, dtype in enumerate(DTYPES):
  633. n = 3
  634. A = (rand(n, n)+eye(n)).astype(dtype)
  635. x = rand(3).astype(dtype)
  636. func, = get_blas_funcs(('trmv',), dtype=dtype)
  637. y1 = func(a=A, x=x)
  638. y2 = triu(A).dot(x)
  639. assert_array_almost_equal(y1, y2)
  640. y1 = func(a=A, x=x, diag=1)
  641. A[arange(n), arange(n)] = dtype(1)
  642. y2 = triu(A).dot(x)
  643. assert_array_almost_equal(y1, y2)
  644. y1 = func(a=A, x=x, diag=1, trans=1)
  645. y2 = triu(A).T.dot(x)
  646. assert_array_almost_equal(y1, y2)
  647. y1 = func(a=A, x=x, diag=1, trans=2)
  648. y2 = triu(A).conj().T.dot(x)
  649. assert_array_almost_equal(y1, y2)
  650. def test_trsv(self):
  651. seed(1234)
  652. for ind, dtype in enumerate(DTYPES):
  653. n = 15
  654. A = (rand(n, n)+eye(n)).astype(dtype)
  655. x = rand(n).astype(dtype)
  656. func, = get_blas_funcs(('trsv',), dtype=dtype)
  657. y1 = func(a=A, x=x)
  658. y2 = solve(triu(A), x)
  659. assert_array_almost_equal(y1, y2)
  660. y1 = func(a=A, x=x, lower=1)
  661. y2 = solve(tril(A), x)
  662. assert_array_almost_equal(y1, y2)
  663. y1 = func(a=A, x=x, diag=1)
  664. A[arange(n), arange(n)] = dtype(1)
  665. y2 = solve(triu(A), x)
  666. assert_array_almost_equal(y1, y2)
  667. y1 = func(a=A, x=x, diag=1, trans=1)
  668. y2 = solve(triu(A).T, x)
  669. assert_array_almost_equal(y1, y2)
  670. y1 = func(a=A, x=x, diag=1, trans=2)
  671. y2 = solve(triu(A).conj().T, x)
  672. assert_array_almost_equal(y1, y2)
  673. class TestFBLAS3Simple:
  674. def test_gemm(self):
  675. for p in 'sd':
  676. f = getattr(fblas, p+'gemm', None)
  677. if f is None:
  678. continue
  679. assert_array_almost_equal(f(3, [3], [-4]), [[-36]])
  680. assert_array_almost_equal(f(3, [3], [-4], 3, [5]), [-21])
  681. for p in 'cz':
  682. f = getattr(fblas, p+'gemm', None)
  683. if f is None:
  684. continue
  685. assert_array_almost_equal(f(3j, [3-4j], [-4]), [[-48-36j]])
  686. assert_array_almost_equal(f(3j, [3-4j], [-4], 3, [5j]), [-48-21j])
  687. def _get_func(func, ps='sdzc'):
  688. """Just a helper: return a specified BLAS function w/typecode."""
  689. for p in ps:
  690. f = getattr(fblas, p+func, None)
  691. if f is None:
  692. continue
  693. yield f
  694. class TestBLAS3Symm:
  695. def setup_method(self):
  696. self.a = np.array([[1., 2.],
  697. [0., 1.]])
  698. self.b = np.array([[1., 0., 3.],
  699. [0., -1., 2.]])
  700. self.c = np.ones((2, 3))
  701. self.t = np.array([[2., -1., 8.],
  702. [3., 0., 9.]])
  703. def test_symm(self):
  704. for f in _get_func('symm'):
  705. res = f(a=self.a, b=self.b, c=self.c, alpha=1., beta=1.)
  706. assert_array_almost_equal(res, self.t)
  707. res = f(a=self.a.T, b=self.b, lower=1, c=self.c, alpha=1., beta=1.)
  708. assert_array_almost_equal(res, self.t)
  709. res = f(a=self.a, b=self.b.T, side=1, c=self.c.T,
  710. alpha=1., beta=1.)
  711. assert_array_almost_equal(res, self.t.T)
  712. def test_summ_wrong_side(self):
  713. f = getattr(fblas, 'dsymm', None)
  714. if f is not None:
  715. assert_raises(Exception, f, **{'a': self.a, 'b': self.b,
  716. 'alpha': 1, 'side': 1})
  717. # `side=1` means C <- B*A, hence shapes of A and B are to be
  718. # compatible. Otherwise, f2py exception is raised
  719. def test_symm_wrong_uplo(self):
  720. """SYMM only considers the upper/lower part of A. Hence setting
  721. wrong value for `lower` (default is lower=0, meaning upper triangle)
  722. gives a wrong result.
  723. """
  724. f = getattr(fblas, 'dsymm', None)
  725. if f is not None:
  726. res = f(a=self.a, b=self.b, c=self.c, alpha=1., beta=1.)
  727. assert np.allclose(res, self.t)
  728. res = f(a=self.a, b=self.b, lower=1, c=self.c, alpha=1., beta=1.)
  729. assert not np.allclose(res, self.t)
  730. class TestBLAS3Syrk:
  731. def setup_method(self):
  732. self.a = np.array([[1., 0.],
  733. [0., -2.],
  734. [2., 3.]])
  735. self.t = np.array([[1., 0., 2.],
  736. [0., 4., -6.],
  737. [2., -6., 13.]])
  738. self.tt = np.array([[5., 6.],
  739. [6., 13.]])
  740. def test_syrk(self):
  741. for f in _get_func('syrk'):
  742. c = f(a=self.a, alpha=1.)
  743. assert_array_almost_equal(np.triu(c), np.triu(self.t))
  744. c = f(a=self.a, alpha=1., lower=1)
  745. assert_array_almost_equal(np.tril(c), np.tril(self.t))
  746. c0 = np.ones(self.t.shape)
  747. c = f(a=self.a, alpha=1., beta=1., c=c0)
  748. assert_array_almost_equal(np.triu(c), np.triu(self.t+c0))
  749. c = f(a=self.a, alpha=1., trans=1)
  750. assert_array_almost_equal(np.triu(c), np.triu(self.tt))
  751. # prints '0-th dimension must be fixed to 3 but got 5',
  752. # FIXME: suppress?
  753. # FIXME: how to catch the _fblas.error?
  754. def test_syrk_wrong_c(self):
  755. f = getattr(fblas, 'dsyrk', None)
  756. if f is not None:
  757. assert_raises(Exception, f, **{'a': self.a, 'alpha': 1.,
  758. 'c': np.ones((5, 8))})
  759. # if C is supplied, it must have compatible dimensions
  760. class TestBLAS3Syr2k:
  761. def setup_method(self):
  762. self.a = np.array([[1., 0.],
  763. [0., -2.],
  764. [2., 3.]])
  765. self.b = np.array([[0., 1.],
  766. [1., 0.],
  767. [0, 1.]])
  768. self.t = np.array([[0., -1., 3.],
  769. [-1., 0., 0.],
  770. [3., 0., 6.]])
  771. self.tt = np.array([[0., 1.],
  772. [1., 6]])
  773. def test_syr2k(self):
  774. for f in _get_func('syr2k'):
  775. c = f(a=self.a, b=self.b, alpha=1.)
  776. assert_array_almost_equal(np.triu(c), np.triu(self.t))
  777. c = f(a=self.a, b=self.b, alpha=1., lower=1)
  778. assert_array_almost_equal(np.tril(c), np.tril(self.t))
  779. c0 = np.ones(self.t.shape)
  780. c = f(a=self.a, b=self.b, alpha=1., beta=1., c=c0)
  781. assert_array_almost_equal(np.triu(c), np.triu(self.t+c0))
  782. c = f(a=self.a, b=self.b, alpha=1., trans=1)
  783. assert_array_almost_equal(np.triu(c), np.triu(self.tt))
  784. # prints '0-th dimension must be fixed to 3 but got 5', FIXME: suppress?
  785. def test_syr2k_wrong_c(self):
  786. f = getattr(fblas, 'dsyr2k', None)
  787. if f is not None:
  788. assert_raises(Exception, f, **{'a': self.a,
  789. 'b': self.b,
  790. 'alpha': 1.,
  791. 'c': np.zeros((15, 8))})
  792. # if C is supplied, it must have compatible dimensions
  793. class TestSyHe:
  794. """Quick and simple tests for (zc)-symm, syrk, syr2k."""
  795. def setup_method(self):
  796. self.sigma_y = np.array([[0., -1.j],
  797. [1.j, 0.]])
  798. def test_symm_zc(self):
  799. for f in _get_func('symm', 'zc'):
  800. # NB: a is symmetric w/upper diag of ONLY
  801. res = f(a=self.sigma_y, b=self.sigma_y, alpha=1.)
  802. assert_array_almost_equal(np.triu(res), np.diag([1, -1]))
  803. def test_hemm_zc(self):
  804. for f in _get_func('hemm', 'zc'):
  805. # NB: a is hermitian w/upper diag of ONLY
  806. res = f(a=self.sigma_y, b=self.sigma_y, alpha=1.)
  807. assert_array_almost_equal(np.triu(res), np.diag([1, 1]))
  808. def test_syrk_zr(self):
  809. for f in _get_func('syrk', 'zc'):
  810. res = f(a=self.sigma_y, alpha=1.)
  811. assert_array_almost_equal(np.triu(res), np.diag([-1, -1]))
  812. def test_herk_zr(self):
  813. for f in _get_func('herk', 'zc'):
  814. res = f(a=self.sigma_y, alpha=1.)
  815. assert_array_almost_equal(np.triu(res), np.diag([1, 1]))
  816. def test_syr2k_zr(self):
  817. for f in _get_func('syr2k', 'zc'):
  818. res = f(a=self.sigma_y, b=self.sigma_y, alpha=1.)
  819. assert_array_almost_equal(np.triu(res), 2.*np.diag([-1, -1]))
  820. def test_her2k_zr(self):
  821. for f in _get_func('her2k', 'zc'):
  822. res = f(a=self.sigma_y, b=self.sigma_y, alpha=1.)
  823. assert_array_almost_equal(np.triu(res), 2.*np.diag([1, 1]))
  824. class TestTRMM:
  825. """Quick and simple tests for dtrmm."""
  826. def setup_method(self):
  827. self.a = np.array([[1., 2., ],
  828. [-2., 1.]])
  829. self.b = np.array([[3., 4., -1.],
  830. [5., 6., -2.]])
  831. self.a2 = np.array([[1, 1, 2, 3],
  832. [0, 1, 4, 5],
  833. [0, 0, 1, 6],
  834. [0, 0, 0, 1]], order="f")
  835. self.b2 = np.array([[1, 4], [2, 5], [3, 6], [7, 8], [9, 10]],
  836. order="f")
  837. @pytest.mark.parametrize("dtype_", DTYPES)
  838. def test_side(self, dtype_):
  839. trmm = get_blas_funcs("trmm", dtype=dtype_)
  840. # Provide large A array that works for side=1 but not 0 (see gh-10841)
  841. assert_raises(Exception, trmm, 1.0, self.a2, self.b2)
  842. res = trmm(1.0, self.a2.astype(dtype_), self.b2.astype(dtype_),
  843. side=1)
  844. k = self.b2.shape[1]
  845. assert_allclose(res, self.b2 @ self.a2[:k, :k], rtol=0.,
  846. atol=100*np.finfo(dtype_).eps)
  847. def test_ab(self):
  848. f = getattr(fblas, 'dtrmm', None)
  849. if f is not None:
  850. result = f(1., self.a, self.b)
  851. # default a is upper triangular
  852. expected = np.array([[13., 16., -5.],
  853. [5., 6., -2.]])
  854. assert_array_almost_equal(result, expected)
  855. def test_ab_lower(self):
  856. f = getattr(fblas, 'dtrmm', None)
  857. if f is not None:
  858. result = f(1., self.a, self.b, lower=True)
  859. expected = np.array([[3., 4., -1.],
  860. [-1., -2., 0.]]) # now a is lower triangular
  861. assert_array_almost_equal(result, expected)
  862. def test_b_overwrites(self):
  863. # BLAS dtrmm modifies B argument in-place.
  864. # Here the default is to copy, but this can be overridden
  865. f = getattr(fblas, 'dtrmm', None)
  866. if f is not None:
  867. for overwr in [True, False]:
  868. bcopy = self.b.copy()
  869. result = f(1., self.a, bcopy, overwrite_b=overwr)
  870. # C-contiguous arrays are copied
  871. assert_(bcopy.flags.f_contiguous is False and
  872. np.may_share_memory(bcopy, result) is False)
  873. assert_equal(bcopy, self.b)
  874. bcopy = np.asfortranarray(self.b.copy()) # or just transpose it
  875. result = f(1., self.a, bcopy, overwrite_b=True)
  876. assert_(bcopy.flags.f_contiguous is True and
  877. np.may_share_memory(bcopy, result) is True)
  878. assert_array_almost_equal(bcopy, result)
  879. def test_trsm():
  880. seed(1234)
  881. for ind, dtype in enumerate(DTYPES):
  882. tol = np.finfo(dtype).eps*1000
  883. func, = get_blas_funcs(('trsm',), dtype=dtype)
  884. # Test protection against size mismatches
  885. A = rand(4, 5).astype(dtype)
  886. B = rand(4, 4).astype(dtype)
  887. alpha = dtype(1)
  888. assert_raises(Exception, func, alpha, A, B)
  889. assert_raises(Exception, func, alpha, A.T, B)
  890. n = 8
  891. m = 7
  892. alpha = dtype(-2.5)
  893. A = (rand(m, m) if ind < 2 else rand(m, m) + rand(m, m)*1j) + eye(m)
  894. A = A.astype(dtype)
  895. Au = triu(A)
  896. Al = tril(A)
  897. B1 = rand(m, n).astype(dtype)
  898. B2 = rand(n, m).astype(dtype)
  899. x1 = func(alpha=alpha, a=A, b=B1)
  900. assert_equal(B1.shape, x1.shape)
  901. x2 = solve(Au, alpha*B1)
  902. assert_allclose(x1, x2, atol=tol)
  903. x1 = func(alpha=alpha, a=A, b=B1, trans_a=1)
  904. x2 = solve(Au.T, alpha*B1)
  905. assert_allclose(x1, x2, atol=tol)
  906. x1 = func(alpha=alpha, a=A, b=B1, trans_a=2)
  907. x2 = solve(Au.conj().T, alpha*B1)
  908. assert_allclose(x1, x2, atol=tol)
  909. x1 = func(alpha=alpha, a=A, b=B1, diag=1)
  910. Au[arange(m), arange(m)] = dtype(1)
  911. x2 = solve(Au, alpha*B1)
  912. assert_allclose(x1, x2, atol=tol)
  913. x1 = func(alpha=alpha, a=A, b=B2, diag=1, side=1)
  914. x2 = solve(Au.conj().T, alpha*B2.conj().T)
  915. assert_allclose(x1, x2.conj().T, atol=tol)
  916. x1 = func(alpha=alpha, a=A, b=B2, diag=1, side=1, lower=1)
  917. Al[arange(m), arange(m)] = dtype(1)
  918. x2 = solve(Al.conj().T, alpha*B2.conj().T)
  919. assert_allclose(x1, x2.conj().T, atol=tol)