Skip to content

Commit 765c4ab

Browse files
committed
Filter elements with an optional filtering function.
1 parent fdb1632 commit 765c4ab

File tree

5 files changed

+131
-21
lines changed

5 files changed

+131
-21
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
2222
add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp)
2323
target_link_libraries(searchKnnCloserFirst_test hnswlib)
2424

25+
add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp)
26+
target_link_libraries(searchKnnWithFilter_test hnswlib)
27+
2528
add_executable(main main.cpp sift_1b.cpp)
2629
target_link_libraries(main hnswlib)
2730
endif()

examples/searchKnnWithFilter_test.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// This is a test file for testing the filtering feature
2+
3+
#include "../hnswlib/hnswlib.h"
4+
5+
#include <assert.h>
6+
7+
#include <vector>
8+
#include <iostream>
9+
10+
namespace
11+
{
12+
13+
using idx_t = hnswlib::labeltype;
14+
15+
// Filter predicate used by the test: accepts an id only when it is a multiple of three.
bool pickIdsDivisibleByThree(unsigned int ep_id) {
    const unsigned int remainder = ep_id % 3u;
    return remainder == 0u;
}
18+
19+
// Filter predicate used by the test: accepts an id only when it is a multiple of seven.
bool pickIdsDivisibleBySeven(unsigned int ep_id) {
    const unsigned int remainder = ep_id % 7u;
    return remainder == 0u;
}
22+
23+
template<typename filter_func_t>
24+
void test(filter_func_t filter_func, size_t div_num) {
25+
int d = 4;
26+
idx_t n = 100;
27+
idx_t nq = 10;
28+
size_t k = 10;
29+
30+
std::vector<float> data(n * d);
31+
std::vector<float> query(nq * d);
32+
33+
std::mt19937 rng;
34+
rng.seed(47);
35+
std::uniform_real_distribution<> distrib;
36+
37+
for (idx_t i = 0; i < n * d; ++i) {
38+
data[i] = distrib(rng);
39+
}
40+
for (idx_t i = 0; i < nq * d; ++i) {
41+
query[i] = distrib(rng);
42+
}
43+
44+
45+
hnswlib::L2Space space(d);
46+
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float,hnswlib::FILTERFUNC>(&space, 2 * n);
47+
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float,hnswlib::FILTERFUNC>(&space, 2 * n);
48+
49+
for (size_t i = 0; i < n; ++i) {
50+
alg_brute->addPoint(data.data() + d * i, i);
51+
alg_hnsw->addPoint(data.data() + d * i, i);
52+
}
53+
54+
// test searchKnnCloserFirst of BruteforceSearch with filtering
55+
for (size_t j = 0; j < nq; ++j) {
56+
const void* p = query.data() + j * d;
57+
auto gd = alg_brute->searchKnn(p, k, filter_func);
58+
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
59+
assert(gd.size() == res.size());
60+
size_t t = gd.size();
61+
while (!gd.empty()) {
62+
assert(gd.top() == res[--t]);
63+
assert((gd.top().second % div_num) == 0);
64+
gd.pop();
65+
}
66+
}
67+
68+
// test searchKnnCloserFirst of hnsw with filtering
69+
for (size_t j = 0; j < nq; ++j) {
70+
const void* p = query.data() + j * d;
71+
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
72+
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
73+
assert(gd.size() == res.size());
74+
size_t t = gd.size();
75+
while (!gd.empty()) {
76+
assert(gd.top() == res[--t]);
77+
assert((gd.top().second % div_num) == 0);
78+
gd.pop();
79+
}
80+
}
81+
82+
delete alg_brute;
83+
delete alg_hnsw;
84+
}
85+
86+
} // namespace
87+
88+
int main() {
89+
std::cout << "Testing ..." << std::endl;
90+
test(pickIdsDivisibleByThree, 3);
91+
test(pickIdsDivisibleBySeven, 7);
92+
std::cout << "Test ok" << std::endl;
93+
94+
return 0;
95+
}

hnswlib/bruteforce.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
#include <algorithm>
66

77
namespace hnswlib {
8-
template<typename dist_t>
9-
class BruteforceSearch : public AlgorithmInterface<dist_t> {
8+
template<typename dist_t, typename filter_func_t=FILTERFUNC>
9+
class BruteforceSearch : public AlgorithmInterface<dist_t,filter_func_t> {
1010
public:
1111
BruteforceSearch(SpaceInterface <dist_t> *s) : data_(nullptr), maxelements_(0),
1212
cur_element_count(0), size_per_element_(0), data_size_(0),
@@ -92,20 +92,24 @@ namespace hnswlib {
9292

9393

9494
std::priority_queue<std::pair<dist_t, labeltype >>
95-
searchKnn(const void *query_data, size_t k) const {
95+
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const {
9696
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
9797
if (cur_element_count == 0) return topResults;
9898
for (int i = 0; i < k; i++) {
9999
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
100-
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
101-
data_size_))));
100+
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
101+
if(isIdAllowed(label)) {
102+
topResults.push(std::pair<dist_t, labeltype>(dist, label));
103+
}
102104
}
103105
dist_t lastdist = topResults.top().first;
104106
for (int i = k; i < cur_element_count; i++) {
105107
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
106108
if (dist <= lastdist) {
107-
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
108-
data_size_))));
109+
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
110+
if(isIdAllowed(label)) {
111+
topResults.push(std::pair<dist_t, labeltype>(dist, label));
112+
}
109113
if (topResults.size() > k)
110114
topResults.pop();
111115
lastdist = topResults.top().first;

