test_backend.py 4.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. from functools import partial
  2. import numpy as np
  3. import scipy.fft
  4. from scipy.fft import _fftlog, _pocketfft, set_backend
  5. from scipy.fft.tests import mock_backend
  6. from numpy.testing import assert_allclose, assert_equal
  7. import pytest
  8. fnames = ('fft', 'fft2', 'fftn',
  9. 'ifft', 'ifft2', 'ifftn',
  10. 'rfft', 'rfft2', 'rfftn',
  11. 'irfft', 'irfft2', 'irfftn',
  12. 'dct', 'idct', 'dctn', 'idctn',
  13. 'dst', 'idst', 'dstn', 'idstn',
  14. 'fht', 'ifht')
  15. np_funcs = (np.fft.fft, np.fft.fft2, np.fft.fftn,
  16. np.fft.ifft, np.fft.ifft2, np.fft.ifftn,
  17. np.fft.rfft, np.fft.rfft2, np.fft.rfftn,
  18. np.fft.irfft, np.fft.irfft2, np.fft.irfftn,
  19. np.fft.hfft, _pocketfft.hfft2, _pocketfft.hfftn, # np has no hfftn
  20. np.fft.ihfft, _pocketfft.ihfft2, _pocketfft.ihfftn,
  21. _pocketfft.dct, _pocketfft.idct, _pocketfft.dctn, _pocketfft.idctn,
  22. _pocketfft.dst, _pocketfft.idst, _pocketfft.dstn, _pocketfft.idstn,
  23. # must provide required kwargs for fht, ifht
  24. partial(_fftlog.fht, dln=2, mu=0.5),
  25. partial(_fftlog.ifht, dln=2, mu=0.5))
  26. funcs = (scipy.fft.fft, scipy.fft.fft2, scipy.fft.fftn,
  27. scipy.fft.ifft, scipy.fft.ifft2, scipy.fft.ifftn,
  28. scipy.fft.rfft, scipy.fft.rfft2, scipy.fft.rfftn,
  29. scipy.fft.irfft, scipy.fft.irfft2, scipy.fft.irfftn,
  30. scipy.fft.hfft, scipy.fft.hfft2, scipy.fft.hfftn,
  31. scipy.fft.ihfft, scipy.fft.ihfft2, scipy.fft.ihfftn,
  32. scipy.fft.dct, scipy.fft.idct, scipy.fft.dctn, scipy.fft.idctn,
  33. scipy.fft.dst, scipy.fft.idst, scipy.fft.dstn, scipy.fft.idstn,
  34. # must provide required kwargs for fht, ifht
  35. partial(scipy.fft.fht, dln=2, mu=0.5),
  36. partial(scipy.fft.ifht, dln=2, mu=0.5))
  37. mocks = (mock_backend.fft, mock_backend.fft2, mock_backend.fftn,
  38. mock_backend.ifft, mock_backend.ifft2, mock_backend.ifftn,
  39. mock_backend.rfft, mock_backend.rfft2, mock_backend.rfftn,
  40. mock_backend.irfft, mock_backend.irfft2, mock_backend.irfftn,
  41. mock_backend.hfft, mock_backend.hfft2, mock_backend.hfftn,
  42. mock_backend.ihfft, mock_backend.ihfft2, mock_backend.ihfftn,
  43. mock_backend.dct, mock_backend.idct,
  44. mock_backend.dctn, mock_backend.idctn,
  45. mock_backend.dst, mock_backend.idst,
  46. mock_backend.dstn, mock_backend.idstn,
  47. mock_backend.fht, mock_backend.ifht)
  48. @pytest.mark.parametrize("func, np_func, mock", zip(funcs, np_funcs, mocks))
  49. def test_backend_call(func, np_func, mock):
  50. x = np.arange(20).reshape((10,2))
  51. answer = np_func(x)
  52. assert_allclose(func(x), answer, atol=1e-10)
  53. with set_backend(mock_backend, only=True):
  54. mock.number_calls = 0
  55. y = func(x)
  56. assert_equal(y, mock.return_value)
  57. assert_equal(mock.number_calls, 1)
  58. assert_allclose(func(x), answer, atol=1e-10)
  59. plan_funcs = (scipy.fft.fft, scipy.fft.fft2, scipy.fft.fftn,
  60. scipy.fft.ifft, scipy.fft.ifft2, scipy.fft.ifftn,
  61. scipy.fft.rfft, scipy.fft.rfft2, scipy.fft.rfftn,
  62. scipy.fft.irfft, scipy.fft.irfft2, scipy.fft.irfftn,
  63. scipy.fft.hfft, scipy.fft.hfft2, scipy.fft.hfftn,
  64. scipy.fft.ihfft, scipy.fft.ihfft2, scipy.fft.ihfftn)
  65. plan_mocks = (mock_backend.fft, mock_backend.fft2, mock_backend.fftn,
  66. mock_backend.ifft, mock_backend.ifft2, mock_backend.ifftn,
  67. mock_backend.rfft, mock_backend.rfft2, mock_backend.rfftn,
  68. mock_backend.irfft, mock_backend.irfft2, mock_backend.irfftn,
  69. mock_backend.hfft, mock_backend.hfft2, mock_backend.hfftn,
  70. mock_backend.ihfft, mock_backend.ihfft2, mock_backend.ihfftn)
  71. @pytest.mark.parametrize("func, mock", zip(plan_funcs, plan_mocks))
  72. def test_backend_plan(func, mock):
  73. x = np.arange(20).reshape((10, 2))
  74. with pytest.raises(NotImplementedError, match='precomputed plan'):
  75. func(x, plan='foo')
  76. with set_backend(mock_backend, only=True):
  77. mock.number_calls = 0
  78. y = func(x, plan='foo')
  79. assert_equal(y, mock.return_value)
  80. assert_equal(mock.number_calls, 1)
  81. assert_equal(mock.last_args[1]['plan'], 'foo')