Skip to content

Commit b440cbd

Browse files
authored
Merge branch 'develop' into replace_deleted
2 parents 0f3214c + 983cea9 commit b440cbd

File tree

6 files changed

+149
-61
lines changed

6 files changed

+149
-61
lines changed

examples/searchKnnWithFilter_test.cpp

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,25 @@ namespace {
1111

1212
using idx_t = hnswlib::labeltype;
1313

14-
bool pickIdsDivisibleByThree(unsigned int label_id) {
15-
return label_id % 3 == 0;
16-
}
17-
18-
bool pickIdsDivisibleBySeven(unsigned int label_id) {
19-
return label_id % 7 == 0;
20-
}
14+
class PickDivisibleIds: public hnswlib::BaseFilterFunctor {
15+
unsigned int divisor = 1;
16+
public:
17+
PickDivisibleIds(unsigned int divisor): divisor(divisor) {
18+
assert(divisor != 0);
19+
}
20+
bool operator()(idx_t label_id) {
21+
return label_id % divisor == 0;
22+
}
23+
};
2124

22-
bool pickNothing(unsigned int label_id) {
23-
return false;
24-
}
25+
class PickNothing: public hnswlib::BaseFilterFunctor {
26+
public:
27+
bool operator()(idx_t label_id) {
28+
return false;
29+
}
30+
};
2531

26-
template<typename filter_func_t>
27-
void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) {
32+
void test_some_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) {
2833
int d = 4;
2934
idx_t n = 100;
3035
idx_t nq = 10;
@@ -45,8 +50,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
4550
}
4651

4752
hnswlib::L2Space space(d);
48-
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
49-
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);
53+
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
54+
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);
5055

5156
for (size_t i = 0; i < n; ++i) {
5257
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
@@ -57,8 +62,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
5762
// test searchKnnCloserFirst of BruteforceSearch with filtering
5863
for (size_t j = 0; j < nq; ++j) {
5964
const void* p = query.data() + j * d;
60-
auto gd = alg_brute->searchKnn(p, k, filter_func);
61-
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
65+
auto gd = alg_brute->searchKnn(p, k, &filter_func);
66+
auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func);
6267
assert(gd.size() == res.size());
6368
size_t t = gd.size();
6469
while (!gd.empty()) {
@@ -71,8 +76,8 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
7176
// test searchKnnCloserFirst of hnsw with filtering
7277
for (size_t j = 0; j < nq; ++j) {
7378
const void* p = query.data() + j * d;
74-
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
75-
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
79+
auto gd = alg_hnsw->searchKnn(p, k, &filter_func);
80+
auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func);
7681
assert(gd.size() == res.size());
7782
size_t t = gd.size();
7883
while (!gd.empty()) {
@@ -86,8 +91,7 @@ void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t labe
8691
delete alg_hnsw;
8792
}
8893

89-
template<typename filter_func_t>
90-
void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
94+
void test_none_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) {
9195
int d = 4;
9296
idx_t n = 100;
9397
idx_t nq = 10;
@@ -108,8 +112,8 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
108112
}
109113

110114
hnswlib::L2Space space(d);
111-
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float, filter_func_t>(&space, 2 * n);
112-
hnswlib::AlgorithmInterface<float, filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float, filter_func_t>(&space, 2 * n);
115+
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
116+
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);
113117

114118
for (size_t i = 0; i < n; ++i) {
115119
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
@@ -120,17 +124,17 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
120124
// test searchKnnCloserFirst of BruteforceSearch with filtering
121125
for (size_t j = 0; j < nq; ++j) {
122126
const void* p = query.data() + j * d;
123-
auto gd = alg_brute->searchKnn(p, k, filter_func);
124-
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
127+
auto gd = alg_brute->searchKnn(p, k, &filter_func);
128+
auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func);
125129
assert(gd.size() == res.size());
126130
assert(0 == gd.size());
127131
}
128132