hnswlib/hnswalg.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ namespace hnswlib {
1313
typedef unsigned int tableint;
1414
typedef unsigned int linklistsizeint;
1515

16-
template<typename dist_t>
17-
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
16+
template<typename dist_t, typename filter_func_t=FILTERFUNC>
17+
class HierarchicalNSW : public AlgorithmInterface<dist_t,filter_func_t> {
1818
public:
1919
static const tableint max_update_element_locks = 65536;
2020
HierarchicalNSW(SpaceInterface<dist_t> *s) {
@@ -238,7 +238,7 @@ namespace hnswlib {
238238

239239
template <bool has_deletions, bool collect_metrics=false>
240240
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
241-
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
241+
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t isIdAllowed) const {
242242
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
243243
vl_type *visited_array = vl->mass;
244244
vl_type visited_array_tag = vl->curV;
@@ -247,7 +247,7 @@ namespace hnswlib {
247247
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
248248

249249
dist_t lowerBound;
250-
if (!has_deletions || !isMarkedDeleted(ep_id)) {
250+
if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) {
251251
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
252252
lowerBound = dist;
253253
top_candidates.emplace(dist, ep_id);
@@ -307,7 +307,7 @@ namespace hnswlib {
307307
_MM_HINT_T0);////////////////////////
308308
#endif
309309

310-
if (!has_deletions || !isMarkedDeleted(candidate_id))
310+
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id))
311311
top_candidates.emplace(dist, candidate_id);
312312

313313
if (top_candidates.size() > ef)
@@ -1111,7 +1111,7 @@ namespace hnswlib {
11111111
};
11121112

11131113
std::priority_queue<std::pair<dist_t, labeltype >>
1114-
searchKnn(const void *query_data, size_t k) const {
1114+
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const {
11151115
std::priority_queue<std::pair<dist_t, labeltype >> result;
11161116
if (cur_element_count == 0) return result;
11171117

@@ -1148,11 +1148,11 @@ namespace hnswlib {
11481148
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
11491149
if (num_deleted_) {
11501150
top_candidates=searchBaseLayerST<true,true>(
1151-
currObj, query_data, std::max(ef_, k));
1151+
currObj, query_data, std::max(ef_, k), isIdAllowed);
11521152
}
11531153
else{
11541154
top_candidates=searchBaseLayerST<false,true>(
1155-
currObj, query_data, std::max(ef_, k));
1155+
currObj, query_data, std::max(ef_, k), isIdAllowed);
11561156
}
11571157

11581158
while (top_candidates.size() > k) {

hnswlib/hnswlib.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,10 @@ static bool AVX512Capable() {
116116
namespace hnswlib {
117117
typedef size_t labeltype;
118118

119+
// Default filter predicate: accepts every id, i.e. performs no filtering.
//
// Declared `inline` because this definition lives in a header; without it,
// including the header from more than one translation unit produces
// multiple-definition link errors. The parameter is intentionally unnamed
// because its value is never inspected.
inline bool allowAllIds(unsigned int /*ep_id*/) {
    return true;
}
122+
119123
template <typename T>
120124
class pairGreater {
121125
public:
@@ -137,6 +141,7 @@ namespace hnswlib {
137141
template<typename MTYPE>
138142
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
139143

144+
using FILTERFUNC = bool(*)(unsigned int);
140145

141146
template<typename MTYPE>
142147
class SpaceInterface {
@@ -151,28 +156,31 @@ namespace hnswlib {
151156
virtual ~SpaceInterface() {}
152157
};
153158

154-
template<typename dist_t>
159+
template<typename dist_t, typename filter_func_t=FILTERFUNC>
155160
class AlgorithmInterface {
156161
public:
157162
virtual void addPoint(const void *datapoint, labeltype label)=0;
158-
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
163+
164+
virtual std::priority_queue<std::pair<dist_t, labeltype >>
165+
searchKnn(const void*, size_t, filter_func_t isIdAllowed=allowAllIds) const = 0;
159166

160167
// Return the k nearest neighbors in closer-first order
161168
virtual std::vector<std::pair<dist_t, labeltype>>
162-
searchKnnCloserFirst(const void* query_data, size_t k) const;
169+
searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const;
163170

164171
virtual void saveIndex(const std::string &location)=0;
165172
virtual ~AlgorithmInterface(){
166173
}
167174
};
168175

169-
template<typename dist_t>
176+
template<typename dist_t, typename filter_func_t>
170177
std::vector<std::pair<dist_t, labeltype>>
171-
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k) const {
178+
AlgorithmInterface<dist_t, filter_func_t>::searchKnnCloserFirst(const void* query_data, size_t k,
179+
filter_func_t isIdAllowed) const {
172180
std::vector<std::pair<dist_t, labeltype>> result;
173181

174182
// here searchKnn returns the result in the order of further first
175-
auto ret = searchKnn(query_data, k);
183+
auto ret = searchKnn(query_data, k, isIdAllowed);
176184
{
177185
size_t sz = ret.size();
178186
result.resize(sz);

0 commit comments

Comments
 (0)