@@ -11,20 +11,25 @@ namespace {
11
11
12
12
using idx_t = hnswlib::labeltype;
13
13
14
- bool pickIdsDivisibleByThree (unsigned int label_id) {
15
- return label_id % 3 == 0 ;
16
- }
17
-
18
- bool pickIdsDivisibleBySeven (unsigned int label_id) {
19
- return label_id % 7 == 0 ;
20
- }
14
+ class PickDivisibleIds : public hnswlib ::BaseFilterFunctor {
15
+ unsigned int divisor = 1 ;
16
+ public:
17
+ PickDivisibleIds (unsigned int divisor): divisor(divisor) {
18
+ assert (divisor != 0 );
19
+ }
20
+ bool operator ()(idx_t label_id) {
21
+ return label_id % divisor == 0 ;
22
+ }
23
+ };
21
24
22
- bool pickNothing (unsigned int label_id) {
23
- return false ;
24
- }
25
+ class PickNothing : public hnswlib ::BaseFilterFunctor {
26
+ public:
27
+ bool operator ()(idx_t label_id) {
28
+ return false ;
29
+ }
30
+ };
25
31
26
- template <typename filter_func_t >
27
- void test_some_filtering (filter_func_t & filter_func, size_t div_num, size_t label_id_start) {
32
+ void test_some_filtering (hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) {
28
33
int d = 4 ;
29
34
idx_t n = 100 ;
30
35
idx_t nq = 10 ;
@@ -45,8 +50,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
45
50
}
46
51
47
52
hnswlib::L2Space space (d);
48
- hnswlib::AlgorithmInterface<float , filter_func_t >* alg_brute = new hnswlib::BruteforceSearch<float , filter_func_t >(&space, 2 * n);
49
- hnswlib::AlgorithmInterface<float , filter_func_t >* alg_hnsw = new hnswlib::HierarchicalNSW<float , filter_func_t >(&space, 2 * n);
53
+ hnswlib::AlgorithmInterface<float >* alg_brute = new hnswlib::BruteforceSearch<float >(&space, 2 * n);
54
+ hnswlib::AlgorithmInterface<float >* alg_hnsw = new hnswlib::HierarchicalNSW<float >(&space, 2 * n);
50
55
51
56
for (size_t i = 0 ; i < n; ++i) {
52
57
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
@@ -57,8 +62,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
57
62
// test searchKnnCloserFirst of BruteforceSearch with filtering
58
63
for (size_t j = 0 ; j < nq; ++j) {
59
64
const void * p = query.data () + j * d;
60
- auto gd = alg_brute->searchKnn (p, k, filter_func);
61
- auto res = alg_brute->searchKnnCloserFirst (p, k, filter_func);
65
+ auto gd = alg_brute->searchKnn (p, k, & filter_func);
66
+ auto res = alg_brute->searchKnnCloserFirst (p, k, & filter_func);
62
67
assert (gd.size () == res.size ());
63
68
size_t t = gd.size ();
64
69
while (!gd.empty ()) {
@@ -71,8 +76,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
71
76
// test searchKnnCloserFirst of hnsw with filtering
72
77
for (size_t j = 0 ; j < nq; ++j) {
73
78
const void * p = query.data () + j * d;
74
- auto gd = alg_hnsw->searchKnn (p, k, filter_func);
75
- auto res = alg_hnsw->searchKnnCloserFirst (p, k, filter_func);
79
+ auto gd = alg_hnsw->searchKnn (p, k, & filter_func);
80
+ auto res = alg_hnsw->searchKnnCloserFirst (p, k, & filter_func);
76
81
assert (gd.size () == res.size ());
77
82
size_t t = gd.size ();
78
83
while (!gd.empty ()) {
@@ -86,8 +91,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
86
91
delete alg_hnsw;
87
92
}
88
93
89
- template <typename filter_func_t >
90
- void test_none_filtering (filter_func_t & filter_func, size_t label_id_start) {
94
+ void test_none_filtering (hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) {
91
95
int d = 4 ;
92
96
idx_t n = 100 ;
93
97
idx_t nq = 10 ;
@@ -108,8 +112,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
108
112
}
109
113
110
114
hnswlib::L2Space space (d);
111
- hnswlib::AlgorithmInterface<float , filter_func_t >* alg_brute = new hnswlib::BruteforceSearch<float , filter_func_t >(&space, 2 * n);
112
- hnswlib::AlgorithmInterface<float , filter_func_t >* alg_hnsw = new hnswlib::HierarchicalNSW<float , filter_func_t >(&space, 2 * n);
115
+ hnswlib::AlgorithmInterface<float >* alg_brute = new hnswlib::BruteforceSearch<float >(&space, 2 * n);
116
+ hnswlib::AlgorithmInterface<float >* alg_hnsw = new hnswlib::HierarchicalNSW<float >(&space, 2 * n);
113
117
114
118
for (size_t i = 0 ; i < n; ++i) {
115
119
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
@@ -120,17 +124,17 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
120
124
// test searchKnnCloserFirst of BruteforceSearch with filtering
121
125
for (size_t j = 0 ; j < nq; ++j) {
122
126
const void * p = query.data () + j * d;
123
- auto gd = alg_brute->searchKnn (p, k, filter_func);
124
- auto res = alg_brute->searchKnnCloserFirst (p, k, filter_func);
127
+ auto gd = alg_brute->searchKnn (p, k, & filter_func);
128
+ auto res = alg_brute->searchKnnCloserFirst (p, k, & filter_func);
125
129
assert (gd.size () == res.size ());
126
130
assert (0 == gd.size ());
127
131
}
128
132
129
133
// test searchKnnCloserFirst of hnsw with filtering
130
134
for (size_t j = 0 ; j < nq; ++j) {
131
135
const void * p = query.data () + j * d;
132
- auto gd = alg_hnsw->searchKnn (p, k, filter_func);
133
- auto res = alg_hnsw->searchKnnCloserFirst (p, k, filter_func);
136
+ auto gd = alg_hnsw->searchKnn (p, k, & filter_func);
137
+ auto res = alg_hnsw->searchKnnCloserFirst (p, k, & filter_func);
134
138
assert (gd.size () == res.size ());
135
139
assert (0 == gd.size ());
136
140
}
@@ -141,13 +145,13 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
141
145
142
146
} // namespace
143
147
144
- class CustomFilterFunctor : public hnswlib ::FilterFunctor {
145
- std::unordered_set<unsigned int > allowed_values;
148
+ class CustomFilterFunctor : public hnswlib ::BaseFilterFunctor {
149
+ std::unordered_set<idx_t > allowed_values;
146
150
147
151
public:
148
- explicit CustomFilterFunctor (const std::unordered_set<unsigned int >& values) : allowed_values(values) {}
152
+ explicit CustomFilterFunctor (const std::unordered_set<idx_t >& values) : allowed_values(values) {}
149
153
150
- bool operator ()(unsigned int id) {
154
+ bool operator ()(idx_t id) {
151
155
return allowed_values.count (id) != 0 ;
152
156
}
153
157
};
@@ -156,10 +160,13 @@ int main() {
156
160
std::cout << " Testing ..." << std::endl;
157
161
158
162
// some of the elements are filtered
163
+ PickDivisibleIds pickIdsDivisibleByThree (3 );
159
164
test_some_filtering (pickIdsDivisibleByThree, 3 , 17 );
165
+ PickDivisibleIds pickIdsDivisibleBySeven (7 );
160
166
test_some_filtering (pickIdsDivisibleBySeven, 7 , 17 );
161
167
162
168
// all of the elements are filtered
169
+ PickNothing pickNothing;
163
170
test_none_filtering (pickNothing, 17 );
164
171
165
172
// functor style which can capture context
0 commit comments