test_fblas.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607
# Test interfaces to fortran blas.
#
# The tests exercise the Python interface more than the underlying BLAS.
# Most checks use very small matrices -- N=3 or so.
#
# !! Complex calculations really aren't checked that carefully.
# !! Most complex tests use only real-valued data; the ?gemv tests scale
# !! their inputs by (1+1j).
  8. from numpy import float32, float64, complex64, complex128, arange, array, \
  9. zeros, shape, transpose, newaxis, common_type, conjugate
  10. from scipy.linalg import _fblas as fblas
  11. from numpy.testing import assert_array_equal, \
  12. assert_allclose, assert_array_almost_equal, assert_
  13. import pytest
# decimal accuracy to require between Python and LAPACK/BLAS calculations
# NOTE(review): nothing in this part of the file reads `accuracy`; the
# assert_array_almost_equal calls below rely on their default `decimal`.
# Confirm it is unused elsewhere before removing it.
accuracy = 5
  16. # Since numpy.dot likely uses the same blas, use this routine
  17. # to check.
  18. def matrixmultiply(a, b):
  19. if len(b.shape) == 1:
  20. b_is_vector = True
  21. b = b[:, newaxis]
  22. else:
  23. b_is_vector = False
  24. assert_(a.shape[1] == b.shape[0])
  25. c = zeros((a.shape[0], b.shape[1]), common_type(a, b))
  26. for i in range(a.shape[0]):
  27. for j in range(b.shape[1]):
  28. s = 0
  29. for k in range(a.shape[1]):
  30. s += a[i, k] * b[k, j]
  31. c[i, j] = s
  32. if b_is_vector:
  33. c = c.reshape((a.shape[0],))
  34. return c
  35. ##################################################
  36. # Test blas ?axpy
  37. class BaseAxpy:
  38. ''' Mixin class for axpy tests '''
  39. def test_default_a(self):
  40. x = arange(3., dtype=self.dtype)
  41. y = arange(3., dtype=x.dtype)
  42. real_y = x*1.+y
  43. y = self.blas_func(x, y)
  44. assert_array_equal(real_y, y)
  45. def test_simple(self):
  46. x = arange(3., dtype=self.dtype)
  47. y = arange(3., dtype=x.dtype)
  48. real_y = x*3.+y
  49. y = self.blas_func(x, y, a=3.)
  50. assert_array_equal(real_y, y)
  51. def test_x_stride(self):
  52. x = arange(6., dtype=self.dtype)
  53. y = zeros(3, x.dtype)
  54. y = arange(3., dtype=x.dtype)
  55. real_y = x[::2]*3.+y
  56. y = self.blas_func(x, y, a=3., n=3, incx=2)
  57. assert_array_equal(real_y, y)
  58. def test_y_stride(self):
  59. x = arange(3., dtype=self.dtype)
  60. y = zeros(6, x.dtype)
  61. real_y = x*3.+y[::2]
  62. y = self.blas_func(x, y, a=3., n=3, incy=2)
  63. assert_array_equal(real_y, y[::2])
  64. def test_x_and_y_stride(self):
  65. x = arange(12., dtype=self.dtype)
  66. y = zeros(6, x.dtype)
  67. real_y = x[::4]*3.+y[::2]
  68. y = self.blas_func(x, y, a=3., n=3, incx=4, incy=2)
  69. assert_array_equal(real_y, y[::2])
  70. def test_x_bad_size(self):
  71. x = arange(12., dtype=self.dtype)
  72. y = zeros(6, x.dtype)
  73. with pytest.raises(Exception, match='failed for 1st keyword'):
  74. self.blas_func(x, y, n=4, incx=5)
  75. def test_y_bad_size(self):
  76. x = arange(12., dtype=self.dtype)
  77. y = zeros(6, x.dtype)
  78. with pytest.raises(Exception, match='failed for 1st keyword'):
  79. self.blas_func(x, y, n=3, incy=5)
  80. try:
  81. class TestSaxpy(BaseAxpy):
  82. blas_func = fblas.saxpy
  83. dtype = float32
  84. except AttributeError:
  85. class TestSaxpy:
  86. pass
  87. class TestDaxpy(BaseAxpy):
  88. blas_func = fblas.daxpy
  89. dtype = float64
  90. try:
  91. class TestCaxpy(BaseAxpy):
  92. blas_func = fblas.caxpy
  93. dtype = complex64
  94. except AttributeError:
  95. class TestCaxpy:
  96. pass
  97. class TestZaxpy(BaseAxpy):
  98. blas_func = fblas.zaxpy
  99. dtype = complex128
  100. ##################################################
  101. # Test blas ?scal
  102. class BaseScal:
  103. ''' Mixin class for scal testing '''
  104. def test_simple(self):
  105. x = arange(3., dtype=self.dtype)
  106. real_x = x*3.
  107. x = self.blas_func(3., x)
  108. assert_array_equal(real_x, x)
  109. def test_x_stride(self):
  110. x = arange(6., dtype=self.dtype)
  111. real_x = x.copy()
  112. real_x[::2] = x[::2]*array(3., self.dtype)
  113. x = self.blas_func(3., x, n=3, incx=2)
  114. assert_array_equal(real_x, x)
  115. def test_x_bad_size(self):
  116. x = arange(12., dtype=self.dtype)
  117. with pytest.raises(Exception, match='failed for 1st keyword'):
  118. self.blas_func(2., x, n=4, incx=5)
  119. try:
  120. class TestSscal(BaseScal):
  121. blas_func = fblas.sscal
  122. dtype = float32
  123. except AttributeError:
  124. class TestSscal:
  125. pass
  126. class TestDscal(BaseScal):
  127. blas_func = fblas.dscal
  128. dtype = float64
  129. try:
  130. class TestCscal(BaseScal):
  131. blas_func = fblas.cscal
  132. dtype = complex64
  133. except AttributeError:
  134. class TestCscal:
  135. pass
  136. class TestZscal(BaseScal):
  137. blas_func = fblas.zscal
  138. dtype = complex128
  139. ##################################################
  140. # Test blas ?copy
  141. class BaseCopy:
  142. ''' Mixin class for copy testing '''
  143. def test_simple(self):
  144. x = arange(3., dtype=self.dtype)
  145. y = zeros(shape(x), x.dtype)
  146. y = self.blas_func(x, y)
  147. assert_array_equal(x, y)
  148. def test_x_stride(self):
  149. x = arange(6., dtype=self.dtype)
  150. y = zeros(3, x.dtype)
  151. y = self.blas_func(x, y, n=3, incx=2)
  152. assert_array_equal(x[::2], y)
  153. def test_y_stride(self):
  154. x = arange(3., dtype=self.dtype)
  155. y = zeros(6, x.dtype)
  156. y = self.blas_func(x, y, n=3, incy=2)
  157. assert_array_equal(x, y[::2])
  158. def test_x_and_y_stride(self):
  159. x = arange(12., dtype=self.dtype)
  160. y = zeros(6, x.dtype)
  161. y = self.blas_func(x, y, n=3, incx=4, incy=2)
  162. assert_array_equal(x[::4], y[::2])
  163. def test_x_bad_size(self):
  164. x = arange(12., dtype=self.dtype)
  165. y = zeros(6, x.dtype)
  166. with pytest.raises(Exception, match='failed for 1st keyword'):
  167. self.blas_func(x, y, n=4, incx=5)
  168. def test_y_bad_size(self):
  169. x = arange(12., dtype=self.dtype)
  170. y = zeros(6, x.dtype)
  171. with pytest.raises(Exception, match='failed for 1st keyword'):
  172. self.blas_func(x, y, n=3, incy=5)
  173. # def test_y_bad_type(self):
  174. ## Hmmm. Should this work? What should be the output.
  175. # x = arange(3.,dtype=self.dtype)
  176. # y = zeros(shape(x))
  177. # self.blas_func(x,y)
  178. # assert_array_equal(x,y)
  179. try:
  180. class TestScopy(BaseCopy):
  181. blas_func = fblas.scopy
  182. dtype = float32
  183. except AttributeError:
  184. class TestScopy:
  185. pass
  186. class TestDcopy(BaseCopy):
  187. blas_func = fblas.dcopy
  188. dtype = float64
  189. try:
  190. class TestCcopy(BaseCopy):
  191. blas_func = fblas.ccopy
  192. dtype = complex64
  193. except AttributeError:
  194. class TestCcopy:
  195. pass
  196. class TestZcopy(BaseCopy):
  197. blas_func = fblas.zcopy
  198. dtype = complex128
  199. ##################################################
  200. # Test blas ?swap
  201. class BaseSwap:
  202. ''' Mixin class for swap tests '''
  203. def test_simple(self):
  204. x = arange(3., dtype=self.dtype)
  205. y = zeros(shape(x), x.dtype)
  206. desired_x = y.copy()
  207. desired_y = x.copy()
  208. x, y = self.blas_func(x, y)
  209. assert_array_equal(desired_x, x)
  210. assert_array_equal(desired_y, y)
  211. def test_x_stride(self):
  212. x = arange(6., dtype=self.dtype)
  213. y = zeros(3, x.dtype)
  214. desired_x = y.copy()
  215. desired_y = x.copy()[::2]
  216. x, y = self.blas_func(x, y, n=3, incx=2)
  217. assert_array_equal(desired_x, x[::2])
  218. assert_array_equal(desired_y, y)
  219. def test_y_stride(self):
  220. x = arange(3., dtype=self.dtype)
  221. y = zeros(6, x.dtype)
  222. desired_x = y.copy()[::2]
  223. desired_y = x.copy()
  224. x, y = self.blas_func(x, y, n=3, incy=2)
  225. assert_array_equal(desired_x, x)
  226. assert_array_equal(desired_y, y[::2])
  227. def test_x_and_y_stride(self):
  228. x = arange(12., dtype=self.dtype)
  229. y = zeros(6, x.dtype)
  230. desired_x = y.copy()[::2]
  231. desired_y = x.copy()[::4]
  232. x, y = self.blas_func(x, y, n=3, incx=4, incy=2)
  233. assert_array_equal(desired_x, x[::4])
  234. assert_array_equal(desired_y, y[::2])
  235. def test_x_bad_size(self):
  236. x = arange(12., dtype=self.dtype)
  237. y = zeros(6, x.dtype)
  238. with pytest.raises(Exception, match='failed for 1st keyword'):
  239. self.blas_func(x, y, n=4, incx=5)
  240. def test_y_bad_size(self):
  241. x = arange(12., dtype=self.dtype)
  242. y = zeros(6, x.dtype)
  243. with pytest.raises(Exception, match='failed for 1st keyword'):
  244. self.blas_func(x, y, n=3, incy=5)
  245. try:
  246. class TestSswap(BaseSwap):
  247. blas_func = fblas.sswap
  248. dtype = float32
  249. except AttributeError:
  250. class TestSswap:
  251. pass
  252. class TestDswap(BaseSwap):
  253. blas_func = fblas.dswap
  254. dtype = float64
  255. try:
  256. class TestCswap(BaseSwap):
  257. blas_func = fblas.cswap
  258. dtype = complex64
  259. except AttributeError:
  260. class TestCswap:
  261. pass
  262. class TestZswap(BaseSwap):
  263. blas_func = fblas.zswap
  264. dtype = complex128
  265. ##################################################
  266. # Test blas ?gemv
  267. # This will be a mess to test all cases.
  268. class BaseGemv:
  269. ''' Mixin class for gemv tests '''
  270. def get_data(self, x_stride=1, y_stride=1):
  271. mult = array(1, dtype=self.dtype)
  272. if self.dtype in [complex64, complex128]:
  273. mult = array(1+1j, dtype=self.dtype)
  274. from numpy.random import normal, seed
  275. seed(1234)
  276. alpha = array(1., dtype=self.dtype) * mult
  277. beta = array(1., dtype=self.dtype) * mult
  278. a = normal(0., 1., (3, 3)).astype(self.dtype) * mult
  279. x = arange(shape(a)[0]*x_stride, dtype=self.dtype) * mult
  280. y = arange(shape(a)[1]*y_stride, dtype=self.dtype) * mult
  281. return alpha, beta, a, x, y
  282. def test_simple(self):
  283. alpha, beta, a, x, y = self.get_data()
  284. desired_y = alpha*matrixmultiply(a, x)+beta*y
  285. y = self.blas_func(alpha, a, x, beta, y)
  286. assert_array_almost_equal(desired_y, y)
  287. def test_default_beta_y(self):
  288. alpha, beta, a, x, y = self.get_data()
  289. desired_y = matrixmultiply(a, x)
  290. y = self.blas_func(1, a, x)
  291. assert_array_almost_equal(desired_y, y)
  292. def test_simple_transpose(self):
  293. alpha, beta, a, x, y = self.get_data()
  294. desired_y = alpha*matrixmultiply(transpose(a), x)+beta*y
  295. y = self.blas_func(alpha, a, x, beta, y, trans=1)
  296. assert_array_almost_equal(desired_y, y)
  297. def test_simple_transpose_conj(self):
  298. alpha, beta, a, x, y = self.get_data()
  299. desired_y = alpha*matrixmultiply(transpose(conjugate(a)), x)+beta*y
  300. y = self.blas_func(alpha, a, x, beta, y, trans=2)
  301. assert_array_almost_equal(desired_y, y)
  302. def test_x_stride(self):
  303. alpha, beta, a, x, y = self.get_data(x_stride=2)
  304. desired_y = alpha*matrixmultiply(a, x[::2])+beta*y
  305. y = self.blas_func(alpha, a, x, beta, y, incx=2)
  306. assert_array_almost_equal(desired_y, y)
  307. def test_x_stride_transpose(self):
  308. alpha, beta, a, x, y = self.get_data(x_stride=2)
  309. desired_y = alpha*matrixmultiply(transpose(a), x[::2])+beta*y
  310. y = self.blas_func(alpha, a, x, beta, y, trans=1, incx=2)
  311. assert_array_almost_equal(desired_y, y)
  312. def test_x_stride_assert(self):
  313. # What is the use of this test?
  314. alpha, beta, a, x, y = self.get_data(x_stride=2)
  315. with pytest.raises(Exception, match='failed for 3rd argument'):
  316. y = self.blas_func(1, a, x, 1, y, trans=0, incx=3)
  317. with pytest.raises(Exception, match='failed for 3rd argument'):
  318. y = self.blas_func(1, a, x, 1, y, trans=1, incx=3)
  319. def test_y_stride(self):
  320. alpha, beta, a, x, y = self.get_data(y_stride=2)
  321. desired_y = y.copy()
  322. desired_y[::2] = alpha*matrixmultiply(a, x)+beta*y[::2]
  323. y = self.blas_func(alpha, a, x, beta, y, incy=2)
  324. assert_array_almost_equal(desired_y, y)
  325. def test_y_stride_transpose(self):
  326. alpha, beta, a, x, y = self.get_data(y_stride=2)
  327. desired_y = y.copy()
  328. desired_y[::2] = alpha*matrixmultiply(transpose(a), x)+beta*y[::2]
  329. y = self.blas_func(alpha, a, x, beta, y, trans=1, incy=2)
  330. assert_array_almost_equal(desired_y, y)
  331. def test_y_stride_assert(self):
  332. # What is the use of this test?
  333. alpha, beta, a, x, y = self.get_data(y_stride=2)
  334. with pytest.raises(Exception, match='failed for 2nd keyword'):
  335. y = self.blas_func(1, a, x, 1, y, trans=0, incy=3)
  336. with pytest.raises(Exception, match='failed for 2nd keyword'):
  337. y = self.blas_func(1, a, x, 1, y, trans=1, incy=3)
  338. try:
  339. class TestSgemv(BaseGemv):
  340. blas_func = fblas.sgemv
  341. dtype = float32
  342. def test_sgemv_on_osx(self):
  343. from itertools import product
  344. import sys
  345. import numpy as np
  346. if sys.platform != 'darwin':
  347. return
  348. def aligned_array(shape, align, dtype, order='C'):
  349. # Make array shape `shape` with aligned at `align` bytes
  350. d = dtype()
  351. # Make array of correct size with `align` extra bytes
  352. N = np.prod(shape)
  353. tmp = np.zeros(N * d.nbytes + align, dtype=np.uint8)
  354. address = tmp.__array_interface__["data"][0]
  355. # Find offset into array giving desired alignment
  356. for offset in range(align):
  357. if (address + offset) % align == 0:
  358. break
  359. tmp = tmp[offset:offset+N*d.nbytes].view(dtype=dtype)
  360. return tmp.reshape(shape, order=order)
  361. def as_aligned(arr, align, dtype, order='C'):
  362. # Copy `arr` into an aligned array with same shape
  363. aligned = aligned_array(arr.shape, align, dtype, order)
  364. aligned[:] = arr[:]
  365. return aligned
  366. def assert_dot_close(A, X, desired):
  367. assert_allclose(self.blas_func(1.0, A, X), desired,
  368. rtol=1e-5, atol=1e-7)
  369. testdata = product((15, 32), (10000,), (200, 89), ('C', 'F'))
  370. for align, m, n, a_order in testdata:
  371. A_d = np.random.rand(m, n)
  372. X_d = np.random.rand(n)
  373. desired = np.dot(A_d, X_d)
  374. # Calculation with aligned single precision
  375. A_f = as_aligned(A_d, align, np.float32, order=a_order)
  376. X_f = as_aligned(X_d, align, np.float32, order=a_order)
  377. assert_dot_close(A_f, X_f, desired)
  378. except AttributeError:
  379. class TestSgemv:
  380. pass
  381. class TestDgemv(BaseGemv):
  382. blas_func = fblas.dgemv
  383. dtype = float64
  384. try:
  385. class TestCgemv(BaseGemv):
  386. blas_func = fblas.cgemv
  387. dtype = complex64
  388. except AttributeError:
  389. class TestCgemv:
  390. pass
  391. class TestZgemv(BaseGemv):
  392. blas_func = fblas.zgemv
  393. dtype = complex128
  394. """
  395. ##################################################
  396. ### Test blas ?ger
  397. ### This will be a mess to test all cases.
  398. class BaseGer:
  399. def get_data(self,x_stride=1,y_stride=1):
  400. from numpy.random import normal, seed
  401. seed(1234)
  402. alpha = array(1., dtype = self.dtype)
  403. a = normal(0.,1.,(3,3)).astype(self.dtype)
  404. x = arange(shape(a)[0]*x_stride,dtype=self.dtype)
  405. y = arange(shape(a)[1]*y_stride,dtype=self.dtype)
  406. return alpha,a,x,y
  407. def test_simple(self):
  408. alpha,a,x,y = self.get_data()
  409. # tranpose takes care of Fortran vs. C(and Python) memory layout
  410. desired_a = alpha*transpose(x[:,newaxis]*y) + a
  411. self.blas_func(x,y,a)
  412. assert_array_almost_equal(desired_a,a)
  413. def test_x_stride(self):
  414. alpha,a,x,y = self.get_data(x_stride=2)
  415. desired_a = alpha*transpose(x[::2,newaxis]*y) + a
  416. self.blas_func(x,y,a,incx=2)
  417. assert_array_almost_equal(desired_a,a)
  418. def test_x_stride_assert(self):
  419. alpha,a,x,y = self.get_data(x_stride=2)
  420. with pytest.raises(ValueError, match='foo'):
  421. self.blas_func(x,y,a,incx=3)
  422. def test_y_stride(self):
  423. alpha,a,x,y = self.get_data(y_stride=2)
  424. desired_a = alpha*transpose(x[:,newaxis]*y[::2]) + a
  425. self.blas_func(x,y,a,incy=2)
  426. assert_array_almost_equal(desired_a,a)
  427. def test_y_stride_assert(self):
  428. alpha,a,x,y = self.get_data(y_stride=2)
  429. with pytest.raises(ValueError, match='foo'):
  430. self.blas_func(a,x,y,incy=3)
  431. class TestSger(BaseGer):
  432. blas_func = fblas.sger
  433. dtype = float32
  434. class TestDger(BaseGer):
  435. blas_func = fblas.dger
  436. dtype = float64
  437. """
