_wavelets.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492
  1. import numpy as np
  2. from scipy.linalg import eig
  3. from scipy.special import comb
  4. from scipy.signal import convolve
# Public API of this module (the names exported by ``from ... import *``).
__all__ = ['daub', 'qmf', 'cascade', 'morlet', 'ricker', 'morlet2', 'cwt']
  6. def daub(p):
  7. """
  8. The coefficients for the FIR low-pass filter producing Daubechies wavelets.
  9. p>=1 gives the order of the zero at f=1/2.
  10. There are 2p filter coefficients.
  11. Parameters
  12. ----------
  13. p : int
  14. Order of the zero at f=1/2, can have values from 1 to 34.
  15. Returns
  16. -------
  17. daub : ndarray
  18. Return
  19. """
  20. sqrt = np.sqrt
  21. if p < 1:
  22. raise ValueError("p must be at least 1.")
  23. if p == 1:
  24. c = 1 / sqrt(2)
  25. return np.array([c, c])
  26. elif p == 2:
  27. f = sqrt(2) / 8
  28. c = sqrt(3)
  29. return f * np.array([1 + c, 3 + c, 3 - c, 1 - c])
  30. elif p == 3:
  31. tmp = 12 * sqrt(10)
  32. z1 = 1.5 + sqrt(15 + tmp) / 6 - 1j * (sqrt(15) + sqrt(tmp - 15)) / 6
  33. z1c = np.conj(z1)
  34. f = sqrt(2) / 8
  35. d0 = np.real((1 - z1) * (1 - z1c))
  36. a0 = np.real(z1 * z1c)
  37. a1 = 2 * np.real(z1)
  38. return f / d0 * np.array([a0, 3 * a0 - a1, 3 * a0 - 3 * a1 + 1,
  39. a0 - 3 * a1 + 3, 3 - a1, 1])
  40. elif p < 35:
  41. # construct polynomial and factor it
  42. if p < 35:
  43. P = [comb(p - 1 + k, k, exact=1) for k in range(p)][::-1]
  44. yj = np.roots(P)
  45. else: # try different polynomial --- needs work
  46. P = [comb(p - 1 + k, k, exact=1) / 4.0**k
  47. for k in range(p)][::-1]
  48. yj = np.roots(P) / 4
  49. # for each root, compute two z roots, select the one with |z|>1
  50. # Build up final polynomial
  51. c = np.poly1d([1, 1])**p
  52. q = np.poly1d([1])
  53. for k in range(p - 1):
  54. yval = yj[k]
  55. part = 2 * sqrt(yval * (yval - 1))
  56. const = 1 - 2 * yval
  57. z1 = const + part
  58. if (abs(z1)) < 1:
  59. z1 = const - part
  60. q = q * [1, -z1]
  61. q = c * np.real(q)
  62. # Normalize result
  63. q = q / np.sum(q) * sqrt(2)
  64. return q.c[::-1]
  65. else:
  66. raise ValueError("Polynomial factorization does not work "
  67. "well for p too large.")
  68. def qmf(hk):
  69. """
  70. Return high-pass qmf filter from low-pass
  71. Parameters
  72. ----------
  73. hk : array_like
  74. Coefficients of high-pass filter.
  75. Returns
  76. -------
  77. array_like
  78. High-pass filter coefficients.
  79. """
  80. N = len(hk) - 1
  81. asgn = [{0: 1, 1: -1}[k % 2] for k in range(N + 1)]
  82. return hk[::-1] * np.array(asgn)
  83. def cascade(hk, J=7):
  84. """
  85. Return (x, phi, psi) at dyadic points ``K/2**J`` from filter coefficients.
  86. Parameters
  87. ----------
  88. hk : array_like
  89. Coefficients of low-pass filter.
  90. J : int, optional
  91. Values will be computed at grid points ``K/2**J``. Default is 7.
  92. Returns
  93. -------
  94. x : ndarray
  95. The dyadic points ``K/2**J`` for ``K=0...N * (2**J)-1`` where
  96. ``len(hk) = len(gk) = N+1``.
  97. phi : ndarray
  98. The scaling function ``phi(x)`` at `x`:
  99. ``phi(x) = sum(hk * phi(2x-k))``, where k is from 0 to N.
  100. psi : ndarray, optional
  101. The wavelet function ``psi(x)`` at `x`:
  102. ``phi(x) = sum(gk * phi(2x-k))``, where k is from 0 to N.
  103. `psi` is only returned if `gk` is not None.
  104. Notes
  105. -----
  106. The algorithm uses the vector cascade algorithm described by Strang and
  107. Nguyen in "Wavelets and Filter Banks". It builds a dictionary of values
  108. and slices for quick reuse. Then inserts vectors into final vector at the
  109. end.
  110. """
  111. N = len(hk) - 1
  112. if (J > 30 - np.log2(N + 1)):
  113. raise ValueError("Too many levels.")
  114. if (J < 1):
  115. raise ValueError("Too few levels.")
  116. # construct matrices needed
  117. nn, kk = np.ogrid[:N, :N]
  118. s2 = np.sqrt(2)
  119. # append a zero so that take works
  120. thk = np.r_[hk, 0]
  121. gk = qmf(hk)
  122. tgk = np.r_[gk, 0]
  123. indx1 = np.clip(2 * nn - kk, -1, N + 1)
  124. indx2 = np.clip(2 * nn - kk + 1, -1, N + 1)
  125. m = np.empty((2, 2, N, N), 'd')
  126. m[0, 0] = np.take(thk, indx1, 0)
  127. m[0, 1] = np.take(thk, indx2, 0)
  128. m[1, 0] = np.take(tgk, indx1, 0)
  129. m[1, 1] = np.take(tgk, indx2, 0)
  130. m *= s2
  131. # construct the grid of points
  132. x = np.arange(0, N * (1 << J), dtype=float) / (1 << J)
  133. phi = 0 * x
  134. psi = 0 * x
  135. # find phi0, and phi1
  136. lam, v = eig(m[0, 0])
  137. ind = np.argmin(np.absolute(lam - 1))
  138. # a dictionary with a binary representation of the
  139. # evaluation points x < 1 -- i.e. position is 0.xxxx
  140. v = np.real(v[:, ind])
  141. # need scaling function to integrate to 1 so find
  142. # eigenvector normalized to sum(v,axis=0)=1
  143. sm = np.sum(v)
  144. if sm < 0: # need scaling function to integrate to 1
  145. v = -v
  146. sm = -sm
  147. bitdic = {'0': v / sm}
  148. bitdic['1'] = np.dot(m[0, 1], bitdic['0'])
  149. step = 1 << J
  150. phi[::step] = bitdic['0']
  151. phi[(1 << (J - 1))::step] = bitdic['1']
  152. psi[::step] = np.dot(m[1, 0], bitdic['0'])
  153. psi[(1 << (J - 1))::step] = np.dot(m[1, 1], bitdic['0'])
  154. # descend down the levels inserting more and more values
  155. # into bitdic -- store the values in the correct location once we
  156. # have computed them -- stored in the dictionary
  157. # for quicker use later.
  158. prevkeys = ['1']
  159. for level in range(2, J + 1):
  160. newkeys = ['%d%s' % (xx, yy) for xx in [0, 1] for yy in prevkeys]
  161. fac = 1 << (J - level)
  162. for key in newkeys:
  163. # convert key to number
  164. num = 0
  165. for pos in range(level):
  166. if key[pos] == '1':
  167. num += (1 << (level - 1 - pos))
  168. pastphi = bitdic[key[1:]]
  169. ii = int(key[0])
  170. temp = np.dot(m[0, ii], pastphi)
  171. bitdic[key] = temp
  172. phi[num * fac::step] = temp
  173. psi[num * fac::step] = np.dot(m[1, ii], pastphi)
  174. prevkeys = newkeys
  175. return x, phi, psi
  176. def morlet(M, w=5.0, s=1.0, complete=True):
  177. """
  178. Complex Morlet wavelet.
  179. Parameters
  180. ----------
  181. M : int
  182. Length of the wavelet.
  183. w : float, optional
  184. Omega0. Default is 5
  185. s : float, optional
  186. Scaling factor, windowed from ``-s*2*pi`` to ``+s*2*pi``. Default is 1.
  187. complete : bool, optional
  188. Whether to use the complete or the standard version.
  189. Returns
  190. -------
  191. morlet : (M,) ndarray
  192. See Also
  193. --------
  194. morlet2 : Implementation of Morlet wavelet, compatible with `cwt`.
  195. scipy.signal.gausspulse
  196. Notes
  197. -----
  198. The standard version::
  199. pi**-0.25 * exp(1j*w*x) * exp(-0.5*(x**2))
  200. This commonly used wavelet is often referred to simply as the
  201. Morlet wavelet. Note that this simplified version can cause
  202. admissibility problems at low values of `w`.
  203. The complete version::
  204. pi**-0.25 * (exp(1j*w*x) - exp(-0.5*(w**2))) * exp(-0.5*(x**2))
  205. This version has a correction
  206. term to improve admissibility. For `w` greater than 5, the
  207. correction term is negligible.
  208. Note that the energy of the return wavelet is not normalised
  209. according to `s`.
  210. The fundamental frequency of this wavelet in Hz is given
  211. by ``f = 2*s*w*r / M`` where `r` is the sampling rate.
  212. Note: This function was created before `cwt` and is not compatible
  213. with it.
  214. Examples
  215. --------
  216. >>> from scipy import signal
  217. >>> import matplotlib.pyplot as plt
  218. >>> M = 100
  219. >>> s = 4.0
  220. >>> w = 2.0
  221. >>> wavelet = signal.morlet(M, s, w)
  222. >>> plt.plot(wavelet)
  223. >>> plt.show()
  224. """
  225. x = np.linspace(-s * 2 * np.pi, s * 2 * np.pi, M)
  226. output = np.exp(1j * w * x)
  227. if complete:
  228. output -= np.exp(-0.5 * (w**2))
  229. output *= np.exp(-0.5 * (x**2)) * np.pi**(-0.25)
  230. return output
  231. def ricker(points, a):
  232. """
  233. Return a Ricker wavelet, also known as the "Mexican hat wavelet".
  234. It models the function:
  235. ``A * (1 - (x/a)**2) * exp(-0.5*(x/a)**2)``,
  236. where ``A = 2/(sqrt(3*a)*(pi**0.25))``.
  237. Parameters
  238. ----------
  239. points : int
  240. Number of points in `vector`.
  241. Will be centered around 0.
  242. a : scalar
  243. Width parameter of the wavelet.
  244. Returns
  245. -------
  246. vector : (N,) ndarray
  247. Array of length `points` in shape of ricker curve.
  248. Examples
  249. --------
  250. >>> from scipy import signal
  251. >>> import matplotlib.pyplot as plt
  252. >>> points = 100
  253. >>> a = 4.0
  254. >>> vec2 = signal.ricker(points, a)
  255. >>> print(len(vec2))
  256. 100
  257. >>> plt.plot(vec2)
  258. >>> plt.show()
  259. """
  260. A = 2 / (np.sqrt(3 * a) * (np.pi**0.25))
  261. wsq = a**2
  262. vec = np.arange(0, points) - (points - 1.0) / 2
  263. xsq = vec**2
  264. mod = (1 - xsq / wsq)
  265. gauss = np.exp(-xsq / (2 * wsq))
  266. total = A * mod * gauss
  267. return total
  268. def morlet2(M, s, w=5):
  269. """
  270. Complex Morlet wavelet, designed to work with `cwt`.
  271. Returns the complete version of morlet wavelet, normalised
  272. according to `s`::
  273. exp(1j*w*x/s) * exp(-0.5*(x/s)**2) * pi**(-0.25) * sqrt(1/s)
  274. Parameters
  275. ----------
  276. M : int
  277. Length of the wavelet.
  278. s : float
  279. Width parameter of the wavelet.
  280. w : float, optional
  281. Omega0. Default is 5
  282. Returns
  283. -------
  284. morlet : (M,) ndarray
  285. See Also
  286. --------
  287. morlet : Implementation of Morlet wavelet, incompatible with `cwt`
  288. Notes
  289. -----
  290. .. versionadded:: 1.4.0
  291. This function was designed to work with `cwt`. Because `morlet2`
  292. returns an array of complex numbers, the `dtype` argument of `cwt`
  293. should be set to `complex128` for best results.
  294. Note the difference in implementation with `morlet`.
  295. The fundamental frequency of this wavelet in Hz is given by::
  296. f = w*fs / (2*s*np.pi)
  297. where ``fs`` is the sampling rate and `s` is the wavelet width parameter.
  298. Similarly we can get the wavelet width parameter at ``f``::
  299. s = w*fs / (2*f*np.pi)
  300. Examples
  301. --------
  302. >>> import numpy as np
  303. >>> from scipy import signal
  304. >>> import matplotlib.pyplot as plt
  305. >>> M = 100
  306. >>> s = 4.0
  307. >>> w = 2.0
  308. >>> wavelet = signal.morlet2(M, s, w)
  309. >>> plt.plot(abs(wavelet))
  310. >>> plt.show()
  311. This example shows basic use of `morlet2` with `cwt` in time-frequency
  312. analysis:
  313. >>> t, dt = np.linspace(0, 1, 200, retstep=True)
  314. >>> fs = 1/dt
  315. >>> w = 6.
  316. >>> sig = np.cos(2*np.pi*(50 + 10*t)*t) + np.sin(40*np.pi*t)
  317. >>> freq = np.linspace(1, fs/2, 100)
  318. >>> widths = w*fs / (2*freq*np.pi)
  319. >>> cwtm = signal.cwt(sig, signal.morlet2, widths, w=w)
  320. >>> plt.pcolormesh(t, freq, np.abs(cwtm), cmap='viridis', shading='gouraud')
  321. >>> plt.show()
  322. """
  323. x = np.arange(0, M) - (M - 1.0) / 2
  324. x = x / s
  325. wavelet = np.exp(1j * w * x) * np.exp(-0.5 * x**2) * np.pi**(-0.25)
  326. output = np.sqrt(1/s) * wavelet
  327. return output
  328. def cwt(data, wavelet, widths, dtype=None, **kwargs):
  329. """
  330. Continuous wavelet transform.
  331. Performs a continuous wavelet transform on `data`,
  332. using the `wavelet` function. A CWT performs a convolution
  333. with `data` using the `wavelet` function, which is characterized
  334. by a width parameter and length parameter. The `wavelet` function
  335. is allowed to be complex.
  336. Parameters
  337. ----------
  338. data : (N,) ndarray
  339. data on which to perform the transform.
  340. wavelet : function
  341. Wavelet function, which should take 2 arguments.
  342. The first argument is the number of points that the returned vector
  343. will have (len(wavelet(length,width)) == length).
  344. The second is a width parameter, defining the size of the wavelet
  345. (e.g. standard deviation of a gaussian). See `ricker`, which
  346. satisfies these requirements.
  347. widths : (M,) sequence
  348. Widths to use for transform.
  349. dtype : data-type, optional
  350. The desired data type of output. Defaults to ``float64`` if the
  351. output of `wavelet` is real and ``complex128`` if it is complex.
  352. .. versionadded:: 1.4.0
  353. kwargs
  354. Keyword arguments passed to wavelet function.
  355. .. versionadded:: 1.4.0
  356. Returns
  357. -------
  358. cwt: (M, N) ndarray
  359. Will have shape of (len(widths), len(data)).
  360. Notes
  361. -----
  362. .. versionadded:: 1.4.0
  363. For non-symmetric, complex-valued wavelets, the input signal is convolved
  364. with the time-reversed complex-conjugate of the wavelet data [1].
  365. ::
  366. length = min(10 * width[ii], len(data))
  367. cwt[ii,:] = signal.convolve(data, np.conj(wavelet(length, width[ii],
  368. **kwargs))[::-1], mode='same')
  369. References
  370. ----------
  371. .. [1] S. Mallat, "A Wavelet Tour of Signal Processing (3rd Edition)",
  372. Academic Press, 2009.
  373. Examples
  374. --------
  375. >>> import numpy as np
  376. >>> from scipy import signal
  377. >>> import matplotlib.pyplot as plt
  378. >>> t = np.linspace(-1, 1, 200, endpoint=False)
  379. >>> sig = np.cos(2 * np.pi * 7 * t) + signal.gausspulse(t - 0.4, fc=2)
  380. >>> widths = np.arange(1, 31)
  381. >>> cwtmatr = signal.cwt(sig, signal.ricker, widths)
  382. .. note:: For cwt matrix plotting it is advisable to flip the y-axis
  383. >>> cwtmatr_yflip = np.flipud(cwtmatr)
  384. >>> plt.imshow(cwtmatr_yflip, extent=[-1, 1, 1, 31], cmap='PRGn', aspect='auto',
  385. ... vmax=abs(cwtmatr).max(), vmin=-abs(cwtmatr).max())
  386. >>> plt.show()
  387. """
  388. # Determine output type
  389. if dtype is None:
  390. if np.asarray(wavelet(1, widths[0], **kwargs)).dtype.char in 'FDG':
  391. dtype = np.complex128
  392. else:
  393. dtype = np.float64
  394. output = np.empty((len(widths), len(data)), dtype=dtype)
  395. for ind, width in enumerate(widths):
  396. N = np.min([10 * width, len(data)])
  397. wavelet_data = np.conj(wavelet(N, width, **kwargs)[::-1])
  398. output[ind] = convolve(data, wavelet_data, mode='same')
  399. return output