Skip to content

Commit 52da3d2

Browse files
authored
Merge pull request #225 from uestc-lfs/fix-interface
Remove temlate interface searchKnn
2 parents b4b7b86 + 21c1ad7 commit 52da3d2

File tree

5 files changed

+109
-40
lines changed

5 files changed

+109
-40
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,6 @@ endif()
2323

2424
add_executable(test_updates examples/updates_test.cpp)
2525

26+
add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp)
27+
2628
target_link_libraries(main sift_test)
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// This is a test file for testing the interface
2+
// >>> virtual std::vector<std::pair<dist_t, labeltype>>
3+
// >>> searchKnnCloserFirst(const void* query_data, size_t k) const;
4+
// of class AlgorithmInterface
5+
6+
#include "../hnswlib/hnswlib.h"
7+
8+
#include <assert.h>
9+
10+
#include <vector>
11+
#include <iostream>
12+
13+
namespace
14+
{
15+
16+
using idx_t = hnswlib::labeltype;
17+
18+
void test() {
19+
int d = 4;
20+
idx_t n = 100;
21+
idx_t nq = 10;
22+
size_t k = 10;
23+
24+
std::vector<float> data(n * d);
25+
std::vector<float> query(nq * d);
26+
27+
std::mt19937 rng;
28+
rng.seed(47);
29+
std::uniform_real_distribution<> distrib;
30+
31+
for (idx_t i = 0; i < n * d; ++i) {
32+
data[i] = distrib(rng);
33+
}
34+
for (idx_t i = 0; i < nq * d; ++i) {
35+
query[i] = distrib(rng);
36+
}
37+
38+
39+
hnswlib::L2Space space(d);
40+
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
41+
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);
42+
43+
for (size_t i = 0; i < n; ++i) {
44+
alg_brute->addPoint(data.data() + d * i, i);
45+
alg_hnsw->addPoint(data.data() + d * i, i);
46+
}
47+
48+
// test searchKnnCloserFirst of BruteforceSearch
49+
for (size_t j = 0; j < nq; ++j) {
50+
const void* p = query.data() + j * d;
51+
auto gd = alg_brute->searchKnn(p, k);
52+
auto res = alg_brute->searchKnnCloserFirst(p, k);
53+
assert(gd.size() == res.size());
54+
size_t t = gd.size();
55+
while (!gd.empty()) {
56+
assert(gd.top() == res[--t]);
57+
gd.pop();
58+
}
59+
}
60+
for (size_t j = 0; j < nq; ++j) {
61+
const void* p = query.data() + j * d;
62+
auto gd = alg_hnsw->searchKnn(p, k);
63+
auto res = alg_hnsw->searchKnnCloserFirst(p, k);
64+
assert(gd.size() == res.size());
65+
size_t t = gd.size();
66+
while (!gd.empty()) {
67+
assert(gd.top() == res[--t]);
68+
gd.pop();
69+
}
70+
}
71+
72+
delete alg_brute;
73+
delete alg_hnsw;
74+
}
75+
76+
} // namespace
77+
78+
int main() {
79+
std::cout << "Testing ..." << std::endl;
80+
test();
81+
std::cout << "Test ok" << std::endl;
82+
83+
return 0;
84+
}

hnswlib/bruteforce.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,24 +111,6 @@ namespace hnswlib {
111111
return topResults;
112112
};
113113

114-
template <typename Comp>
115-
std::vector<std::pair<dist_t, labeltype>>
116-
searchKnn(const void* query_data, size_t k, Comp comp) {
117-
std::vector<std::pair<dist_t, labeltype>> result;
118-
if (cur_element_count == 0) return result;
119-
120-
auto ret = searchKnn(query_data, k);
121-
122-
while (!ret.empty()) {
123-
result.push_back(ret.top());
124-
ret.pop();
125-
}
126-
127-
std::sort(result.begin(), result.end(), comp);
128-
129-
return result;
130-
}
131-
132114
void saveIndex(const std::string &location) {
133115
std::ofstream output(location, std::ios::binary);
134116
std::streampos position;

hnswlib/hnswalg.h

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <unordered_set>
1010
#include <list>
1111

12-
1312
namespace hnswlib {
1413
typedef unsigned int tableint;
1514
typedef unsigned int linklistsizeint;
@@ -1156,24 +1155,6 @@ namespace hnswlib {
11561155
return result;
11571156
};
11581157

1159-
template <typename Comp>
1160-
std::vector<std::pair<dist_t, labeltype>>
1161-
searchKnn(const void* query_data, size_t k, Comp comp) {
1162-
std::vector<std::pair<dist_t, labeltype>> result;
1163-
if (cur_element_count == 0) return result;
1164-
1165-
auto ret = searchKnn(query_data, k);
1166-
1167-
while (!ret.empty()) {
1168-
result.push_back(ret.top());
1169-
ret.pop();
1170-
}
1171-
1172-
std::sort(result.begin(), result.end(), comp);
1173-
1174-
return result;
1175-
}
1176-
11771158
void checkIntegrity(){
11781159
int connections_checked=0;
11791160
std::vector <int > inbound_connections_num(cur_element_count,0);

hnswlib/hnswlib.h

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,34 @@ namespace hnswlib {
7171
public:
7272
virtual void addPoint(const void *datapoint, labeltype label)=0;
7373
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
74-
template <typename Comp>
75-
std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp) {
76-
}
74+
75+
// Return k nearest neighbor in the order of closer fist
76+
virtual std::vector<std::pair<dist_t, labeltype>>
77+
searchKnnCloserFirst(const void* query_data, size_t k) const;
78+
7779
virtual void saveIndex(const std::string &location)=0;
7880
virtual ~AlgorithmInterface(){
7981
}
8082
};
8183

84+
template<typename dist_t>
85+
std::vector<std::pair<dist_t, labeltype>>
86+
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k) const {
87+
std::vector<std::pair<dist_t, labeltype>> result;
88+
89+
// here searchKnn returns the result in the order of further first
90+
auto ret = searchKnn(query_data, k);
91+
{
92+
size_t sz = ret.size();
93+
result.resize(sz);
94+
while (!ret.empty()) {
95+
result[--sz] = ret.top();
96+
ret.pop();
97+
}
98+
}
99+
100+
return result;
101+
}
82102

83103
}
84104

0 commit comments

Comments
 (0)