test_real_transforms.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815
  1. from os.path import join, dirname
  2. import numpy as np
  3. from numpy.testing import assert_array_almost_equal, assert_equal
  4. import pytest
  5. from pytest import raises as assert_raises
  6. from scipy.fftpack._realtransforms import (
  7. dct, idct, dst, idst, dctn, idctn, dstn, idstn)
  8. # Matlab reference data
  9. MDATA = np.load(join(dirname(__file__), 'test.npz'))
  10. X = [MDATA['x%d' % i] for i in range(8)]
  11. Y = [MDATA['y%d' % i] for i in range(8)]
  12. # FFTW reference data: the data are organized as follows:
  13. # * SIZES is an array containing all available sizes
  14. # * for every type (1, 2, 3, 4) and every size, the array dct_type_size
  15. # contains the output of the DCT applied to the input np.linspace(0, size-1,
  16. # size)
  17. FFTWDATA_DOUBLE = np.load(join(dirname(__file__), 'fftw_double_ref.npz'))
  18. FFTWDATA_SINGLE = np.load(join(dirname(__file__), 'fftw_single_ref.npz'))
  19. FFTWDATA_SIZES = FFTWDATA_DOUBLE['sizes']
  20. def fftw_dct_ref(type, size, dt):
  21. x = np.linspace(0, size-1, size).astype(dt)
  22. dt = np.result_type(np.float32, dt)
  23. if dt == np.double:
  24. data = FFTWDATA_DOUBLE
  25. elif dt == np.float32:
  26. data = FFTWDATA_SINGLE
  27. else:
  28. raise ValueError()
  29. y = (data['dct_%d_%d' % (type, size)]).astype(dt)
  30. return x, y, dt
  31. def fftw_dst_ref(type, size, dt):
  32. x = np.linspace(0, size-1, size).astype(dt)
  33. dt = np.result_type(np.float32, dt)
  34. if dt == np.double:
  35. data = FFTWDATA_DOUBLE
  36. elif dt == np.float32:
  37. data = FFTWDATA_SINGLE
  38. else:
  39. raise ValueError()
  40. y = (data['dst_%d_%d' % (type, size)]).astype(dt)
  41. return x, y, dt
  42. def dct_2d_ref(x, **kwargs):
  43. """Calculate reference values for testing dct2."""
  44. x = np.array(x, copy=True)
  45. for row in range(x.shape[0]):
  46. x[row, :] = dct(x[row, :], **kwargs)
  47. for col in range(x.shape[1]):
  48. x[:, col] = dct(x[:, col], **kwargs)
  49. return x
  50. def idct_2d_ref(x, **kwargs):
  51. """Calculate reference values for testing idct2."""
  52. x = np.array(x, copy=True)
  53. for row in range(x.shape[0]):
  54. x[row, :] = idct(x[row, :], **kwargs)
  55. for col in range(x.shape[1]):
  56. x[:, col] = idct(x[:, col], **kwargs)
  57. return x
  58. def dst_2d_ref(x, **kwargs):
  59. """Calculate reference values for testing dst2."""
  60. x = np.array(x, copy=True)
  61. for row in range(x.shape[0]):
  62. x[row, :] = dst(x[row, :], **kwargs)
  63. for col in range(x.shape[1]):
  64. x[:, col] = dst(x[:, col], **kwargs)
  65. return x
  66. def idst_2d_ref(x, **kwargs):
  67. """Calculate reference values for testing idst2."""
  68. x = np.array(x, copy=True)
  69. for row in range(x.shape[0]):
  70. x[row, :] = idst(x[row, :], **kwargs)
  71. for col in range(x.shape[1]):
  72. x[:, col] = idst(x[:, col], **kwargs)
  73. return x
  74. def naive_dct1(x, norm=None):
  75. """Calculate textbook definition version of DCT-I."""
  76. x = np.array(x, copy=True)
  77. N = len(x)
  78. M = N-1
  79. y = np.zeros(N)
  80. m0, m = 1, 2
  81. if norm == 'ortho':
  82. m0 = np.sqrt(1.0/M)
  83. m = np.sqrt(2.0/M)
  84. for k in range(N):
  85. for n in range(1, N-1):
  86. y[k] += m*x[n]*np.cos(np.pi*n*k/M)
  87. y[k] += m0 * x[0]
  88. y[k] += m0 * x[N-1] * (1 if k % 2 == 0 else -1)
  89. if norm == 'ortho':
  90. y[0] *= 1/np.sqrt(2)
  91. y[N-1] *= 1/np.sqrt(2)
  92. return y
  93. def naive_dst1(x, norm=None):
  94. """Calculate textbook definition version of DST-I."""
  95. x = np.array(x, copy=True)
  96. N = len(x)
  97. M = N+1
  98. y = np.zeros(N)
  99. for k in range(N):
  100. for n in range(N):
  101. y[k] += 2*x[n]*np.sin(np.pi*(n+1.0)*(k+1.0)/M)
  102. if norm == 'ortho':
  103. y *= np.sqrt(0.5/M)
  104. return y
  105. def naive_dct4(x, norm=None):
  106. """Calculate textbook definition version of DCT-IV."""
  107. x = np.array(x, copy=True)
  108. N = len(x)
  109. y = np.zeros(N)
  110. for k in range(N):
  111. for n in range(N):
  112. y[k] += x[n]*np.cos(np.pi*(n+0.5)*(k+0.5)/(N))
  113. if norm == 'ortho':
  114. y *= np.sqrt(2.0/N)
  115. else:
  116. y *= 2
  117. return y
  118. def naive_dst4(x, norm=None):
  119. """Calculate textbook definition version of DST-IV."""
  120. x = np.array(x, copy=True)
  121. N = len(x)
  122. y = np.zeros(N)
  123. for k in range(N):
  124. for n in range(N):
  125. y[k] += x[n]*np.sin(np.pi*(n+0.5)*(k+0.5)/(N))
  126. if norm == 'ortho':
  127. y *= np.sqrt(2.0/N)
  128. else:
  129. y *= 2
  130. return y
  131. class TestComplex:
  132. def test_dct_complex64(self):
  133. y = dct(1j*np.arange(5, dtype=np.complex64))
  134. x = 1j*dct(np.arange(5))
  135. assert_array_almost_equal(x, y)
  136. def test_dct_complex(self):
  137. y = dct(np.arange(5)*1j)
  138. x = 1j*dct(np.arange(5))
  139. assert_array_almost_equal(x, y)
  140. def test_idct_complex(self):
  141. y = idct(np.arange(5)*1j)
  142. x = 1j*idct(np.arange(5))
  143. assert_array_almost_equal(x, y)
  144. def test_dst_complex64(self):
  145. y = dst(np.arange(5, dtype=np.complex64)*1j)
  146. x = 1j*dst(np.arange(5))
  147. assert_array_almost_equal(x, y)
  148. def test_dst_complex(self):
  149. y = dst(np.arange(5)*1j)
  150. x = 1j*dst(np.arange(5))
  151. assert_array_almost_equal(x, y)
  152. def test_idst_complex(self):
  153. y = idst(np.arange(5)*1j)
  154. x = 1j*idst(np.arange(5))
  155. assert_array_almost_equal(x, y)
  156. class _TestDCTBase:
  157. def setup_method(self):
  158. self.rdt = None
  159. self.dec = 14
  160. self.type = None
  161. def test_definition(self):
  162. for i in FFTWDATA_SIZES:
  163. x, yr, dt = fftw_dct_ref(self.type, i, self.rdt)
  164. y = dct(x, type=self.type)
  165. assert_equal(y.dtype, dt)
  166. # XXX: we divide by np.max(y) because the tests fail otherwise. We
  167. # should really use something like assert_array_approx_equal. The
  168. # difference is due to fftw using a better algorithm w.r.t error
  169. # propagation compared to the ones from fftpack.
  170. assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec,
  171. err_msg="Size %d failed" % i)
  172. def test_axis(self):
  173. nt = 2
  174. for i in [7, 8, 9, 16, 32, 64]:
  175. x = np.random.randn(nt, i)
  176. y = dct(x, type=self.type)
  177. for j in range(nt):
  178. assert_array_almost_equal(y[j], dct(x[j], type=self.type),
  179. decimal=self.dec)
  180. x = x.T
  181. y = dct(x, axis=0, type=self.type)
  182. for j in range(nt):
  183. assert_array_almost_equal(y[:,j], dct(x[:,j], type=self.type),
  184. decimal=self.dec)
  185. class _TestDCTIBase(_TestDCTBase):
  186. def test_definition_ortho(self):
  187. # Test orthornomal mode.
  188. dt = np.result_type(np.float32, self.rdt)
  189. for xr in X:
  190. x = np.array(xr, dtype=self.rdt)
  191. y = dct(x, norm='ortho', type=1)
  192. y2 = naive_dct1(x, norm='ortho')
  193. assert_equal(y.dtype, dt)
  194. assert_array_almost_equal(y / np.max(y), y2 / np.max(y), decimal=self.dec)
  195. class _TestDCTIIBase(_TestDCTBase):
  196. def test_definition_matlab(self):
  197. # Test correspondence with MATLAB (orthornomal mode).
  198. dt = np.result_type(np.float32, self.rdt)
  199. for xr, yr in zip(X, Y):
  200. x = np.array(xr, dtype=dt)
  201. y = dct(x, norm="ortho", type=2)
  202. assert_equal(y.dtype, dt)
  203. assert_array_almost_equal(y, yr, decimal=self.dec)
  204. class _TestDCTIIIBase(_TestDCTBase):
  205. def test_definition_ortho(self):
  206. # Test orthornomal mode.
  207. dt = np.result_type(np.float32, self.rdt)
  208. for xr in X:
  209. x = np.array(xr, dtype=self.rdt)
  210. y = dct(x, norm='ortho', type=2)
  211. xi = dct(y, norm="ortho", type=3)
  212. assert_equal(xi.dtype, dt)
  213. assert_array_almost_equal(xi, x, decimal=self.dec)
  214. class _TestDCTIVBase(_TestDCTBase):
  215. def test_definition_ortho(self):
  216. # Test orthornomal mode.
  217. dt = np.result_type(np.float32, self.rdt)
  218. for xr in X:
  219. x = np.array(xr, dtype=self.rdt)
  220. y = dct(x, norm='ortho', type=4)
  221. y2 = naive_dct4(x, norm='ortho')
  222. assert_equal(y.dtype, dt)
  223. assert_array_almost_equal(y / np.max(y), y2 / np.max(y), decimal=self.dec)
  224. class TestDCTIDouble(_TestDCTIBase):
  225. def setup_method(self):
  226. self.rdt = np.double
  227. self.dec = 10
  228. self.type = 1
  229. class TestDCTIFloat(_TestDCTIBase):
  230. def setup_method(self):
  231. self.rdt = np.float32
  232. self.dec = 4
  233. self.type = 1
  234. class TestDCTIInt(_TestDCTIBase):
  235. def setup_method(self):
  236. self.rdt = int
  237. self.dec = 5
  238. self.type = 1
  239. class TestDCTIIDouble(_TestDCTIIBase):
  240. def setup_method(self):
  241. self.rdt = np.double
  242. self.dec = 10
  243. self.type = 2
  244. class TestDCTIIFloat(_TestDCTIIBase):
  245. def setup_method(self):
  246. self.rdt = np.float32
  247. self.dec = 5
  248. self.type = 2
  249. class TestDCTIIInt(_TestDCTIIBase):
  250. def setup_method(self):
  251. self.rdt = int
  252. self.dec = 5
  253. self.type = 2
  254. class TestDCTIIIDouble(_TestDCTIIIBase):
  255. def setup_method(self):
  256. self.rdt = np.double
  257. self.dec = 14
  258. self.type = 3
  259. class TestDCTIIIFloat(_TestDCTIIIBase):
  260. def setup_method(self):
  261. self.rdt = np.float32
  262. self.dec = 5
  263. self.type = 3
  264. class TestDCTIIIInt(_TestDCTIIIBase):
  265. def setup_method(self):
  266. self.rdt = int
  267. self.dec = 5
  268. self.type = 3
  269. class TestDCTIVDouble(_TestDCTIVBase):
  270. def setup_method(self):
  271. self.rdt = np.double
  272. self.dec = 12
  273. self.type = 3
  274. class TestDCTIVFloat(_TestDCTIVBase):
  275. def setup_method(self):
  276. self.rdt = np.float32
  277. self.dec = 5
  278. self.type = 3
  279. class TestDCTIVInt(_TestDCTIVBase):
  280. def setup_method(self):
  281. self.rdt = int
  282. self.dec = 5
  283. self.type = 3
  284. class _TestIDCTBase:
  285. def setup_method(self):
  286. self.rdt = None
  287. self.dec = 14
  288. self.type = None
  289. def test_definition(self):
  290. for i in FFTWDATA_SIZES:
  291. xr, yr, dt = fftw_dct_ref(self.type, i, self.rdt)
  292. x = idct(yr, type=self.type)
  293. if self.type == 1:
  294. x /= 2 * (i-1)
  295. else:
  296. x /= 2 * i
  297. assert_equal(x.dtype, dt)
  298. # XXX: we divide by np.max(y) because the tests fail otherwise. We
  299. # should really use something like assert_array_approx_equal. The
  300. # difference is due to fftw using a better algorithm w.r.t error
  301. # propagation compared to the ones from fftpack.
  302. assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec,
  303. err_msg="Size %d failed" % i)
  304. class TestIDCTIDouble(_TestIDCTBase):
  305. def setup_method(self):
  306. self.rdt = np.double
  307. self.dec = 10
  308. self.type = 1
  309. class TestIDCTIFloat(_TestIDCTBase):
  310. def setup_method(self):
  311. self.rdt = np.float32
  312. self.dec = 4
  313. self.type = 1
  314. class TestIDCTIInt(_TestIDCTBase):
  315. def setup_method(self):
  316. self.rdt = int
  317. self.dec = 4
  318. self.type = 1
  319. class TestIDCTIIDouble(_TestIDCTBase):
  320. def setup_method(self):
  321. self.rdt = np.double
  322. self.dec = 10
  323. self.type = 2
  324. class TestIDCTIIFloat(_TestIDCTBase):
  325. def setup_method(self):
  326. self.rdt = np.float32
  327. self.dec = 5
  328. self.type = 2
  329. class TestIDCTIIInt(_TestIDCTBase):
  330. def setup_method(self):
  331. self.rdt = int
  332. self.dec = 5
  333. self.type = 2
  334. class TestIDCTIIIDouble(_TestIDCTBase):
  335. def setup_method(self):
  336. self.rdt = np.double
  337. self.dec = 14
  338. self.type = 3
  339. class TestIDCTIIIFloat(_TestIDCTBase):
  340. def setup_method(self):
  341. self.rdt = np.float32
  342. self.dec = 5
  343. self.type = 3
  344. class TestIDCTIIIInt(_TestIDCTBase):
  345. def setup_method(self):
  346. self.rdt = int
  347. self.dec = 5
  348. self.type = 3
  349. class TestIDCTIVDouble(_TestIDCTBase):
  350. def setup_method(self):
  351. self.rdt = np.double
  352. self.dec = 12
  353. self.type = 4
  354. class TestIDCTIVFloat(_TestIDCTBase):
  355. def setup_method(self):
  356. self.rdt = np.float32
  357. self.dec = 5
  358. self.type = 4
  359. class TestIDCTIVInt(_TestIDCTBase):
  360. def setup_method(self):
  361. self.rdt = int
  362. self.dec = 5
  363. self.type = 4
  364. class _TestDSTBase:
  365. def setup_method(self):
  366. self.rdt = None # dtype
  367. self.dec = None # number of decimals to match
  368. self.type = None # dst type
  369. def test_definition(self):
  370. for i in FFTWDATA_SIZES:
  371. xr, yr, dt = fftw_dst_ref(self.type, i, self.rdt)
  372. y = dst(xr, type=self.type)
  373. assert_equal(y.dtype, dt)
  374. # XXX: we divide by np.max(y) because the tests fail otherwise. We
  375. # should really use something like assert_array_approx_equal. The
  376. # difference is due to fftw using a better algorithm w.r.t error
  377. # propagation compared to the ones from fftpack.
  378. assert_array_almost_equal(y / np.max(y), yr / np.max(y), decimal=self.dec,
  379. err_msg="Size %d failed" % i)
  380. class _TestDSTIBase(_TestDSTBase):
  381. def test_definition_ortho(self):
  382. # Test orthornomal mode.
  383. dt = np.result_type(np.float32, self.rdt)
  384. for xr in X:
  385. x = np.array(xr, dtype=self.rdt)
  386. y = dst(x, norm='ortho', type=1)
  387. y2 = naive_dst1(x, norm='ortho')
  388. assert_equal(y.dtype, dt)
  389. assert_array_almost_equal(y / np.max(y), y2 / np.max(y), decimal=self.dec)
  390. class _TestDSTIVBase(_TestDSTBase):
  391. def test_definition_ortho(self):
  392. # Test orthornomal mode.
  393. dt = np.result_type(np.float32, self.rdt)
  394. for xr in X:
  395. x = np.array(xr, dtype=self.rdt)
  396. y = dst(x, norm='ortho', type=4)
  397. y2 = naive_dst4(x, norm='ortho')
  398. assert_equal(y.dtype, dt)
  399. assert_array_almost_equal(y, y2, decimal=self.dec)
  400. class TestDSTIDouble(_TestDSTIBase):
  401. def setup_method(self):
  402. self.rdt = np.double
  403. self.dec = 12
  404. self.type = 1
  405. class TestDSTIFloat(_TestDSTIBase):
  406. def setup_method(self):
  407. self.rdt = np.float32
  408. self.dec = 4
  409. self.type = 1
  410. class TestDSTIInt(_TestDSTIBase):
  411. def setup_method(self):
  412. self.rdt = int
  413. self.dec = 5
  414. self.type = 1
  415. class TestDSTIIDouble(_TestDSTBase):
  416. def setup_method(self):
  417. self.rdt = np.double
  418. self.dec = 14
  419. self.type = 2
  420. class TestDSTIIFloat(_TestDSTBase):
  421. def setup_method(self):
  422. self.rdt = np.float32
  423. self.dec = 6
  424. self.type = 2
  425. class TestDSTIIInt(_TestDSTBase):
  426. def setup_method(self):
  427. self.rdt = int
  428. self.dec = 6
  429. self.type = 2
  430. class TestDSTIIIDouble(_TestDSTBase):
  431. def setup_method(self):
  432. self.rdt = np.double
  433. self.dec = 14
  434. self.type = 3
  435. class TestDSTIIIFloat(_TestDSTBase):
  436. def setup_method(self):
  437. self.rdt = np.float32
  438. self.dec = 7
  439. self.type = 3
  440. class TestDSTIIIInt(_TestDSTBase):
  441. def setup_method(self):
  442. self.rdt = int
  443. self.dec = 7
  444. self.type = 3
  445. class TestDSTIVDouble(_TestDSTIVBase):
  446. def setup_method(self):
  447. self.rdt = np.double
  448. self.dec = 12
  449. self.type = 4
  450. class TestDSTIVFloat(_TestDSTIVBase):
  451. def setup_method(self):
  452. self.rdt = np.float32
  453. self.dec = 4
  454. self.type = 4
  455. class TestDSTIVInt(_TestDSTIVBase):
  456. def setup_method(self):
  457. self.rdt = int
  458. self.dec = 5
  459. self.type = 4
  460. class _TestIDSTBase:
  461. def setup_method(self):
  462. self.rdt = None
  463. self.dec = None
  464. self.type = None
  465. def test_definition(self):
  466. for i in FFTWDATA_SIZES:
  467. xr, yr, dt = fftw_dst_ref(self.type, i, self.rdt)
  468. x = idst(yr, type=self.type)
  469. if self.type == 1:
  470. x /= 2 * (i+1)
  471. else:
  472. x /= 2 * i
  473. assert_equal(x.dtype, dt)
  474. # XXX: we divide by np.max(x) because the tests fail otherwise. We
  475. # should really use something like assert_array_approx_equal. The
  476. # difference is due to fftw using a better algorithm w.r.t error
  477. # propagation compared to the ones from fftpack.
  478. assert_array_almost_equal(x / np.max(x), xr / np.max(x), decimal=self.dec,
  479. err_msg="Size %d failed" % i)
  480. class TestIDSTIDouble(_TestIDSTBase):
  481. def setup_method(self):
  482. self.rdt = np.double
  483. self.dec = 12
  484. self.type = 1
  485. class TestIDSTIFloat(_TestIDSTBase):
  486. def setup_method(self):
  487. self.rdt = np.float32
  488. self.dec = 4
  489. self.type = 1
  490. class TestIDSTIInt(_TestIDSTBase):
  491. def setup_method(self):
  492. self.rdt = int
  493. self.dec = 4
  494. self.type = 1
  495. class TestIDSTIIDouble(_TestIDSTBase):
  496. def setup_method(self):
  497. self.rdt = np.double
  498. self.dec = 14
  499. self.type = 2
  500. class TestIDSTIIFloat(_TestIDSTBase):
  501. def setup_method(self):
  502. self.rdt = np.float32
  503. self.dec = 6
  504. self.type = 2
  505. class TestIDSTIIInt(_TestIDSTBase):
  506. def setup_method(self):
  507. self.rdt = int
  508. self.dec = 6
  509. self.type = 2
  510. class TestIDSTIIIDouble(_TestIDSTBase):
  511. def setup_method(self):
  512. self.rdt = np.double
  513. self.dec = 14
  514. self.type = 3
  515. class TestIDSTIIIFloat(_TestIDSTBase):
  516. def setup_method(self):
  517. self.rdt = np.float32
  518. self.dec = 6
  519. self.type = 3
  520. class TestIDSTIIIInt(_TestIDSTBase):
  521. def setup_method(self):
  522. self.rdt = int
  523. self.dec = 6
  524. self.type = 3
  525. class TestIDSTIVDouble(_TestIDSTBase):
  526. def setup_method(self):
  527. self.rdt = np.double
  528. self.dec = 12
  529. self.type = 4
  530. class TestIDSTIVFloat(_TestIDSTBase):
  531. def setup_method(self):
  532. self.rdt = np.float32
  533. self.dec = 6
  534. self.type = 4
  535. class TestIDSTIVnt(_TestIDSTBase):
  536. def setup_method(self):
  537. self.rdt = int
  538. self.dec = 6
  539. self.type = 4
  540. class TestOverwrite:
  541. """Check input overwrite behavior."""
  542. real_dtypes = [np.float32, np.float64]
  543. def _check(self, x, routine, type, fftsize, axis, norm, overwrite_x, **kw):
  544. x2 = x.copy()
  545. routine(x2, type, fftsize, axis, norm, overwrite_x=overwrite_x)
  546. sig = "%s(%s%r, %r, axis=%r, overwrite_x=%r)" % (
  547. routine.__name__, x.dtype, x.shape, fftsize, axis, overwrite_x)
  548. if not overwrite_x:
  549. assert_equal(x2, x, err_msg="spurious overwrite in %s" % sig)
  550. def _check_1d(self, routine, dtype, shape, axis):
  551. np.random.seed(1234)
  552. if np.issubdtype(dtype, np.complexfloating):
  553. data = np.random.randn(*shape) + 1j*np.random.randn(*shape)
  554. else:
  555. data = np.random.randn(*shape)
  556. data = data.astype(dtype)
  557. for type in [1, 2, 3, 4]:
  558. for overwrite_x in [True, False]:
  559. for norm in [None, 'ortho']:
  560. self._check(data, routine, type, None, axis, norm,
  561. overwrite_x)
  562. def test_dct(self):
  563. for dtype in self.real_dtypes:
  564. self._check_1d(dct, dtype, (16,), -1)
  565. self._check_1d(dct, dtype, (16, 2), 0)
  566. self._check_1d(dct, dtype, (2, 16), 1)
  567. def test_idct(self):
  568. for dtype in self.real_dtypes:
  569. self._check_1d(idct, dtype, (16,), -1)
  570. self._check_1d(idct, dtype, (16, 2), 0)
  571. self._check_1d(idct, dtype, (2, 16), 1)
  572. def test_dst(self):
  573. for dtype in self.real_dtypes:
  574. self._check_1d(dst, dtype, (16,), -1)
  575. self._check_1d(dst, dtype, (16, 2), 0)
  576. self._check_1d(dst, dtype, (2, 16), 1)
  577. def test_idst(self):
  578. for dtype in self.real_dtypes:
  579. self._check_1d(idst, dtype, (16,), -1)
  580. self._check_1d(idst, dtype, (16, 2), 0)
  581. self._check_1d(idst, dtype, (2, 16), 1)
  582. class Test_DCTN_IDCTN:
  583. dec = 14
  584. dct_type = [1, 2, 3, 4]
  585. norms = [None, 'ortho']
  586. rstate = np.random.RandomState(1234)
  587. shape = (32, 16)
  588. data = rstate.randn(*shape)
  589. @pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
  590. (dstn, idstn)])
  591. @pytest.mark.parametrize('axes', [None,
  592. 1, (1,), [1],
  593. 0, (0,), [0],
  594. (0, 1), [0, 1],
  595. (-2, -1), [-2, -1]])
  596. @pytest.mark.parametrize('dct_type', dct_type)
  597. @pytest.mark.parametrize('norm', ['ortho'])
  598. def test_axes_round_trip(self, fforward, finverse, axes, dct_type, norm):
  599. tmp = fforward(self.data, type=dct_type, axes=axes, norm=norm)
  600. tmp = finverse(tmp, type=dct_type, axes=axes, norm=norm)
  601. assert_array_almost_equal(self.data, tmp, decimal=12)
  602. @pytest.mark.parametrize('fforward,fforward_ref', [(dctn, dct_2d_ref),
  603. (dstn, dst_2d_ref)])
  604. @pytest.mark.parametrize('dct_type', dct_type)
  605. @pytest.mark.parametrize('norm', norms)
  606. def test_dctn_vs_2d_reference(self, fforward, fforward_ref,
  607. dct_type, norm):
  608. y1 = fforward(self.data, type=dct_type, axes=None, norm=norm)
  609. y2 = fforward_ref(self.data, type=dct_type, norm=norm)
  610. assert_array_almost_equal(y1, y2, decimal=11)
  611. @pytest.mark.parametrize('finverse,finverse_ref', [(idctn, idct_2d_ref),
  612. (idstn, idst_2d_ref)])
  613. @pytest.mark.parametrize('dct_type', dct_type)
  614. @pytest.mark.parametrize('norm', [None, 'ortho'])
  615. def test_idctn_vs_2d_reference(self, finverse, finverse_ref,
  616. dct_type, norm):
  617. fdata = dctn(self.data, type=dct_type, norm=norm)
  618. y1 = finverse(fdata, type=dct_type, norm=norm)
  619. y2 = finverse_ref(fdata, type=dct_type, norm=norm)
  620. assert_array_almost_equal(y1, y2, decimal=11)
  621. @pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
  622. (dstn, idstn)])
  623. def test_axes_and_shape(self, fforward, finverse):
  624. with assert_raises(ValueError,
  625. match="when given, axes and shape arguments"
  626. " have to be of the same length"):
  627. fforward(self.data, shape=self.data.shape[0], axes=(0, 1))
  628. with assert_raises(ValueError,
  629. match="when given, axes and shape arguments"
  630. " have to be of the same length"):
  631. fforward(self.data, shape=self.data.shape[0], axes=None)
  632. with assert_raises(ValueError,
  633. match="when given, axes and shape arguments"
  634. " have to be of the same length"):
  635. fforward(self.data, shape=self.data.shape, axes=0)
  636. @pytest.mark.parametrize('fforward', [dctn, dstn])
  637. def test_shape(self, fforward):
  638. tmp = fforward(self.data, shape=(128, 128), axes=None)
  639. assert_equal(tmp.shape, (128, 128))
  640. @pytest.mark.parametrize('fforward,finverse', [(dctn, idctn),
  641. (dstn, idstn)])
  642. @pytest.mark.parametrize('axes', [1, (1,), [1],
  643. 0, (0,), [0]])
  644. def test_shape_is_none_with_axes(self, fforward, finverse, axes):
  645. tmp = fforward(self.data, shape=None, axes=axes, norm='ortho')
  646. tmp = finverse(tmp, shape=None, axes=axes, norm='ortho')
  647. assert_array_almost_equal(self.data, tmp, decimal=self.dec)