Skip to content

Filter elements with an optional filtering function #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 6, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,6 @@ jobs:
run: |
cd build
./searchKnnCloserFirst_test
./searchKnnWithFilter_test
./test_updates
./test_updates update
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp)
target_link_libraries(searchKnnCloserFirst_test hnswlib)

add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp)
target_link_libraries(searchKnnWithFilter_test hnswlib)

add_executable(main main.cpp sift_1b.cpp)
target_link_libraries(main hnswlib)
endif()
173 changes: 173 additions & 0 deletions examples/searchKnnWithFilter_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// This is a test file for testing the filtering feature

#include "../hnswlib/hnswlib.h"

#include <assert.h>

#include <iostream>
#include <memory>
#include <random>
#include <unordered_set>
#include <vector>

namespace
{

using idx_t = hnswlib::labeltype;

// Filter predicate: accepts a label only when it is a multiple of three.
// The parameter is `size_t` (the type behind hnswlib::labeltype) rather than
// `unsigned int`, so a 64-bit label is not truncated before the check.
bool pickIdsDivisibleByThree(size_t label_id) {
    return (label_id % 3) == 0;
}

// Filter predicate: accepts a label only when it is a multiple of seven.
// The parameter is `size_t` (the type behind hnswlib::labeltype) rather than
// `unsigned int`, so a 64-bit label is not truncated before the check.
bool pickIdsDivisibleBySeven(size_t label_id) {
    return (label_id % 7) == 0;
}

// Filter predicate that rejects every label; used to check that a search
// returns no results when nothing is allowed.
// The parameter is left unnamed because it is never read.
bool pickNothing(size_t /*label_id*/) {
    return false;
}

template<typename filter_func_t>
void test_some_filtering(filter_func_t& filter_func, size_t div_num, size_t label_id_start) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);

std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; ++i) {
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i) {
query[i] = distrib(rng);
}

hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
alg_brute->addPoint(data.data() + d * i, label_id_start + i);
alg_hnsw->addPoint(data.data() + d * i, label_id_start + i);
}

// test searchKnnCloserFirst of BruteforceSearch with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k, filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
assert(gd.top() == res[--t]);
assert((gd.top().second % div_num) == 0);
gd.pop();
}
}

// test searchKnnCloserFirst of hnsw with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
assert(gd.top() == res[--t]);
assert((gd.top().second % div_num) == 0);
gd.pop();
}
}

delete alg_brute;
delete alg_hnsw;
}

template<typename filter_func_t>
void test_none_filtering(filter_func_t& filter_func, size_t label_id_start) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);

std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; ++i) {
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i) {
query[i] = distrib(rng);
}

hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_brute = new hnswlib::BruteforceSearch<float,filter_func_t>(&space, 2 * n);
hnswlib::AlgorithmInterface<float,filter_func_t>* alg_hnsw = new hnswlib::HierarchicalNSW<float,filter_func_t>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
alg_brute->addPoint(data.data() + d * i, label_id_start + i);
alg_hnsw->addPoint(data.data() + d * i, label_id_start + i);
}

// test searchKnnCloserFirst of BruteforceSearch with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k, filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
assert(gd.size() == res.size());
assert(0 == gd.size());
}

// test searchKnnCloserFirst of hnsw with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
assert(gd.size() == res.size());
assert(0 == gd.size());
}

delete alg_brute;
delete alg_hnsw;
}

} // namespace

class CustomFilterFunctor: public hnswlib::FilterFunctor {
std::unordered_set<unsigned int> allowed_values;

public:
explicit CustomFilterFunctor(const std::unordered_set<unsigned int>& values) : allowed_values(values) {}

bool operator()(unsigned int id) {
return allowed_values.count(id) != 0;
}
};

int main() {
std::cout << "Testing ..." << std::endl;

// some of the elements are filtered
test_some_filtering(pickIdsDivisibleByThree, 3, 17);
test_some_filtering(pickIdsDivisibleBySeven, 7, 17);

// all of the elements are filtered
test_none_filtering(pickNothing, 17);

// functor style which can capture context
CustomFilterFunctor pickIdsDivisibleByThirteen({26, 39, 52, 65});
test_some_filtering(pickIdsDivisibleByThirteen, 13, 21);

std::cout << "Test ok" << std::endl;

return 0;
}
25 changes: 16 additions & 9 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#include <algorithm>

