realtransforms.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import numpy as np
  2. from . import pypocketfft as pfft
  3. from .helper import (_asfarray, _init_nd_shape_and_axes, _datacopied,
  4. _fix_shape, _fix_shape_1d, _normalization, _workers)
  5. import functools
  6. def _r2r(forward, transform, x, type=2, n=None, axis=-1, norm=None,
  7. overwrite_x=False, workers=None, orthogonalize=None):
  8. """Forward or backward 1-D DCT/DST
  9. Parameters
  10. ----------
  11. forward : bool
  12. Transform direction (determines type and normalisation)
  13. transform : {pypocketfft.dct, pypocketfft.dst}
  14. The transform to perform
  15. """
  16. tmp = _asfarray(x)
  17. overwrite_x = overwrite_x or _datacopied(tmp, x)
  18. norm = _normalization(norm, forward)
  19. workers = _workers(workers)
  20. if not forward:
  21. if type == 2:
  22. type = 3
  23. elif type == 3:
  24. type = 2
  25. if n is not None:
  26. tmp, copied = _fix_shape_1d(tmp, n, axis)
  27. overwrite_x = overwrite_x or copied
  28. elif tmp.shape[axis] < 1:
  29. raise ValueError("invalid number of data points ({0}) specified"
  30. .format(tmp.shape[axis]))
  31. out = (tmp if overwrite_x else None)
  32. # For complex input, transform real and imaginary components separably
  33. if np.iscomplexobj(x):
  34. out = np.empty_like(tmp) if out is None else out
  35. transform(tmp.real, type, (axis,), norm, out.real, workers)
  36. transform(tmp.imag, type, (axis,), norm, out.imag, workers)
  37. return out
  38. return transform(tmp, type, (axis,), norm, out, workers, orthogonalize)
  39. dct = functools.partial(_r2r, True, pfft.dct)
  40. dct.__name__ = 'dct'
  41. idct = functools.partial(_r2r, False, pfft.dct)
  42. idct.__name__ = 'idct'
  43. dst = functools.partial(_r2r, True, pfft.dst)
  44. dst.__name__ = 'dst'
  45. idst = functools.partial(_r2r, False, pfft.dst)
  46. idst.__name__ = 'idst'
  47. def _r2rn(forward, transform, x, type=2, s=None, axes=None, norm=None,
  48. overwrite_x=False, workers=None, orthogonalize=None):
  49. """Forward or backward nd DCT/DST
  50. Parameters
  51. ----------
  52. forward : bool
  53. Transform direction (determines type and normalisation)
  54. transform : {pypocketfft.dct, pypocketfft.dst}
  55. The transform to perform
  56. """
  57. tmp = _asfarray(x)
  58. shape, axes = _init_nd_shape_and_axes(tmp, s, axes)
  59. overwrite_x = overwrite_x or _datacopied(tmp, x)
  60. if len(axes) == 0:
  61. return x
  62. tmp, copied = _fix_shape(tmp, shape, axes)
  63. overwrite_x = overwrite_x or copied
  64. if not forward:
  65. if type == 2:
  66. type = 3
  67. elif type == 3:
  68. type = 2
  69. norm = _normalization(norm, forward)
  70. workers = _workers(workers)
  71. out = (tmp if overwrite_x else None)
  72. # For complex input, transform real and imaginary components separably
  73. if np.iscomplexobj(x):
  74. out = np.empty_like(tmp) if out is None else out
  75. transform(tmp.real, type, axes, norm, out.real, workers)
  76. transform(tmp.imag, type, axes, norm, out.imag, workers)
  77. return out
  78. return transform(tmp, type, axes, norm, out, workers, orthogonalize)
  79. dctn = functools.partial(_r2rn, True, pfft.dct)
  80. dctn.__name__ = 'dctn'
  81. idctn = functools.partial(_r2rn, False, pfft.dct)
  82. idctn.__name__ = 'idctn'
  83. dstn = functools.partial(_r2rn, True, pfft.dst)
  84. dstn.__name__ = 'dstn'
  85. idstn = functools.partial(_r2rn, False, pfft.dst)
  86. idstn.__name__ = 'idstn'