test_czt.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # This program is public domain
  2. # Authors: Paul Kienzle, Nadav Horesh
  3. '''
  4. A unit test module for czt.py
  5. '''
  6. import pytest
  7. from numpy.testing import assert_allclose
  8. from scipy.fft import fft
  9. from scipy.signal import (czt, zoom_fft, czt_points, CZT, ZoomFFT)
  10. import numpy as np
  11. def check_czt(x):
  12. # Check that czt is the equivalent of normal fft
  13. y = fft(x)
  14. y1 = czt(x)
  15. assert_allclose(y1, y, rtol=1e-13)
  16. # Check that interpolated czt is the equivalent of normal fft
  17. y = fft(x, 100*len(x))
  18. y1 = czt(x, 100*len(x))
  19. assert_allclose(y1, y, rtol=1e-12)
  20. def check_zoom_fft(x):
  21. # Check that zoom_fft is the equivalent of normal fft
  22. y = fft(x)
  23. y1 = zoom_fft(x, [0, 2-2./len(y)], endpoint=True)
  24. assert_allclose(y1, y, rtol=1e-11, atol=1e-14)
  25. y1 = zoom_fft(x, [0, 2])
  26. assert_allclose(y1, y, rtol=1e-11, atol=1e-14)
  27. # Test fn scalar
  28. y1 = zoom_fft(x, 2-2./len(y), endpoint=True)
  29. assert_allclose(y1, y, rtol=1e-11, atol=1e-14)
  30. y1 = zoom_fft(x, 2)
  31. assert_allclose(y1, y, rtol=1e-11, atol=1e-14)
  32. # Check that zoom_fft with oversampling is equivalent to zero padding
  33. over = 10
  34. yover = fft(x, over*len(x))
  35. y2 = zoom_fft(x, [0, 2-2./len(yover)], m=len(yover), endpoint=True)
  36. assert_allclose(y2, yover, rtol=1e-12, atol=1e-10)
  37. y2 = zoom_fft(x, [0, 2], m=len(yover))
  38. assert_allclose(y2, yover, rtol=1e-12, atol=1e-10)
  39. # Check that zoom_fft works on a subrange
  40. w = np.linspace(0, 2-2./len(x), len(x))
  41. f1, f2 = w[3], w[6]
  42. y3 = zoom_fft(x, [f1, f2], m=3*over+1, endpoint=True)
  43. idx3 = slice(3*over, 6*over+1)
  44. assert_allclose(y3, yover[idx3], rtol=1e-13)
  45. def test_1D():
  46. # Test of 1D version of the transforms
  47. np.random.seed(0) # Deterministic randomness
  48. # Random signals
  49. lengths = np.random.randint(8, 200, 20)
  50. np.append(lengths, 1)
  51. for length in lengths:
  52. x = np.random.random(length)
  53. check_zoom_fft(x)
  54. check_czt(x)
  55. # Gauss
  56. t = np.linspace(-2, 2, 128)
  57. x = np.exp(-t**2/0.01)
  58. check_zoom_fft(x)
  59. # Linear
  60. x = [1, 2, 3, 4, 5, 6, 7]
  61. check_zoom_fft(x)
  62. # Check near powers of two
  63. check_zoom_fft(range(126-31))
  64. check_zoom_fft(range(127-31))
  65. check_zoom_fft(range(128-31))
  66. check_zoom_fft(range(129-31))
  67. check_zoom_fft(range(130-31))
  68. # Check transform on n-D array input
  69. x = np.reshape(np.arange(3*2*28), (3, 2, 28))
  70. y1 = zoom_fft(x, [0, 2-2./28])
  71. y2 = zoom_fft(x[2, 0, :], [0, 2-2./28])
  72. assert_allclose(y1[2, 0], y2, rtol=1e-13, atol=1e-12)
  73. y1 = zoom_fft(x, [0, 2], endpoint=False)
  74. y2 = zoom_fft(x[2, 0, :], [0, 2], endpoint=False)
  75. assert_allclose(y1[2, 0], y2, rtol=1e-13, atol=1e-12)
  76. # Random (not a test condition)
  77. x = np.random.rand(101)
  78. check_zoom_fft(x)
  79. # Spikes
  80. t = np.linspace(0, 1, 128)
  81. x = np.sin(2*np.pi*t*5)+np.sin(2*np.pi*t*13)
  82. check_zoom_fft(x)
  83. # Sines
  84. x = np.zeros(100, dtype=complex)
  85. x[[1, 5, 21]] = 1
  86. check_zoom_fft(x)
  87. # Sines plus complex component
  88. x += 1j*np.linspace(0, 0.5, x.shape[0])
  89. check_zoom_fft(x)
  90. def test_large_prime_lengths():
  91. np.random.seed(0) # Deterministic randomness
  92. for N in (101, 1009, 10007):
  93. x = np.random.rand(N)
  94. y = fft(x)
  95. y1 = czt(x)
  96. assert_allclose(y, y1, rtol=1e-12)
  97. @pytest.mark.slow
  98. def test_czt_vs_fft():
  99. np.random.seed(123)
  100. random_lengths = np.random.exponential(100000, size=10).astype('int')
  101. for n in random_lengths:
  102. a = np.random.randn(n)
  103. assert_allclose(czt(a), fft(a), rtol=1e-11)
  104. def test_empty_input():
  105. with pytest.raises(ValueError, match='Invalid number of CZT'):
  106. czt([])
  107. with pytest.raises(ValueError, match='Invalid number of CZT'):
  108. zoom_fft([], 0.5)
  109. def test_0_rank_input():
  110. with pytest.raises(IndexError, match='tuple index out of range'):
  111. czt(5)
  112. with pytest.raises(IndexError, match='tuple index out of range'):
  113. zoom_fft(5, 0.5)
  114. @pytest.mark.parametrize('impulse', ([0, 0, 1], [0, 0, 1, 0, 0],
  115. np.concatenate((np.array([0, 0, 1]),
  116. np.zeros(100)))))
  117. @pytest.mark.parametrize('m', (1, 3, 5, 8, 101, 1021))
  118. @pytest.mark.parametrize('a', (1, 2, 0.5, 1.1))
  119. # Step that tests away from the unit circle, but not so far it explodes from
  120. # numerical error
  121. @pytest.mark.parametrize('w', (None, 0.98534 + 0.17055j))
  122. def test_czt_math(impulse, m, w, a):
  123. # z-transform of an impulse is 1 everywhere
  124. assert_allclose(czt(impulse[2:], m=m, w=w, a=a),
  125. np.ones(m), rtol=1e-10)
  126. # z-transform of a delayed impulse is z**-1
  127. assert_allclose(czt(impulse[1:], m=m, w=w, a=a),
  128. czt_points(m=m, w=w, a=a)**-1, rtol=1e-10)
  129. # z-transform of a 2-delayed impulse is z**-2
  130. assert_allclose(czt(impulse, m=m, w=w, a=a),
  131. czt_points(m=m, w=w, a=a)**-2, rtol=1e-10)
  132. def test_int_args():
  133. # Integer argument `a` was producing all 0s
  134. assert_allclose(abs(czt([0, 1], m=10, a=2)), 0.5*np.ones(10), rtol=1e-15)
  135. assert_allclose(czt_points(11, w=2), 1/(2**np.arange(11)), rtol=1e-30)
  136. def test_czt_points():
  137. for N in (1, 2, 3, 8, 11, 100, 101, 10007):
  138. assert_allclose(czt_points(N), np.exp(2j*np.pi*np.arange(N)/N),
  139. rtol=1e-30)
  140. assert_allclose(czt_points(7, w=1), np.ones(7), rtol=1e-30)
  141. assert_allclose(czt_points(11, w=2.), 1/(2**np.arange(11)), rtol=1e-30)
  142. func = CZT(12, m=11, w=2., a=1)
  143. assert_allclose(func.points(), 1/(2**np.arange(11)), rtol=1e-30)
  144. @pytest.mark.parametrize('cls, args', [(CZT, (100,)), (ZoomFFT, (100, 0.2))])
  145. def test_CZT_size_mismatch(cls, args):
  146. # Data size doesn't match function's expected size
  147. myfunc = cls(*args)
  148. with pytest.raises(ValueError, match='CZT defined for'):
  149. myfunc(np.arange(5))
  150. def test_invalid_range():
  151. with pytest.raises(ValueError, match='2-length sequence'):
  152. ZoomFFT(100, [1, 2, 3])
  153. @pytest.mark.parametrize('m', [0, -11, 5.5, 4.0])
  154. def test_czt_points_errors(m):
  155. # Invalid number of points
  156. with pytest.raises(ValueError, match='Invalid number of CZT'):
  157. czt_points(m)
  158. @pytest.mark.parametrize('size', [0, -5, 3.5, 4.0])
  159. def test_nonsense_size(size):
  160. # Numpy and Scipy fft() give ValueError for 0 output size, so we do, too
  161. with pytest.raises(ValueError, match='Invalid number of CZT'):
  162. CZT(size, 3)
  163. with pytest.raises(ValueError, match='Invalid number of CZT'):
  164. ZoomFFT(size, 0.2, 3)
  165. with pytest.raises(ValueError, match='Invalid number of CZT'):
  166. CZT(3, size)
  167. with pytest.raises(ValueError, match='Invalid number of CZT'):
  168. ZoomFFT(3, 0.2, size)
  169. with pytest.raises(ValueError, match='Invalid number of CZT'):
  170. czt([1, 2, 3], size)
  171. with pytest.raises(ValueError, match='Invalid number of CZT'):
  172. zoom_fft([1, 2, 3], 0.2, size)