namespace hnswlib {
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
template<typename dist_t, typename filter_func_t=FilterFunctor>
class BruteforceSearch : public AlgorithmInterface<dist_t,filter_func_t> {
public:
BruteforceSearch(SpaceInterface <dist_t> *s) : data_(nullptr), maxelements_(0),
cur_element_count(0), size_per_element_(0), data_size_(0),
Expand Down Expand Up @@ -92,23 +92,30 @@ namespace hnswlib {


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const {
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add the same flag as in hnswlib/hnswalg.h, and a check on k:

assert(k <= cur_element_count);
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;

if (cur_element_count == 0) return topResults;
for (int i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
Copy link
Contributor

@dyashuni dyashuni Sep 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (is_filter_disabled || isIdAllowed(label)) {

topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
}
dist_t lastdist = topResults.top().first;
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
for (int i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (is_filter_disabled || isIdAllowed(label)) {

topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
if (topResults.size() > k)
topResults.pop();
lastdist = topResults.top().first;

if (!topResults.empty()) {
lastdist = topResults.top().first;
}
}

}
Expand Down
17 changes: 9 additions & 8 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;

template<typename dist_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
template<typename dist_t, typename filter_func_t=FilterFunctor>
class HierarchicalNSW : public AlgorithmInterface<dist_t,filter_func_t> {
public:
static const tableint max_update_element_locks = 65536;
HierarchicalNSW(SpaceInterface<dist_t> *s) {
Expand Down Expand Up @@ -238,7 +238,7 @@ namespace hnswlib {

template <bool has_deletions, bool collect_metrics=false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t& isIdAllowed) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand All @@ -247,7 +247,8 @@ namespace hnswlib {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add the following micro optimisation to not call isIdAllowed at all if filtering is disabled:

bool is_filter_disabled = isIdAllowed == allowAllIds;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {

if (!has_deletions || !isMarkedDeleted(ep_id)) {
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
Expand Down Expand Up @@ -307,7 +308,7 @@ namespace hnswlib {
_MM_HINT_T0);////////////////////////
#endif

if (!has_deletions || !isMarkedDeleted(candidate_id))
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(candidate_id))))
top_candidates.emplace(dist, candidate_id);

if (top_candidates.size() > ef)
Expand Down Expand Up @@ -1111,7 +1112,7 @@ namespace hnswlib {
};

std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

Expand Down Expand Up @@ -1148,11 +1149,11 @@ namespace hnswlib {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (num_deleted_) {
top_candidates=searchBaseLayerST<true,true>(
currObj, query_data, std::max(ef_, k));
currObj, query_data, std::max(ef_, k), isIdAllowed);
}
else{
top_candidates=searchBaseLayerST<false,true>(
currObj, query_data, std::max(ef_, k));
currObj, query_data, std::max(ef_, k), isIdAllowed);
}

while (top_candidates.size() > k) {
Expand Down
24 changes: 17 additions & 7 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ static bool AVX512Capable() {
namespace hnswlib {
typedef size_t labeltype;

// Default filter: accepts every id, whatever arguments it is called with.
// This can be extended to store state for filtering (e.g. from a std::set).
// The call operator is not virtual, so a derived type's operator() is used
// only where that derived type itself is the filter template argument.
struct FilterFunctor {
    // const + noexcept: there is no state to modify and nothing that can fail,
    // so the functor may also be invoked through a const reference.
    template<class...Args>
    bool operator()(Args&&...) const noexcept { return true; }
};

static FilterFunctor allowAllIds;

template <typename T>
class pairGreater {
public:
Expand All @@ -137,7 +145,6 @@ namespace hnswlib {
template<typename MTYPE>
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);


template<typename MTYPE>
class SpaceInterface {
public:
Expand All @@ -151,28 +158,31 @@ namespace hnswlib {
virtual ~SpaceInterface() {}
};

template<typename dist_t>
template<typename dist_t, typename filter_func_t=FilterFunctor>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label)=0;
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;

virtual std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void*, size_t, filter_func_t& isIdAllowed=allowAllIds) const = 0;

// Return the k nearest neighbors, ordered closest first
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k) const;
searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const;

virtual void saveIndex(const std::string &location)=0;
virtual ~AlgorithmInterface(){
}
};

template<typename dist_t>
template<typename dist_t, typename filter_func_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k) const {
AlgorithmInterface<dist_t, filter_func_t>::searchKnnCloserFirst(const void* query_data, size_t k,
filter_func_t& isIdAllowed) const {
std::vector<std::pair<dist_t, labeltype>> result;

// here searchKnn returns the result in the order of further first
auto ret = searchKnn(query_data, k);
auto ret = searchKnn(query_data, k, isIdAllowed);
{
size_t sz = ret.size();
result.resize(sz);
Expand Down