-
Notifications
You must be signed in to change notification settings - Fork 711
Filter elements with an optional filtering function #402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
765c4ab
ad3440c
4f6dcc3
1c833a7
aaee13a
b87f623
f0dedf3
de22860
e4705fd
7f419ea
1fe7baf
c9897b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
// This is a test file for testing the filtering feature | ||
|
||
#include "../hnswlib/hnswlib.h"

#include <assert.h>

#include <iostream>
#include <memory>
#include <random>
#include <unordered_set>
#include <vector>
|
||
namespace | ||
{ | ||
|
||
using idx_t = hnswlib::labeltype; | ||
|
||
// Filter predicate: admits a label only when it is a multiple of three.
bool pickIdsDivisibleByThree(unsigned int label_id) {
    const bool is_multiple_of_three = (label_id % 3) == 0;
    return is_multiple_of_three;
}
|
||
// Filter predicate: admits a label only when it is a multiple of seven.
bool pickIdsDivisibleBySeven(unsigned int label_id) {
    const bool is_multiple_of_seven = (label_id % 7) == 0;
    return is_multiple_of_seven;
}
|
||
// Filter predicate that rejects every label.
// The parameter is deliberately left unnamed: it is needed only so the
// signature matches the other filter predicates, and naming it would trigger
// an unused-parameter warning.
bool pickNothing(unsigned int /*label_id*/) {
    return false;
}
|
||
template<typename filter_func_t> | ||
void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) { | ||
int d = 4; | ||
idx_t n = 100; | ||
idx_t nq = 10; | ||
size_t k = 10; | ||
|
||
std::vector<float> data(n * d); | ||
std::vector<float> query(nq * d); | ||
|
||
std::mt19937 rng; | ||
rng.seed(47); | ||
std::uniform_real_distribution<> distrib; | ||
|
||
for (idx_t i = 0; i < n * d; ++i) { | ||
data[i] = distrib(rng); | ||
} | ||
for (idx_t i = 0; i < nq * d; ++i) { | ||
query[i] = distrib(rng); | ||
} | ||
|
||
hnswlib::L2Space space(d); | ||
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n); | ||
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n); | ||
|
||
for (size_t i = 0; i < n; ++i) { | ||
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs | ||
alg_brute->addPoint(data.data() + d * i, label_id_start + i); | ||
alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); | ||
} | ||
|
||
// test searchKnnCloserFirst of BruteforceSearch with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_brute->searchKnn(p, k, filter_func); | ||
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
size_t t = gd.size(); | ||
while (!gd.empty()) { | ||
assert(gd.top() == res[--t]); | ||
assert((gd.top().second % div_num) == 0); | ||
gd.pop(); | ||
} | ||
} | ||
|
||
// test searchKnnCloserFirst of hnsw with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_hnsw->searchKnn(p, k, filter_func); | ||
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
size_t t = gd.size(); | ||
while (!gd.empty()) { | ||
assert(gd.top() == res[--t]); | ||
assert((gd.top().second % div_num) == 0); | ||
gd.pop(); | ||
} | ||
} | ||
|
||
delete alg_brute; | ||
delete alg_hnsw; | ||
} | ||
|
||
template<typename filter_func_t> | ||
void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) { | ||
int d = 4; | ||
idx_t n = 100; | ||
idx_t nq = 10; | ||
size_t k = 10; | ||
|
||
std::vector<float> data(n * d); | ||
std::vector<float> query(nq * d); | ||
|
||
std::mt19937 rng; | ||
rng.seed(47); | ||
std::uniform_real_distribution<> distrib; | ||
|
||
for (idx_t i = 0; i < n * d; ++i) { | ||
data[i] = distrib(rng); | ||
} | ||
for (idx_t i = 0; i < nq * d; ++i) { | ||
query[i] = distrib(rng); | ||
} | ||
|
||
hnswlib::L2Space space(d); | ||
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n); | ||
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n); | ||
|
||
for (size_t i = 0; i < n; ++i) { | ||
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs | ||
alg_brute->addPoint(data.data() + d * i, label_id_start + i); | ||
alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); | ||
} | ||
|
||
// test searchKnnCloserFirst of BruteforceSearch with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_brute->searchKnn(p, k, filter_func); | ||
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
assert(0 == gd.size()); | ||
} | ||
|
||
// test searchKnnCloserFirst of hnsw with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_hnsw->searchKnn(p, k, filter_func); | ||
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
assert(0 == gd.size()); | ||
} | ||
|
||
delete alg_brute; | ||
delete alg_hnsw; | ||
} | ||
|
||
} // namespace | ||
|
||
class CustomFilterFunctor: public hnswlib::FilterFunctor { | ||
std::unordered_set<unsigned int> allowed_values; | ||
|
||
public: | ||
explicit CustomFilterFunctor(const std::unordered_set<unsigned int>& values) : allowed_values(values) {} | ||
|
||
bool operator()(unsigned int id) { | ||
return allowed_values.count(id) != 0; | ||
} | ||
}; | ||
|
||
int main() { | ||
std::cout << "Testing ..." << std::endl; | ||
|
||
// some of the elements are filtered | ||
test_some_filtering(pickIdsDivisibleByThree, 3, 17); | ||
test_some_filtering(pickIdsDivisibleBySeven, 7, 17); | ||
|
||
// all of the elements are filtered | ||
test_none_filtering(pickNothing, 17); | ||
|
||
// functor style which can capture context | ||
CustomFilterFunctor pickIdsDivisibleByThirteen({26, 39, 52, 65}); | ||
test_some_filtering(pickIdsDivisibleByThirteen, 13, 21); | ||
|
||
std::cout << "Test ok" << std::endl; | ||
|
||
return 0; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,8 @@ | |
#include <algorithm> | ||
|
||
namespace hnswlib { | ||
template<typename dist_t> | ||
class BruteforceSearch : public AlgorithmInterface<dist_t> { | ||
template<typename dist_t, typename filter_func_t=FilterFunctor> | ||
class BruteforceSearch : public AlgorithmInterface<dist_t,filter_func_t> { | ||
public: | ||
BruteforceSearch(SpaceInterface <dist_t> *s) : data_(nullptr), maxelements_(0), | ||
cur_element_count(0), size_per_element_(0), data_size_(0), | ||
|
@@ -92,23 +92,30 @@ namespace hnswlib { | |
|
||
|
||
std::priority_queue<std::pair<dist_t, labeltype >> | ||
searchKnn(const void *query_data, size_t k) const { | ||
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { | ||
std::priority_queue<std::pair<dist_t, labeltype >> topResults; | ||
if (cur_element_count == 0) return topResults; | ||
for (int i = 0; i < k; i++) { | ||
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); | ||
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i + | ||
data_size_)))); | ||
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); | ||
if(isIdAllowed(label)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
topResults.push(std::pair<dist_t, labeltype>(dist, label)); | ||
} | ||
} | ||
dist_t lastdist = topResults.top().first; | ||
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first; | ||
for (int i = k; i < cur_element_count; i++) { | ||
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); | ||
if (dist <= lastdist) { | ||
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i + | ||
data_size_)))); | ||
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); | ||
if(isIdAllowed(label)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
topResults.push(std::pair<dist_t, labeltype>(dist, label)); | ||
} | ||
if (topResults.size() > k) | ||
topResults.pop(); | ||
lastdist = topResults.top().first; | ||
|
||
if (!topResults.empty()) { | ||
lastdist = topResults.top().first; | ||
} | ||
} | ||
|
||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,8 @@ namespace hnswlib { | |
typedef unsigned int tableint; | ||
typedef unsigned int linklistsizeint; | ||
|
||
template<typename dist_t> | ||
class HierarchicalNSW : public AlgorithmInterface<dist_t> { | ||
template<typename dist_t, typename filter_func_t=FilterFunctor> | ||
class HierarchicalNSW : public AlgorithmInterface<dist_t,filter_func_t> { | ||
public: | ||
static const tableint max_update_element_locks = 65536; | ||
HierarchicalNSW(SpaceInterface<dist_t> *s) { | ||
|
@@ -238,7 +238,7 @@ namespace hnswlib { | |
|
||
template <bool has_deletions, bool collect_metrics=false> | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> | ||
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { | ||
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const { | ||
VisitedList *vl = visited_list_pool_->getFreeVisitedList(); | ||
vl_type *visited_array = vl->mass; | ||
vl_type visited_array_tag = vl->curV; | ||
|
@@ -247,7 +247,8 @@ namespace hnswlib { | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set; | ||
|
||
dist_t lowerBound; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add the following micro-optimisation so that isIdAllowed is not called at all when filtering is disabled:
|
||
if (!has_deletions || !isMarkedDeleted(ep_id)) { | ||
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value; | ||
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) { | ||
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); | ||
lowerBound = dist; | ||
top_candidates.emplace(dist, ep_id); | ||
|
@@ -307,7 +308,7 @@ namespace hnswlib { | |
_MM_HINT_T0);//////////////////////// | ||
#endif | ||
|
||
if (!has_deletions || !isMarkedDeleted(candidate_id)) | ||
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id)))) | ||
top_candidates.emplace(dist, candidate_id); | ||
|
||
if (top_candidates.size() > ef) | ||
|
@@ -1111,7 +1112,7 @@ namespace hnswlib { | |
}; | ||
|
||
std::priority_queue<std::pair<dist_t, labeltype >> | ||
searchKnn(const void *query_data, size_t k) const { | ||
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { | ||
std::priority_queue<std::pair<dist_t, labeltype >> result; | ||
if (cur_element_count == 0) return result; | ||
|
||
|
@@ -1148,11 +1149,11 @@ namespace hnswlib { | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; | ||
if (num_deleted_) { | ||
top_candidates=searchBaseLayerST<true,true>( | ||
currObj, query_data, std::max(ef_, k)); | ||
currObj, query_data, std::max(ef_, k), isIdAllowed); | ||
} | ||
else{ | ||
top_candidates=searchBaseLayerST<false,true>( | ||
currObj, query_data, std::max(ef_, k)); | ||
currObj, query_data, std::max(ef_, k), isIdAllowed); | ||
} | ||
|
||
while (top_candidates.size() > k) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add the same flag as in hnswlib/hnswalg.h and check of k