test_slerp.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. import numpy as np
  2. from numpy.testing import assert_allclose
  3. import pytest
  4. from scipy.spatial import geometric_slerp
  5. def _generate_spherical_points(ndim=3, n_pts=2):
  6. # generate uniform points on sphere
  7. # see: https://stackoverflow.com/a/23785326
  8. # tentatively extended to arbitrary dims
  9. # for 0-sphere it will always produce antipodes
  10. np.random.seed(123)
  11. points = np.random.normal(size=(n_pts, ndim))
  12. points /= np.linalg.norm(points, axis=1)[:, np.newaxis]
  13. return points[0], points[1]
  14. class TestGeometricSlerp:
  15. # Test various properties of the geometric slerp code
  16. @pytest.mark.parametrize("n_dims", [2, 3, 5, 7, 9])
  17. @pytest.mark.parametrize("n_pts", [0, 3, 17])
  18. def test_shape_property(self, n_dims, n_pts):
  19. # geometric_slerp output shape should match
  20. # input dimensionality & requested number
  21. # of interpolation points
  22. start, end = _generate_spherical_points(n_dims, 2)
  23. actual = geometric_slerp(start=start,
  24. end=end,
  25. t=np.linspace(0, 1, n_pts))
  26. assert actual.shape == (n_pts, n_dims)
  27. @pytest.mark.parametrize("n_dims", [2, 3, 5, 7, 9])
  28. @pytest.mark.parametrize("n_pts", [3, 17])
  29. def test_include_ends(self, n_dims, n_pts):
  30. # geometric_slerp should return a data structure
  31. # that includes the start and end coordinates
  32. # when t includes 0 and 1 ends
  33. # this is convenient for plotting surfaces represented
  34. # by interpolations for example
  35. # the generator doesn't work so well for the unit
  36. # sphere (it always produces antipodes), so use
  37. # custom values there
  38. start, end = _generate_spherical_points(n_dims, 2)
  39. actual = geometric_slerp(start=start,
  40. end=end,
  41. t=np.linspace(0, 1, n_pts))
  42. assert_allclose(actual[0], start)
  43. assert_allclose(actual[-1], end)
  44. @pytest.mark.parametrize("start, end", [
  45. # both arrays are not flat
  46. (np.zeros((1, 3)), np.ones((1, 3))),
  47. # only start array is not flat
  48. (np.zeros((1, 3)), np.ones(3)),
  49. # only end array is not flat
  50. (np.zeros(1), np.ones((3, 1))),
  51. ])
  52. def test_input_shape_flat(self, start, end):
  53. # geometric_slerp should handle input arrays that are
  54. # not flat appropriately
  55. with pytest.raises(ValueError, match='one-dimensional'):
  56. geometric_slerp(start=start,
  57. end=end,
  58. t=np.linspace(0, 1, 10))
  59. @pytest.mark.parametrize("start, end", [
  60. # 7-D and 3-D ends
  61. (np.zeros(7), np.ones(3)),
  62. # 2-D and 1-D ends
  63. (np.zeros(2), np.ones(1)),
  64. # empty, "3D" will also get caught this way
  65. (np.array([]), np.ones(3)),
  66. ])
  67. def test_input_dim_mismatch(self, start, end):
  68. # geometric_slerp must appropriately handle cases where
  69. # an interpolation is attempted across two different
  70. # dimensionalities
  71. with pytest.raises(ValueError, match='dimensions'):
  72. geometric_slerp(start=start,
  73. end=end,
  74. t=np.linspace(0, 1, 10))
  75. @pytest.mark.parametrize("start, end", [
  76. # both empty
  77. (np.array([]), np.array([])),
  78. ])
  79. def test_input_at_least1d(self, start, end):
  80. # empty inputs to geometric_slerp must
  81. # be handled appropriately when not detected
  82. # by mismatch
  83. with pytest.raises(ValueError, match='at least two-dim'):
  84. geometric_slerp(start=start,
  85. end=end,
  86. t=np.linspace(0, 1, 10))
  87. @pytest.mark.parametrize("start, end, expected", [
  88. # North and South Poles are definitely antipodes
  89. # but should be handled gracefully now
  90. (np.array([0, 0, 1.0]), np.array([0, 0, -1.0]), "warning"),
  91. # this case will issue a warning & be handled
  92. # gracefully as well;
  93. # North Pole was rotated very slightly
  94. # using r = R.from_euler('x', 0.035, degrees=True)
  95. # to achieve Euclidean distance offset from diameter by
  96. # 9.328908379124812e-08, within the default tol
  97. (np.array([0.00000000e+00,
  98. -6.10865200e-04,
  99. 9.99999813e-01]), np.array([0, 0, -1.0]), "warning"),
  100. # this case should succeed without warning because a
  101. # sufficiently large
  102. # rotation was applied to North Pole point to shift it
  103. # to a Euclidean distance of 2.3036691931821451e-07
  104. # from South Pole, which is larger than tol
  105. (np.array([0.00000000e+00,
  106. -9.59930941e-04,
  107. 9.99999539e-01]), np.array([0, 0, -1.0]), "success"),
  108. ])
  109. def test_handle_antipodes(self, start, end, expected):
  110. # antipodal points must be handled appropriately;
  111. # there are an infinite number of possible geodesic
  112. # interpolations between them in higher dims
  113. if expected == "warning":
  114. with pytest.warns(UserWarning, match='antipodes'):
  115. res = geometric_slerp(start=start,
  116. end=end,
  117. t=np.linspace(0, 1, 10))
  118. else:
  119. res = geometric_slerp(start=start,
  120. end=end,
  121. t=np.linspace(0, 1, 10))
  122. # antipodes or near-antipodes should still produce
  123. # slerp paths on the surface of the sphere (but they
  124. # may be ambiguous):
  125. assert_allclose(np.linalg.norm(res, axis=1), 1.0)
  126. @pytest.mark.parametrize("start, end, expected", [
  127. # 2-D with n_pts=4 (two new interpolation points)
  128. # this is an actual circle
  129. (np.array([1, 0]),
  130. np.array([0, 1]),
  131. np.array([[1, 0],
  132. [np.sqrt(3) / 2, 0.5], # 30 deg on unit circle
  133. [0.5, np.sqrt(3) / 2], # 60 deg on unit circle
  134. [0, 1]])),
  135. # likewise for 3-D (add z = 0 plane)
  136. # this is an ordinary sphere
  137. (np.array([1, 0, 0]),
  138. np.array([0, 1, 0]),
  139. np.array([[1, 0, 0],
  140. [np.sqrt(3) / 2, 0.5, 0],
  141. [0.5, np.sqrt(3) / 2, 0],
  142. [0, 1, 0]])),
  143. # for 5-D, pad more columns with constants
  144. # zeros are easiest--non-zero values on unit
  145. # circle are more difficult to reason about
  146. # at higher dims
  147. (np.array([1, 0, 0, 0, 0]),
  148. np.array([0, 1, 0, 0, 0]),
  149. np.array([[1, 0, 0, 0, 0],
  150. [np.sqrt(3) / 2, 0.5, 0, 0, 0],
  151. [0.5, np.sqrt(3) / 2, 0, 0, 0],
  152. [0, 1, 0, 0, 0]])),
  153. ])
  154. def test_straightforward_examples(self, start, end, expected):
  155. # some straightforward interpolation tests, sufficiently
  156. # simple to use the unit circle to deduce expected values;
  157. # for larger dimensions, pad with constants so that the
  158. # data is N-D but simpler to reason about
  159. actual = geometric_slerp(start=start,
  160. end=end,
  161. t=np.linspace(0, 1, 4))
  162. assert_allclose(actual, expected, atol=1e-16)
  163. @pytest.mark.parametrize("t", [
  164. # both interval ends clearly violate limits
  165. np.linspace(-20, 20, 300),
  166. # only one interval end violating limit slightly
  167. np.linspace(-0.0001, 0.0001, 17),
  168. ])
  169. def test_t_values_limits(self, t):
  170. # geometric_slerp() should appropriately handle
  171. # interpolation parameters < 0 and > 1
  172. with pytest.raises(ValueError, match='interpolation parameter'):
  173. _ = geometric_slerp(start=np.array([1, 0]),
  174. end=np.array([0, 1]),
  175. t=t)
  176. @pytest.mark.parametrize("start, end", [
  177. (np.array([1]),
  178. np.array([0])),
  179. (np.array([0]),
  180. np.array([1])),
  181. (np.array([-17.7]),
  182. np.array([165.9])),
  183. ])
  184. def test_0_sphere_handling(self, start, end):
  185. # it does not make sense to interpolate the set of
  186. # two points that is the 0-sphere
  187. with pytest.raises(ValueError, match='at least two-dim'):
  188. _ = geometric_slerp(start=start,
  189. end=end,
  190. t=np.linspace(0, 1, 4))
  191. @pytest.mark.parametrize("tol", [
  192. # an integer currently raises
  193. 5,
  194. # string raises
  195. "7",
  196. # list and arrays also raise
  197. [5, 6, 7], np.array(9.0),
  198. ])
  199. def test_tol_type(self, tol):
  200. # geometric_slerp() should raise if tol is not
  201. # a suitable float type
  202. with pytest.raises(ValueError, match='must be a float'):
  203. _ = geometric_slerp(start=np.array([1, 0]),
  204. end=np.array([0, 1]),
  205. t=np.linspace(0, 1, 5),
  206. tol=tol)
  207. @pytest.mark.parametrize("tol", [
  208. -5e-6,
  209. -7e-10,
  210. ])
  211. def test_tol_sign(self, tol):
  212. # geometric_slerp() currently handles negative
  213. # tol values, as long as they are floats
  214. _ = geometric_slerp(start=np.array([1, 0]),
  215. end=np.array([0, 1]),
  216. t=np.linspace(0, 1, 5),
  217. tol=tol)
  218. @pytest.mark.parametrize("start, end", [
  219. # 1-sphere (circle) with one point at origin
  220. # and the other on the circle
  221. (np.array([1, 0]), np.array([0, 0])),
  222. # 2-sphere (normal sphere) with both points
  223. # just slightly off sphere by the same amount
  224. # in different directions
  225. (np.array([1 + 1e-6, 0, 0]),
  226. np.array([0, 1 - 1e-6, 0])),
  227. # same thing in 4-D
  228. (np.array([1 + 1e-6, 0, 0, 0]),
  229. np.array([0, 1 - 1e-6, 0, 0])),
  230. ])
  231. def test_unit_sphere_enforcement(self, start, end):
  232. # geometric_slerp() should raise on input that clearly
  233. # cannot be on an n-sphere of radius 1
  234. with pytest.raises(ValueError, match='unit n-sphere'):
  235. geometric_slerp(start=start,
  236. end=end,
  237. t=np.linspace(0, 1, 5))
  238. @pytest.mark.parametrize("start, end", [
  239. # 1-sphere 45 degree case
  240. (np.array([1, 0]),
  241. np.array([np.sqrt(2) / 2.,
  242. np.sqrt(2) / 2.])),
  243. # 2-sphere 135 degree case
  244. (np.array([1, 0]),
  245. np.array([-np.sqrt(2) / 2.,
  246. np.sqrt(2) / 2.])),
  247. ])
  248. @pytest.mark.parametrize("t_func", [
  249. np.linspace, np.logspace])
  250. def test_order_handling(self, start, end, t_func):
  251. # geometric_slerp() should handle scenarios with
  252. # ascending and descending t value arrays gracefully;
  253. # results should simply be reversed
  254. # for scrambled / unsorted parameters, the same values
  255. # should be returned, just in scrambled order
  256. num_t_vals = 20
  257. np.random.seed(789)
  258. forward_t_vals = t_func(0, 10, num_t_vals)
  259. # normalize to max of 1
  260. forward_t_vals /= forward_t_vals.max()
  261. reverse_t_vals = np.flipud(forward_t_vals)
  262. shuffled_indices = np.arange(num_t_vals)
  263. np.random.shuffle(shuffled_indices)
  264. scramble_t_vals = forward_t_vals.copy()[shuffled_indices]
  265. forward_results = geometric_slerp(start=start,
  266. end=end,
  267. t=forward_t_vals)
  268. reverse_results = geometric_slerp(start=start,
  269. end=end,
  270. t=reverse_t_vals)
  271. scrambled_results = geometric_slerp(start=start,
  272. end=end,
  273. t=scramble_t_vals)
  274. # check fidelity to input order
  275. assert_allclose(forward_results, np.flipud(reverse_results))
  276. assert_allclose(forward_results[shuffled_indices],
  277. scrambled_results)
  278. @pytest.mark.parametrize("t", [
  279. # string:
  280. "15, 5, 7",
  281. # complex numbers currently produce a warning
  282. # but not sure we need to worry about it too much:
  283. # [3 + 1j, 5 + 2j],
  284. ])
  285. def test_t_values_conversion(self, t):
  286. with pytest.raises(ValueError):
  287. _ = geometric_slerp(start=np.array([1]),
  288. end=np.array([0]),
  289. t=t)
  290. def test_accept_arraylike(self):
  291. # array-like support requested by reviewer
  292. # in gh-10380
  293. actual = geometric_slerp([1, 0], [0, 1], [0, 1/3, 0.5, 2/3, 1])
  294. # expected values are based on visual inspection
  295. # of the unit circle for the progressions along
  296. # the circumference provided in t
  297. expected = np.array([[1, 0],
  298. [np.sqrt(3) / 2, 0.5],
  299. [np.sqrt(2) / 2,
  300. np.sqrt(2) / 2],
  301. [0.5, np.sqrt(3) / 2],
  302. [0, 1]], dtype=np.float64)
  303. # Tyler's original Cython implementation of geometric_slerp
  304. # can pass at atol=0 here, but on balance we will accept
  305. # 1e-16 for an implementation that avoids Cython and
  306. # makes up accuracy ground elsewhere
  307. assert_allclose(actual, expected, atol=1e-16)
  308. def test_scalar_t(self):
  309. # when t is a scalar, return value is a single
  310. # interpolated point of the appropriate dimensionality
  311. # requested by reviewer in gh-10380
  312. actual = geometric_slerp([1, 0], [0, 1], 0.5)
  313. expected = np.array([np.sqrt(2) / 2,
  314. np.sqrt(2) / 2], dtype=np.float64)
  315. assert actual.shape == (2,)
  316. assert_allclose(actual, expected)
  317. @pytest.mark.parametrize('start', [
  318. np.array([1, 0, 0]),
  319. np.array([0, 1]),
  320. ])
  321. @pytest.mark.parametrize('t', [
  322. np.array(1),
  323. np.array([1]),
  324. np.array([[1]]),
  325. np.array([[[1]]]),
  326. np.array([]),
  327. np.linspace(0, 1, 5),
  328. ])
  329. def test_degenerate_input(self, start, t):
  330. if np.asarray(t).ndim > 1:
  331. with pytest.raises(ValueError):
  332. geometric_slerp(start=start, end=start, t=t)
  333. else:
  334. shape = (t.size,) + start.shape
  335. expected = np.full(shape, start)
  336. actual = geometric_slerp(start=start, end=start, t=t)
  337. assert_allclose(actual, expected)
  338. # Check that degenerate and non-degenerate
  339. # inputs yield the same size
  340. non_degenerate = geometric_slerp(start=start, end=start[::-1], t=t)
  341. assert actual.size == non_degenerate.size
  342. @pytest.mark.parametrize('k', np.logspace(-10, -1, 10))
  343. def test_numerical_stability_pi(self, k):
  344. # geometric_slerp should have excellent numerical
  345. # stability for angles approaching pi between
  346. # the start and end points
  347. angle = np.pi - k
  348. ts = np.linspace(0, 1, 100)
  349. P = np.array([1, 0, 0, 0])
  350. Q = np.array([np.cos(angle), np.sin(angle), 0, 0])
  351. # the test should only be enforced for cases where
  352. # geometric_slerp determines that the input is actually
  353. # on the unit sphere
  354. with np.testing.suppress_warnings() as sup:
  355. sup.filter(UserWarning)
  356. result = geometric_slerp(P, Q, ts, 1e-18)
  357. norms = np.linalg.norm(result, axis=1)
  358. error = np.max(np.abs(norms - 1))
  359. assert error < 4e-15
  360. @pytest.mark.parametrize('t', [
  361. [[0, 0.5]],
  362. [[[[[[[[[0, 0.5]]]]]]]]],
  363. ])
  364. def test_interpolation_param_ndim(self, t):
  365. # regression test for gh-14465
  366. arr1 = np.array([0, 1])
  367. arr2 = np.array([1, 0])
  368. with pytest.raises(ValueError):
  369. geometric_slerp(start=arr1,
  370. end=arr2,
  371. t=t)
  372. with pytest.raises(ValueError):
  373. geometric_slerp(start=arr1,
  374. end=arr1,
  375. t=t)