lsh_index.h 15 KB


  1. /***********************************************************************
  2. * Software License Agreement (BSD License)
  3. *
  4. * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
  5. * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
  6. *
  7. * THE BSD LICENSE
  8. *
  9. * Redistribution and use in source and binary forms, with or without
  10. * modification, are permitted provided that the following conditions
  11. * are met:
  12. *
  13. * 1. Redistributions of source code must retain the above copyright
  14. * notice, this list of conditions and the following disclaimer.
  15. * 2. Redistributions in binary form must reproduce the above copyright
  16. * notice, this list of conditions and the following disclaimer in the
  17. * documentation and/or other materials provided with the distribution.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
  20. * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  21. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
  22. * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
  23. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
  24. * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
  28. * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *************************************************************************/
  30. /***********************************************************************
  31. * Author: Vincent Rabaud
  32. *************************************************************************/
  33. #ifndef OPENCV_FLANN_LSH_INDEX_H_
  34. #define OPENCV_FLANN_LSH_INDEX_H_
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <limits>
#include <map>
#include <vector>
#include "general.h"
#include "nn_index.h"
#include "matrix.h"
#include "result_set.h"
#include "heap.h"
#include "lsh_table.h"
#include "allocator.h"
#include "random.h"
#include "saving.h"
  49. namespace cvflann
  50. {
  51. struct LshIndexParams : public IndexParams
  52. {
  53. LshIndexParams(unsigned int table_number = 12, unsigned int key_size = 20, unsigned int multi_probe_level = 2)
  54. {
  55. (* this)["algorithm"] = FLANN_INDEX_LSH;
  56. // The number of hash tables to use
  57. (*this)["table_number"] = table_number;
  58. // The length of the key in the hash tables
  59. (*this)["key_size"] = key_size;
  60. // Number of levels to use in multi-probe (0 for standard LSH)
  61. (*this)["multi_probe_level"] = multi_probe_level;
  62. }
  63. };
  64. /**
  65. * Randomized kd-tree index
  66. *
  67. * Contains the k-d trees and other information for indexing a set of points
  68. * for nearest-neighbor matching.
  69. */
  70. template<typename Distance>
  71. class LshIndex : public NNIndex<Distance>
  72. {
  73. public:
  74. typedef typename Distance::ElementType ElementType;
  75. typedef typename Distance::ResultType DistanceType;
  76. /** Constructor
  77. * @param input_data dataset with the input features
  78. * @param params parameters passed to the LSH algorithm
  79. * @param d the distance used
  80. */
  81. LshIndex(const Matrix<ElementType>& input_data, const IndexParams& params = LshIndexParams(),
  82. Distance d = Distance()) :
  83. dataset_(input_data), index_params_(params), distance_(d)
  84. {
  85. // cv::flann::IndexParams sets integer params as 'int', so it is used with get_param
  86. // in place of 'unsigned int'
  87. table_number_ = (unsigned int)get_param<int>(index_params_,"table_number",12);
  88. key_size_ = (unsigned int)get_param<int>(index_params_,"key_size",20);
  89. multi_probe_level_ = (unsigned int)get_param<int>(index_params_,"multi_probe_level",2);
  90. feature_size_ = (unsigned)dataset_.cols;
  91. fill_xor_mask(0, key_size_, multi_probe_level_, xor_masks_);
  92. }
  93. LshIndex(const LshIndex&);
  94. LshIndex& operator=(const LshIndex&);
  95. /**
  96. * Builds the index
  97. */
  98. void buildIndex() CV_OVERRIDE
  99. {
  100. tables_.resize(table_number_);
  101. for (unsigned int i = 0; i < table_number_; ++i) {
  102. lsh::LshTable<ElementType>& table = tables_[i];
  103. table = lsh::LshTable<ElementType>(feature_size_, key_size_);
  104. // Add the features to the table
  105. table.add(dataset_);
  106. }
  107. }
  108. flann_algorithm_t getType() const CV_OVERRIDE
  109. {
  110. return FLANN_INDEX_LSH;
  111. }
  112. void saveIndex(FILE* stream) CV_OVERRIDE
  113. {
  114. save_value(stream,table_number_);
  115. save_value(stream,key_size_);
  116. save_value(stream,multi_probe_level_);
  117. save_value(stream, dataset_);
  118. }
  119. void loadIndex(FILE* stream) CV_OVERRIDE
  120. {
  121. load_value(stream, table_number_);
  122. load_value(stream, key_size_);
  123. load_value(stream, multi_probe_level_);
  124. load_value(stream, dataset_);
  125. // Building the index is so fast we can afford not storing it
  126. buildIndex();
  127. index_params_["algorithm"] = getType();
  128. index_params_["table_number"] = table_number_;
  129. index_params_["key_size"] = key_size_;
  130. index_params_["multi_probe_level"] = multi_probe_level_;
  131. }
  132. /**
  133. * Returns size of index.
  134. */
  135. size_t size() const CV_OVERRIDE
  136. {
  137. return dataset_.rows;
  138. }
  139. /**
  140. * Returns the length of an index feature.
  141. */
  142. size_t veclen() const CV_OVERRIDE
  143. {
  144. return feature_size_;
  145. }
  146. /**
  147. * Computes the index memory usage
  148. * Returns: memory used by the index
  149. */
  150. int usedMemory() const CV_OVERRIDE
  151. {
  152. return (int)(dataset_.rows * sizeof(int));
  153. }
  154. IndexParams getParameters() const CV_OVERRIDE
  155. {
  156. return index_params_;
  157. }
  158. /**
  159. * \brief Perform k-nearest neighbor search
  160. * \param[in] queries The query points for which to find the nearest neighbors
  161. * \param[out] indices The indices of the nearest neighbors found
  162. * \param[out] dists Distances to the nearest neighbors found
  163. * \param[in] knn Number of nearest neighbors to return
  164. * \param[in] params Search parameters
  165. */
  166. virtual void knnSearch(const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists, int knn, const SearchParams& params) CV_OVERRIDE
  167. {
  168. assert(queries.cols == veclen());
  169. assert(indices.rows >= queries.rows);
  170. assert(dists.rows >= queries.rows);
  171. assert(int(indices.cols) >= knn);
  172. assert(int(dists.cols) >= knn);
  173. KNNUniqueResultSet<DistanceType> resultSet(knn);
  174. for (size_t i = 0; i < queries.rows; i++) {
  175. resultSet.clear();
  176. std::fill_n(indices[i], knn, -1);
  177. std::fill_n(dists[i], knn, std::numeric_limits<DistanceType>::max());
  178. findNeighbors(resultSet, queries[i], params);
  179. if (get_param(params,"sorted",true)) resultSet.sortAndCopy(indices[i], dists[i], knn);
  180. else resultSet.copy(indices[i], dists[i], knn);
  181. }
  182. }
  183. /**
  184. * Find set of nearest neighbors to vec. Their indices are stored inside
  185. * the result object.
  186. *
  187. * Params:
  188. * result = the result object in which the indices of the nearest-neighbors are stored
  189. * vec = the vector for which to search the nearest neighbors
  190. * maxCheck = the maximum number of restarts (in a best-bin-first manner)
  191. */
  192. void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& /*searchParams*/) CV_OVERRIDE
  193. {
  194. getNeighbors(vec, result);
  195. }
  196. private:
  197. /** Defines the comparator on score and index
  198. */
  199. typedef std::pair<float, unsigned int> ScoreIndexPair;
  200. struct SortScoreIndexPairOnSecond
  201. {
  202. bool operator()(const ScoreIndexPair& left, const ScoreIndexPair& right) const
  203. {
  204. return left.second < right.second;
  205. }
  206. };
  207. /** Fills the different xor masks to use when getting the neighbors in multi-probe LSH
  208. * @param key the key we build neighbors from
  209. * @param lowest_index the lowest index of the bit set
  210. * @param level the multi-probe level we are at
  211. * @param xor_masks all the xor mask
  212. */
  213. void fill_xor_mask(lsh::BucketKey key, int lowest_index, unsigned int level,
  214. std::vector<lsh::BucketKey>& xor_masks)
  215. {
  216. xor_masks.push_back(key);
  217. if (level == 0) return;
  218. for (int index = lowest_index - 1; index >= 0; --index) {
  219. // Create a new key
  220. lsh::BucketKey new_key = key | (1 << index);
  221. fill_xor_mask(new_key, index, level - 1, xor_masks);
  222. }
  223. }
  224. /** Performs the approximate nearest-neighbor search.
  225. * @param vec the feature to analyze
  226. * @param do_radius flag indicating if we check the radius too
  227. * @param radius the radius if it is a radius search
  228. * @param do_k flag indicating if we limit the number of nn
  229. * @param k_nn the number of nearest neighbors
  230. * @param checked_average used for debugging
  231. */
  232. void getNeighbors(const ElementType* vec, bool /*do_radius*/, float radius, bool do_k, unsigned int k_nn,
  233. float& /*checked_average*/)
  234. {
  235. static std::vector<ScoreIndexPair> score_index_heap;
  236. if (do_k) {
  237. unsigned int worst_score = std::numeric_limits<unsigned int>::max();
  238. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
  239. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
  240. for (; table != table_end; ++table) {
  241. size_t key = table->getKey(vec);
  242. std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
  243. std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
  244. for (; xor_mask != xor_mask_end; ++xor_mask) {
  245. size_t sub_key = key ^ (*xor_mask);
  246. const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
  247. if (bucket == 0) continue;
  248. // Go over each descriptor index
  249. std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
  250. std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
  251. DistanceType hamming_distance;
  252. // Process the rest of the candidates
  253. for (; training_index < last_training_index; ++training_index) {
  254. hamming_distance = distance_(vec, dataset_[*training_index], dataset_.cols);
  255. if (hamming_distance < worst_score) {
  256. // Insert the new element
  257. score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
  258. std::push_heap(score_index_heap.begin(), score_index_heap.end());
  259. if (score_index_heap.size() > (unsigned int)k_nn) {
  260. // Remove the highest distance value as we have too many elements
  261. std::pop_heap(score_index_heap.begin(), score_index_heap.end());
  262. score_index_heap.pop_back();
  263. // Keep track of the worst score
  264. worst_score = score_index_heap.front().first;
  265. }
  266. }
  267. }
  268. }
  269. }
  270. }
  271. else {
  272. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
  273. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
  274. for (; table != table_end; ++table) {
  275. size_t key = table->getKey(vec);
  276. std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
  277. std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
  278. for (; xor_mask != xor_mask_end; ++xor_mask) {
  279. size_t sub_key = key ^ (*xor_mask);
  280. const lsh::Bucket* bucket = table->getBucketFromKey(sub_key);
  281. if (bucket == 0) continue;
  282. // Go over each descriptor index
  283. std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
  284. std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
  285. DistanceType hamming_distance;
  286. // Process the rest of the candidates
  287. for (; training_index < last_training_index; ++training_index) {
  288. // Compute the Hamming distance
  289. hamming_distance = distance_(vec, dataset_[*training_index], dataset_.cols);
  290. if (hamming_distance < radius) score_index_heap.push_back(ScoreIndexPair(hamming_distance, training_index));
  291. }
  292. }
  293. }
  294. }
  295. }
  296. /** Performs the approximate nearest-neighbor search.
  297. * This is a slower version than the above as it uses the ResultSet
  298. * @param vec the feature to analyze
  299. */
  300. void getNeighbors(const ElementType* vec, ResultSet<DistanceType>& result)
  301. {
  302. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table = tables_.begin();
  303. typename std::vector<lsh::LshTable<ElementType> >::const_iterator table_end = tables_.end();
  304. for (; table != table_end; ++table) {
  305. size_t key = table->getKey(vec);
  306. std::vector<lsh::BucketKey>::const_iterator xor_mask = xor_masks_.begin();
  307. std::vector<lsh::BucketKey>::const_iterator xor_mask_end = xor_masks_.end();
  308. for (; xor_mask != xor_mask_end; ++xor_mask) {
  309. size_t sub_key = key ^ (*xor_mask);
  310. const lsh::Bucket* bucket = table->getBucketFromKey((lsh::BucketKey)sub_key);
  311. if (bucket == 0) continue;
  312. // Go over each descriptor index
  313. std::vector<lsh::FeatureIndex>::const_iterator training_index = bucket->begin();
  314. std::vector<lsh::FeatureIndex>::const_iterator last_training_index = bucket->end();
  315. DistanceType hamming_distance;
  316. // Process the rest of the candidates
  317. for (; training_index < last_training_index; ++training_index) {
  318. // Compute the Hamming distance
  319. hamming_distance = distance_(vec, dataset_[*training_index], (int)dataset_.cols);
  320. result.addPoint(hamming_distance, *training_index);
  321. }
  322. }
  323. }
  324. }
  325. /** The different hash tables */
  326. std::vector<lsh::LshTable<ElementType> > tables_;
  327. /** The data the LSH tables where built from */
  328. Matrix<ElementType> dataset_;
  329. /** The size of the features (as ElementType[]) */
  330. unsigned int feature_size_;
  331. IndexParams index_params_;
  332. /** table number */
  333. unsigned int table_number_;
  334. /** key size */
  335. unsigned int key_size_;
  336. /** How far should we look for neighbors in multi-probe LSH */
  337. unsigned int multi_probe_level_;
  338. /** The XOR masks to apply to a key to get the neighboring buckets */
  339. std::vector<lsh::BucketKey> xor_masks_;
  340. Distance distance_;
  341. };
  342. }
  343. #endif //OPENCV_FLANN_LSH_INDEX_H_