# test_c_api.py (3.4 KB)
# Compares scipy.ndimage generic_filter, generic_filter1d and geometric_transform
# run with compiled (_ctest / _cytest) callbacks against pure-Python reference callbacks.
  1. import numpy as np
  2. from numpy.testing import assert_allclose
  3. from scipy import ndimage
  4. from scipy.ndimage import _ctest
  5. from scipy.ndimage import _cytest
  6. from scipy._lib._ccallback import LowLevelCallable
  7. FILTER1D_FUNCTIONS = [
  8. lambda filter_size: _ctest.filter1d(filter_size),
  9. lambda filter_size: _cytest.filter1d(filter_size, with_signature=False),
  10. lambda filter_size: LowLevelCallable(_cytest.filter1d(filter_size, with_signature=True)),
  11. lambda filter_size: LowLevelCallable.from_cython(_cytest, "_filter1d",
  12. _cytest.filter1d_capsule(filter_size)),
  13. ]
  14. FILTER2D_FUNCTIONS = [
  15. lambda weights: _ctest.filter2d(weights),
  16. lambda weights: _cytest.filter2d(weights, with_signature=False),
  17. lambda weights: LowLevelCallable(_cytest.filter2d(weights, with_signature=True)),
  18. lambda weights: LowLevelCallable.from_cython(_cytest, "_filter2d", _cytest.filter2d_capsule(weights)),
  19. ]
  20. TRANSFORM_FUNCTIONS = [
  21. lambda shift: _ctest.transform(shift),
  22. lambda shift: _cytest.transform(shift, with_signature=False),
  23. lambda shift: LowLevelCallable(_cytest.transform(shift, with_signature=True)),
  24. lambda shift: LowLevelCallable.from_cython(_cytest, "_transform", _cytest.transform_capsule(shift)),
  25. ]
  26. def test_generic_filter():
  27. def filter2d(footprint_elements, weights):
  28. return (weights*footprint_elements).sum()
  29. def check(j):
  30. func = FILTER2D_FUNCTIONS[j]
  31. im = np.ones((20, 20))
  32. im[:10,:10] = 0
  33. footprint = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]])
  34. footprint_size = np.count_nonzero(footprint)
  35. weights = np.ones(footprint_size)/footprint_size
  36. res = ndimage.generic_filter(im, func(weights),
  37. footprint=footprint)
  38. std = ndimage.generic_filter(im, filter2d, footprint=footprint,
  39. extra_arguments=(weights,))
  40. assert_allclose(res, std, err_msg="#{} failed".format(j))
  41. for j, func in enumerate(FILTER2D_FUNCTIONS):
  42. check(j)
  43. def test_generic_filter1d():
  44. def filter1d(input_line, output_line, filter_size):
  45. for i in range(output_line.size):
  46. output_line[i] = 0
  47. for j in range(filter_size):
  48. output_line[i] += input_line[i+j]
  49. output_line /= filter_size
  50. def check(j):
  51. func = FILTER1D_FUNCTIONS[j]
  52. im = np.tile(np.hstack((np.zeros(10), np.ones(10))), (10, 1))
  53. filter_size = 3
  54. res = ndimage.generic_filter1d(im, func(filter_size),
  55. filter_size)
  56. std = ndimage.generic_filter1d(im, filter1d, filter_size,
  57. extra_arguments=(filter_size,))
  58. assert_allclose(res, std, err_msg="#{} failed".format(j))
  59. for j, func in enumerate(FILTER1D_FUNCTIONS):
  60. check(j)
  61. def test_geometric_transform():
  62. def transform(output_coordinates, shift):
  63. return output_coordinates[0] - shift, output_coordinates[1] - shift
  64. def check(j):
  65. func = TRANSFORM_FUNCTIONS[j]
  66. im = np.arange(12).reshape(4, 3).astype(np.float64)
  67. shift = 0.5
  68. res = ndimage.geometric_transform(im, func(shift))
  69. std = ndimage.geometric_transform(im, transform, extra_arguments=(shift,))
  70. assert_allclose(res, std, err_msg="#{} failed".format(j))
  71. for j, func in enumerate(TRANSFORM_FUNCTIONS):
  72. check(j)