test_rbf.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # Created by John Travers, Robert Hetland, 2007
  2. """ Test functions for rbf module """
  3. import numpy as np
  4. from numpy.testing import (assert_, assert_array_almost_equal,
  5. assert_almost_equal)
  6. from numpy import linspace, sin, cos, random, exp, allclose
  7. from scipy.interpolate._rbf import Rbf
  8. FUNCTIONS = ('multiquadric', 'inverse multiquadric', 'gaussian',
  9. 'cubic', 'quintic', 'thin-plate', 'linear')
  10. def check_rbf1d_interpolation(function):
  11. # Check that the Rbf function interpolates through the nodes (1D)
  12. x = linspace(0,10,9)
  13. y = sin(x)
  14. rbf = Rbf(x, y, function=function)
  15. yi = rbf(x)
  16. assert_array_almost_equal(y, yi)
  17. assert_almost_equal(rbf(float(x[0])), y[0])
  18. def check_rbf2d_interpolation(function):
  19. # Check that the Rbf function interpolates through the nodes (2D).
  20. x = random.rand(50,1)*4-2
  21. y = random.rand(50,1)*4-2
  22. z = x*exp(-x**2-1j*y**2)
  23. rbf = Rbf(x, y, z, epsilon=2, function=function)
  24. zi = rbf(x, y)
  25. zi.shape = x.shape
  26. assert_array_almost_equal(z, zi)
  27. def check_rbf3d_interpolation(function):
  28. # Check that the Rbf function interpolates through the nodes (3D).
  29. x = random.rand(50, 1)*4 - 2
  30. y = random.rand(50, 1)*4 - 2
  31. z = random.rand(50, 1)*4 - 2
  32. d = x*exp(-x**2 - y**2)
  33. rbf = Rbf(x, y, z, d, epsilon=2, function=function)
  34. di = rbf(x, y, z)
  35. di.shape = x.shape
  36. assert_array_almost_equal(di, d)
  37. def test_rbf_interpolation():
  38. for function in FUNCTIONS:
  39. check_rbf1d_interpolation(function)
  40. check_rbf2d_interpolation(function)
  41. check_rbf3d_interpolation(function)
  42. def check_2drbf1d_interpolation(function):
  43. # Check that the 2-D Rbf function interpolates through the nodes (1D)
  44. x = linspace(0, 10, 9)
  45. y0 = sin(x)
  46. y1 = cos(x)
  47. y = np.vstack([y0, y1]).T
  48. rbf = Rbf(x, y, function=function, mode='N-D')
  49. yi = rbf(x)
  50. assert_array_almost_equal(y, yi)
  51. assert_almost_equal(rbf(float(x[0])), y[0])
  52. def check_2drbf2d_interpolation(function):
  53. # Check that the 2-D Rbf function interpolates through the nodes (2D).
  54. x = random.rand(50, ) * 4 - 2
  55. y = random.rand(50, ) * 4 - 2
  56. z0 = x * exp(-x ** 2 - 1j * y ** 2)
  57. z1 = y * exp(-y ** 2 - 1j * x ** 2)
  58. z = np.vstack([z0, z1]).T
  59. rbf = Rbf(x, y, z, epsilon=2, function=function, mode='N-D')
  60. zi = rbf(x, y)
  61. zi.shape = z.shape
  62. assert_array_almost_equal(z, zi)
  63. def check_2drbf3d_interpolation(function):
  64. # Check that the 2-D Rbf function interpolates through the nodes (3D).
  65. x = random.rand(50, ) * 4 - 2
  66. y = random.rand(50, ) * 4 - 2
  67. z = random.rand(50, ) * 4 - 2
  68. d0 = x * exp(-x ** 2 - y ** 2)
  69. d1 = y * exp(-y ** 2 - x ** 2)
  70. d = np.vstack([d0, d1]).T
  71. rbf = Rbf(x, y, z, d, epsilon=2, function=function, mode='N-D')
  72. di = rbf(x, y, z)
  73. di.shape = d.shape
  74. assert_array_almost_equal(di, d)
  75. def test_2drbf_interpolation():
  76. for function in FUNCTIONS:
  77. check_2drbf1d_interpolation(function)
  78. check_2drbf2d_interpolation(function)
  79. check_2drbf3d_interpolation(function)
  80. def check_rbf1d_regularity(function, atol):
  81. # Check that the Rbf function approximates a smooth function well away
  82. # from the nodes.
  83. x = linspace(0, 10, 9)
  84. y = sin(x)
  85. rbf = Rbf(x, y, function=function)
  86. xi = linspace(0, 10, 100)
  87. yi = rbf(xi)
  88. msg = "abs-diff: %f" % abs(yi - sin(xi)).max()
  89. assert_(allclose(yi, sin(xi), atol=atol), msg)
  90. def test_rbf_regularity():
  91. tolerances = {
  92. 'multiquadric': 0.1,
  93. 'inverse multiquadric': 0.15,
  94. 'gaussian': 0.15,
  95. 'cubic': 0.15,
  96. 'quintic': 0.1,
  97. 'thin-plate': 0.1,
  98. 'linear': 0.2
  99. }
  100. for function in FUNCTIONS:
  101. check_rbf1d_regularity(function, tolerances.get(function, 1e-2))
  102. def check_2drbf1d_regularity(function, atol):
  103. # Check that the 2-D Rbf function approximates a smooth function well away
  104. # from the nodes.
  105. x = linspace(0, 10, 9)
  106. y0 = sin(x)
  107. y1 = cos(x)
  108. y = np.vstack([y0, y1]).T
  109. rbf = Rbf(x, y, function=function, mode='N-D')
  110. xi = linspace(0, 10, 100)
  111. yi = rbf(xi)
  112. msg = "abs-diff: %f" % abs(yi - np.vstack([sin(xi), cos(xi)]).T).max()
  113. assert_(allclose(yi, np.vstack([sin(xi), cos(xi)]).T, atol=atol), msg)
  114. def test_2drbf_regularity():
  115. tolerances = {
  116. 'multiquadric': 0.1,
  117. 'inverse multiquadric': 0.15,
  118. 'gaussian': 0.15,
  119. 'cubic': 0.15,
  120. 'quintic': 0.1,
  121. 'thin-plate': 0.15,
  122. 'linear': 0.2
  123. }
  124. for function in FUNCTIONS:
  125. check_2drbf1d_regularity(function, tolerances.get(function, 1e-2))
  126. def check_rbf1d_stability(function):
  127. # Check that the Rbf function with default epsilon is not subject
  128. # to overshoot. Regression for issue #4523.
  129. #
  130. # Generate some data (fixed random seed hence deterministic)
  131. np.random.seed(1234)
  132. x = np.linspace(0, 10, 50)
  133. z = x + 4.0 * np.random.randn(len(x))
  134. rbf = Rbf(x, z, function=function)
  135. xi = np.linspace(0, 10, 1000)
  136. yi = rbf(xi)
  137. # subtract the linear trend and make sure there no spikes
  138. assert_(np.abs(yi-xi).max() / np.abs(z-x).max() < 1.1)
  139. def test_rbf_stability():
  140. for function in FUNCTIONS:
  141. check_rbf1d_stability(function)
  142. def test_default_construction():
  143. # Check that the Rbf class can be constructed with the default
  144. # multiquadric basis function. Regression test for ticket #1228.
  145. x = linspace(0,10,9)
  146. y = sin(x)
  147. rbf = Rbf(x, y)
  148. yi = rbf(x)
  149. assert_array_almost_equal(y, yi)
  150. def test_function_is_callable():
  151. # Check that the Rbf class can be constructed with function=callable.
  152. x = linspace(0,10,9)
  153. y = sin(x)
  154. linfunc = lambda x:x
  155. rbf = Rbf(x, y, function=linfunc)
  156. yi = rbf(x)
  157. assert_array_almost_equal(y, yi)
  158. def test_two_arg_function_is_callable():
  159. # Check that the Rbf class can be constructed with a two argument
  160. # function=callable.
  161. def _func(self, r):
  162. return self.epsilon + r
  163. x = linspace(0,10,9)
  164. y = sin(x)
  165. rbf = Rbf(x, y, function=_func)
  166. yi = rbf(x)
  167. assert_array_almost_equal(y, yi)
  168. def test_rbf_epsilon_none():
  169. x = linspace(0, 10, 9)
  170. y = sin(x)
  171. Rbf(x, y, epsilon=None)
  172. def test_rbf_epsilon_none_collinear():
  173. # Check that collinear points in one dimension doesn't cause an error
  174. # due to epsilon = 0
  175. x = [1, 2, 3]
  176. y = [4, 4, 4]
  177. z = [5, 6, 7]
  178. rbf = Rbf(x, y, z, epsilon=None)
  179. assert_(rbf.epsilon > 0)