basic.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. """
  2. Discrete Fourier Transforms - basic.py
  3. """
  4. import numpy as np
  5. import functools
  6. from . import pypocketfft as pfft
  7. from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
  8. _fix_shape, _fix_shape_1d, _normalization,
  9. _workers)
  10. def c2c(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
  11. workers=None, *, plan=None):
  12. """ Return discrete Fourier transform of real or complex sequence. """
  13. if plan is not None:
  14. raise NotImplementedError('Passing a precomputed plan is not yet '
  15. 'supported by scipy.fft functions')
  16. tmp = _asfarray(x)
  17. overwrite_x = overwrite_x or _datacopied(tmp, x)
  18. norm = _normalization(norm, forward)
  19. workers = _workers(workers)
  20. if n is not None:
  21. tmp, copied = _fix_shape_1d(tmp, n, axis)
  22. overwrite_x = overwrite_x or copied
  23. elif tmp.shape[axis] < 1:
  24. raise ValueError("invalid number of data points ({0}) specified"
  25. .format(tmp.shape[axis]))
  26. out = (tmp if overwrite_x and tmp.dtype.kind == 'c' else None)
  27. return pfft.c2c(tmp, (axis,), forward, norm, out, workers)
  28. fft = functools.partial(c2c, True)
  29. fft.__name__ = 'fft'
  30. ifft = functools.partial(c2c, False)
  31. ifft.__name__ = 'ifft'
  32. def r2c(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
  33. workers=None, *, plan=None):
  34. """
  35. Discrete Fourier transform of a real sequence.
  36. """
  37. if plan is not None:
  38. raise NotImplementedError('Passing a precomputed plan is not yet '
  39. 'supported by scipy.fft functions')
  40. tmp = _asfarray(x)
  41. norm = _normalization(norm, forward)
  42. workers = _workers(workers)
  43. if not np.isrealobj(tmp):
  44. raise TypeError("x must be a real sequence")
  45. if n is not None:
  46. tmp, _ = _fix_shape_1d(tmp, n, axis)
  47. elif tmp.shape[axis] < 1:
  48. raise ValueError("invalid number of data points ({0}) specified"
  49. .format(tmp.shape[axis]))
  50. # Note: overwrite_x is not utilised
  51. return pfft.r2c(tmp, (axis,), forward, norm, None, workers)
  52. rfft = functools.partial(r2c, True)
  53. rfft.__name__ = 'rfft'
  54. ihfft = functools.partial(r2c, False)
  55. ihfft.__name__ = 'ihfft'
  56. def c2r(forward, x, n=None, axis=-1, norm=None, overwrite_x=False,
  57. workers=None, *, plan=None):
  58. """
  59. Return inverse discrete Fourier transform of real sequence x.
  60. """
  61. if plan is not None:
  62. raise NotImplementedError('Passing a precomputed plan is not yet '
  63. 'supported by scipy.fft functions')
  64. tmp = _asfarray(x)
  65. norm = _normalization(norm, forward)
  66. workers = _workers(workers)
  67. # TODO: Optimize for hermitian and real?
  68. if np.isrealobj(tmp):
  69. tmp = tmp + 0.j
  70. # Last axis utilizes hermitian symmetry
  71. if n is None:
  72. n = (tmp.shape[axis] - 1) * 2
  73. if n < 1:
  74. raise ValueError("Invalid number of data points ({0}) specified"
  75. .format(n))
  76. else:
  77. tmp, _ = _fix_shape_1d(tmp, (n//2) + 1, axis)
  78. # Note: overwrite_x is not utilized
  79. return pfft.c2r(tmp, (axis,), n, forward, norm, None, workers)
  80. hfft = functools.partial(c2r, True)
  81. hfft.__name__ = 'hfft'
  82. irfft = functools.partial(c2r, False)
  83. irfft.__name__ = 'irfft'
  84. def fft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
  85. *, plan=None):
  86. """
  87. 2-D discrete Fourier transform.
  88. """
  89. if plan is not None:
  90. raise NotImplementedError('Passing a precomputed plan is not yet '
  91. 'supported by scipy.fft functions')
  92. return fftn(x, s, axes, norm, overwrite_x, workers)
  93. def ifft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
  94. *, plan=None):
  95. """
  96. 2-D discrete inverse Fourier transform of real or complex sequence.
  97. """
  98. if plan is not None:
  99. raise NotImplementedError('Passing a precomputed plan is not yet '
  100. 'supported by scipy.fft functions')
  101. return ifftn(x, s, axes, norm, overwrite_x, workers)
  102. def rfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
  103. *, plan=None):
  104. """
  105. 2-D discrete Fourier transform of a real sequence
  106. """
  107. if plan is not None:
  108. raise NotImplementedError('Passing a precomputed plan is not yet '
  109. 'supported by scipy.fft functions')
  110. return rfftn(x, s, axes, norm, overwrite_x, workers)
  111. def irfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
  112. *, plan=None):
  113. """
  114. 2-D discrete inverse Fourier transform of a real sequence
  115. """
  116. if plan is not None:
  117. raise NotImplementedError('Passing a precomputed plan is not yet '
  118. 'supported by scipy.fft functions')
  119. return irfftn(x, s, axes, norm, overwrite_x, workers)
  120. def hfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
  121. *, plan=None):
  122. """
  123. 2-D discrete Fourier transform of a Hermitian sequence
  124. """
  125. if plan is not None:
  126. raise NotImplementedError('Passing a precomputed plan is not yet '
  127. 'supported by scipy.fft functions')
  128. return hfftn(x, s, axes, norm, overwrite_x, workers)
  129. def ihfft2(x, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None,
  130. *, plan=None):
  131. """
  132. 2-D discrete inverse Fourier transform of a Hermitian sequence
  133. """
  134. if plan is not None:
  135. raise NotImplementedError('Passing a precomputed plan is not yet '
  136. 'supported by scipy.fft functions')
  137. return ihfftn(x, s, axes, norm, overwrite_x, workers)
  138. def c2cn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
  139. workers=None, *, plan=None):
  140. """
  141. Return multidimensional discrete Fourier transform.
  142. """
  143. if plan is not None:
  144. raise NotImplementedError('Passing a precomputed plan is not yet '
  145. 'supported by scipy.fft functions')
  146. tmp = _asfarray(x)
  147. shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
  148. overwrite_x = overwrite_x or _datacopied(tmp, x)
  149. workers = _workers(workers)
  150. if len(axes) == 0:
  151. return x
  152. tmp, copied = _fix_shape(tmp, shape, axes)
  153. overwrite_x = overwrite_x or copied
  154. norm = _normalization(norm, forward)
  155. out = (tmp if overwrite_x and tmp.dtype.kind == 'c' else None)
  156. return pfft.c2c(tmp, axes, forward, norm, out, workers)
  157. fftn = functools.partial(c2cn, True)
  158. fftn.__name__ = 'fftn'
  159. ifftn = functools.partial(c2cn, False)
  160. ifftn.__name__ = 'ifftn'
  161. def r2cn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
  162. workers=None, *, plan=None):
  163. """Return multidimensional discrete Fourier transform of real input"""
  164. if plan is not None:
  165. raise NotImplementedError('Passing a precomputed plan is not yet '
  166. 'supported by scipy.fft functions')
  167. tmp = _asfarray(x)
  168. if not np.isrealobj(tmp):
  169. raise TypeError("x must be a real sequence")
  170. shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
  171. tmp, _ = _fix_shape(tmp, shape, axes)
  172. norm = _normalization(norm, forward)
  173. workers = _workers(workers)
  174. if len(axes) == 0:
  175. raise ValueError("at least 1 axis must be transformed")
  176. # Note: overwrite_x is not utilized
  177. return pfft.r2c(tmp, axes, forward, norm, None, workers)
  178. rfftn = functools.partial(r2cn, True)
  179. rfftn.__name__ = 'rfftn'
  180. ihfftn = functools.partial(r2cn, False)
  181. ihfftn.__name__ = 'ihfftn'
  182. def c2rn(forward, x, s=None, axes=None, norm=None, overwrite_x=False,
  183. workers=None, *, plan=None):
  184. """Multidimensional inverse discrete fourier transform with real output"""
  185. if plan is not None:
  186. raise NotImplementedError('Passing a precomputed plan is not yet '
  187. 'supported by scipy.fft functions')
  188. tmp = _asfarray(x)
  189. # TODO: Optimize for hermitian and real?
  190. if np.isrealobj(tmp):
  191. tmp = tmp + 0.j
  192. noshape = s is None
  193. shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
  194. if len(axes) == 0:
  195. raise ValueError("at least 1 axis must be transformed")
  196. if noshape:
  197. shape[-1] = (x.shape[axes[-1]] - 1) * 2
  198. norm = _normalization(norm, forward)
  199. workers = _workers(workers)
  200. # Last axis utilizes hermitian symmetry
  201. lastsize = shape[-1]
  202. shape[-1] = (shape[-1] // 2) + 1
  203. tmp, _ = _fix_shape(tmp, shape, axes)
  204. # Note: overwrite_x is not utilized
  205. return pfft.c2r(tmp, axes, lastsize, forward, norm, None, workers)
  206. hfftn = functools.partial(c2rn, True)
  207. hfftn.__name__ = 'hfftn'
  208. irfftn = functools.partial(c2rn, False)
  209. irfftn.__name__ = 'irfftn'
  210. def r2r_fftpack(forward, x, n=None, axis=-1, norm=None, overwrite_x=False):
  211. """FFT of a real sequence, returning fftpack half complex format"""
  212. tmp = _asfarray(x)
  213. overwrite_x = overwrite_x or _datacopied(tmp, x)
  214. norm = _normalization(norm, forward)
  215. workers = _workers(None)
  216. if tmp.dtype.kind == 'c':
  217. raise TypeError('x must be a real sequence')
  218. if n is not None:
  219. tmp, copied = _fix_shape_1d(tmp, n, axis)
  220. overwrite_x = overwrite_x or copied
  221. elif tmp.shape[axis] < 1:
  222. raise ValueError("invalid number of data points ({0}) specified"
  223. .format(tmp.shape[axis]))
  224. out = (tmp if overwrite_x else None)
  225. return pfft.r2r_fftpack(tmp, (axis,), forward, forward, norm, out, workers)
  226. rfft_fftpack = functools.partial(r2r_fftpack, True)
  227. rfft_fftpack.__name__ = 'rfft_fftpack'
  228. irfft_fftpack = functools.partial(r2r_fftpack, False)
  229. irfft_fftpack.__name__ = 'irfft_fftpack'