test_waveforms.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import numpy as np
  2. from numpy.testing import (assert_almost_equal, assert_equal,
  3. assert_, assert_allclose, assert_array_equal)
  4. from pytest import raises as assert_raises
  5. import scipy.signal._waveforms as waveforms
  6. # These chirp_* functions are the instantaneous frequencies of the signals
  7. # returned by chirp().
  8. def chirp_linear(t, f0, f1, t1):
  9. f = f0 + (f1 - f0) * t / t1
  10. return f
  11. def chirp_quadratic(t, f0, f1, t1, vertex_zero=True):
  12. if vertex_zero:
  13. f = f0 + (f1 - f0) * t**2 / t1**2
  14. else:
  15. f = f1 - (f1 - f0) * (t1 - t)**2 / t1**2
  16. return f
  17. def chirp_geometric(t, f0, f1, t1):
  18. f = f0 * (f1/f0)**(t/t1)
  19. return f
  20. def chirp_hyperbolic(t, f0, f1, t1):
  21. f = f0*f1*t1 / ((f0 - f1)*t + f1*t1)
  22. return f
  23. def compute_frequency(t, theta):
  24. """
  25. Compute theta'(t)/(2*pi), where theta'(t) is the derivative of theta(t).
  26. """
  27. # Assume theta and t are 1-D NumPy arrays.
  28. # Assume that t is uniformly spaced.
  29. dt = t[1] - t[0]
  30. f = np.diff(theta)/(2*np.pi) / dt
  31. tf = 0.5*(t[1:] + t[:-1])
  32. return tf, f
  33. class TestChirp:
  34. def test_linear_at_zero(self):
  35. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='linear')
  36. assert_almost_equal(w, 1.0)
  37. def test_linear_freq_01(self):
  38. method = 'linear'
  39. f0 = 1.0
  40. f1 = 2.0
  41. t1 = 1.0
  42. t = np.linspace(0, t1, 100)
  43. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  44. tf, f = compute_frequency(t, phase)
  45. abserr = np.max(np.abs(f - chirp_linear(tf, f0, f1, t1)))
  46. assert_(abserr < 1e-6)
  47. def test_linear_freq_02(self):
  48. method = 'linear'
  49. f0 = 200.0
  50. f1 = 100.0
  51. t1 = 10.0
  52. t = np.linspace(0, t1, 100)
  53. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  54. tf, f = compute_frequency(t, phase)
  55. abserr = np.max(np.abs(f - chirp_linear(tf, f0, f1, t1)))
  56. assert_(abserr < 1e-6)
  57. def test_quadratic_at_zero(self):
  58. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='quadratic')
  59. assert_almost_equal(w, 1.0)
  60. def test_quadratic_at_zero2(self):
  61. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='quadratic',
  62. vertex_zero=False)
  63. assert_almost_equal(w, 1.0)
  64. def test_quadratic_freq_01(self):
  65. method = 'quadratic'
  66. f0 = 1.0
  67. f1 = 2.0
  68. t1 = 1.0
  69. t = np.linspace(0, t1, 2000)
  70. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  71. tf, f = compute_frequency(t, phase)
  72. abserr = np.max(np.abs(f - chirp_quadratic(tf, f0, f1, t1)))
  73. assert_(abserr < 1e-6)
  74. def test_quadratic_freq_02(self):
  75. method = 'quadratic'
  76. f0 = 20.0
  77. f1 = 10.0
  78. t1 = 10.0
  79. t = np.linspace(0, t1, 2000)
  80. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  81. tf, f = compute_frequency(t, phase)
  82. abserr = np.max(np.abs(f - chirp_quadratic(tf, f0, f1, t1)))
  83. assert_(abserr < 1e-6)
  84. def test_logarithmic_at_zero(self):
  85. w = waveforms.chirp(t=0, f0=1.0, f1=2.0, t1=1.0, method='logarithmic')
  86. assert_almost_equal(w, 1.0)
  87. def test_logarithmic_freq_01(self):
  88. method = 'logarithmic'
  89. f0 = 1.0
  90. f1 = 2.0
  91. t1 = 1.0
  92. t = np.linspace(0, t1, 10000)
  93. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  94. tf, f = compute_frequency(t, phase)
  95. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  96. assert_(abserr < 1e-6)
  97. def test_logarithmic_freq_02(self):
  98. method = 'logarithmic'
  99. f0 = 200.0
  100. f1 = 100.0
  101. t1 = 10.0
  102. t = np.linspace(0, t1, 10000)
  103. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  104. tf, f = compute_frequency(t, phase)
  105. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  106. assert_(abserr < 1e-6)
  107. def test_logarithmic_freq_03(self):
  108. method = 'logarithmic'
  109. f0 = 100.0
  110. f1 = 100.0
  111. t1 = 10.0
  112. t = np.linspace(0, t1, 10000)
  113. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  114. tf, f = compute_frequency(t, phase)
  115. abserr = np.max(np.abs(f - chirp_geometric(tf, f0, f1, t1)))
  116. assert_(abserr < 1e-6)
  117. def test_hyperbolic_at_zero(self):
  118. w = waveforms.chirp(t=0, f0=10.0, f1=1.0, t1=1.0, method='hyperbolic')
  119. assert_almost_equal(w, 1.0)
  120. def test_hyperbolic_freq_01(self):
  121. method = 'hyperbolic'
  122. t1 = 1.0
  123. t = np.linspace(0, t1, 10000)
  124. # f0 f1
  125. cases = [[10.0, 1.0],
  126. [1.0, 10.0],
  127. [-10.0, -1.0],
  128. [-1.0, -10.0]]
  129. for f0, f1 in cases:
  130. phase = waveforms._chirp_phase(t, f0, t1, f1, method)
  131. tf, f = compute_frequency(t, phase)
  132. expected = chirp_hyperbolic(tf, f0, f1, t1)
  133. assert_allclose(f, expected)
  134. def test_hyperbolic_zero_freq(self):
  135. # f0=0 or f1=0 must raise a ValueError.
  136. method = 'hyperbolic'
  137. t1 = 1.0
  138. t = np.linspace(0, t1, 5)
  139. assert_raises(ValueError, waveforms.chirp, t, 0, t1, 1, method)
  140. assert_raises(ValueError, waveforms.chirp, t, 1, t1, 0, method)
  141. def test_unknown_method(self):
  142. method = "foo"
  143. f0 = 10.0
  144. f1 = 20.0
  145. t1 = 1.0
  146. t = np.linspace(0, t1, 10)
  147. assert_raises(ValueError, waveforms.chirp, t, f0, t1, f1, method)
  148. def test_integer_t1(self):
  149. f0 = 10.0
  150. f1 = 20.0
  151. t = np.linspace(-1, 1, 11)
  152. t1 = 3.0
  153. float_result = waveforms.chirp(t, f0, t1, f1)
  154. t1 = 3
  155. int_result = waveforms.chirp(t, f0, t1, f1)
  156. err_msg = "Integer input 't1=3' gives wrong result"
  157. assert_equal(int_result, float_result, err_msg=err_msg)
  158. def test_integer_f0(self):
  159. f1 = 20.0
  160. t1 = 3.0
  161. t = np.linspace(-1, 1, 11)
  162. f0 = 10.0
  163. float_result = waveforms.chirp(t, f0, t1, f1)
  164. f0 = 10
  165. int_result = waveforms.chirp(t, f0, t1, f1)
  166. err_msg = "Integer input 'f0=10' gives wrong result"
  167. assert_equal(int_result, float_result, err_msg=err_msg)
  168. def test_integer_f1(self):
  169. f0 = 10.0
  170. t1 = 3.0
  171. t = np.linspace(-1, 1, 11)
  172. f1 = 20.0
  173. float_result = waveforms.chirp(t, f0, t1, f1)
  174. f1 = 20
  175. int_result = waveforms.chirp(t, f0, t1, f1)
  176. err_msg = "Integer input 'f1=20' gives wrong result"
  177. assert_equal(int_result, float_result, err_msg=err_msg)
  178. def test_integer_all(self):
  179. f0 = 10
  180. t1 = 3
  181. f1 = 20
  182. t = np.linspace(-1, 1, 11)
  183. float_result = waveforms.chirp(t, float(f0), float(t1), float(f1))
  184. int_result = waveforms.chirp(t, f0, t1, f1)
  185. err_msg = "Integer input 'f0=10, t1=3, f1=20' gives wrong result"
  186. assert_equal(int_result, float_result, err_msg=err_msg)
  187. class TestSweepPoly:
  188. def test_sweep_poly_quad1(self):
  189. p = np.poly1d([1.0, 0.0, 1.0])
  190. t = np.linspace(0, 3.0, 10000)
  191. phase = waveforms._sweep_poly_phase(t, p)
  192. tf, f = compute_frequency(t, phase)
  193. expected = p(tf)
  194. abserr = np.max(np.abs(f - expected))
  195. assert_(abserr < 1e-6)
  196. def test_sweep_poly_const(self):
  197. p = np.poly1d(2.0)
  198. t = np.linspace(0, 3.0, 10000)
  199. phase = waveforms._sweep_poly_phase(t, p)
  200. tf, f = compute_frequency(t, phase)
  201. expected = p(tf)
  202. abserr = np.max(np.abs(f - expected))
  203. assert_(abserr < 1e-6)
  204. def test_sweep_poly_linear(self):
  205. p = np.poly1d([-1.0, 10.0])
  206. t = np.linspace(0, 3.0, 10000)
  207. phase = waveforms._sweep_poly_phase(t, p)
  208. tf, f = compute_frequency(t, phase)
  209. expected = p(tf)
  210. abserr = np.max(np.abs(f - expected))
  211. assert_(abserr < 1e-6)
  212. def test_sweep_poly_quad2(self):
  213. p = np.poly1d([1.0, 0.0, -2.0])
  214. t = np.linspace(0, 3.0, 10000)
  215. phase = waveforms._sweep_poly_phase(t, p)
  216. tf, f = compute_frequency(t, phase)
  217. expected = p(tf)
  218. abserr = np.max(np.abs(f - expected))
  219. assert_(abserr < 1e-6)
  220. def test_sweep_poly_cubic(self):
  221. p = np.poly1d([2.0, 1.0, 0.0, -2.0])
  222. t = np.linspace(0, 2.0, 10000)
  223. phase = waveforms._sweep_poly_phase(t, p)
  224. tf, f = compute_frequency(t, phase)
  225. expected = p(tf)
  226. abserr = np.max(np.abs(f - expected))
  227. assert_(abserr < 1e-6)
  228. def test_sweep_poly_cubic2(self):
  229. """Use an array of coefficients instead of a poly1d."""
  230. p = np.array([2.0, 1.0, 0.0, -2.0])
  231. t = np.linspace(0, 2.0, 10000)
  232. phase = waveforms._sweep_poly_phase(t, p)
  233. tf, f = compute_frequency(t, phase)
  234. expected = np.poly1d(p)(tf)
  235. abserr = np.max(np.abs(f - expected))
  236. assert_(abserr < 1e-6)
  237. def test_sweep_poly_cubic3(self):
  238. """Use a list of coefficients instead of a poly1d."""
  239. p = [2.0, 1.0, 0.0, -2.0]
  240. t = np.linspace(0, 2.0, 10000)
  241. phase = waveforms._sweep_poly_phase(t, p)
  242. tf, f = compute_frequency(t, phase)
  243. expected = np.poly1d(p)(tf)
  244. abserr = np.max(np.abs(f - expected))
  245. assert_(abserr < 1e-6)
  246. class TestGaussPulse:
  247. def test_integer_fc(self):
  248. float_result = waveforms.gausspulse('cutoff', fc=1000.0)
  249. int_result = waveforms.gausspulse('cutoff', fc=1000)
  250. err_msg = "Integer input 'fc=1000' gives wrong result"
  251. assert_equal(int_result, float_result, err_msg=err_msg)
  252. def test_integer_bw(self):
  253. float_result = waveforms.gausspulse('cutoff', bw=1.0)
  254. int_result = waveforms.gausspulse('cutoff', bw=1)
  255. err_msg = "Integer input 'bw=1' gives wrong result"
  256. assert_equal(int_result, float_result, err_msg=err_msg)
  257. def test_integer_bwr(self):
  258. float_result = waveforms.gausspulse('cutoff', bwr=-6.0)
  259. int_result = waveforms.gausspulse('cutoff', bwr=-6)
  260. err_msg = "Integer input 'bwr=-6' gives wrong result"
  261. assert_equal(int_result, float_result, err_msg=err_msg)
  262. def test_integer_tpr(self):
  263. float_result = waveforms.gausspulse('cutoff', tpr=-60.0)
  264. int_result = waveforms.gausspulse('cutoff', tpr=-60)
  265. err_msg = "Integer input 'tpr=-60' gives wrong result"
  266. assert_equal(int_result, float_result, err_msg=err_msg)
  267. class TestUnitImpulse:
  268. def test_no_index(self):
  269. assert_array_equal(waveforms.unit_impulse(7), [1, 0, 0, 0, 0, 0, 0])
  270. assert_array_equal(waveforms.unit_impulse((3, 3)),
  271. [[1, 0, 0], [0, 0, 0], [0, 0, 0]])
  272. def test_index(self):
  273. assert_array_equal(waveforms.unit_impulse(10, 3),
  274. [0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
  275. assert_array_equal(waveforms.unit_impulse((3, 3), (1, 1)),
  276. [[0, 0, 0], [0, 1, 0], [0, 0, 0]])
  277. # Broadcasting
  278. imp = waveforms.unit_impulse((4, 4), 2)
  279. assert_array_equal(imp, np.array([[0, 0, 0, 0],
  280. [0, 0, 0, 0],
  281. [0, 0, 1, 0],
  282. [0, 0, 0, 0]]))
  283. def test_mid(self):
  284. assert_array_equal(waveforms.unit_impulse((3, 3), 'mid'),
  285. [[0, 0, 0], [0, 1, 0], [0, 0, 0]])
  286. assert_array_equal(waveforms.unit_impulse(9, 'mid'),
  287. [0, 0, 0, 0, 1, 0, 0, 0, 0])
  288. def test_dtype(self):
  289. imp = waveforms.unit_impulse(7)
  290. assert_(np.issubdtype(imp.dtype, np.floating))
  291. imp = waveforms.unit_impulse(5, 3, dtype=int)
  292. assert_(np.issubdtype(imp.dtype, np.integer))
  293. imp = waveforms.unit_impulse((5, 2), (3, 1), dtype=complex)
  294. assert_(np.issubdtype(imp.dtype, np.complexfloating))