test_kdtree.py 46 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470
  1. # Copyright Anne M. Archibald 2008
  2. # Released under the scipy license
  3. import os
  4. from numpy.testing import (assert_equal, assert_array_equal, assert_,
  5. assert_almost_equal, assert_array_almost_equal,
  6. assert_allclose)
  7. from pytest import raises as assert_raises
  8. import pytest
  9. from platform import python_implementation
  10. import numpy as np
  11. from scipy.spatial import KDTree, Rectangle, distance_matrix, cKDTree
  12. from scipy.spatial._ckdtree import cKDTreeNode
  13. from scipy.spatial import minkowski_distance
  14. import itertools
  15. @pytest.fixture(params=[KDTree, cKDTree])
  16. def kdtree_type(request):
  17. return request.param
  18. def KDTreeTest(kls):
  19. """Class decorator to create test cases for KDTree and cKDTree
  20. Tests use the class variable ``kdtree_type`` as the tree constructor.
  21. """
  22. if not kls.__name__.startswith('_Test'):
  23. raise RuntimeError("Expected a class name starting with _Test")
  24. for tree in (KDTree, cKDTree):
  25. test_name = kls.__name__[1:] + '_' + tree.__name__
  26. if test_name in globals():
  27. raise RuntimeError("Duplicated test name: " + test_name)
  28. # Create a new sub-class with kdtree_type defined
  29. test_case = type(test_name, (kls,), {'kdtree_type': tree})
  30. globals()[test_name] = test_case
  31. return kls
  32. def distance_box(a, b, p, boxsize):
  33. diff = a - b
  34. diff[diff > 0.5 * boxsize] -= boxsize
  35. diff[diff < -0.5 * boxsize] += boxsize
  36. d = minkowski_distance(diff, 0, p)
  37. return d
  38. class ConsistencyTests:
  39. def distance(self, a, b, p):
  40. return minkowski_distance(a, b, p)
  41. def test_nearest(self):
  42. x = self.x
  43. d, i = self.kdtree.query(x, 1)
  44. assert_almost_equal(d**2, np.sum((x-self.data[i])**2))
  45. eps = 1e-8
  46. assert_(np.all(np.sum((self.data-x[np.newaxis, :])**2, axis=1) > d**2-eps))
  47. def test_m_nearest(self):
  48. x = self.x
  49. m = self.m
  50. dd, ii = self.kdtree.query(x, m)
  51. d = np.amax(dd)
  52. i = ii[np.argmax(dd)]
  53. assert_almost_equal(d**2, np.sum((x-self.data[i])**2))
  54. eps = 1e-8
  55. assert_equal(np.sum(np.sum((self.data-x[np.newaxis, :])**2, axis=1) < d**2+eps), m)
  56. def test_points_near(self):
  57. x = self.x
  58. d = self.d
  59. dd, ii = self.kdtree.query(x, k=self.kdtree.n, distance_upper_bound=d)
  60. eps = 1e-8
  61. hits = 0
  62. for near_d, near_i in zip(dd, ii):
  63. if near_d == np.inf:
  64. continue
  65. hits += 1
  66. assert_almost_equal(near_d**2, np.sum((x-self.data[near_i])**2))
  67. assert_(near_d < d+eps, "near_d=%g should be less than %g" % (near_d, d))
  68. assert_equal(np.sum(self.distance(self.data, x, 2) < d**2+eps), hits)
  69. def test_points_near_l1(self):
  70. x = self.x
  71. d = self.d
  72. dd, ii = self.kdtree.query(x, k=self.kdtree.n, p=1, distance_upper_bound=d)
  73. eps = 1e-8
  74. hits = 0
  75. for near_d, near_i in zip(dd, ii):
  76. if near_d == np.inf:
  77. continue
  78. hits += 1
  79. assert_almost_equal(near_d, self.distance(x, self.data[near_i], 1))
  80. assert_(near_d < d+eps, "near_d=%g should be less than %g" % (near_d, d))
  81. assert_equal(np.sum(self.distance(self.data, x, 1) < d+eps), hits)
  82. def test_points_near_linf(self):
  83. x = self.x
  84. d = self.d
  85. dd, ii = self.kdtree.query(x, k=self.kdtree.n, p=np.inf, distance_upper_bound=d)
  86. eps = 1e-8
  87. hits = 0
  88. for near_d, near_i in zip(dd, ii):
  89. if near_d == np.inf:
  90. continue
  91. hits += 1
  92. assert_almost_equal(near_d, self.distance(x, self.data[near_i], np.inf))
  93. assert_(near_d < d+eps, "near_d=%g should be less than %g" % (near_d, d))
  94. assert_equal(np.sum(self.distance(self.data, x, np.inf) < d+eps), hits)
  95. def test_approx(self):
  96. x = self.x
  97. k = self.k
  98. eps = 0.1
  99. d_real, i_real = self.kdtree.query(x, k)
  100. d, i = self.kdtree.query(x, k, eps=eps)
  101. assert_(np.all(d <= d_real*(1+eps)))
  102. @KDTreeTest
  103. class _Test_random(ConsistencyTests):
  104. def setup_method(self):
  105. self.n = 100
  106. self.m = 4
  107. np.random.seed(1234)
  108. self.data = np.random.randn(self.n, self.m)
  109. self.kdtree = self.kdtree_type(self.data, leafsize=2)
  110. self.x = np.random.randn(self.m)
  111. self.d = 0.2
  112. self.k = 10
  113. @KDTreeTest
  114. class _Test_random_far(_Test_random):
  115. def setup_method(self):
  116. super().setup_method()
  117. self.x = np.random.randn(self.m)+10
  118. @KDTreeTest
  119. class _Test_small(ConsistencyTests):
  120. def setup_method(self):
  121. self.data = np.array([[0, 0, 0],
  122. [0, 0, 1],
  123. [0, 1, 0],
  124. [0, 1, 1],
  125. [1, 0, 0],
  126. [1, 0, 1],
  127. [1, 1, 0],
  128. [1, 1, 1]])
  129. self.kdtree = self.kdtree_type(self.data)
  130. self.n = self.kdtree.n
  131. self.m = self.kdtree.m
  132. np.random.seed(1234)
  133. self.x = np.random.randn(3)
  134. self.d = 0.5
  135. self.k = 4
  136. def test_nearest(self):
  137. assert_array_equal(
  138. self.kdtree.query((0, 0, 0.1), 1),
  139. (0.1, 0))
  140. def test_nearest_two(self):
  141. assert_array_equal(
  142. self.kdtree.query((0, 0, 0.1), 2),
  143. ([0.1, 0.9], [0, 1]))
  144. @KDTreeTest
  145. class _Test_small_nonleaf(_Test_small):
  146. def setup_method(self):
  147. super().setup_method()
  148. self.kdtree = self.kdtree_type(self.data, leafsize=1)
  149. class Test_vectorization_KDTree:
  150. def setup_method(self):
  151. self.data = np.array([[0, 0, 0],
  152. [0, 0, 1],
  153. [0, 1, 0],
  154. [0, 1, 1],
  155. [1, 0, 0],
  156. [1, 0, 1],
  157. [1, 1, 0],
  158. [1, 1, 1]])
  159. self.kdtree = KDTree(self.data)
  160. def test_single_query(self):
  161. d, i = self.kdtree.query(np.array([0, 0, 0]))
  162. assert_(isinstance(d, float))
  163. assert_(np.issubdtype(i, np.signedinteger))
  164. def test_vectorized_query(self):
  165. d, i = self.kdtree.query(np.zeros((2, 4, 3)))
  166. assert_equal(np.shape(d), (2, 4))
  167. assert_equal(np.shape(i), (2, 4))
  168. def test_single_query_multiple_neighbors(self):
  169. s = 23
  170. kk = self.kdtree.n+s
  171. d, i = self.kdtree.query(np.array([0, 0, 0]), k=kk)
  172. assert_equal(np.shape(d), (kk,))
  173. assert_equal(np.shape(i), (kk,))
  174. assert_(np.all(~np.isfinite(d[-s:])))
  175. assert_(np.all(i[-s:] == self.kdtree.n))
  176. def test_vectorized_query_multiple_neighbors(self):
  177. s = 23
  178. kk = self.kdtree.n+s
  179. d, i = self.kdtree.query(np.zeros((2, 4, 3)), k=kk)
  180. assert_equal(np.shape(d), (2, 4, kk))
  181. assert_equal(np.shape(i), (2, 4, kk))
  182. assert_(np.all(~np.isfinite(d[:, :, -s:])))
  183. assert_(np.all(i[:, :, -s:] == self.kdtree.n))
  184. def test_query_raises_for_k_none(self):
  185. x = 1.0
  186. with pytest.raises(ValueError, match="k must be an integer or*"):
  187. self.kdtree.query(x, k=None)
  188. class Test_vectorization_cKDTree:
  189. def setup_method(self):
  190. self.data = np.array([[0, 0, 0],
  191. [0, 0, 1],
  192. [0, 1, 0],
  193. [0, 1, 1],
  194. [1, 0, 0],
  195. [1, 0, 1],
  196. [1, 1, 0],
  197. [1, 1, 1]])
  198. self.kdtree = cKDTree(self.data)
  199. def test_single_query(self):
  200. d, i = self.kdtree.query([0, 0, 0])
  201. assert_(isinstance(d, float))
  202. assert_(isinstance(i, int))
  203. def test_vectorized_query(self):
  204. d, i = self.kdtree.query(np.zeros((2, 4, 3)))
  205. assert_equal(np.shape(d), (2, 4))
  206. assert_equal(np.shape(i), (2, 4))
  207. def test_vectorized_query_noncontiguous_values(self):
  208. np.random.seed(1234)
  209. qs = np.random.randn(3, 1000).T
  210. ds, i_s = self.kdtree.query(qs)
  211. for q, d, i in zip(qs, ds, i_s):
  212. assert_equal(self.kdtree.query(q), (d, i))
  213. def test_single_query_multiple_neighbors(self):
  214. s = 23
  215. kk = self.kdtree.n+s
  216. d, i = self.kdtree.query([0, 0, 0], k=kk)
  217. assert_equal(np.shape(d), (kk,))
  218. assert_equal(np.shape(i), (kk,))
  219. assert_(np.all(~np.isfinite(d[-s:])))
  220. assert_(np.all(i[-s:] == self.kdtree.n))
  221. def test_vectorized_query_multiple_neighbors(self):
  222. s = 23
  223. kk = self.kdtree.n+s
  224. d, i = self.kdtree.query(np.zeros((2, 4, 3)), k=kk)
  225. assert_equal(np.shape(d), (2, 4, kk))
  226. assert_equal(np.shape(i), (2, 4, kk))
  227. assert_(np.all(~np.isfinite(d[:, :, -s:])))
  228. assert_(np.all(i[:, :, -s:] == self.kdtree.n))
  229. class ball_consistency:
  230. tol = 0.0
  231. def distance(self, a, b, p):
  232. return minkowski_distance(a * 1.0, b * 1.0, p)
  233. def test_in_ball(self):
  234. x = np.atleast_2d(self.x)
  235. d = np.broadcast_to(self.d, x.shape[:-1])
  236. l = self.T.query_ball_point(x, self.d, p=self.p, eps=self.eps)
  237. for i, ind in enumerate(l):
  238. dist = self.distance(self.data[ind], x[i], self.p) - d[i]*(1.+self.eps)
  239. norm = self.distance(self.data[ind], x[i], self.p) + d[i]*(1.+self.eps)
  240. assert_array_equal(dist < self.tol * norm, True)
  241. def test_found_all(self):
  242. x = np.atleast_2d(self.x)
  243. d = np.broadcast_to(self.d, x.shape[:-1])
  244. l = self.T.query_ball_point(x, self.d, p=self.p, eps=self.eps)
  245. for i, ind in enumerate(l):
  246. c = np.ones(self.T.n, dtype=bool)
  247. c[ind] = False
  248. dist = self.distance(self.data[c], x[i], self.p) - d[i]/(1.+self.eps)
  249. norm = self.distance(self.data[c], x[i], self.p) + d[i]/(1.+self.eps)
  250. assert_array_equal(dist > -self.tol * norm, True)
  251. @KDTreeTest
  252. class _Test_random_ball(ball_consistency):
  253. def setup_method(self):
  254. n = 100
  255. m = 4
  256. np.random.seed(1234)
  257. self.data = np.random.randn(n, m)
  258. self.T = self.kdtree_type(self.data, leafsize=2)
  259. self.x = np.random.randn(m)
  260. self.p = 2.
  261. self.eps = 0
  262. self.d = 0.2
  263. @KDTreeTest
  264. class _Test_random_ball_periodic(ball_consistency):
  265. def distance(self, a, b, p):
  266. return distance_box(a, b, p, 1.0)
  267. def setup_method(self):
  268. n = 10000
  269. m = 4
  270. np.random.seed(1234)
  271. self.data = np.random.uniform(size=(n, m))
  272. self.T = self.kdtree_type(self.data, leafsize=2, boxsize=1)
  273. self.x = np.full(m, 0.1)
  274. self.p = 2.
  275. self.eps = 0
  276. self.d = 0.2
  277. def test_in_ball_outside(self):
  278. l = self.T.query_ball_point(self.x + 1.0, self.d, p=self.p, eps=self.eps)
  279. for i in l:
  280. assert_(self.distance(self.data[i], self.x, self.p) <= self.d*(1.+self.eps))
  281. l = self.T.query_ball_point(self.x - 1.0, self.d, p=self.p, eps=self.eps)
  282. for i in l:
  283. assert_(self.distance(self.data[i], self.x, self.p) <= self.d*(1.+self.eps))
  284. def test_found_all_outside(self):
  285. c = np.ones(self.T.n, dtype=bool)
  286. l = self.T.query_ball_point(self.x + 1.0, self.d, p=self.p, eps=self.eps)
  287. c[l] = False
  288. assert_(np.all(self.distance(self.data[c], self.x, self.p) >= self.d/(1.+self.eps)))
  289. l = self.T.query_ball_point(self.x - 1.0, self.d, p=self.p, eps=self.eps)
  290. c[l] = False
  291. assert_(np.all(self.distance(self.data[c], self.x, self.p) >= self.d/(1.+self.eps)))
  292. @KDTreeTest
  293. class _Test_random_ball_largep_issue9890(ball_consistency):
  294. # allow some roundoff errors due to numerical issues
  295. tol = 1e-13
  296. def setup_method(self):
  297. n = 1000
  298. m = 2
  299. np.random.seed(123)
  300. self.data = np.random.randint(100, 1000, size=(n, m))
  301. self.T = self.kdtree_type(self.data)
  302. self.x = self.data
  303. self.p = 100
  304. self.eps = 0
  305. self.d = 10
  306. @KDTreeTest
  307. class _Test_random_ball_approx(_Test_random_ball):
  308. def setup_method(self):
  309. super().setup_method()
  310. self.eps = 0.1
  311. @KDTreeTest
  312. class _Test_random_ball_approx_periodic(_Test_random_ball):
  313. def setup_method(self):
  314. super().setup_method()
  315. self.eps = 0.1
  316. @KDTreeTest
  317. class _Test_random_ball_far(_Test_random_ball):
  318. def setup_method(self):
  319. super().setup_method()
  320. self.d = 2.
  321. @KDTreeTest
  322. class _Test_random_ball_far_periodic(_Test_random_ball_periodic):
  323. def setup_method(self):
  324. super().setup_method()
  325. self.d = 2.
  326. @KDTreeTest
  327. class _Test_random_ball_l1(_Test_random_ball):
  328. def setup_method(self):
  329. super().setup_method()
  330. self.p = 1
  331. @KDTreeTest
  332. class _Test_random_ball_linf(_Test_random_ball):
  333. def setup_method(self):
  334. super().setup_method()
  335. self.p = np.inf
  336. def test_random_ball_vectorized(kdtree_type):
  337. n = 20
  338. m = 5
  339. np.random.seed(1234)
  340. T = kdtree_type(np.random.randn(n, m))
  341. r = T.query_ball_point(np.random.randn(2, 3, m), 1)
  342. assert_equal(r.shape, (2, 3))
  343. assert_(isinstance(r[0, 0], list))
  344. def test_query_ball_point_multithreading(kdtree_type):
  345. np.random.seed(0)
  346. n = 5000
  347. k = 2
  348. points = np.random.randn(n, k)
  349. T = kdtree_type(points)
  350. l1 = T.query_ball_point(points, 0.003, workers=1)
  351. l2 = T.query_ball_point(points, 0.003, workers=64)
  352. l3 = T.query_ball_point(points, 0.003, workers=-1)
  353. for i in range(n):
  354. if l1[i] or l2[i]:
  355. assert_array_equal(l1[i], l2[i])
  356. for i in range(n):
  357. if l1[i] or l3[i]:
  358. assert_array_equal(l1[i], l3[i])
  359. class two_trees_consistency:
  360. def distance(self, a, b, p):
  361. return minkowski_distance(a, b, p)
  362. def test_all_in_ball(self):
  363. r = self.T1.query_ball_tree(self.T2, self.d, p=self.p, eps=self.eps)
  364. for i, l in enumerate(r):
  365. for j in l:
  366. assert_(self.distance(self.data1[i], self.data2[j], self.p) <= self.d*(1.+self.eps))
  367. def test_found_all(self):
  368. r = self.T1.query_ball_tree(self.T2, self.d, p=self.p, eps=self.eps)
  369. for i, l in enumerate(r):
  370. c = np.ones(self.T2.n, dtype=bool)
  371. c[l] = False
  372. assert_(np.all(self.distance(self.data2[c], self.data1[i], self.p) >= self.d/(1.+self.eps)))
  373. @KDTreeTest
  374. class _Test_two_random_trees(two_trees_consistency):
  375. def setup_method(self):
  376. n = 50
  377. m = 4
  378. np.random.seed(1234)
  379. self.data1 = np.random.randn(n, m)
  380. self.T1 = self.kdtree_type(self.data1, leafsize=2)
  381. self.data2 = np.random.randn(n, m)
  382. self.T2 = self.kdtree_type(self.data2, leafsize=2)
  383. self.p = 2.
  384. self.eps = 0
  385. self.d = 0.2
  386. @KDTreeTest
  387. class _Test_two_random_trees_periodic(two_trees_consistency):
  388. def distance(self, a, b, p):
  389. return distance_box(a, b, p, 1.0)
  390. def setup_method(self):
  391. n = 50
  392. m = 4
  393. np.random.seed(1234)
  394. self.data1 = np.random.uniform(size=(n, m))
  395. self.T1 = self.kdtree_type(self.data1, leafsize=2, boxsize=1.0)
  396. self.data2 = np.random.uniform(size=(n, m))
  397. self.T2 = self.kdtree_type(self.data2, leafsize=2, boxsize=1.0)
  398. self.p = 2.
  399. self.eps = 0
  400. self.d = 0.2
  401. @KDTreeTest
  402. class _Test_two_random_trees_far(_Test_two_random_trees):
  403. def setup_method(self):
  404. super().setup_method()
  405. self.d = 2
  406. @KDTreeTest
  407. class _Test_two_random_trees_far_periodic(_Test_two_random_trees_periodic):
  408. def setup_method(self):
  409. super().setup_method()
  410. self.d = 2
  411. @KDTreeTest
  412. class _Test_two_random_trees_linf(_Test_two_random_trees):
  413. def setup_method(self):
  414. super().setup_method()
  415. self.p = np.inf
  416. @KDTreeTest
  417. class _Test_two_random_trees_linf_periodic(_Test_two_random_trees_periodic):
  418. def setup_method(self):
  419. super().setup_method()
  420. self.p = np.inf
  421. class Test_rectangle:
  422. def setup_method(self):
  423. self.rect = Rectangle([0, 0], [1, 1])
  424. def test_min_inside(self):
  425. assert_almost_equal(self.rect.min_distance_point([0.5, 0.5]), 0)
  426. def test_min_one_side(self):
  427. assert_almost_equal(self.rect.min_distance_point([0.5, 1.5]), 0.5)
  428. def test_min_two_sides(self):
  429. assert_almost_equal(self.rect.min_distance_point([2, 2]), np.sqrt(2))
  430. def test_max_inside(self):
  431. assert_almost_equal(self.rect.max_distance_point([0.5, 0.5]), 1/np.sqrt(2))
  432. def test_max_one_side(self):
  433. assert_almost_equal(self.rect.max_distance_point([0.5, 1.5]), np.hypot(0.5, 1.5))
  434. def test_max_two_sides(self):
  435. assert_almost_equal(self.rect.max_distance_point([2, 2]), 2*np.sqrt(2))
  436. def test_split(self):
  437. less, greater = self.rect.split(0, 0.1)
  438. assert_array_equal(less.maxes, [0.1, 1])
  439. assert_array_equal(less.mins, [0, 0])
  440. assert_array_equal(greater.maxes, [1, 1])
  441. assert_array_equal(greater.mins, [0.1, 0])
  442. def test_distance_l2():
  443. assert_almost_equal(minkowski_distance([0, 0], [1, 1], 2), np.sqrt(2))
  444. def test_distance_l1():
  445. assert_almost_equal(minkowski_distance([0, 0], [1, 1], 1), 2)
  446. def test_distance_linf():
  447. assert_almost_equal(minkowski_distance([0, 0], [1, 1], np.inf), 1)
  448. def test_distance_vectorization():
  449. np.random.seed(1234)
  450. x = np.random.randn(10, 1, 3)
  451. y = np.random.randn(1, 7, 3)
  452. assert_equal(minkowski_distance(x, y).shape, (10, 7))
  453. class count_neighbors_consistency:
  454. def test_one_radius(self):
  455. r = 0.2
  456. assert_equal(self.T1.count_neighbors(self.T2, r),
  457. np.sum([len(l) for l in self.T1.query_ball_tree(self.T2, r)]))
  458. def test_large_radius(self):
  459. r = 1000
  460. assert_equal(self.T1.count_neighbors(self.T2, r),
  461. np.sum([len(l) for l in self.T1.query_ball_tree(self.T2, r)]))
  462. def test_multiple_radius(self):
  463. rs = np.exp(np.linspace(np.log(0.01), np.log(10), 3))
  464. results = self.T1.count_neighbors(self.T2, rs)
  465. assert_(np.all(np.diff(results) >= 0))
  466. for r, result in zip(rs, results):
  467. assert_equal(self.T1.count_neighbors(self.T2, r), result)
  468. @KDTreeTest
  469. class _Test_count_neighbors(count_neighbors_consistency):
  470. def setup_method(self):
  471. n = 50
  472. m = 2
  473. np.random.seed(1234)
  474. self.T1 = self.kdtree_type(np.random.randn(n, m), leafsize=2)
  475. self.T2 = self.kdtree_type(np.random.randn(n, m), leafsize=2)
  476. class sparse_distance_matrix_consistency:
  477. def distance(self, a, b, p):
  478. return minkowski_distance(a, b, p)
  479. def test_consistency_with_neighbors(self):
  480. M = self.T1.sparse_distance_matrix(self.T2, self.r)
  481. r = self.T1.query_ball_tree(self.T2, self.r)
  482. for i, l in enumerate(r):
  483. for j in l:
  484. assert_almost_equal(M[i, j],
  485. self.distance(self.T1.data[i], self.T2.data[j], self.p),
  486. decimal=14)
  487. for ((i, j), d) in M.items():
  488. assert_(j in r[i])
  489. def test_zero_distance(self):
  490. # raises an exception for bug 870 (FIXME: Does it?)
  491. self.T1.sparse_distance_matrix(self.T1, self.r)
  492. def test_consistency(self):
  493. # Test consistency with a distance_matrix
  494. M1 = self.T1.sparse_distance_matrix(self.T2, self.r)
  495. expected = distance_matrix(self.T1.data, self.T2.data)
  496. expected[expected > self.r] = 0
  497. assert_array_almost_equal(M1.toarray(), expected, decimal=14)
  498. def test_against_logic_error_regression(self):
  499. # regression test for gh-5077 logic error
  500. np.random.seed(0)
  501. too_many = np.array(np.random.randn(18, 2), dtype=int)
  502. tree = self.kdtree_type(
  503. too_many, balanced_tree=False, compact_nodes=False)
  504. d = tree.sparse_distance_matrix(tree, 3).toarray()
  505. assert_array_almost_equal(d, d.T, decimal=14)
  506. def test_ckdtree_return_types(self):
  507. # brute-force reference
  508. ref = np.zeros((self.n, self.n))
  509. for i in range(self.n):
  510. for j in range(self.n):
  511. v = self.data1[i, :] - self.data2[j, :]
  512. ref[i, j] = np.dot(v, v)
  513. ref = np.sqrt(ref)
  514. ref[ref > self.r] = 0.
  515. # test return type 'dict'
  516. dist = np.zeros((self.n, self.n))
  517. r = self.T1.sparse_distance_matrix(self.T2, self.r, output_type='dict')
  518. for i, j in r.keys():
  519. dist[i, j] = r[(i, j)]
  520. assert_array_almost_equal(ref, dist, decimal=14)
  521. # test return type 'ndarray'
  522. dist = np.zeros((self.n, self.n))
  523. r = self.T1.sparse_distance_matrix(self.T2, self.r,
  524. output_type='ndarray')
  525. for k in range(r.shape[0]):
  526. i = r['i'][k]
  527. j = r['j'][k]
  528. v = r['v'][k]
  529. dist[i, j] = v
  530. assert_array_almost_equal(ref, dist, decimal=14)
  531. # test return type 'dok_matrix'
  532. r = self.T1.sparse_distance_matrix(self.T2, self.r,
  533. output_type='dok_matrix')
  534. assert_array_almost_equal(ref, r.toarray(), decimal=14)
  535. # test return type 'coo_matrix'
  536. r = self.T1.sparse_distance_matrix(self.T2, self.r,
  537. output_type='coo_matrix')
  538. assert_array_almost_equal(ref, r.toarray(), decimal=14)
  539. @KDTreeTest
  540. class _Test_sparse_distance_matrix(sparse_distance_matrix_consistency):
  541. def setup_method(self):
  542. n = 50
  543. m = 4
  544. np.random.seed(1234)
  545. data1 = np.random.randn(n, m)
  546. data2 = np.random.randn(n, m)
  547. self.T1 = self.kdtree_type(data1, leafsize=2)
  548. self.T2 = self.kdtree_type(data2, leafsize=2)
  549. self.r = 0.5
  550. self.p = 2
  551. self.data1 = data1
  552. self.data2 = data2
  553. self.n = n
  554. self.m = m
  555. def test_distance_matrix():
  556. m = 10
  557. n = 11
  558. k = 4
  559. np.random.seed(1234)
  560. xs = np.random.randn(m, k)
  561. ys = np.random.randn(n, k)
  562. ds = distance_matrix(xs, ys)
  563. assert_equal(ds.shape, (m, n))
  564. for i in range(m):
  565. for j in range(n):
  566. assert_almost_equal(minkowski_distance(xs[i], ys[j]), ds[i, j])
  567. def test_distance_matrix_looping():
  568. m = 10
  569. n = 11
  570. k = 4
  571. np.random.seed(1234)
  572. xs = np.random.randn(m, k)
  573. ys = np.random.randn(n, k)
  574. ds = distance_matrix(xs, ys)
  575. dsl = distance_matrix(xs, ys, threshold=1)
  576. assert_equal(ds, dsl)
  577. def check_onetree_query(T, d):
  578. r = T.query_ball_tree(T, d)
  579. s = set()
  580. for i, l in enumerate(r):
  581. for j in l:
  582. if i < j:
  583. s.add((i, j))
  584. assert_(s == T.query_pairs(d))
  585. def test_onetree_query(kdtree_type):
  586. np.random.seed(0)
  587. n = 50
  588. k = 4
  589. points = np.random.randn(n, k)
  590. T = kdtree_type(points)
  591. check_onetree_query(T, 0.1)
  592. points = np.random.randn(3*n, k)
  593. points[:n] *= 0.001
  594. points[n:2*n] += 2
  595. T = kdtree_type(points)
  596. check_onetree_query(T, 0.1)
  597. check_onetree_query(T, 0.001)
  598. check_onetree_query(T, 0.00001)
  599. check_onetree_query(T, 1e-6)
  600. def test_query_pairs_single_node(kdtree_type):
  601. tree = kdtree_type([[0, 1]])
  602. assert_equal(tree.query_pairs(0.5), set())
  603. def test_kdtree_query_pairs(kdtree_type):
  604. np.random.seed(0)
  605. n = 50
  606. k = 2
  607. r = 0.1
  608. r2 = r**2
  609. points = np.random.randn(n, k)
  610. T = kdtree_type(points)
  611. # brute force reference
  612. brute = set()
  613. for i in range(n):
  614. for j in range(i+1, n):
  615. v = points[i, :] - points[j, :]
  616. if np.dot(v, v) <= r2:
  617. brute.add((i, j))
  618. l0 = sorted(brute)
  619. # test default return type
  620. s = T.query_pairs(r)
  621. l1 = sorted(s)
  622. assert_array_equal(l0, l1)
  623. # test return type 'set'
  624. s = T.query_pairs(r, output_type='set')
  625. l1 = sorted(s)
  626. assert_array_equal(l0, l1)
  627. # test return type 'ndarray'
  628. s = set()
  629. arr = T.query_pairs(r, output_type='ndarray')
  630. for i in range(arr.shape[0]):
  631. s.add((int(arr[i, 0]), int(arr[i, 1])))
  632. l2 = sorted(s)
  633. assert_array_equal(l0, l2)
  634. def test_query_pairs_eps(kdtree_type):
  635. spacing = np.sqrt(2)
  636. # irrational spacing to have potential rounding errors
  637. x_range = np.linspace(0, 3 * spacing, 4)
  638. y_range = np.linspace(0, 3 * spacing, 4)
  639. xy_array = [(xi, yi) for xi in x_range for yi in y_range]
  640. tree = kdtree_type(xy_array)
  641. pairs_eps = tree.query_pairs(r=spacing, eps=.1)
  642. # result: 24 with eps, 16 without due to rounding
  643. pairs = tree.query_pairs(r=spacing * 1.01)
  644. # result: 24
  645. assert_equal(pairs, pairs_eps)
  646. def test_ball_point_ints(kdtree_type):
  647. # Regression test for #1373.
  648. x, y = np.mgrid[0:4, 0:4]
  649. points = list(zip(x.ravel(), y.ravel()))
  650. tree = kdtree_type(points)
  651. assert_equal(sorted([4, 8, 9, 12]),
  652. sorted(tree.query_ball_point((2, 0), 1)))
  653. points = np.asarray(points, dtype=float)
  654. tree = kdtree_type(points)
  655. assert_equal(sorted([4, 8, 9, 12]),
  656. sorted(tree.query_ball_point((2, 0), 1)))
  657. def test_kdtree_comparisons():
  658. # Regression test: node comparisons were done wrong in 0.12 w/Py3.
  659. nodes = [KDTree.node() for _ in range(3)]
  660. assert_equal(sorted(nodes), sorted(nodes[::-1]))
  661. def test_kdtree_build_modes(kdtree_type):
  662. # check if different build modes for KDTree give similar query results
  663. np.random.seed(0)
  664. n = 5000
  665. k = 4
  666. points = np.random.randn(n, k)
  667. T1 = kdtree_type(points).query(points, k=5)[-1]
  668. T2 = kdtree_type(points, compact_nodes=False).query(points, k=5)[-1]
  669. T3 = kdtree_type(points, balanced_tree=False).query(points, k=5)[-1]
  670. T4 = kdtree_type(points, compact_nodes=False,
  671. balanced_tree=False).query(points, k=5)[-1]
  672. assert_array_equal(T1, T2)
  673. assert_array_equal(T1, T3)
  674. assert_array_equal(T1, T4)
  675. def test_kdtree_pickle(kdtree_type):
  676. # test if it is possible to pickle a KDTree
  677. import pickle
  678. np.random.seed(0)
  679. n = 50
  680. k = 4
  681. points = np.random.randn(n, k)
  682. T1 = kdtree_type(points)
  683. tmp = pickle.dumps(T1)
  684. T2 = pickle.loads(tmp)
  685. T1 = T1.query(points, k=5)[-1]
  686. T2 = T2.query(points, k=5)[-1]
  687. assert_array_equal(T1, T2)
  688. def test_kdtree_pickle_boxsize(kdtree_type):
  689. # test if it is possible to pickle a periodic KDTree
  690. import pickle
  691. np.random.seed(0)
  692. n = 50
  693. k = 4
  694. points = np.random.uniform(size=(n, k))
  695. T1 = kdtree_type(points, boxsize=1.0)
  696. tmp = pickle.dumps(T1)
  697. T2 = pickle.loads(tmp)
  698. T1 = T1.query(points, k=5)[-1]
  699. T2 = T2.query(points, k=5)[-1]
  700. assert_array_equal(T1, T2)
  701. def test_kdtree_copy_data(kdtree_type):
  702. # check if copy_data=True makes the kd-tree
  703. # impervious to data corruption by modification of
  704. # the data arrray
  705. np.random.seed(0)
  706. n = 5000
  707. k = 4
  708. points = np.random.randn(n, k)
  709. T = kdtree_type(points, copy_data=True)
  710. q = points.copy()
  711. T1 = T.query(q, k=5)[-1]
  712. points[...] = np.random.randn(n, k)
  713. T2 = T.query(q, k=5)[-1]
  714. assert_array_equal(T1, T2)
  715. def test_ckdtree_parallel(kdtree_type, monkeypatch):
  716. # check if parallel=True also generates correct query results
  717. np.random.seed(0)
  718. n = 5000
  719. k = 4
  720. points = np.random.randn(n, k)
  721. T = kdtree_type(points)
  722. T1 = T.query(points, k=5, workers=64)[-1]
  723. T2 = T.query(points, k=5, workers=-1)[-1]
  724. T3 = T.query(points, k=5)[-1]
  725. assert_array_equal(T1, T2)
  726. assert_array_equal(T1, T3)
  727. monkeypatch.setattr(os, 'cpu_count', lambda: None)
  728. with pytest.raises(NotImplementedError, match="Cannot determine the"):
  729. T.query(points, 1, workers=-1)
  730. def test_ckdtree_view():
  731. # Check that the nodes can be correctly viewed from Python.
  732. # This test also sanity checks each node in the cKDTree, and
  733. # thus verifies the internal structure of the kd-tree.
  734. np.random.seed(0)
  735. n = 100
  736. k = 4
  737. points = np.random.randn(n, k)
  738. kdtree = cKDTree(points)
  739. # walk the whole kd-tree and sanity check each node
  740. def recurse_tree(n):
  741. assert_(isinstance(n, cKDTreeNode))
  742. if n.split_dim == -1:
  743. assert_(n.lesser is None)
  744. assert_(n.greater is None)
  745. assert_(n.indices.shape[0] <= kdtree.leafsize)
  746. else:
  747. recurse_tree(n.lesser)
  748. recurse_tree(n.greater)
  749. x = n.lesser.data_points[:, n.split_dim]
  750. y = n.greater.data_points[:, n.split_dim]
  751. assert_(x.max() < y.min())
  752. recurse_tree(kdtree.tree)
  753. # check that indices are correctly retrieved
  754. n = kdtree.tree
  755. assert_array_equal(np.sort(n.indices), range(100))
  756. # check that data_points are correctly retrieved
  757. assert_array_equal(kdtree.data[n.indices, :], n.data_points)
  758. # KDTree is specialized to type double points, so no need to make
  759. # a unit test corresponding to test_ball_point_ints()
  760. def test_kdtree_list_k(kdtree_type):
  761. # check kdtree periodic boundary
  762. n = 200
  763. m = 2
  764. klist = [1, 2, 3]
  765. kint = 3
  766. np.random.seed(1234)
  767. data = np.random.uniform(size=(n, m))
  768. kdtree = kdtree_type(data, leafsize=1)
  769. # check agreement between arange(1, k+1) and k
  770. dd, ii = kdtree.query(data, klist)
  771. dd1, ii1 = kdtree.query(data, kint)
  772. assert_equal(dd, dd1)
  773. assert_equal(ii, ii1)
  774. # now check skipping one element
  775. klist = np.array([1, 3])
  776. kint = 3
  777. dd, ii = kdtree.query(data, kint)
  778. dd1, ii1 = kdtree.query(data, klist)
  779. assert_equal(dd1, dd[..., klist - 1])
  780. assert_equal(ii1, ii[..., klist - 1])
  781. # check k == 1 special case
  782. # and k == [1] non-special case
  783. dd, ii = kdtree.query(data, 1)
  784. dd1, ii1 = kdtree.query(data, [1])
  785. assert_equal(len(dd.shape), 1)
  786. assert_equal(len(dd1.shape), 2)
  787. assert_equal(dd, np.ravel(dd1))
  788. assert_equal(ii, np.ravel(ii1))
  789. def test_kdtree_box(kdtree_type):
  790. # check ckdtree periodic boundary
  791. n = 2000
  792. m = 3
  793. k = 3
  794. np.random.seed(1234)
  795. data = np.random.uniform(size=(n, m))
  796. kdtree = kdtree_type(data, leafsize=1, boxsize=1.0)
  797. # use the standard python KDTree for the simulated periodic box
  798. kdtree2 = kdtree_type(data, leafsize=1)
  799. for p in [1, 2, 3.0, np.inf]:
  800. dd, ii = kdtree.query(data, k, p=p)
  801. dd1, ii1 = kdtree.query(data + 1.0, k, p=p)
  802. assert_almost_equal(dd, dd1)
  803. assert_equal(ii, ii1)
  804. dd1, ii1 = kdtree.query(data - 1.0, k, p=p)
  805. assert_almost_equal(dd, dd1)
  806. assert_equal(ii, ii1)
  807. dd2, ii2 = simulate_periodic_box(kdtree2, data, k, boxsize=1.0, p=p)
  808. assert_almost_equal(dd, dd2)
  809. assert_equal(ii, ii2)
  810. def test_kdtree_box_0boxsize(kdtree_type):
  811. # check ckdtree periodic boundary that mimics non-periodic
  812. n = 2000
  813. m = 2
  814. k = 3
  815. np.random.seed(1234)
  816. data = np.random.uniform(size=(n, m))
  817. kdtree = kdtree_type(data, leafsize=1, boxsize=0.0)
  818. # use the standard python KDTree for the simulated periodic box
  819. kdtree2 = kdtree_type(data, leafsize=1)
  820. for p in [1, 2, np.inf]:
  821. dd, ii = kdtree.query(data, k, p=p)
  822. dd1, ii1 = kdtree2.query(data, k, p=p)
  823. assert_almost_equal(dd, dd1)
  824. assert_equal(ii, ii1)
  825. def test_kdtree_box_upper_bounds(kdtree_type):
  826. data = np.linspace(0, 2, 10).reshape(-1, 2)
  827. data[:, 1] += 10
  828. with pytest.raises(ValueError):
  829. kdtree_type(data, leafsize=1, boxsize=1.0)
  830. with pytest.raises(ValueError):
  831. kdtree_type(data, leafsize=1, boxsize=(0.0, 2.0))
  832. # skip a dimension.
  833. kdtree_type(data, leafsize=1, boxsize=(2.0, 0.0))
  834. def test_kdtree_box_lower_bounds(kdtree_type):
  835. data = np.linspace(-1, 1, 10)
  836. assert_raises(ValueError, kdtree_type, data, leafsize=1, boxsize=1.0)
  837. def simulate_periodic_box(kdtree, data, k, boxsize, p):
  838. dd = []
  839. ii = []
  840. x = np.arange(3 ** data.shape[1])
  841. nn = np.array(np.unravel_index(x, [3] * data.shape[1])).T
  842. nn = nn - 1.0
  843. for n in nn:
  844. image = data + n * 1.0 * boxsize
  845. dd2, ii2 = kdtree.query(image, k, p=p)
  846. dd2 = dd2.reshape(-1, k)
  847. ii2 = ii2.reshape(-1, k)
  848. dd.append(dd2)
  849. ii.append(ii2)
  850. dd = np.concatenate(dd, axis=-1)
  851. ii = np.concatenate(ii, axis=-1)
  852. result = np.empty([len(data), len(nn) * k], dtype=[
  853. ('ii', 'i8'),
  854. ('dd', 'f8')])
  855. result['ii'][:] = ii
  856. result['dd'][:] = dd
  857. result.sort(order='dd')
  858. return result['dd'][:, :k], result['ii'][:, :k]
  859. @pytest.mark.skipif(python_implementation() == 'PyPy',
  860. reason="Fails on PyPy CI runs. See #9507")
  861. def test_ckdtree_memuse():
  862. # unit test adaptation of gh-5630
  863. # NOTE: this will fail when run via valgrind,
  864. # because rss is no longer a reliable memory usage indicator.
  865. try:
  866. import resource
  867. except ImportError:
  868. # resource is not available on Windows
  869. return
  870. # Make some data
  871. dx, dy = 0.05, 0.05
  872. y, x = np.mgrid[slice(1, 5 + dy, dy),
  873. slice(1, 5 + dx, dx)]
  874. z = np.sin(x)**10 + np.cos(10 + y*x) * np.cos(x)
  875. z_copy = np.empty_like(z)
  876. z_copy[:] = z
  877. # Place FILLVAL in z_copy at random number of random locations
  878. FILLVAL = 99.
  879. mask = np.random.randint(0, z.size, np.random.randint(50) + 5)
  880. z_copy.flat[mask] = FILLVAL
  881. igood = np.vstack(np.nonzero(x != FILLVAL)).T
  882. ibad = np.vstack(np.nonzero(x == FILLVAL)).T
  883. mem_use = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  884. # burn-in
  885. for i in range(10):
  886. tree = cKDTree(igood)
  887. # count memleaks while constructing and querying cKDTree
  888. num_leaks = 0
  889. for i in range(100):
  890. mem_use = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  891. tree = cKDTree(igood)
  892. dist, iquery = tree.query(ibad, k=4, p=2)
  893. new_mem_use = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
  894. if new_mem_use > mem_use:
  895. num_leaks += 1
  896. # ideally zero leaks, but errors might accidentally happen
  897. # outside cKDTree
  898. assert_(num_leaks < 10)
  899. def test_kdtree_weights(kdtree_type):
  900. data = np.linspace(0, 1, 4).reshape(-1, 1)
  901. tree1 = kdtree_type(data, leafsize=1)
  902. weights = np.ones(len(data), dtype='f4')
  903. nw = tree1._build_weights(weights)
  904. assert_array_equal(nw, [4, 2, 1, 1, 2, 1, 1])
  905. assert_raises(ValueError, tree1._build_weights, weights[:-1])
  906. for i in range(10):
  907. # since weights are uniform, these shall agree:
  908. c1 = tree1.count_neighbors(tree1, np.linspace(0, 10, i))
  909. c2 = tree1.count_neighbors(tree1, np.linspace(0, 10, i),
  910. weights=(weights, weights))
  911. c3 = tree1.count_neighbors(tree1, np.linspace(0, 10, i),
  912. weights=(weights, None))
  913. c4 = tree1.count_neighbors(tree1, np.linspace(0, 10, i),
  914. weights=(None, weights))
  915. tree1.count_neighbors(tree1, np.linspace(0, 10, i),
  916. weights=weights)
  917. assert_array_equal(c1, c2)
  918. assert_array_equal(c1, c3)
  919. assert_array_equal(c1, c4)
  920. for i in range(len(data)):
  921. # this tests removal of one data point by setting weight to 0
  922. w1 = weights.copy()
  923. w1[i] = 0
  924. data2 = data[w1 != 0]
  925. tree2 = kdtree_type(data2)
  926. c1 = tree1.count_neighbors(tree1, np.linspace(0, 10, 100),
  927. weights=(w1, w1))
  928. # "c2 is correct"
  929. c2 = tree2.count_neighbors(tree2, np.linspace(0, 10, 100))
  930. assert_array_equal(c1, c2)
  931. #this asserts for two different trees, singular weights
  932. # crashes
  933. assert_raises(ValueError, tree1.count_neighbors,
  934. tree2, np.linspace(0, 10, 100), weights=w1)
  935. def test_kdtree_count_neighbous_multiple_r(kdtree_type):
  936. n = 2000
  937. m = 2
  938. np.random.seed(1234)
  939. data = np.random.normal(size=(n, m))
  940. kdtree = kdtree_type(data, leafsize=1)
  941. r0 = [0, 0.01, 0.01, 0.02, 0.05]
  942. i0 = np.arange(len(r0))
  943. n0 = kdtree.count_neighbors(kdtree, r0)
  944. nnc = kdtree.count_neighbors(kdtree, r0, cumulative=False)
  945. assert_equal(n0, nnc.cumsum())
  946. for i, r in zip(itertools.permutations(i0),
  947. itertools.permutations(r0)):
  948. # permute n0 by i and it shall agree
  949. n = kdtree.count_neighbors(kdtree, r)
  950. assert_array_equal(n, n0[list(i)])
  951. def test_len0_arrays(kdtree_type):
  952. # make sure len-0 arrays are handled correctly
  953. # in range queries (gh-5639)
  954. np.random.seed(1234)
  955. X = np.random.rand(10, 2)
  956. Y = np.random.rand(10, 2)
  957. tree = kdtree_type(X)
  958. # query_ball_point (single)
  959. d, i = tree.query([.5, .5], k=1)
  960. z = tree.query_ball_point([.5, .5], 0.1*d)
  961. assert_array_equal(z, [])
  962. # query_ball_point (multiple)
  963. d, i = tree.query(Y, k=1)
  964. mind = d.min()
  965. z = tree.query_ball_point(Y, 0.1*mind)
  966. y = np.empty(shape=(10, ), dtype=object)
  967. y.fill([])
  968. assert_array_equal(y, z)
  969. # query_ball_tree
  970. other = kdtree_type(Y)
  971. y = tree.query_ball_tree(other, 0.1*mind)
  972. assert_array_equal(10*[[]], y)
  973. # count_neighbors
  974. y = tree.count_neighbors(other, 0.1*mind)
  975. assert_(y == 0)
  976. # sparse_distance_matrix
  977. y = tree.sparse_distance_matrix(other, 0.1*mind, output_type='dok_matrix')
  978. assert_array_equal(y == np.zeros((10, 10)), True)
  979. y = tree.sparse_distance_matrix(other, 0.1*mind, output_type='coo_matrix')
  980. assert_array_equal(y == np.zeros((10, 10)), True)
  981. y = tree.sparse_distance_matrix(other, 0.1*mind, output_type='dict')
  982. assert_equal(y, {})
  983. y = tree.sparse_distance_matrix(other, 0.1*mind, output_type='ndarray')
  984. _dtype = [('i', np.intp), ('j', np.intp), ('v', np.float64)]
  985. res_dtype = np.dtype(_dtype, align=True)
  986. z = np.empty(shape=(0, ), dtype=res_dtype)
  987. assert_array_equal(y, z)
  988. # query_pairs
  989. d, i = tree.query(X, k=2)
  990. mind = d[:, -1].min()
  991. y = tree.query_pairs(0.1*mind, output_type='set')
  992. assert_equal(y, set())
  993. y = tree.query_pairs(0.1*mind, output_type='ndarray')
  994. z = np.empty(shape=(0, 2), dtype=np.intp)
  995. assert_array_equal(y, z)
  996. def test_kdtree_duplicated_inputs(kdtree_type):
  997. # check kdtree with duplicated inputs
  998. n = 1024
  999. for m in range(1, 8):
  1000. data = np.ones((n, m))
  1001. data[n//2:] = 2
  1002. for balanced, compact in itertools.product((False, True), repeat=2):
  1003. kdtree = kdtree_type(data, balanced_tree=balanced,
  1004. compact_nodes=compact, leafsize=1)
  1005. assert kdtree.size == 3
  1006. tree = (kdtree.tree if kdtree_type is cKDTree else
  1007. kdtree.tree._node)
  1008. assert_equal(
  1009. np.sort(tree.lesser.indices),
  1010. np.arange(0, n // 2))
  1011. assert_equal(
  1012. np.sort(tree.greater.indices),
  1013. np.arange(n // 2, n))
  1014. def test_kdtree_noncumulative_nondecreasing(kdtree_type):
  1015. # check kdtree with duplicated inputs
  1016. # it shall not divide more than 3 nodes.
  1017. # root left (1), and right (2)
  1018. kdtree = kdtree_type([[0]], leafsize=1)
  1019. assert_raises(ValueError, kdtree.count_neighbors,
  1020. kdtree, [0.1, 0], cumulative=False)
  1021. def test_short_knn(kdtree_type):
  1022. # The test case is based on github: #6425 by @SteveDoyle2
  1023. xyz = np.array([
  1024. [0., 0., 0.],
  1025. [1.01, 0., 0.],
  1026. [0., 1., 0.],
  1027. [0., 1.01, 0.],
  1028. [1., 0., 0.],
  1029. [1., 1., 0.]],
  1030. dtype='float64')
  1031. ckdt = kdtree_type(xyz)
  1032. deq, ieq = ckdt.query(xyz, k=4, distance_upper_bound=0.2)
  1033. assert_array_almost_equal(deq,
  1034. [[0., np.inf, np.inf, np.inf],
  1035. [0., 0.01, np.inf, np.inf],
  1036. [0., 0.01, np.inf, np.inf],
  1037. [0., 0.01, np.inf, np.inf],
  1038. [0., 0.01, np.inf, np.inf],
  1039. [0., np.inf, np.inf, np.inf]])
  1040. def test_query_ball_point_vector_r(kdtree_type):
  1041. np.random.seed(1234)
  1042. data = np.random.normal(size=(100, 3))
  1043. query = np.random.normal(size=(100, 3))
  1044. tree = kdtree_type(data)
  1045. d = np.random.uniform(0, 0.3, size=len(query))
  1046. rvector = tree.query_ball_point(query, d)
  1047. rscalar = [tree.query_ball_point(qi, di) for qi, di in zip(query, d)]
  1048. for a, b in zip(rvector, rscalar):
  1049. assert_array_equal(sorted(a), sorted(b))
  1050. def test_query_ball_point_length(kdtree_type):
  1051. np.random.seed(1234)
  1052. data = np.random.normal(size=(100, 3))
  1053. query = np.random.normal(size=(100, 3))
  1054. tree = kdtree_type(data)
  1055. d = 0.3
  1056. length = tree.query_ball_point(query, d, return_length=True)
  1057. length2 = [len(ind) for ind in tree.query_ball_point(query, d, return_length=False)]
  1058. length3 = [len(tree.query_ball_point(qi, d)) for qi in query]
  1059. length4 = [tree.query_ball_point(qi, d, return_length=True) for qi in query]
  1060. assert_array_equal(length, length2)
  1061. assert_array_equal(length, length3)
  1062. assert_array_equal(length, length4)
  1063. def test_discontiguous(kdtree_type):
  1064. np.random.seed(1234)
  1065. data = np.random.normal(size=(100, 3))
  1066. d_contiguous = np.arange(100) * 0.04
  1067. d_discontiguous = np.ascontiguousarray(
  1068. np.arange(100)[::-1] * 0.04)[::-1]
  1069. query_contiguous = np.random.normal(size=(100, 3))
  1070. query_discontiguous = np.ascontiguousarray(query_contiguous.T).T
  1071. assert query_discontiguous.strides[-1] != query_contiguous.strides[-1]
  1072. assert d_discontiguous.strides[-1] != d_contiguous.strides[-1]
  1073. tree = kdtree_type(data)
  1074. length1 = tree.query_ball_point(query_contiguous,
  1075. d_contiguous, return_length=True)
  1076. length2 = tree.query_ball_point(query_discontiguous,
  1077. d_discontiguous, return_length=True)
  1078. assert_array_equal(length1, length2)
  1079. d1, i1 = tree.query(query_contiguous, 1)
  1080. d2, i2 = tree.query(query_discontiguous, 1)
  1081. assert_array_equal(d1, d2)
  1082. assert_array_equal(i1, i2)
  1083. @pytest.mark.parametrize("balanced_tree, compact_nodes",
  1084. [(True, False),
  1085. (True, True),
  1086. (False, False),
  1087. (False, True)])
  1088. def test_kdtree_empty_input(kdtree_type, balanced_tree, compact_nodes):
  1089. # https://github.com/scipy/scipy/issues/5040
  1090. np.random.seed(1234)
  1091. empty_v3 = np.empty(shape=(0, 3))
  1092. query_v3 = np.ones(shape=(1, 3))
  1093. query_v2 = np.ones(shape=(2, 3))
  1094. tree = kdtree_type(empty_v3, balanced_tree=balanced_tree,
  1095. compact_nodes=compact_nodes)
  1096. length = tree.query_ball_point(query_v3, 0.3, return_length=True)
  1097. assert length == 0
  1098. dd, ii = tree.query(query_v2, 2)
  1099. assert ii.shape == (2, 2)
  1100. assert dd.shape == (2, 2)
  1101. assert np.isinf(dd).all()
  1102. N = tree.count_neighbors(tree, [0, 1])
  1103. assert_array_equal(N, [0, 0])
  1104. M = tree.sparse_distance_matrix(tree, 0.3)
  1105. assert M.shape == (0, 0)
  1106. @KDTreeTest
  1107. class _Test_sorted_query_ball_point:
  1108. def setup_method(self):
  1109. np.random.seed(1234)
  1110. self.x = np.random.randn(100, 1)
  1111. self.ckdt = self.kdtree_type(self.x)
  1112. def test_return_sorted_True(self):
  1113. idxs_list = self.ckdt.query_ball_point(self.x, 1., return_sorted=True)
  1114. for idxs in idxs_list:
  1115. assert_array_equal(idxs, sorted(idxs))
  1116. for xi in self.x:
  1117. idxs = self.ckdt.query_ball_point(xi, 1., return_sorted=True)
  1118. assert_array_equal(idxs, sorted(idxs))
  1119. def test_return_sorted_None(self):
  1120. """Previous behavior was to sort the returned indices if there were
  1121. multiple points per query but not sort them if there was a single point
  1122. per query."""
  1123. idxs_list = self.ckdt.query_ball_point(self.x, 1.)
  1124. for idxs in idxs_list:
  1125. assert_array_equal(idxs, sorted(idxs))
  1126. idxs_list_single = [self.ckdt.query_ball_point(xi, 1.) for xi in self.x]
  1127. idxs_list_False = self.ckdt.query_ball_point(self.x, 1., return_sorted=False)
  1128. for idxs0, idxs1 in zip(idxs_list_False, idxs_list_single):
  1129. assert_array_equal(idxs0, idxs1)
  1130. def test_kdtree_complex_data():
  1131. # Test that KDTree rejects complex input points (gh-9108)
  1132. points = np.random.rand(10, 2).view(complex)
  1133. with pytest.raises(TypeError, match="complex data"):
  1134. t = KDTree(points)
  1135. t = KDTree(points.real)
  1136. with pytest.raises(TypeError, match="complex data"):
  1137. t.query(points)
  1138. with pytest.raises(TypeError, match="complex data"):
  1139. t.query_ball_point(points, r=1)
  1140. def test_kdtree_tree_access():
  1141. # Test KDTree.tree can be used to traverse the KDTree
  1142. np.random.seed(1234)
  1143. points = np.random.rand(100, 4)
  1144. t = KDTree(points)
  1145. root = t.tree
  1146. assert isinstance(root, KDTree.innernode)
  1147. assert root.children == points.shape[0]
  1148. # Visit the tree and assert some basic properties for each node
  1149. nodes = [root]
  1150. while nodes:
  1151. n = nodes.pop(-1)
  1152. if isinstance(n, KDTree.leafnode):
  1153. assert isinstance(n.children, int)
  1154. assert n.children == len(n.idx)
  1155. assert_array_equal(points[n.idx], n._node.data_points)
  1156. else:
  1157. assert isinstance(n, KDTree.innernode)
  1158. assert isinstance(n.split_dim, int)
  1159. assert 0 <= n.split_dim < t.m
  1160. assert isinstance(n.split, float)
  1161. assert isinstance(n.children, int)
  1162. assert n.children == n.less.children + n.greater.children
  1163. nodes.append(n.greater)
  1164. nodes.append(n.less)
  1165. def test_kdtree_attributes():
  1166. # Test KDTree's attributes are available
  1167. np.random.seed(1234)
  1168. points = np.random.rand(100, 4)
  1169. t = KDTree(points)
  1170. assert isinstance(t.m, int)
  1171. assert t.n == points.shape[0]
  1172. assert isinstance(t.n, int)
  1173. assert t.m == points.shape[1]
  1174. assert isinstance(t.leafsize, int)
  1175. assert t.leafsize == 10
  1176. assert_array_equal(t.maxes, np.amax(points, axis=0))
  1177. assert_array_equal(t.mins, np.amin(points, axis=0))
  1178. assert t.data is points
  1179. @pytest.mark.parametrize("kdtree_class", [KDTree, cKDTree])
  1180. def test_kdtree_count_neighbors_weighted(kdtree_class):
  1181. np.random.seed(1234)
  1182. r = np.arange(0.05, 1, 0.05)
  1183. A = np.random.random(21).reshape((7,3))
  1184. B = np.random.random(45).reshape((15,3))
  1185. wA = np.random.random(7)
  1186. wB = np.random.random(15)
  1187. kdA = kdtree_class(A)
  1188. kdB = kdtree_class(B)
  1189. nAB = kdA.count_neighbors(kdB, r, cumulative=False, weights=(wA,wB))
  1190. # Compare against brute-force
  1191. weights = wA[None, :] * wB[:, None]
  1192. dist = np.linalg.norm(A[None, :, :] - B[:, None, :], axis=-1)
  1193. expect = [np.sum(weights[(prev_radius < dist) & (dist <= radius)])
  1194. for prev_radius, radius in zip(itertools.chain([0], r[:-1]), r)]
  1195. assert_allclose(nAB, expect)
  1196. def test_kdtree_nan():
  1197. vals = [1, 5, -10, 7, -4, -16, -6, 6, 3, -11]
  1198. n = len(vals)
  1199. data = np.concatenate([vals, np.full(n, np.nan)])[:, None]
  1200. query_with_nans = KDTree(data).query_pairs(2)
  1201. query_without_nans = KDTree(data[:n]).query_pairs(2)
  1202. assert query_with_nans == query_without_nans