result_set.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. /***********************************************************************
  2. * Software License Agreement (BSD License)
  3. *
  4. * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
  5. * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
  6. *
  7. * THE BSD LICENSE
  8. *
  9. * Redistribution and use in source and binary forms, with or without
  10. * modification, are permitted provided that the following conditions
  11. * are met:
  12. *
  13. * 1. Redistributions of source code must retain the above copyright
  14. * notice, this list of conditions and the following disclaimer.
  15. * 2. Redistributions in binary form must reproduce the above copyright
  16. * notice, this list of conditions and the following disclaimer in the
  17. * documentation and/or other materials provided with the distribution.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
  20. * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  21. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
  22. * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
  23. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
  24. * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
  28. * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *************************************************************************/
  30. #ifndef OPENCV_FLANN_RESULTSET_H
  31. #define OPENCV_FLANN_RESULTSET_H
  32. //! @cond IGNORED
  33. #include <algorithm>
  34. #include <cstring>
  35. #include <iostream>
  36. #include <limits>
  37. #include <set>
  38. #include <vector>
  39. #include "opencv2/core/base.hpp"
  40. #include "opencv2/core/cvdef.h"
  41. namespace cvflann
  42. {
  /* A branch point remembered while descending a search tree: the node at
   * which the search can later resume, paired with the minimum distance from
   * the query point to anything stored underneath that node.
   * Branches compare by that minimum distance only (see operator<).
   */
  template <typename T, typename DistanceType>
  struct BranchStruct
  {
      T node;               /* Tree node at which the search resumes */
      DistanceType mindist; /* Minimum distance to the query for all nodes below */

      /* Members are default-initialized (left indeterminate for scalar types). */
      BranchStruct() {}

      BranchStruct(const T& aNode, DistanceType dist)
          : node(aNode),
            mindist(dist)
      {
      }

      /* Ordering key is mindist; node is ignored. */
      bool operator<(const BranchStruct& rhs) const
      {
          return this->mindist < rhs.mindist;
      }
  };
  /* Abstract sink through which the search algorithms report candidate
   * neighbours; each concrete result set decides which candidates it keeps.
   */
  template <typename DistanceType>
  class ResultSet
  {
  public:
      /* Virtual so that derived result sets can be destroyed through a base pointer. */
      virtual ~ResultSet()
      {
      }

      /* Whether the set currently reports itself as full. */
      virtual bool full() const = 0;

      /* Offer one candidate: its distance to the query and its point index. */
      virtual void addPoint(DistanceType dist, int index) = 0;

      /* Current distance threshold of the set (for example the k-th best
       * distance so far, or a fixed search radius, depending on the subclass). */
      virtual DistanceType worstDist() const = 0;
  };
  68. /**
  69. * KNNSimpleResultSet does not ensure that the element it holds are unique.
  70. * Is used in those cases where the nearest neighbour algorithm used does not
  71. * attempt to insert the same element multiple times.
  72. */
  73. template <typename DistanceType>
  74. class KNNSimpleResultSet : public ResultSet<DistanceType>
  75. {
  76. int* indices;
  77. DistanceType* dists;
  78. int capacity;
  79. int count;
  80. DistanceType worst_distance_;
  81. public:
  82. KNNSimpleResultSet(int capacity_) : capacity(capacity_), count(0)
  83. {
  84. }
  85. void init(int* indices_, DistanceType* dists_)
  86. {
  87. indices = indices_;
  88. dists = dists_;
  89. count = 0;
  90. worst_distance_ = (std::numeric_limits<DistanceType>::max)();
  91. dists[capacity-1] = worst_distance_;
  92. }
  93. size_t size() const
  94. {
  95. return count;
  96. }
  97. bool full() const CV_OVERRIDE
  98. {
  99. return count == capacity;
  100. }
  101. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  102. {
  103. if (dist >= worst_distance_) return;
  104. int i;
  105. for (i=count; i>0; --i) {
  106. #ifdef FLANN_FIRST_MATCH
  107. if ( (dists[i-1]>dist) || ((dist==dists[i-1])&&(indices[i-1]>index)) )
  108. #else
  109. if (dists[i-1]>dist)
  110. #endif
  111. {
  112. if (i<capacity) {
  113. dists[i] = dists[i-1];
  114. indices[i] = indices[i-1];
  115. }
  116. }
  117. else break;
  118. }
  119. if (count < capacity) ++count;
  120. dists[i] = dist;
  121. indices[i] = index;
  122. worst_distance_ = dists[capacity-1];
  123. }
  124. DistanceType worstDist() const CV_OVERRIDE
  125. {
  126. return worst_distance_;
  127. }
  128. };
  129. /**
  130. * K-Nearest neighbour result set. Ensures that the elements inserted are unique
  131. */
  132. template <typename DistanceType>
  133. class KNNResultSet : public ResultSet<DistanceType>
  134. {
  135. int* indices;
  136. DistanceType* dists;
  137. int capacity;
  138. int count;
  139. DistanceType worst_distance_;
  140. public:
  141. KNNResultSet(int capacity_)
  142. : indices(NULL), dists(NULL), capacity(capacity_), count(0), worst_distance_(0)
  143. {
  144. }
  145. void init(int* indices_, DistanceType* dists_)
  146. {
  147. indices = indices_;
  148. dists = dists_;
  149. count = 0;
  150. worst_distance_ = (std::numeric_limits<DistanceType>::max)();
  151. dists[capacity-1] = worst_distance_;
  152. }
  153. size_t size() const
  154. {
  155. return count;
  156. }
  157. bool full() const CV_OVERRIDE
  158. {
  159. return count == capacity;
  160. }
  161. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  162. {
  163. CV_DbgAssert(indices);
  164. CV_DbgAssert(dists);
  165. if (dist >= worst_distance_) return;
  166. int i;
  167. for (i = count; i > 0; --i) {
  168. #ifdef FLANN_FIRST_MATCH
  169. if ( (dists[i-1]<=dist) && ((dist!=dists[i-1])||(indices[i-1]<=index)) )
  170. #else
  171. if (dists[i-1]<=dist)
  172. #endif
  173. {
  174. // Check for duplicate indices
  175. for (int j = i; dists[j] == dist && j--;) {
  176. if (indices[j] == index) {
  177. return;
  178. }
  179. }
  180. break;
  181. }
  182. }
  183. if (count < capacity) ++count;
  184. for (int j = count-1; j > i; --j) {
  185. dists[j] = dists[j-1];
  186. indices[j] = indices[j-1];
  187. }
  188. dists[i] = dist;
  189. indices[i] = index;
  190. worst_distance_ = dists[capacity-1];
  191. }
  192. DistanceType worstDist() const CV_OVERRIDE
  193. {
  194. return worst_distance_;
  195. }
  196. };
  197. /**
  198. * A result-set class used when performing a radius based search.
  199. */
  200. template <typename DistanceType>
  201. class RadiusResultSet : public ResultSet<DistanceType>
  202. {
  203. DistanceType radius;
  204. int* indices;
  205. DistanceType* dists;
  206. size_t capacity;
  207. size_t count;
  208. public:
  209. RadiusResultSet(DistanceType radius_, int* indices_, DistanceType* dists_, int capacity_) :
  210. radius(radius_), indices(indices_), dists(dists_), capacity(capacity_)
  211. {
  212. init();
  213. }
  214. ~RadiusResultSet()
  215. {
  216. }
  217. void init()
  218. {
  219. count = 0;
  220. }
  221. size_t size() const
  222. {
  223. return count;
  224. }
  225. bool full() const
  226. {
  227. return true;
  228. }
  229. void addPoint(DistanceType dist, int index)
  230. {
  231. if (dist<radius) {
  232. if ((capacity>0)&&(count < capacity)) {
  233. dists[count] = dist;
  234. indices[count] = index;
  235. }
  236. count++;
  237. }
  238. }
  239. DistanceType worstDist() const
  240. {
  241. return radius;
  242. }
  243. };
  244. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  245. /** Class that holds the k NN neighbors
  246. * Faster than KNNResultSet as it uses a binary heap and does not maintain two arrays
  247. */
  248. template<typename DistanceType>
  249. class UniqueResultSet : public ResultSet<DistanceType>
  250. {
  251. public:
  252. struct DistIndex
  253. {
  254. DistIndex(DistanceType dist, unsigned int index) :
  255. dist_(dist), index_(index)
  256. {
  257. }
  258. bool operator<(const DistIndex dist_index) const
  259. {
  260. return (dist_ < dist_index.dist_) || ((dist_ == dist_index.dist_) && index_ < dist_index.index_);
  261. }
  262. DistanceType dist_;
  263. unsigned int index_;
  264. };
  265. /** Default constructor */
  266. UniqueResultSet() :
  267. is_full_(false), worst_distance_(std::numeric_limits<DistanceType>::max())
  268. {
  269. }
  270. /** Check the status of the set
  271. * @return true if we have k NN
  272. */
  273. inline bool full() const CV_OVERRIDE
  274. {
  275. return is_full_;
  276. }
  277. /** Remove all elements in the set
  278. */
  279. virtual void clear() = 0;
  280. /** Copy the set to two C arrays
  281. * @param indices pointer to a C array of indices
  282. * @param dist pointer to a C array of distances
  283. * @param n_neighbors the number of neighbors to copy
  284. */
  285. virtual void copy(int* indices, DistanceType* dist, int n_neighbors = -1) const
  286. {
  287. if (n_neighbors < 0) {
  288. for (typename std::set<DistIndex>::const_iterator dist_index = dist_indices_.begin(), dist_index_end =
  289. dist_indices_.end(); dist_index != dist_index_end; ++dist_index, ++indices, ++dist) {
  290. *indices = dist_index->index_;
  291. *dist = dist_index->dist_;
  292. }
  293. }
  294. else {
  295. int i = 0;
  296. for (typename std::set<DistIndex>::const_iterator dist_index = dist_indices_.begin(), dist_index_end =
  297. dist_indices_.end(); (dist_index != dist_index_end) && (i < n_neighbors); ++dist_index, ++indices, ++dist, ++i) {
  298. *indices = dist_index->index_;
  299. *dist = dist_index->dist_;
  300. }
  301. }
  302. }
  303. /** Copy the set to two C arrays but sort it according to the distance first
  304. * @param indices pointer to a C array of indices
  305. * @param dist pointer to a C array of distances
  306. * @param n_neighbors the number of neighbors to copy
  307. */
  308. virtual void sortAndCopy(int* indices, DistanceType* dist, int n_neighbors = -1) const
  309. {
  310. copy(indices, dist, n_neighbors);
  311. }
  312. /** The number of neighbors in the set
  313. */
  314. size_t size() const
  315. {
  316. return dist_indices_.size();
  317. }
  318. /** The distance of the furthest neighbor
  319. * If we don't have enough neighbors, it returns the max possible value
  320. */
  321. inline DistanceType worstDist() const CV_OVERRIDE
  322. {
  323. return worst_distance_;
  324. }
  325. protected:
  326. /** Flag to say if the set is full */
  327. bool is_full_;
  328. /** The worst distance found so far */
  329. DistanceType worst_distance_;
  330. /** The best candidates so far */
  331. std::set<DistIndex> dist_indices_;
  332. };
  333. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  334. /** Class that holds the k NN neighbors
  335. * Faster than KNNResultSet as it uses a binary heap and does not maintain two arrays
  336. */
  337. template<typename DistanceType>
  338. class KNNUniqueResultSet : public UniqueResultSet<DistanceType>
  339. {
  340. public:
  341. /** Constructor
  342. * @param capacity the number of neighbors to store at max
  343. */
  344. KNNUniqueResultSet(unsigned int capacity) : capacity_(capacity)
  345. {
  346. this->is_full_ = false;
  347. this->clear();
  348. }
  349. /** Add a possible candidate to the best neighbors
  350. * @param dist distance for that neighbor
  351. * @param index index of that neighbor
  352. */
  353. inline void addPoint(DistanceType dist, int index) CV_OVERRIDE
  354. {
  355. // Don't do anything if we are worse than the worst
  356. if (dist >= worst_distance_) return;
  357. dist_indices_.insert(DistIndex(dist, index));
  358. if (is_full_) {
  359. if (dist_indices_.size() > capacity_) {
  360. dist_indices_.erase(*dist_indices_.rbegin());
  361. worst_distance_ = dist_indices_.rbegin()->dist_;
  362. }
  363. }
  364. else if (dist_indices_.size() == capacity_) {
  365. is_full_ = true;
  366. worst_distance_ = dist_indices_.rbegin()->dist_;
  367. }
  368. }
  369. /** Remove all elements in the set
  370. */
  371. void clear() CV_OVERRIDE
  372. {
  373. dist_indices_.clear();
  374. worst_distance_ = std::numeric_limits<DistanceType>::max();
  375. is_full_ = false;
  376. }
  377. protected:
  378. typedef typename UniqueResultSet<DistanceType>::DistIndex DistIndex;
  379. using UniqueResultSet<DistanceType>::is_full_;
  380. using UniqueResultSet<DistanceType>::worst_distance_;
  381. using UniqueResultSet<DistanceType>::dist_indices_;
  382. /** The number of neighbors to keep */
  383. unsigned int capacity_;
  384. };
  385. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  386. /** Class that holds the radius nearest neighbors
  387. * It is more accurate than RadiusResult as it is not limited in the number of neighbors
  388. */
  389. template<typename DistanceType>
  390. class RadiusUniqueResultSet : public UniqueResultSet<DistanceType>
  391. {
  392. public:
  393. /** Constructor
  394. * @param radius the maximum distance of a neighbor
  395. */
  396. RadiusUniqueResultSet(DistanceType radius) :
  397. radius_(radius)
  398. {
  399. is_full_ = true;
  400. }
  401. /** Add a possible candidate to the best neighbors
  402. * @param dist distance for that neighbor
  403. * @param index index of that neighbor
  404. */
  405. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  406. {
  407. if (dist <= radius_) dist_indices_.insert(DistIndex(dist, index));
  408. }
  409. /** Remove all elements in the set
  410. */
  411. inline void clear() CV_OVERRIDE
  412. {
  413. dist_indices_.clear();
  414. }
  415. /** Check the status of the set
  416. * @return alwys false
  417. */
  418. inline bool full() const CV_OVERRIDE
  419. {
  420. return true;
  421. }
  422. /** The distance of the furthest neighbor
  423. * If we don't have enough neighbors, it returns the max possible value
  424. */
  425. inline DistanceType worstDist() const CV_OVERRIDE
  426. {
  427. return radius_;
  428. }
  429. private:
  430. typedef typename UniqueResultSet<DistanceType>::DistIndex DistIndex;
  431. using UniqueResultSet<DistanceType>::dist_indices_;
  432. using UniqueResultSet<DistanceType>::is_full_;
  433. /** The furthest distance a neighbor can be */
  434. DistanceType radius_;
  435. };
  436. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  437. /** Class that holds the k NN neighbors within a radius distance
  438. */
  439. template<typename DistanceType>
  440. class KNNRadiusUniqueResultSet : public KNNUniqueResultSet<DistanceType>
  441. {
  442. public:
  443. /** Constructor
  444. * @param capacity the number of neighbors to store at max
  445. * @param radius the maximum distance of a neighbor
  446. */
  447. KNNRadiusUniqueResultSet(unsigned int capacity, DistanceType radius)
  448. {
  449. this->capacity_ = capacity;
  450. this->radius_ = radius;
  451. this->dist_indices_.reserve(capacity_);
  452. this->clear();
  453. }
  454. /** Remove all elements in the set
  455. */
  456. void clear()
  457. {
  458. dist_indices_.clear();
  459. worst_distance_ = radius_;
  460. is_full_ = false;
  461. }
  462. private:
  463. using KNNUniqueResultSet<DistanceType>::dist_indices_;
  464. using KNNUniqueResultSet<DistanceType>::is_full_;
  465. using KNNUniqueResultSet<DistanceType>::worst_distance_;
  466. /** The maximum number of neighbors to consider */
  467. unsigned int capacity_;
  468. /** The maximum distance of a neighbor */
  469. DistanceType radius_;
  470. };
  471. }
  472. //! @endcond
  473. #endif //OPENCV_FLANN_RESULTSET_H