# test_wavelets.py (5.8 KB)
import numpy as np
from numpy.testing import (
    assert_,
    assert_array_almost_equal,
    assert_array_equal,
    assert_array_less,
    assert_equal,
)

import scipy.signal._wavelets as wavelets
  5. class TestWavelets:
  6. def test_qmf(self):
  7. assert_array_equal(wavelets.qmf([1, 1]), [1, -1])
  8. def test_daub(self):
  9. for i in range(1, 15):
  10. assert_equal(len(wavelets.daub(i)), i * 2)
  11. def test_cascade(self):
  12. for J in range(1, 7):
  13. for i in range(1, 5):
  14. lpcoef = wavelets.daub(i)
  15. k = len(lpcoef)
  16. x, phi, psi = wavelets.cascade(lpcoef, J)
  17. assert_(len(x) == len(phi) == len(psi))
  18. assert_equal(len(x), (k - 1) * 2 ** J)
  19. def test_morlet(self):
  20. x = wavelets.morlet(50, 4.1, complete=True)
  21. y = wavelets.morlet(50, 4.1, complete=False)
  22. # Test if complete and incomplete wavelet have same lengths:
  23. assert_equal(len(x), len(y))
  24. # Test if complete wavelet is less than incomplete wavelet:
  25. assert_array_less(x, y)
  26. x = wavelets.morlet(10, 50, complete=False)
  27. y = wavelets.morlet(10, 50, complete=True)
  28. # For large widths complete and incomplete wavelets should be
  29. # identical within numerical precision:
  30. assert_equal(x, y)
  31. # miscellaneous tests:
  32. x = np.array([1.73752399e-09 + 9.84327394e-25j,
  33. 6.49471756e-01 + 0.00000000e+00j,
  34. 1.73752399e-09 - 9.84327394e-25j])
  35. y = wavelets.morlet(3, w=2, complete=True)
  36. assert_array_almost_equal(x, y)
  37. x = np.array([2.00947715e-09 + 9.84327394e-25j,
  38. 7.51125544e-01 + 0.00000000e+00j,
  39. 2.00947715e-09 - 9.84327394e-25j])
  40. y = wavelets.morlet(3, w=2, complete=False)
  41. assert_array_almost_equal(x, y, decimal=2)
  42. x = wavelets.morlet(10000, s=4, complete=True)
  43. y = wavelets.morlet(20000, s=8, complete=True)[5000:15000]
  44. assert_array_almost_equal(x, y, decimal=2)
  45. x = wavelets.morlet(10000, s=4, complete=False)
  46. assert_array_almost_equal(y, x, decimal=2)
  47. y = wavelets.morlet(20000, s=8, complete=False)[5000:15000]
  48. assert_array_almost_equal(x, y, decimal=2)
  49. x = wavelets.morlet(10000, w=3, s=5, complete=True)
  50. y = wavelets.morlet(20000, w=3, s=10, complete=True)[5000:15000]
  51. assert_array_almost_equal(x, y, decimal=2)
  52. x = wavelets.morlet(10000, w=3, s=5, complete=False)
  53. assert_array_almost_equal(y, x, decimal=2)
  54. y = wavelets.morlet(20000, w=3, s=10, complete=False)[5000:15000]
  55. assert_array_almost_equal(x, y, decimal=2)
  56. x = wavelets.morlet(10000, w=7, s=10, complete=True)
  57. y = wavelets.morlet(20000, w=7, s=20, complete=True)[5000:15000]
  58. assert_array_almost_equal(x, y, decimal=2)
  59. x = wavelets.morlet(10000, w=7, s=10, complete=False)
  60. assert_array_almost_equal(x, y, decimal=2)
  61. y = wavelets.morlet(20000, w=7, s=20, complete=False)[5000:15000]
  62. assert_array_almost_equal(x, y, decimal=2)
  63. def test_morlet2(self):
  64. w = wavelets.morlet2(1.0, 0.5)
  65. expected = (np.pi**(-0.25) * np.sqrt(1/0.5)).astype(complex)
  66. assert_array_equal(w, expected)
  67. lengths = [5, 11, 15, 51, 101]
  68. for length in lengths:
  69. w = wavelets.morlet2(length, 1.0)
  70. assert_(len(w) == length)
  71. max_loc = np.argmax(w)
  72. assert_(max_loc == (length // 2))
  73. points = 100
  74. w = abs(wavelets.morlet2(points, 2.0))
  75. half_vec = np.arange(0, points // 2)
  76. assert_array_almost_equal(w[half_vec], w[-(half_vec + 1)])
  77. x = np.array([5.03701224e-09 + 2.46742437e-24j,
  78. 1.88279253e+00 + 0.00000000e+00j,
  79. 5.03701224e-09 - 2.46742437e-24j])
  80. y = wavelets.morlet2(3, s=1/(2*np.pi), w=2)
  81. assert_array_almost_equal(x, y)
  82. def test_ricker(self):
  83. w = wavelets.ricker(1.0, 1)
  84. expected = 2 / (np.sqrt(3 * 1.0) * (np.pi ** 0.25))
  85. assert_array_equal(w, expected)
  86. lengths = [5, 11, 15, 51, 101]
  87. for length in lengths:
  88. w = wavelets.ricker(length, 1.0)
  89. assert_(len(w) == length)
  90. max_loc = np.argmax(w)
  91. assert_(max_loc == (length // 2))
  92. points = 100
  93. w = wavelets.ricker(points, 2.0)
  94. half_vec = np.arange(0, points // 2)
  95. #Wavelet should be symmetric
  96. assert_array_almost_equal(w[half_vec], w[-(half_vec + 1)])
  97. #Check zeros
  98. aas = [5, 10, 15, 20, 30]
  99. points = 99
  100. for a in aas:
  101. w = wavelets.ricker(points, a)
  102. vec = np.arange(0, points) - (points - 1.0) / 2
  103. exp_zero1 = np.argmin(np.abs(vec - a))
  104. exp_zero2 = np.argmin(np.abs(vec + a))
  105. assert_array_almost_equal(w[exp_zero1], 0)
  106. assert_array_almost_equal(w[exp_zero2], 0)
  107. def test_cwt(self):
  108. widths = [1.0]
  109. delta_wavelet = lambda s, t: np.array([1])
  110. len_data = 100
  111. test_data = np.sin(np.pi * np.arange(0, len_data) / 10.0)
  112. #Test delta function input gives same data as output
  113. cwt_dat = wavelets.cwt(test_data, delta_wavelet, widths)
  114. assert_(cwt_dat.shape == (len(widths), len_data))
  115. assert_array_almost_equal(test_data, cwt_dat.flatten())
  116. #Check proper shape on output
  117. widths = [1, 3, 4, 5, 10]
  118. cwt_dat = wavelets.cwt(test_data, wavelets.ricker, widths)
  119. assert_(cwt_dat.shape == (len(widths), len_data))
  120. widths = [len_data * 10]
  121. #Note: this wavelet isn't defined quite right, but is fine for this test
  122. flat_wavelet = lambda l, w: np.full(w, 1 / w)
  123. cwt_dat = wavelets.cwt(test_data, flat_wavelet, widths)
  124. assert_array_almost_equal(cwt_dat, np.mean(test_data))