test_real_transforms.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. import numpy as np
  2. from numpy.testing import assert_allclose, assert_array_equal
  3. import pytest
  4. from scipy.fft import dct, idct, dctn, idctn, dst, idst, dstn, idstn
  5. import scipy.fft as fft
  6. from scipy import fftpack
  7. import math
# Scaling factor shared by the test_orthogonalize_* tests. They multiply or
# divide the first (and, for DCT-I, also the last) samples/coefficients by
# sqrt(2) to relate the orthogonalized and plain transforms.
SQRT_2 = math.sqrt(2)
# scipy.fft wraps the fftpack versions but with normalized inverse transforms,
# so the forward transforms and their definitions are already thoroughly
# tested in fftpack/test_real_transforms.py. The tests below cover round-trip
# identities, overwrite_x handling, fftpack equivalence and orthogonalize.
  12. @pytest.mark.parametrize("forward, backward", [(dct, idct), (dst, idst)])
  13. @pytest.mark.parametrize("type", [1, 2, 3, 4])
  14. @pytest.mark.parametrize("n", [2, 3, 4, 5, 10, 16])
  15. @pytest.mark.parametrize("axis", [0, 1])
  16. @pytest.mark.parametrize("norm", [None, 'backward', 'ortho', 'forward'])
  17. @pytest.mark.parametrize("orthogonalize", [False, True])
  18. def test_identity_1d(forward, backward, type, n, axis, norm, orthogonalize):
  19. # Test the identity f^-1(f(x)) == x
  20. x = np.random.rand(n, n)
  21. y = forward(x, type, axis=axis, norm=norm, orthogonalize=orthogonalize)
  22. z = backward(y, type, axis=axis, norm=norm, orthogonalize=orthogonalize)
  23. assert_allclose(z, x)
  24. pad = [(0, 0)] * 2
  25. pad[axis] = (0, 4)
  26. y2 = np.pad(y, pad, mode='edge')
  27. z2 = backward(y2, type, n, axis, norm, orthogonalize=orthogonalize)
  28. assert_allclose(z2, x)
  29. @pytest.mark.parametrize("forward, backward", [(dct, idct), (dst, idst)])
  30. @pytest.mark.parametrize("type", [1, 2, 3, 4])
  31. @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64,
  32. np.complex64, np.complex128])
  33. @pytest.mark.parametrize("axis", [0, 1])
  34. @pytest.mark.parametrize("norm", [None, 'backward', 'ortho', 'forward'])
  35. @pytest.mark.parametrize("overwrite_x", [True, False])
  36. def test_identity_1d_overwrite(forward, backward, type, dtype, axis, norm,
  37. overwrite_x):
  38. # Test the identity f^-1(f(x)) == x
  39. x = np.random.rand(7, 8).astype(dtype)
  40. x_orig = x.copy()
  41. y = forward(x, type, axis=axis, norm=norm, overwrite_x=overwrite_x)
  42. y_orig = y.copy()
  43. z = backward(y, type, axis=axis, norm=norm, overwrite_x=overwrite_x)
  44. if not overwrite_x:
  45. assert_allclose(z, x, rtol=1e-6, atol=1e-6)
  46. assert_array_equal(x, x_orig)
  47. assert_array_equal(y, y_orig)
  48. else:
  49. assert_allclose(z, x_orig, rtol=1e-6, atol=1e-6)
  50. @pytest.mark.parametrize("forward, backward", [(dctn, idctn), (dstn, idstn)])
  51. @pytest.mark.parametrize("type", [1, 2, 3, 4])
  52. @pytest.mark.parametrize("shape, axes",
  53. [
  54. ((4, 4), 0),
  55. ((4, 4), 1),
  56. ((4, 4), None),
  57. ((4, 4), (0, 1)),
  58. ((10, 12), None),
  59. ((10, 12), (0, 1)),
  60. ((4, 5, 6), None),
  61. ((4, 5, 6), 1),
  62. ((4, 5, 6), (0, 2)),
  63. ])
  64. @pytest.mark.parametrize("norm", [None, 'backward', 'ortho', 'forward'])
  65. @pytest.mark.parametrize("orthogonalize", [False, True])
  66. def test_identity_nd(forward, backward, type, shape, axes, norm,
  67. orthogonalize):
  68. # Test the identity f^-1(f(x)) == x
  69. x = np.random.random(shape)
  70. if axes is not None:
  71. shape = np.take(shape, axes)
  72. y = forward(x, type, axes=axes, norm=norm, orthogonalize=orthogonalize)
  73. z = backward(y, type, axes=axes, norm=norm, orthogonalize=orthogonalize)
  74. assert_allclose(z, x)
  75. if axes is None:
  76. pad = [(0, 4)] * x.ndim
  77. elif isinstance(axes, int):
  78. pad = [(0, 0)] * x.ndim
  79. pad[axes] = (0, 4)
  80. else:
  81. pad = [(0, 0)] * x.ndim
  82. for a in axes:
  83. pad[a] = (0, 4)
  84. y2 = np.pad(y, pad, mode='edge')
  85. z2 = backward(y2, type, shape, axes, norm, orthogonalize=orthogonalize)
  86. assert_allclose(z2, x)
  87. @pytest.mark.parametrize("forward, backward", [(dctn, idctn), (dstn, idstn)])
  88. @pytest.mark.parametrize("type", [1, 2, 3, 4])
  89. @pytest.mark.parametrize("shape, axes",
  90. [
  91. ((4, 5), 0),
  92. ((4, 5), 1),
  93. ((4, 5), None),
  94. ])
  95. @pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64,
  96. np.complex64, np.complex128])
  97. @pytest.mark.parametrize("norm", [None, 'backward', 'ortho', 'forward'])
  98. @pytest.mark.parametrize("overwrite_x", [False, True])
  99. def test_identity_nd_overwrite(forward, backward, type, shape, axes, dtype,
  100. norm, overwrite_x):
  101. # Test the identity f^-1(f(x)) == x
  102. x = np.random.random(shape).astype(dtype)
  103. x_orig = x.copy()
  104. if axes is not None:
  105. shape = np.take(shape, axes)
  106. y = forward(x, type, axes=axes, norm=norm)
  107. y_orig = y.copy()
  108. z = backward(y, type, axes=axes, norm=norm)
  109. if overwrite_x:
  110. assert_allclose(z, x_orig, rtol=1e-6, atol=1e-6)
  111. else:
  112. assert_allclose(z, x, rtol=1e-6, atol=1e-6)
  113. assert_array_equal(x, x_orig)
  114. assert_array_equal(y, y_orig)
  115. @pytest.mark.parametrize("func", ['dct', 'dst', 'dctn', 'dstn'])
  116. @pytest.mark.parametrize("type", [1, 2, 3, 4])
  117. @pytest.mark.parametrize("norm", [None, 'backward', 'ortho', 'forward'])
  118. def test_fftpack_equivalience(func, type, norm):
  119. x = np.random.rand(8, 16)
  120. fft_res = getattr(fft, func)(x, type, norm=norm)
  121. fftpack_res = getattr(fftpack, func)(x, type, norm=norm)
  122. assert_allclose(fft_res, fftpack_res)
  123. @pytest.mark.parametrize("func", [dct, dst, dctn, dstn])
  124. @pytest.mark.parametrize("type", [1, 2, 3, 4])
  125. def test_orthogonalize_default(func, type):
  126. # Test orthogonalize is the default when norm="ortho", but not otherwise
  127. x = np.random.rand(100)
  128. for norm, ortho in [
  129. ("forward", False),
  130. ("backward", False),
  131. ("ortho", True),
  132. ]:
  133. a = func(x, type=type, norm=norm, orthogonalize=ortho)
  134. b = func(x, type=type, norm=norm)
  135. assert_allclose(a, b)
  136. @pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
  137. @pytest.mark.parametrize("func, type", [
  138. (dct, 4), (dst, 1), (dst, 4)])
  139. def test_orthogonalize_noop(func, type, norm):
  140. # Transforms where orthogonalize is a no-op
  141. x = np.random.rand(100)
  142. y1 = func(x, type=type, norm=norm, orthogonalize=True)
  143. y2 = func(x, type=type, norm=norm, orthogonalize=False)
  144. assert_allclose(y1, y2)
  145. @pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
  146. def test_orthogonalize_dct1(norm):
  147. x = np.random.rand(100)
  148. x2 = x.copy()
  149. x2[0] *= SQRT_2
  150. x2[-1] *= SQRT_2
  151. y1 = dct(x, type=1, norm=norm, orthogonalize=True)
  152. y2 = dct(x2, type=1, norm=norm, orthogonalize=False)
  153. y2[0] /= SQRT_2
  154. y2[-1] /= SQRT_2
  155. assert_allclose(y1, y2)
  156. @pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
  157. @pytest.mark.parametrize("func", [dct, dst])
  158. def test_orthogonalize_dcst2(func, norm):
  159. x = np.random.rand(100)
  160. y1 = func(x, type=2, norm=norm, orthogonalize=True)
  161. y2 = func(x, type=2, norm=norm, orthogonalize=False)
  162. y2[0] /= SQRT_2
  163. assert_allclose(y1, y2)
  164. @pytest.mark.parametrize("norm", ["backward", "ortho", "forward"])
  165. @pytest.mark.parametrize("func", [dct, dst])
  166. def test_orthogonalize_dcst3(func, norm):
  167. x = np.random.rand(100)
  168. x2 = x.copy()
  169. x2[0] *= SQRT_2
  170. y1 = func(x, type=3, norm=norm, orthogonalize=True)
  171. y2 = func(x2, type=3, norm=norm, orthogonalize=False)
  172. assert_allclose(y1, y2)