Skip to content

Commit 9fe639d

Browse files
committed
fix interface
1 parent 4cf279b commit 9fe639d

File tree

4 files changed

+107
-33
lines changed

4 files changed

+107
-33
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// This is a test file for testing the interface
2+
// >>> virtual std::vector<std::pair<dist_t, labeltype>>
3+
// >>> searchKnnCloserFirst(const void* query_data, size_t k) const;
4+
// of class AlgorithmInterface
5+
6+
#include "../hnswlib/hnswlib.h"
7+
8+
#include <assert.h>
9+
10+
#include <vector>
11+
#include <iostream>
12+
13+
namespace
14+
{
15+
16+
using idx_t = hnswlib::labeltype;
17+
18+
void test() {
19+
int d = 4;
20+
idx_t n = 100;
21+
idx_t nq = 10;
22+
size_t k = 10;
23+
24+
std::vector<float> data(n * d);
25+
std::vector<float> query(nq * d);
26+
27+
std::mt19937 rng;
28+
rng.seed(47);
29+
std::uniform_real_distribution<> distrib;
30+
31+
for (idx_t i = 0; i < n * d; ++i) {
32+
data[i] = distrib(rng);
33+
}
34+
for (idx_t i = 0; i < nq * d; ++i) {
35+
query[i] = distrib(rng);
36+
}
37+
38+
39+
hnswlib::L2Space space(d);
40+
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
41+
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);
42+
43+
for (size_t i = 0; i < n; ++i) {
44+
alg_brute->addPoint(data.data() + d * i, i);
45+
alg_hnsw->addPoint(data.data() + d * i, i);
46+
}
47+
48+
// test searchKnnCloserFirst of BruteforceSearch
49+
for (size_t j = 0; j < nq; ++j) {
50+
const void* p = query.data() + j * d;
51+
auto gd = alg_brute->searchKnn(p, k);
52+
auto res = alg_brute->searchKnnCloserFirst(p, k);
53+
assert(gd.size() == res.size());
54+
size_t t = gd.size();
55+
while (!gd.empty()) {
56+
assert(gd.top() == res[--t]);
57+
gd.pop();
58+
}
59+
}
60+
for (size_t j = 0; j < nq; ++j) {
61+
const void* p = query.data() + j * d;
62+
auto gd = alg_hnsw->searchKnn(p, k);
63+
auto res = alg_hnsw->searchKnnCloserFirst(p, k);
64+
assert(gd.size() == res.size());
65+
size_t t = gd.size();
66+
while (!gd.empty()) {
67+
assert(gd.top() == res[--t]);
68+
gd.pop();
69+
}
70+
}
71+
72+
delete alg_brute;
73+
delete alg_hnsw;
74+
}
75+
76+
} // namespace
77+
78+
int main() {
79+
std::cout << "Testing ..." << std::endl;
80+
test();
81+
std::cout << "Test ok" << std::endl;
82+
83+
return 0;
84+
}

hnswlib/bruteforce.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,24 +111,6 @@ namespace hnswlib {
111111
return topResults;
112112
};
113113

114-
template <typename Comp>
115-
std::vector<std::pair<dist_t, labeltype>>
116-
searchKnn(const void* query_data, size_t k, Comp comp) {
117-
std::vector<std::pair<dist_t, labeltype>> result;
118-
if (cur_element_count == 0) return result;
119-
120-
auto ret = searchKnn(query_data, k);
121-
122-
while (!ret.empty()) {
123-
result.push_back(ret.top());
124-
ret.pop();
125-
}
126-
127-
std::sort(result.begin(), result.end(), comp);
128-
129-
return result;
130-
}
131-
132114
void saveIndex(const std::string &location) {
133115
std::ofstream output(location, std::ios::binary);
134116
std::streampos position;

hnswlib/hnswalg.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,19 +1157,6 @@ namespace hnswlib {
11571157
return result;
11581158
};
11591159

1160-
int searchKnn(const void* x,
1161-
int k, labeltype* labels, dist_t* dists = nullptr) const override {
1162-
if (labels == nullptr) return -1;
1163-
auto ret = searchKnn(x, k);
1164-
for (int i = k - 1; i >= 0; --i) {
1165-
if (dists)
1166-
dists[i] = ret.top().first;
1167-
labels[i] = ret.top().second;
1168-
}
1169-
return 0;
1170-
}
1171-
1172-
11731160
void checkIntegrity(){
11741161
int connections_checked=0;
11751162
std::vector <int > inbound_connections_num(cur_element_count,0);

hnswlib/hnswlib.h

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,34 @@ namespace hnswlib {
7171
public:
7272
virtual void addPoint(const void *datapoint, labeltype label)=0;
7373
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
74-
virtual int searchKnn(const void* x,
75-
int k, labeltype* labels, dist_t* dists) const = 0;
74+
75+
// Return k nearest neighbor in the order of closer fist
76+
virtual std::vector<std::pair<dist_t, labeltype>>
77+
searchKnnCloserFirst(const void* query_data, size_t k) const;
78+
7679
virtual void saveIndex(const std::string &location)=0;
7780
virtual ~AlgorithmInterface(){
7881
}
7982
};
8083

84+
template<typename dist_t>
85+
std::vector<std::pair<dist_t, labeltype>>
86+
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k) const {
87+
std::vector<std::pair<dist_t, labeltype>> result;
88+
89+
// here searchKnn returns the result in the order of further first
90+
auto ret = searchKnn(query_data, k);
91+
{
92+
size_t sz = ret.size();
93+
result.resize(sz);
94+
while (!ret.empty()) {
95+
result[--sz] = ret.top();
96+
ret.pop();
97+
}
98+
}
99+
100+
return result;
101+
}
81102

82103
}
83104

0 commit comments

Comments
 (0)