test_polyint.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. import warnings
  2. import io
  3. import numpy as np
  4. from numpy.testing import (
  5. assert_almost_equal, assert_array_equal, assert_array_almost_equal,
  6. assert_allclose, assert_equal, assert_)
  7. from pytest import raises as assert_raises
  8. import pytest
  9. from scipy.interpolate import (
  10. KroghInterpolator, krogh_interpolate,
  11. BarycentricInterpolator, barycentric_interpolate,
  12. approximate_taylor_polynomial, CubicHermiteSpline, pchip,
  13. PchipInterpolator, pchip_interpolate, Akima1DInterpolator, CubicSpline,
  14. make_interp_spline)
  def check_shape(interpolator_cls, x_shape, y_shape, deriv_shape=None, axis=0,
                  extra_args=None):
      """Check output shape (and values) of an interpolator on random data.

      Random data ``y`` is built from ``y_shape`` with the interpolation axis
      (6 samples at ``x = [-1, 0, 1, 2, 3, 4]``) inserted at position
      ``axis``. The interpolant is evaluated at ``xi = zeros(x_shape)`` and
      the result shape is compared with the expected one. The construction is
      repeated with list inputs. When ``deriv_shape`` is None, the values at
      ``xi = 0`` (the node ``x[1]``) are compared with the data.

      Parameters
      ----------
      interpolator_cls : callable
          Class or factory called as ``f(x, y, axis=axis, **extra_args)``
          (``CubicHermiteSpline`` additionally receives random slopes),
          returning a callable evaluated at ``xi``.
      x_shape, y_shape : tuple
          Shape of the evaluation points and of one data sample.
      deriv_shape : tuple or None
          Extra leading output dimensions. ``None`` also enables the value
          check; pass ``()`` for derivative-like callables with no extra
          dimensions.
      axis : int
          Interpolation axis of ``y``.
      extra_args : dict or None
          Extra keyword arguments forwarded to ``interpolator_cls``.
      """
      if extra_args is None:
          # avoid a shared mutable default argument
          extra_args = {}
      np.random.seed(1234)
      x = [-1, 0, 1, 2, 3, 4]
      # permutation that moves the length-6 data axis into position `axis`
      s = list(range(1, len(y_shape) + 1))
      s.insert(axis % (len(y_shape) + 1), 0)
      y = np.random.rand(*((6,) + y_shape)).transpose(s)
      xi = np.zeros(x_shape)
      if interpolator_cls is CubicHermiteSpline:
          dydx = np.random.rand(*((6,) + y_shape)).transpose(s)
          yi = interpolator_cls(x, y, dydx, axis=axis, **extra_args)(xi)
      else:
          yi = interpolator_cls(x, y, axis=axis, **extra_args)(xi)
      target_shape = ((deriv_shape or ()) + y.shape[:axis]
                      + x_shape + y.shape[axis:][1:])
      assert_equal(yi.shape, target_shape)

      # check it works also with lists
      if x_shape and y.size > 0:
          if interpolator_cls is CubicHermiteSpline:
              interpolator_cls(list(x), list(y), list(dydx), axis=axis,
                               **extra_args)(list(xi))
          else:
              interpolator_cls(list(x), list(y), axis=axis,
                               **extra_args)(list(xi))

      # check also values: every xi equals x[1] == 0, so the result must equal
      # the data sample at index 1 along the interpolation axis
      if xi.size > 0 and deriv_shape is None:
          bs_shape = y.shape[:axis] + (1,) * len(x_shape) + y.shape[axis:][1:]
          yv = y[((slice(None),) * (axis % y.ndim)) + (1,)]
          yv = yv.reshape(bs_shape)
          yi, yv = np.broadcast_arrays(yi, yv)
          assert_allclose(yi, yv)
  # Sample/evaluation shapes exercised by the shape tests: scalar, empty,
  # length-one and multi-dimensional.
  SHAPES = [
      (),
      (0,),
      (1,),
      (6, 2, 5),
  ]
  def test_shapes():
      """Output shapes of every interpolator for all combinations of SHAPES."""

      def spl_interp(x, y, axis):
          return make_interp_spline(x, y, axis=axis)

      for ip in [KroghInterpolator, BarycentricInterpolator, CubicHermiteSpline,
                 pchip, Akima1DInterpolator, CubicSpline, spl_interp]:
          # CubicSpline is additionally checked with two boundary conditions;
          # the other interpolators get no extra arguments. Classes are
          # compared by identity, as in check_shape.
          if ip is CubicSpline:
              extras = [{'bc_type': bc} for bc in ['natural', 'clamped']]
          else:
              extras = [{}]
          for s1 in SHAPES:
              for s2 in SHAPES:
                  for axis in range(-len(s2), len(s2)):
                      for extra in extras:
                          check_shape(ip, s1, s2, None, axis, extra)
  def test_derivs_shapes():
      """Output shapes of ``KroghInterpolator.derivatives``."""

      def krogh_derivs(x, y, axis=0):
          # `axis` is passed positionally as the third constructor argument
          return KroghInterpolator(x, y, axis).derivatives

      shape_pairs = [(a, b) for a in SHAPES for b in SHAPES]
      for s1, s2 in shape_pairs:
          for axis in range(-len(s2), len(s2)):
              # the result carries one extra leading axis of length 6
              check_shape(krogh_derivs, s1, s2, (6,), axis)
  def test_deriv_shapes():
      """Output shapes of derivative / antiderivative callables.

      Each factory takes ``(x, y, axis)`` and returns a callable whose result
      is checked with an empty ``deriv_shape``, i.e. it must have the same
      shape as evaluating the interpolant itself.
      """
      def krogh_deriv(x, y, axis=0):
          return KroghInterpolator(x, y, axis).derivative

      def pchip_deriv(x, y, axis=0):
          return pchip(x, y, axis).derivative()

      def pchip_deriv2(x, y, axis=0):
          return pchip(x, y, axis).derivative(2)

      def pchip_antideriv(x, y, axis=0):
          return pchip(x, y, axis).antiderivative()

      def pchip_antideriv2(x, y, axis=0):
          return pchip(x, y, axis).antiderivative(2)

      def pchip_deriv_inplace(x, y, axis=0):
          # subclass whose __call__ always evaluates the first derivative
          class P(PchipInterpolator):
              def __call__(self, x):
                  return PchipInterpolator.__call__(self, x, 1)
          return P(x, y, axis)

      def akima_deriv(x, y, axis=0):
          return Akima1DInterpolator(x, y, axis).derivative()

      def akima_antideriv(x, y, axis=0):
          return Akima1DInterpolator(x, y, axis).antiderivative()

      def cspline_deriv(x, y, axis=0):
          return CubicSpline(x, y, axis).derivative()

      def cspline_antideriv(x, y, axis=0):
          return CubicSpline(x, y, axis).antiderivative()

      def bspl_deriv(x, y, axis=0):
          # make_interp_spline's third positional argument is `k`, so the
          # axis must be passed by keyword
          return make_interp_spline(x, y, axis=axis).derivative()

      def bspl_antideriv(x, y, axis=0):
          return make_interp_spline(x, y, axis=axis).antiderivative()

      factories = (
          krogh_deriv, pchip_deriv, pchip_deriv2, pchip_deriv_inplace,
          pchip_antideriv, pchip_antideriv2, akima_deriv, akima_antideriv,
          cspline_deriv, cspline_antideriv, bspl_deriv, bspl_antideriv,
      )
      for ip in factories:
          for s1 in SHAPES:
              for s2 in SHAPES:
                  for axis in range(-len(s2), len(s2)):
                      check_shape(ip, s1, s2, (), axis)
  def test_complex():
      """Interpolants reproduce complex-valued data at the nodes."""
      nodes = [1, 2, 3, 4]
      values = [1, 2, 1j, 3]
      for make in (KroghInterpolator, BarycentricInterpolator, pchip,
                   CubicSpline):
          interpolant = make(nodes, values)
          assert_allclose(values, interpolant(nodes))

      slopes = [0, -1j, 2, 3j]
      hermite = CubicHermiteSpline(nodes, values, slopes)
      assert_allclose(values, hermite(nodes))
      # the first derivative at the nodes equals the prescribed slopes
      assert_allclose(slopes, hermite(nodes, 1))
  class TestKrogh:
      """Tests for ``KroghInterpolator`` and the ``krogh_interpolate`` wrapper."""

      def setup_method(self):
          # a quartic sampled at 5 nodes, so the interpolant reproduces it
          self.true_poly = np.poly1d([-2, 3, 1, 5, -4])
          self.test_xs = np.linspace(-1, 1, 100)
          self.xs = np.linspace(-1, 1, 5)
          self.ys = self.true_poly(self.xs)

      def test_lagrange(self):
          P = KroghInterpolator(self.xs, self.ys)
          assert_almost_equal(self.true_poly(self.test_xs), P(self.test_xs))

      def test_scalar(self):
          P = KroghInterpolator(self.xs, self.ys)
          assert_almost_equal(self.true_poly(7), P(7))
          assert_almost_equal(self.true_poly(np.array(7)), P(np.array(7)))

      def test_derivatives(self):
          P = KroghInterpolator(self.xs, self.ys)
          D = P.derivatives(self.test_xs)
          for i in range(D.shape[0]):
              assert_almost_equal(self.true_poly.deriv(i)(self.test_xs),
                                  D[i])

      def test_low_derivatives(self):
          # request more derivatives than there are nodes; the orders above
          # the polynomial degree must match the (zero) exact derivatives
          P = KroghInterpolator(self.xs, self.ys)
          D = P.derivatives(self.test_xs, len(self.xs) + 2)
          for i in range(D.shape[0]):
              assert_almost_equal(self.true_poly.deriv(i)(self.test_xs),
                                  D[i])

      def test_derivative(self):
          # derivative(x, i) must agree with row i of derivatives(x, m)
          P = KroghInterpolator(self.xs, self.ys)
          m = 10
          r = P.derivatives(self.test_xs, m)
          for i in range(m):
              assert_almost_equal(P.derivative(self.test_xs, i), r[i])

      def test_high_derivative(self):
          # derivatives of order >= number of nodes vanish identically
          P = KroghInterpolator(self.xs, self.ys)
          for i in range(len(self.xs), 2 * len(self.xs)):
              assert_almost_equal(P.derivative(self.test_xs, i),
                                  np.zeros(len(self.test_xs)))

      def test_hermite(self):
          # Repeated nodes prescribe successive derivatives: value, first
          # and second derivative at 0 and at 1, and only the value at 2.
          # These 7 conditions determine an interpolant of degree <= 6,
          # which must reproduce the quartic exactly.
          xs = [0, 0, 0, 1, 1, 1, 2]
          ys = [self.true_poly(0),
                self.true_poly.deriv(1)(0),
                self.true_poly.deriv(2)(0),
                self.true_poly(1),
                self.true_poly.deriv(1)(1),
                self.true_poly.deriv(2)(1),
                self.true_poly(2)]
          P = KroghInterpolator(xs, ys)
          assert_almost_equal(self.true_poly(self.test_xs), P(self.test_xs))

      def test_vector(self):
          xs = [0, 1, 2]
          ys = np.array([[0, 1], [1, 0], [2, 1]])
          P = KroghInterpolator(xs, ys)
          Pi = [KroghInterpolator(xs, ys[:, i]) for i in range(ys.shape[1])]
          test_xs = np.linspace(-1, 3, 100)
          assert_almost_equal(P(test_xs),
                              np.asarray([p(test_xs) for p in Pi]).T)
          assert_almost_equal(P.derivatives(test_xs),
                              np.transpose(np.asarray([p.derivatives(test_xs)
                                                       for p in Pi]),
                                           (1, 2, 0)))

      def test_empty(self):
          P = KroghInterpolator(self.xs, self.ys)
          assert_array_equal(P([]), [])

      def test_shapes_scalarvalue(self):
          P = KroghInterpolator(self.xs, self.ys)
          assert_array_equal(np.shape(P(0)), ())
          assert_array_equal(np.shape(P(np.array(0))), ())
          assert_array_equal(np.shape(P([0])), (1,))
          assert_array_equal(np.shape(P([0, 1])), (2,))

      def test_shapes_scalarvalue_derivative(self):
          P = KroghInterpolator(self.xs, self.ys)
          n = P.n
          assert_array_equal(np.shape(P.derivatives(0)), (n,))
          assert_array_equal(np.shape(P.derivatives(np.array(0))), (n,))
          assert_array_equal(np.shape(P.derivatives([0])), (n, 1))
          assert_array_equal(np.shape(P.derivatives([0, 1])), (n, 2))

      def test_shapes_vectorvalue(self):
          P = KroghInterpolator(self.xs, np.outer(self.ys, np.arange(3)))
          assert_array_equal(np.shape(P(0)), (3,))
          assert_array_equal(np.shape(P([0])), (1, 3))
          assert_array_equal(np.shape(P([0, 1])), (2, 3))

      def test_shapes_1d_vectorvalue(self):
          P = KroghInterpolator(self.xs, np.outer(self.ys, [1]))
          assert_array_equal(np.shape(P(0)), (1,))
          assert_array_equal(np.shape(P([0])), (1, 1))
          assert_array_equal(np.shape(P([0, 1])), (2, 1))

      def test_shapes_vectorvalue_derivative(self):
          P = KroghInterpolator(self.xs, np.outer(self.ys, np.arange(3)))
          n = P.n
          assert_array_equal(np.shape(P.derivatives(0)), (n, 3))
          assert_array_equal(np.shape(P.derivatives([0])), (n, 1, 3))
          assert_array_equal(np.shape(P.derivatives([0, 1])), (n, 2, 3))

      def test_wrapper(self):
          P = KroghInterpolator(self.xs, self.ys)
          ki = krogh_interpolate
          assert_almost_equal(P(self.test_xs), ki(self.xs, self.ys, self.test_xs))
          assert_almost_equal(P.derivative(self.test_xs, 2),
                              ki(self.xs, self.ys, self.test_xs, der=2))
          assert_almost_equal(P.derivatives(self.test_xs, 2),
                              ki(self.xs, self.ys, self.test_xs, der=[0, 1]))

      def test_int_inputs(self):
          # Check input args are cast correctly to floats, gh-3669
          x = [0, 234, 468, 702, 936, 1170, 1404, 2340, 3744, 6084, 8424,
               13104, 60000]
          offset_cdf = np.array([-0.95, -0.86114777, -0.8147762, -0.64072425,
                                 -0.48002351, -0.34925329, -0.26503107,
                                 -0.13148093, -0.12988833, -0.12979296,
                                 -0.12973574, -0.08582937, 0.05])
          f = KroghInterpolator(x, offset_cdf)
          assert_allclose(abs((f(x) - offset_cdf) / f.derivative(x, 1)),
                          0, atol=1e-10)

      def test_derivatives_complex(self):
          # regression test for gh-7381: krogh.derivatives(0) fails complex y
          x, y = np.array([-1, -1, 0, 1, 1]), np.array([1, 1.0j, 0, -1, 1.0j])
          func = KroghInterpolator(x, y)
          cmplx = func.derivatives(0)
          cmplx2 = (KroghInterpolator(x, y.real).derivatives(0) +
                    1j * KroghInterpolator(x, y.imag).derivatives(0))
          assert_allclose(cmplx, cmplx2, atol=1e-15)

      def test_high_degree_warning(self):
          with pytest.warns(UserWarning, match="40 degrees provided,"):
              KroghInterpolator(np.arange(40), np.ones(40))
  class TestTaylor:
      """Tests for ``approximate_taylor_polynomial``."""

      def test_exponential(self):
          # Every derivative of exp equals 1 at the origin, so each of the
          # first degree + 1 derivatives of the approximation is 1 at 0.
          degree = 5
          poly = approximate_taylor_polynomial(np.exp, 0, degree, 1, 15)
          for _ in range(degree + 1):
              assert_almost_equal(poly(0), 1)
              poly = poly.deriv()
          # after degree + 1 differentiations nothing is left
          assert_almost_equal(poly(0), 0)
  class TestBarycentric:
      """Tests for ``BarycentricInterpolator`` and ``barycentric_interpolate``."""

      def setup_method(self):
          # a quartic sampled at 5 nodes, so the interpolant reproduces it
          self.true_poly = np.poly1d([-2, 3, 1, 5, -4])
          self.test_xs = np.linspace(-1, 1, 100)
          self.xs = np.linspace(-1, 1, 5)
          self.ys = self.true_poly(self.xs)

      def test_lagrange(self):
          interp = BarycentricInterpolator(self.xs, self.ys)
          expected = self.true_poly(self.test_xs)
          assert_almost_equal(expected, interp(self.test_xs))

      def test_scalar(self):
          interp = BarycentricInterpolator(self.xs, self.ys)
          assert_almost_equal(self.true_poly(7), interp(7))
          assert_almost_equal(self.true_poly(np.array(7)), interp(np.array(7)))

      def test_delayed(self):
          # construct from the nodes only, supply the values afterwards
          interp = BarycentricInterpolator(self.xs)
          interp.set_yi(self.ys)
          assert_almost_equal(self.true_poly(self.test_xs), interp(self.test_xs))

      def test_append(self):
          # start with 3 nodes and add the remaining ones later
          interp = BarycentricInterpolator(self.xs[:3], self.ys[:3])
          interp.add_xi(self.xs[3:], self.ys[3:])
          assert_almost_equal(self.true_poly(self.test_xs), interp(self.test_xs))

      def test_vector(self):
          xs = [0, 1, 2]
          ys = np.array([[0, 1], [1, 0], [2, 1]])
          vector_interp = BarycentricInterpolator(xs, ys)
          # one scalar interpolator per column of ys
          column_interps = [BarycentricInterpolator(xs, ys[:, k])
                            for k in range(ys.shape[1])]
          test_xs = np.linspace(-1, 3, 100)
          expected = np.asarray([f(test_xs) for f in column_interps]).T
          assert_almost_equal(vector_interp(test_xs), expected)

      def test_shapes_scalarvalue(self):
          interp = BarycentricInterpolator(self.xs, self.ys)
          cases = [(0, ()), (np.array(0), ()), ([0], (1,)), ([0, 1], (2,))]
          for point, shape in cases:
              assert_array_equal(np.shape(interp(point)), shape)

      def test_shapes_vectorvalue(self):
          interp = BarycentricInterpolator(self.xs,
                                           np.outer(self.ys, np.arange(3)))
          cases = [(0, (3,)), ([0], (1, 3)), ([0, 1], (2, 3))]
          for point, shape in cases:
              assert_array_equal(np.shape(interp(point)), shape)

      def test_shapes_1d_vectorvalue(self):
          interp = BarycentricInterpolator(self.xs, np.outer(self.ys, [1]))
          cases = [(0, (1,)), ([0], (1, 1)), ([0, 1], (2, 1))]
          for point, shape in cases:
              assert_array_equal(np.shape(interp(point)), shape)

      def test_wrapper(self):
          interp = BarycentricInterpolator(self.xs, self.ys)
          via_wrapper = barycentric_interpolate(self.xs, self.ys, self.test_xs)
          assert_almost_equal(interp(self.test_xs), via_wrapper)

      def test_int_input(self):
          # integer nodes: np.prod(x[-1] - x[:-1]) would overflow
          x = 1000 * np.arange(1, 11)
          y = np.arange(1, 11)
          value = barycentric_interpolate(x, y, 1000 * 9.5)
          assert_almost_equal(value, 9.5)

      def test_large_chebyshev(self):
          # Chebyshev points of the second kind have analytically known
          # barycentric weights. A naive weight computation underflows or
          # overflows for large N, so compare against the analytic weights:
          # without capacity scaling or permutation n=800 fails, with just
          # capacity scaling n=1097 fails, and with both capacity scaling
          # and random permutation n=30000 succeeds.
          n = 800
          j = np.arange(n + 1).astype(np.float64)
          x = np.cos(j * np.pi / n)
          # weights from page 506 of Berrut and Trefethen 2004
          w = (-1) ** j
          w[0] *= 0.5
          w[-1] *= 0.5
          interp = BarycentricInterpolator(x)
          # a constant scaling of the weights cancels out in the evaluation
          # of the polynomial, so normalize by the first weight
          factor = interp.wi[0]
          assert_almost_equal(interp.wi / (2 * factor), w)

      def test_warning(self):
          # Evaluating exactly at the interpolation nodes must not raise the
          # divide-by-zero floating point error ...
          interp = BarycentricInterpolator([0, 1], [1, 2])
          with np.errstate(divide='raise'):
              yi = interp(interp.xi)
          # ... and must return the values stored at the nodes.
          assert_almost_equal(yi, interp.yi.ravel())
  class TestPCHIP:
      """Tests for the PCHIP interpolator (``pchip`` / ``pchip_interpolate``)."""

      def _make_random(self, npts=20):
          # reproducible random data on sorted nodes
          np.random.seed(1234)
          xi = np.sort(np.random.random(npts))
          yi = np.random.random(npts)
          return pchip(xi, yi), xi, yi

      def test_overshoot(self):
          # PCHIP should not overshoot: on each interval the interpolant
          # stays between the two bracketing data values
          p, xi, yi = self._make_random()
          for i in range(len(xi)-1):
              x1, x2 = xi[i], xi[i+1]
              y1, y2 = yi[i], yi[i+1]
              if y1 > y2:
                  y1, y2 = y2, y1
              xp = np.linspace(x1, x2, 10)
              yp = p(xp)
              assert_(((y1 <= yp + 1e-15) & (yp <= y2 + 1e-15)).all())

      def test_monotone(self):
          # PCHIP should preserve monotonicity: on each interval every step
          # between consecutive samples has the sign of y2 - y1. (Comparing
          # each sample only with the first one would not check this.)
          p, xi, yi = self._make_random()
          for i in range(len(xi)-1):
              x1, x2 = xi[i], xi[i+1]
              y1, y2 = yi[i], yi[i+1]
              xp = np.linspace(x1, x2, 10)
              yp = p(xp)
              assert_(((y2 - y1) * np.diff(yp) > 0).all())

      def test_cast(self):
          # regression test for integer input data, see gh-3453
          data = np.array([[0, 4, 12, 27, 47, 60, 79, 87, 99, 100],
                           [-33, -33, -19, -2, 12, 26, 38, 45, 53, 55]])
          xx = np.arange(100)
          curve = pchip(data[0], data[1])(xx)
          data1 = data * 1.0
          curve1 = pchip(data1[0], data1[1])(xx)
          assert_allclose(curve, curve1, atol=1e-14, rtol=1e-14)

      def test_nag(self):
          # Example from NAG C implementation,
          # http://nag.com/numeric/cl/nagdoc_cl25/html/e01/e01bec.html
          # suggested in gh-5326 as a smoke test for the way the derivatives
          # are computed (see also gh-3453)
          dataStr = '''
            7.99   0.00000E+0
            8.09   0.27643E-4
            8.19   0.43750E-1
            8.70   0.16918E+0
            9.20   0.46943E+0
           10.00   0.94374E+0
           12.00   0.99864E+0
           15.00   0.99992E+0
           20.00   0.99999E+0
          '''
          data = np.loadtxt(io.StringIO(dataStr))
          pch = pchip(data[:, 0], data[:, 1])
          resultStr = '''
             7.9900       0.0000
             9.1910       0.4640
            10.3920       0.9645
            11.5930       0.9965
            12.7940       0.9992
            13.9950       0.9998
            15.1960       0.9999
            16.3970       1.0000
            17.5980       1.0000
            18.7990       1.0000
            20.0000       1.0000
          '''
          result = np.loadtxt(io.StringIO(resultStr))
          assert_allclose(result[:, 1], pch(result[:, 0]), rtol=0., atol=5e-5)

      def test_endslopes(self):
          # this is a smoke test for gh-3453: PCHIP interpolator should not
          # set edge slopes to zero if the data do not suggest zero edge derivatives
          x = np.array([0.0, 0.1, 0.25, 0.35])
          y1 = np.array([279.35, 0.5e3, 1.0e3, 2.5e3])
          y2 = np.array([279.35, 2.5e3, 1.50e3, 1.0e3])
          for pp in (pchip(x, y1), pchip(x, y2)):
              for t in (x[0], x[-1]):
                  assert_(pp(t, 1) != 0)

      def test_all_zeros(self):
          x = np.arange(10)
          y = np.zeros_like(x)
          # this should work and not generate any warnings
          with warnings.catch_warnings():
              warnings.filterwarnings('error')
              pch = pchip(x, y)
          xx = np.linspace(0, 9, 101)
          assert_equal(pch(xx), 0.)

      def test_two_points(self):
          # regression test for gh-6222: pchip([0, 1], [0, 1]) fails because
          # it tries to use a three-point scheme to estimate edge derivatives,
          # while there are only two points available.
          # Instead, it should construct a linear interpolator.
          x = np.linspace(0, 1, 11)
          p = pchip([0, 1], [0, 2])
          assert_allclose(p(x), 2*x, atol=1e-15)

      def test_pchip_interpolate(self):
          assert_array_almost_equal(
              pchip_interpolate([1, 2, 3], [4, 5, 6], [0.5], der=1),
              [1.])
          assert_array_almost_equal(
              pchip_interpolate([1, 2, 3], [4, 5, 6], [0.5], der=0),
              [3.5])
          assert_array_almost_equal(
              pchip_interpolate([1, 2, 3], [4, 5, 6], [0.5], der=[0, 1]),
              [[3.5], [1]])

      def test_roots(self):
          # regression test for gh-6357: .roots method should work
          p = pchip([0, 1], [-1, 1])
          r = p.roots()
          assert_allclose(r, 0.5)
  class TestCubicSpline:
      """Tests for ``CubicSpline``: continuity, boundary conditions, inputs."""

      @staticmethod
      def check_correctness(S, bc_start='not-a-knot', bc_end='not-a-knot',
                            tol=1e-14):
          """Check that spline coefficients satisfy the continuity and boundary
          conditions."""
          x = S.x
          c = S.c
          widths = np.diff(x)
          # shape the interval widths so they broadcast against c[k, i, ...]
          widths = widths.reshape([widths.shape[0]] + [1] * (c.ndim - 2))
          h = widths[:-1]
          close = dict(rtol=tol, atol=tol)

          # C2 continuity: the power-basis coefficients at the start of
          # interval i + 1 follow from evaluating interval i at its right end.
          assert_allclose(c[3, 1:], c[0, :-1] * h**3 + c[1, :-1] * h**2 +
                          c[2, :-1] * h + c[3, :-1], **close)
          assert_allclose(c[2, 1:], 3 * c[0, :-1] * h**2 +
                          2 * c[1, :-1] * h + c[2, :-1], **close)
          assert_allclose(c[1, 1:], 3 * c[0, :-1] * h + c[1, :-1], **close)

          # Three points with not-a-knot at both ends give a single parabola,
          # so the cubic coefficient vanishes.
          if x.size == 3 and bc_start == 'not-a-knot' and bc_end == 'not-a-knot':
              assert_allclose(c[0], 0, **close)
              return

          # Periodic: value and first two derivatives agree at both ends.
          if bc_start == 'periodic':
              for order in (0, 1, 2):
                  assert_allclose(S(x[0], order), S(x[-1], order), **close)
              return

          # Left boundary condition.
          if bc_start == 'not-a-knot':
              if x.size == 2:
                  slope = (S(x[1]) - S(x[0])) / widths[0]
                  assert_allclose(S(x[0], 1), slope, **close)
              else:
                  assert_allclose(c[0, 0], c[0, 1], **close)
          elif bc_start == 'clamped':
              assert_allclose(S(x[0], 1), 0, **close)
          elif bc_start == 'natural':
              assert_allclose(S(x[0], 2), 0, **close)
          else:
              order, value = bc_start
              assert_allclose(S(x[0], order), value, **close)

          # Right boundary condition.
          if bc_end == 'not-a-knot':
              if x.size == 2:
                  slope = (S(x[1]) - S(x[0])) / widths[0]
                  assert_allclose(S(x[1], 1), slope, **close)
              else:
                  assert_allclose(c[0, -1], c[0, -2], **close)
          elif bc_end == 'clamped':
              assert_allclose(S(x[-1], 1), 0, **close)
          elif bc_end == 'natural':
              assert_allclose(S(x[-1], 2), 0, rtol=2*tol, atol=2*tol)
          else:
              order, value = bc_end
              assert_allclose(S(x[-1], order), value, **close)

      def check_all_bc(self, x, y, axis):
          """Build splines with every boundary-condition combination."""
          deriv_shape = list(y.shape)
          del deriv_shape[axis]
          first_deriv = np.full(deriv_shape, 2.0)
          second_deriv = np.full(deriv_shape, -1.0)
          bc_all = ['not-a-knot', 'natural', 'clamped',
                    (1, first_deriv), (2, second_deriv)]

          # the named conditions may be given as a single value for both ends
          for bc in bc_all[:3]:
              S = CubicSpline(x, y, axis=axis, bc_type=bc)
              self.check_correctness(S, bc, bc)

          for bc_start, bc_end in ((a, b) for a in bc_all for b in bc_all):
              S = CubicSpline(x, y, axis=axis, bc_type=(bc_start, bc_end))
              self.check_correctness(S, bc_start, bc_end, tol=2e-14)

      @staticmethod
      def _stacked(y, offsets):
          """Return an array ``Y`` of shape (2, len(y), 2) with
          ``Y[i, :, j] == y + offsets[2*i + j]``."""
          a, b, c, d = offsets
          return np.stack([np.stack([y + a, y + b], axis=-1),
                           np.stack([y + c, y + d], axis=-1)])

      def test_general(self):
          x = np.array([-1, 0, 0.5, 2, 4, 4.5, 5.5, 9])
          y = np.array([0, -0.5, 2, 3, 2.5, 1, 1, 0.5])
          for n in [2, 3, x.size]:
              xn, yn = x[:n], y[:n]
              self.check_all_bc(xn, yn, 0)
              self.check_all_bc(xn, self._stacked(yn, (0, -1, 2, 3)), 1)

      def test_periodic(self):
          for n in [2, 3, 5]:
              x = np.linspace(0, 2 * np.pi, n)
              y = np.cos(x)
              S = CubicSpline(x, y, bc_type='periodic')
              self.check_correctness(S, 'periodic', 'periodic')

              Y = self._stacked(y, (0, 2, -1, 5))
              S = CubicSpline(x, Y, axis=1, bc_type='periodic')
              self.check_correctness(S, 'periodic', 'periodic')

      def test_periodic_eval(self):
          x = np.linspace(0, 2 * np.pi, 10)
          S = CubicSpline(x, np.cos(x), bc_type='periodic')
          assert_almost_equal(S(1), S(1 + 2 * np.pi), decimal=15)

      def test_second_derivative_continuity_gh_11758(self):
          # gh-11758: C2 continuity fail
          x = np.array([0.9, 1.3, 1.9, 2.1, 2.6, 3.0, 3.9, 4.4, 4.7, 5.0, 6.0,
                        7.0, 8.0, 9.2, 10.5, 11.3, 11.6, 12.0, 12.6, 13.0, 13.3])
          y = np.array([1.3, 1.5, 1.85, 2.1, 2.6, 2.7, 2.4, 2.15, 2.05, 2.1,
                        2.25, 2.3, 2.25, 1.95, 1.4, 0.9, 0.7, 0.6, 0.5, 0.4, 1.3])
          S = CubicSpline(x, y, bc_type='periodic', extrapolate='periodic')
          self.check_correctness(S, 'periodic', 'periodic')

      def test_three_points(self):
          # gh-11758: computing a_m2_m1 failed.
          # Here the first derivatives s follow from a 2x2 linear system:
          # s[i] = (h1 * m2 + h2 * m1) / (h1 + h2), with h1 = x[1] - x[0],
          # h2 = x[2] - x[1], m1 = (y[1] - y[0]) / h1, m2 = (y[2] - y[1]) / h2.
          x = np.array([1.0, 2.75, 3.0])
          y = np.array([1.0, 15.0, 1.0])
          S = CubicSpline(x, y, bc_type='periodic')
          self.check_correctness(S, 'periodic', 'periodic')
          assert_allclose(S.derivative(1)(x), np.array([-48.0, -48.0, -48.0]))

      def test_dtypes(self):
          x = np.array([0, 1, 2, 3], dtype=int)

          # integer data
          y = np.array([-5, 2, 3, 1], dtype=int)
          self.check_correctness(CubicSpline(x, y))

          # complex data
          y = np.array([-1+1j, 0.0, 1-1j, 0.5-1.5j])
          self.check_correctness(CubicSpline(x, y))

          # complex boundary values
          S = CubicSpline(x, x ** 3, bc_type=("natural", (1, 2j)))
          self.check_correctness(S, "natural", (1, 2j))

          y = np.array([-5, 2, 3, 1])
          S = CubicSpline(x, y, bc_type=[(1, 2 + 0.5j), (2, 0.5 - 1j)])
          self.check_correctness(S, (1, 2 + 0.5j), (2, 0.5 - 1j))

      def test_small_dx(self):
          rng = np.random.RandomState(0)
          x = np.sort(rng.uniform(size=100))
          y = 1e4 + rng.uniform(size=100)
          self.check_correctness(CubicSpline(x, y), tol=1e-13)

      def test_incorrect_inputs(self):
          x = np.array([1, 2, 3, 4])
          y = np.array([1, 2, 3, 4])
          xc = np.array([1 + 1j, 2, 3, 4])
          xn = np.array([np.nan, 2, 3, 4])
          xo = np.array([2, 1, 3, 4])
          yn = np.array([np.nan, 2, 3, 4])
          y3 = [1, 2, 3]
          x1 = [1]
          y1 = [1]

          # complex x, NaNs, unsorted x, length mismatch, 2-D x, one point
          bad_data = [(xc, y), (xn, y), (x, yn), (xo, y), (x, y3),
                      (x[:, np.newaxis], y), (x1, y1)]
          for args in bad_data:
              assert_raises(ValueError, CubicSpline, *args)

          wrong_bc = [('periodic', 'clamped'),
                      ((2, 0), (3, 10)),
                      ((1, 0), ),
                      (0., 0.),
                      'not-a-typo']
          for bc_type in wrong_bc:
              assert_raises(ValueError, CubicSpline, x, y, 0, bc_type, True)

          # Shapes mismatch when giving arbitrary derivative values:
          Y = np.c_[y, y]
          bc1 = ('clamped', (1, 0))
          bc2 = ('clamped', (1, [0, 0, 0]))
          bc3 = ('clamped', (1, [[0, 0]]))
          for bc in (bc1, bc2, bc3):
              assert_raises(ValueError, CubicSpline, x, Y, 0, bc, True)

          # periodic condition, y[-1] must be equal to y[0]:
          assert_raises(ValueError, CubicSpline, x, y, 0, 'periodic', True)
  def test_CubicHermiteSpline_correctness():
      """A cubic Hermite spline interpolates both values and slopes."""
      nodes = [0, 2, 7]
      values = [-1, 2, 3]
      slopes = [0, 3, 7]
      spline = CubicHermiteSpline(nodes, values, slopes)
      assert_allclose(spline(nodes), values, rtol=1e-15)
      assert_allclose(spline(nodes, 1), slopes, rtol=1e-15)
  def test_CubicHermiteSpline_error_handling():
      """Slopes of the wrong length or containing NaN are rejected."""
      x = [1, 2, 3]
      y = [0, 3, 5]
      bad_slopes = (
          [1, -1, 2, 3],    # one slope more than there are nodes
          [1, 0, np.nan],   # non-finite slope
      )
      for dydx in bad_slopes:
          assert_raises(ValueError, CubicHermiteSpline, x, y, dydx)
  def test_roots_extrapolate_gh_11185():
      """Regression test for gh-11185: roots of a single-interval spline."""
      nodes = np.array([0.001, 0.002])
      values = np.array([1.66066935e-06, 1.10410807e-06])
      slopes = np.array([-1.60061854, -1.600619])
      spline = CubicHermiteSpline(nodes, values, slopes)
      # with extrapolate=True the single cubic piece should yield all three
      # of its real roots
      found = spline.roots(extrapolate=True)
      assert_equal(spline.c.shape[1], 1)
      assert_equal(found.size, 3)
  class TestZeroSizeArrays:
      # regression tests for gh-17241 : CubicSpline et al must not segfault
      # when y.size == 0
      # The two test methods differ only in the constructor arguments: one is
      # for objects which have the `bc_type` argument (CubicSpline) and the
      # other one is for those which do not (Pchip, Akima1D). The shared
      # checks live in `_check_zero_size`.

      @staticmethod
      def _check_zero_size(cls, y, axis, **kwargs):
          """Evaluate ``cls`` built on empty data with default and explicit axis.

          ``kwargs`` are forwarded to both constructor calls.
          """
          x = np.arange(10)
          xval = np.arange(3)

          # default axis
          res = cls(x, y, **kwargs)(xval)
          assert res.size == 0
          assert res.shape == xval.shape + y.shape[1:]

          # Also check with an explicit non-default axis:
          # (10, 0, 5) --> (0, 10, 5) if axis=1 etc
          yt = np.moveaxis(y, 0, axis)
          res = cls(x, yt, axis=axis, **kwargs)(xval)
          expected = yt.shape[:axis] + (xval.size, ) + yt.shape[axis+1:]
          assert res.size == 0
          assert res.shape == expected

      @pytest.mark.parametrize('y', [np.zeros((10, 0, 5)),
                                     np.zeros((10, 5, 0))])
      @pytest.mark.parametrize('bc_type',
                               ['not-a-knot', 'periodic', 'natural', 'clamped'])
      @pytest.mark.parametrize('axis', [0, 1, 2])
      @pytest.mark.parametrize('cls', [make_interp_spline, CubicSpline])
      def test_zero_size(self, cls, y, bc_type, axis):
          self._check_zero_size(cls, y, axis, bc_type=bc_type)

      @pytest.mark.parametrize('y', [np.zeros((10, 0, 5)),
                                     np.zeros((10, 5, 0))])
      @pytest.mark.parametrize('axis', [0, 1, 2])
      @pytest.mark.parametrize('cls', [PchipInterpolator, Akima1DInterpolator])
      def test_zero_size_2(self, cls, y, axis):
          self._check_zero_size(cls, y, axis)