##################################################
# Test blas ?gerc
# This will be a mess to test all cases.
# NOTE(review): like the ?ger section, everything below is a bare string
# literal and never runs. BaseGerComplex subclasses BaseGer, which (as far
# as this file shows) exists only inside the disabled ?ger string. Also,
# test_simple zeroes `a` and calls fblas.cgeru directly instead of
# self.blas_func, so the Zgeru/Cgerc/Zgerc classes would not test their
# own routine. The `transform` hooks are referenced only in commented-out
# lines.
"""
class BaseGerComplex(BaseGer):
    def get_data(self,x_stride=1,y_stride=1):
        from numpy.random import normal, seed
        seed(1234)
        alpha = array(1+1j, dtype = self.dtype)
        a = normal(0.,1.,(3,3)).astype(self.dtype)
        a = a + normal(0.,1.,(3,3)) * array(1j, dtype = self.dtype)
        x = normal(0.,1.,shape(a)[0]*x_stride).astype(self.dtype)
        x = x + x * array(1j, dtype = self.dtype)
        y = normal(0.,1.,shape(a)[1]*y_stride).astype(self.dtype)
        y = y + y * array(1j, dtype = self.dtype)
        return alpha,a,x,y
    def test_simple(self):
        alpha,a,x,y = self.get_data()
        # tranpose takes care of Fortran vs. C(and Python) memory layout
        a = a * array(0.,dtype = self.dtype)
        #desired_a = alpha*transpose(x[:,newaxis]*self.transform(y)) + a
        desired_a = alpha*transpose(x[:,newaxis]*y) + a
        #self.blas_func(x,y,a,alpha = alpha)
        fblas.cgeru(x,y,a,alpha = alpha)
        assert_array_almost_equal(desired_a,a)
    #def test_x_stride(self):
    # alpha,a,x,y = self.get_data(x_stride=2)
    # desired_a = alpha*transpose(x[::2,newaxis]*self.transform(y)) + a
    # self.blas_func(x,y,a,incx=2)
    # assert_array_almost_equal(desired_a,a)
    #def test_y_stride(self):
    # alpha,a,x,y = self.get_data(y_stride=2)
    # desired_a = alpha*transpose(x[:,newaxis]*self.transform(y[::2])) + a
    # self.blas_func(x,y,a,incy=2)
    # assert_array_almost_equal(desired_a,a)
class TestCgeru(BaseGerComplex):
    blas_func = fblas.cgeru
    dtype = complex64
    def transform(self,x):
        return x
class TestZgeru(BaseGerComplex):
    blas_func = fblas.zgeru
    dtype = complex128
    def transform(self,x):
        return x
class TestCgerc(BaseGerComplex):
    blas_func = fblas.cgerc
    dtype = complex64
    def transform(self,x):
        return conjugate(x)
class TestZgerc(BaseGerComplex):
    blas_func = fblas.zgerc
    dtype = complex128
    def transform(self,x):
        return conjugate(x)
"""