129133
// test searchKnnCloserFirst of hnsw with filtering
130134
for (size_t j = 0; j < nq; ++j) {
131135
const void* p = query.data() + j * d;
132-
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
133-
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
136+
auto gd = alg_hnsw->searchKnn(p, k, &filter_func);
137+
auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func);
134138
assert(gd.size() == res.size());
135139
assert(0 == gd.size());
136140
}
@@ -141,13 +145,13 @@ void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
141145

142146
} // namespace
143147

144-
class CustomFilterFunctor: public hnswlib::FilterFunctor {
145-
std::unordered_set<unsigned int> allowed_values;
148+
class CustomFilterFunctor: public hnswlib::BaseFilterFunctor {
149+
std::unordered_set<idx_t> allowed_values;
146150

147151
public:
148-
explicit CustomFilterFunctor(const std::unordered_set<unsigned int>& values) : allowed_values(values) {}
152+
explicit CustomFilterFunctor(const std::unordered_set<idx_t>& values) : allowed_values(values) {}
149153

150-
bool operator()(unsigned int id) {
154+
bool operator()(idx_t id) {
151155
return allowed_values.count(id) != 0;
152156
}
153157
};
@@ -156,10 +160,13 @@ int main() {
156160
std::cout << "Testing ..." << std::endl;
157161

158162
// some of the elements are filtered
163+
PickDivisibleIds pickIdsDivisibleByThree(3);
159164
test_some_filtering(pickIdsDivisibleByThree, 3, 17);
165+
PickDivisibleIds pickIdsDivisibleBySeven(7);
160166
test_some_filtering(pickIdsDivisibleBySeven, 7, 17);
161167

162168
// all of the elements are filtered
169+
PickNothing pickNothing;
163170
test_none_filtering(pickNothing, 17);
164171

165172
// functor style which can capture context

hnswlib/bruteforce.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include <assert.h>
77

88
namespace hnswlib {
9-
template<typename dist_t, typename filter_func_t = FilterFunctor>
10-
class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {
9+
template<typename dist_t>
10+
class BruteforceSearch : public AlgorithmInterface<dist_t> {
1111
public:
1212
char *data_;
1313
size_t maxelements_;
@@ -98,15 +98,14 @@ class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {
9898

9999

100100
std::priority_queue<std::pair<dist_t, labeltype >>
101-
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const {
101+
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
102102
assert(k <= cur_element_count);
103103
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
104104
if (cur_element_count == 0) return topResults;
105-
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
106105
for (int i = 0; i < k; i++) {
107106
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
108107
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
109-
if (is_filter_disabled || isIdAllowed(label)) {
108+
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
110109
topResults.push(std::pair<dist_t, labeltype>(dist, label));
111110
}
112111
}
@@ -115,7 +114,7 @@ class BruteforceSearch : public AlgorithmInterface<dist_t, filter_func_t> {
115114
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
116115
if (dist <= lastdist) {
117116
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
118-
if (is_filter_disabled || isIdAllowed(label)) {
117+
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
119118
topResults.push(std::pair<dist_t, labeltype>(dist, label));
120119
}
121120
if (topResults.size() > k)

hnswlib/hnswalg.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ namespace hnswlib {
1313
typedef unsigned int tableint;
1414
typedef unsigned int linklistsizeint;
1515

16-
template<typename dist_t, typename filter_func_t = FilterFunctor>
17-
class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
16+
template<typename dist_t>
17+
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
1818
public:
1919
static const tableint max_update_element_locks = 65536;
2020
static const unsigned char DELETE_MARK = 0x01;
@@ -277,7 +277,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
277277

278278
template <bool has_deletions, bool collect_metrics = false>
279279
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
280-
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const {
280+
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const {
281281
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
282282
vl_type *visited_array = vl->mass;
283283
vl_type visited_array_tag = vl->curV;
@@ -286,8 +286,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
286286
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
287287

288288
dist_t lowerBound;
289-
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
290-
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {
289+
if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) {
291290
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
292291
lowerBound = dist;
293292
top_candidates.emplace(dist, ep_id);
@@ -345,7 +344,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
345344
_MM_HINT_T0); ////////////////////////
346345
#endif
347346

348-
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id))))
347+
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id))))
349348
top_candidates.emplace(dist, candidate_id);
350349

351350
if (top_candidates.size() > ef)
@@ -1137,7 +1136,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t, filter_func_t> {
11371136

11381137

11391138
std::priority_queue<std::pair<dist_t, labeltype >>
1140-
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const {
1139+
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
11411140
std::priority_queue<std::pair<dist_t, labeltype >> result;
11421141
if (cur_element_count == 0) return result;
11431142

hnswlib/hnswlib.h

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,11 @@ namespace hnswlib {
116116
typedef size_t labeltype;
117117

118118
// This can be extended to store state for filtering (e.g. from a std::set)
119-
struct FilterFunctor {
120-
template<class...Args>
121-
bool operator()(Args&&...) { return true; }
119+
class BaseFilterFunctor {
120+
public:
121+
virtual bool operator()(hnswlib::labeltype id) { return true; }
122122
};
123123

124-
static FilterFunctor allowAllIds;
125-
126124
template <typename T>
127125
class pairGreater {
128126
public:
@@ -157,27 +155,27 @@ class SpaceInterface {
157155
virtual ~SpaceInterface() {}
158156
};
159157

160-
template<typename dist_t, typename filter_func_t = FilterFunctor>
158+
template<typename dist_t>
161159
class AlgorithmInterface {
162160
public:
163161
virtual void addPoint(const void *datapoint, labeltype label) = 0;
164162

165163
virtual std::priority_queue<std::pair<dist_t, labeltype>>
166-
searchKnn(const void*, size_t, filter_func_t& isIdAllowed = allowAllIds) const = 0;
164+
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;
167165

168166
// Return k nearest neighbor in the order of closer fist
169167
virtual std::vector<std::pair<dist_t, labeltype>>
170-
searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed = allowAllIds) const;
168+
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;
171169

172170
virtual void saveIndex(const std::string &location) = 0;
173171
virtual ~AlgorithmInterface(){
174172
}
175173
};
176174

177-
template<typename dist_t, typename filter_func_t>
175+
template<typename dist_t>
178176
std::vector<std::pair<dist_t, labeltype>>
179-
AlgorithmInterface<dist_t, filter_func_t>::searchKnnCloserFirst(const void* query_data, size_t k,
180-
filter_func_t& isIdAllowed) const {
177+
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
178+
BaseFilterFunctor* isIdAllowed) const {
181179
std::vector<std::pair<dist_t, labeltype>> result;
182180

183181
// here searchKnn returns the result in the order of further first

python_bindings/bindings.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <iostream>
2+
#include <pybind11/functional.h>
23
#include <pybind11/pybind11.h>
34
#include <pybind11/numpy.h>
45
#include <pybind11/stl.h>
@@ -79,6 +80,20 @@ inline void assert_true(bool expr, const std::string & msg) {
7980
}
8081

8182

83+
class CustomFilterFunctor: public hnswlib::BaseFilterFunctor {
84+
std::function<bool(hnswlib::labeltype)> filter;
85+
86+
public:
87+
explicit CustomFilterFunctor(const std::function<bool(hnswlib::labeltype)>& f) {
88+
filter = f;
89+
}
90+
91+
bool operator()(hnswlib::labeltype id) {
92+
return filter(id);
93+
}
94+
};
95+
96+
8297
inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) {
8398
if (buffer.ndim != 2 && buffer.ndim != 1) {
8499
char msg[256];
@@ -654,7 +669,11 @@ class Index {
654669
}
655670

656671

657-
py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) {
672+
py::object knnQuery_return_numpy(
673+
py::object input,
674+
size_t k = 1,
675+
int num_threads = -1,
676+
const std::function<bool(hnswlib::labeltype)>& filter = nullptr) {
658677
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
659678
auto buffer = items.request();
660679
hnswlib::labeltype* data_numpy_l;
@@ -676,10 +695,13 @@ class Index {
676695
data_numpy_l = new hnswlib::labeltype[rows * k];
677696
data_numpy_d = new dist_t[rows * k];
678697

698+
CustomFilterFunctor idFilter(filter);
699+
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;
700+
679701
if (normalize == false) {
680702
ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) {
681703
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
682-
(void*)items.data(row), k);
704+
(void*)items.data(row), k, p_idFilter);
683705
if (result.size() != k)
684706
throw std::runtime_error(
685707
"Cannot return the results in a contigious 2D array. Probably ef or M is too small");
@@ -699,7 +721,7 @@ class Index {
699721
normalize_vector((float*)items.data(row), (norm_array.data() + start_idx));
700722

701723
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = appr_alg->searchKnn(
702-
(void*)(norm_array.data() + start_idx), k);
724+
(void*)(norm_array.data() + start_idx), k, p_idFilter);
703725
if (result.size() != k)
704726
throw std::runtime_error(
705727
"Cannot return the results in a contigious 2D array. Probably ef or M is too small");
@@ -866,7 +888,10 @@ class BFIndex {
866888
}
867889

868890

869-
py::object knnQuery_return_numpy(py::object input, size_t k = 1) {
891+
py::object knnQuery_return_numpy(
892+
py::object input,
893+
size_t k = 1,
894+
const std::function<bool(hnswlib::labeltype)>& filter = nullptr) {
870895
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
871896
auto buffer = items.request();
872897
hnswlib::labeltype *data_numpy_l;
@@ -880,9 +905,12 @@ class BFIndex {
880905
data_numpy_l = new hnswlib::labeltype[rows * k];
881906
data_numpy_d = new dist_t[rows * k];
882907

908+
CustomFilterFunctor idFilter(filter);
909+
CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr;
910+
883911
for (size_t row = 0; row < rows; row++) {
884912
std::priority_queue<std::pair<dist_t, hnswlib::labeltype >> result = alg->searchKnn(
885-
(void *) items.data(row), k);
913+
(void *) items.data(row), k, p_idFilter);
886914
for (int i = k - 1; i >= 0; i--) {
887915
auto &result_tuple = result.top();
888916
data_numpy_d[row * k + i] = result_tuple.first;
@@ -935,7 +963,8 @@ PYBIND11_PLUGIN(hnswlib) {
935963
&Index<float>::knnQuery_return_numpy,
936964
py::arg("data"),
937965
py::arg("k") = 1,
938-
py::arg("num_threads") = -1)
966+
py::arg("num_threads") = -1,
967+
py::arg("filter") = py::none())
939968
.def("add_items",
940969
&Index<float>::addItems,
941970
py::arg("data"),
@@ -1003,7 +1032,7 @@ PYBIND11_PLUGIN(hnswlib) {
10031032
py::class_<BFIndex<float>>(m, "BFIndex")
10041033
.def(py::init<const std::string &, const int>(), py::arg("space"), py::arg("dim"))
10051034
.def("init_index", &BFIndex<float>::init_new_index, py::arg("max_elements"))
1006-
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1)
1035+
.def("knn_query", &BFIndex<float>::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("filter") = py::none())
10071036
.def("add_items", &BFIndex<float>::addItems, py::arg("data"), py::arg("ids") = py::none())
10081037
.def("delete_vector", &BFIndex<float>::deleteVector, py::arg("label"))
10091038
.def("save_index", &BFIndex<float>::saveIndex, py::arg("path_to_index"))

0 commit comments

Comments
 (0)