Skip to content

Commit 0dcfb91

Browse files
authored
Merge pull request #166 from nmslib/develop
Merge develop into master
2 parents bbddf19 + 38482db commit 0dcfb91

File tree

6 files changed

+99
-15
lines changed

6 files changed

+99
-15
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ Index methods:
6767

6868
* `get_ids_list()` - returns a list of all elements' ids.
6969

70+
* `get_max_elements()` - returns the current capacity of the index
7071

72+
* `get_current_count()` - returns the current number of element stored in the index
7173

7274

7375

hnswlib/bruteforce.h

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <unordered_map>
33
#include <fstream>
44
#include <mutex>
5+
#include <algorithm>
56

67
namespace hnswlib {
78
template<typename dist_t>
@@ -21,6 +22,8 @@ namespace hnswlib {
2122
dist_func_param_ = s->get_dist_func_param();
2223
size_per_element_ = data_size_ + sizeof(labeltype);
2324
data_ = (char *) malloc(maxElements * size_per_element_);
25+
if (data_ == nullptr)
26+
std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
2427
cur_element_count = 0;
2528
}
2629

@@ -40,7 +43,7 @@ namespace hnswlib {
4043

4144
std::unordered_map<labeltype,size_t > dict_external_to_internal;
4245

43-
void addPoint(void *datapoint, labeltype label) {
46+
void addPoint(const void *datapoint, labeltype label) {
4447

4548
int idx;
4649
{
@@ -84,8 +87,10 @@ namespace hnswlib {
8487
}
8588

8689

87-
std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *query_data, size_t k) const {
90+
std::priority_queue<std::pair<dist_t, labeltype >>
91+
searchKnn(const void *query_data, size_t k) const {
8892
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
93+
if (cur_element_count == 0) return topResults;
8994
for (int i = 0; i < k; i++) {
9095
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
9196
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
@@ -106,6 +111,24 @@ namespace hnswlib {
106111
return topResults;
107112
};
108113

114+
template <typename Comp>
115+
std::vector<std::pair<dist_t, labeltype>>
116+
searchKnn(const void* query_data, size_t k, Comp comp) {
117+
std::vector<std::pair<dist_t, labeltype>> result;
118+
if (cur_element_count == 0) return result;
119+
120+
auto ret = searchKnn(query_data, k);
121+
122+
while (!ret.empty()) {
123+
result.push_back(ret.top());
124+
ret.pop();
125+
}
126+
127+
std::sort(result.begin(), result.end(), comp);
128+
129+
return result;
130+
}
131+
109132
void saveIndex(const std::string &location) {
110133
std::ofstream output(location, std::ios::binary);
111134
std::streampos position;
@@ -134,6 +157,8 @@ namespace hnswlib {
134157
dist_func_param_ = s->get_dist_func_param();
135158
size_per_element_ = data_size_ + sizeof(labeltype);
136159
data_ = (char *) malloc(maxelements_ * size_per_element_);
160+
if (data_ == nullptr)
161+
std::runtime_error("Not enough memory: loadIndex failed to allocate data");
137162

138163
input.read(data_, maxelements_ * size_per_element_);
139164

hnswlib/hnswalg.h

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ namespace hnswlib {
6161
maxlevel_ = -1;
6262

6363
linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
64+
if (linkLists_ == nullptr)
65+
throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
6466
size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
6567
mult_ = 1 / log(1.0 * M_);
6668
revSize_ = 1.0 / mult_;
@@ -150,7 +152,7 @@ namespace hnswlib {
150152
}
151153

152154
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
153-
searchBaseLayer(tableint ep_id, void *data_point, int layer) {
155+
searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
154156
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
155157
vl_type *visited_array = vl->mass;
156158
vl_type visited_array_tag = vl->curV;
@@ -371,7 +373,7 @@ namespace hnswlib {
371373
return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
372374
};
373375

374-
void mutuallyConnectNewElement(void *data_point, tableint cur_c,
376+
void mutuallyConnectNewElement(const void *data_point, tableint cur_c,
375377
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates,
376378
int level) {
377379

@@ -484,6 +486,8 @@ namespace hnswlib {
484486

485487

486488
std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
489+
std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
490+
if (cur_element_count == 0) return top_candidates;
487491
tableint currObj = enterpoint_node_;
488492
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
489493

@@ -510,8 +514,6 @@ namespace hnswlib {
510514
}
511515
}
512516

513-
514-
std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
515517
if (has_deletions_) {
516518
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
517519
ef_);
@@ -546,12 +548,16 @@ namespace hnswlib {
546548

547549
// Reallocate base layer
548550
char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_);
551+
if (data_level0_memory_new == nullptr)
552+
throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
549553
memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_);
550554
free(data_level0_memory_);
551555
data_level0_memory_=data_level0_memory_new;
552556

553557
// Reallocate all other layers
554558
char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements);
559+
if (linkLists_new == nullptr)
560+
throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
555561
memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *));
556562
free(linkLists_);
557563
linkLists_=linkLists_new;
@@ -659,6 +665,8 @@ namespace hnswlib {
659665

660666

661667
data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
668+
if (data_level0_memory_ == nullptr)
669+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
662670
input.read(data_level0_memory_, cur_element_count * size_data_per_element_);
663671

664672

@@ -675,6 +683,8 @@ namespace hnswlib {
675683

676684

677685
linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
686+
if (linkLists_ == nullptr)
687+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
678688
element_levels_ = std::vector<int>(max_elements);
679689
revSize_ = 1.0 / mult_;
680690
ef_ = 10;
@@ -689,6 +699,8 @@ namespace hnswlib {
689699
} else {
690700
element_levels_[i] = linkListSize / size_links_per_element_;
691701
linkLists_[i] = (char *) malloc(linkListSize);
702+
if (linkLists_[i] == nullptr)
703+
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
692704
input.read(linkLists_[i], linkListSize);
693705
}
694706
}
@@ -779,11 +791,11 @@ namespace hnswlib {
779791
*((unsigned short int*)(ptr))=*((unsigned short int *)&size);
780792
}
781793

782-
void addPoint(void *data_point, labeltype label) {
794+
void addPoint(const void *data_point, labeltype label) {
783795
addPoint(data_point, label,-1);
784796
}
785797

786-
tableint addPoint(void *data_point, labeltype label, int level) {
798+
tableint addPoint(const void *data_point, labeltype label, int level) {
787799
tableint cur_c = 0;
788800
{
789801
std::unique_lock <std::mutex> lock(cur_element_count_guard_);
@@ -797,6 +809,7 @@ namespace hnswlib {
797809
auto search = label_lookup_.find(label);
798810
if (search != label_lookup_.end()) {
799811
std::unique_lock <std::mutex> lock_el(link_list_locks_[search->second]);
812+
has_deletions_ = true;
800813
markDeletedInternal(search->second);
801814
}
802815
label_lookup_[label] = cur_c;
@@ -827,6 +840,8 @@ namespace hnswlib {
827840

828841
if (curlevel) {
829842
linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
843+
if (linkLists_[cur_c] == nullptr)
844+
throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
830845
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
831846
}
832847

@@ -895,7 +910,11 @@ namespace hnswlib {
895910
return cur_c;
896911
};
897912

898-
std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *query_data, size_t k) const {
913+
std::priority_queue<std::pair<dist_t, labeltype >>
914+
searchKnn(const void *query_data, size_t k) const {
915+
std::priority_queue<std::pair<dist_t, labeltype >> result;
916+
if (cur_element_count == 0) return result;
917+
899918
tableint currObj = enterpoint_node_;
900919
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);
901920

@@ -934,18 +953,34 @@ namespace hnswlib {
934953
currObj, query_data, std::max(ef_, k));
935954
top_candidates.swap(top_candidates1);
936955
}
937-
std::priority_queue<std::pair<dist_t, labeltype >> results;
938956
while (top_candidates.size() > k) {
939957
top_candidates.pop();
940958
}
941959
while (top_candidates.size() > 0) {
942960
std::pair<dist_t, tableint> rez = top_candidates.top();
943-
results.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
961+
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
944962
top_candidates.pop();
945963
}
946-
return results;
964+
return result;
947965
};
948966

967+
template <typename Comp>
968+
std::vector<std::pair<dist_t, labeltype>>
969+
searchKnn(const void* query_data, size_t k, Comp comp) {
970+
std::vector<std::pair<dist_t, labeltype>> result;
971+
if (cur_element_count == 0) return result;
972+
973+
auto ret = searchKnn(query_data, k);
974+
975+
while (!ret.empty()) {
976+
result.push_back(ret.top());
977+
ret.pop();
978+
}
979+
980+
std::sort(result.begin(), result.end(), comp);
981+
982+
return result;
983+
}
949984

950985
};
951986

hnswlib/hnswlib.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,21 @@
2424
#endif
2525

2626
#include <queue>
27+
#include <vector>
2728

2829
#include <string.h>
2930

3031
namespace hnswlib {
3132
typedef size_t labeltype;
3233

34+
template <typename T>
35+
class pairGreater {
36+
public:
37+
bool operator()(const T& p1, const T& p2) {
38+
return p1.first > p2.first;
39+
}
40+
};
41+
3342
template<typename T>
3443
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
3544
out.write((char *) &podRef, sizeof(T));
@@ -60,8 +69,11 @@ namespace hnswlib {
6069
template<typename dist_t>
6170
class AlgorithmInterface {
6271
public:
63-
virtual void addPoint(void *datapoint, labeltype label)=0;
72+
virtual void addPoint(const void *datapoint, labeltype label)=0;
6473
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
74+
template <typename Comp>
75+
std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp) {
76+
}
6577
virtual void saveIndex(const std::string &location)=0;
6678
virtual ~AlgorithmInterface(){
6779
}

python_bindings/bindings.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,14 @@ class Index {
359359
appr_alg->resizeIndex(new_size);
360360
}
361361

362+
size_t getMaxElements() const {
363+
return appr_alg->max_elements_;
364+
}
365+
366+
size_t getCurrentCount() const {
367+
return appr_alg->cur_element_count;
368+
}
369+
362370
std::string space_name;
363371
int dim;
364372

@@ -397,6 +405,8 @@ PYBIND11_PLUGIN(hnswlib) {
397405
.def("load_index", &Index<float>::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0)
398406
.def("mark_deleted", &Index<float>::markDeleted, py::arg("label"))
399407
.def("resize_index", &Index<float>::resizeIndex, py::arg("new_size"))
408+
.def("get_max_elements", &Index<float>::getMaxElements)
409+
.def("get_current_count", &Index<float>::getCurrentCount)
400410
.def("__repr__",
401411
[](const Index<float> &a) {
402412
return "<HNSW-lib index>";

python_bindings/tests/bindings_test_labels.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def testRandomSelf(self):
5050
self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data1))),1.0,3)
5151

5252
# Check that the returned element data is correct:
53-
diff_with_gt_labels=np.max(np.abs(data1-items))
53+
diff_with_gt_labels=np.mean(np.abs(data1-items))
5454
self.assertAlmostEqual(diff_with_gt_labels, 0, delta = 1e-4)
5555

5656
# Serializing and deleting the index.
@@ -83,7 +83,7 @@ def testRandomSelf(self):
8383
self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))),1.0,3)
8484

8585
# Check that the returned element data is correct:
86-
diff_with_gt_labels=np.max(np.abs(data-items))
86+
diff_with_gt_labels=np.mean(np.abs(data-items))
8787
self.assertAlmostEqual(diff_with_gt_labels, 0, delta = 1e-4) # deleting index.
8888

8989
# Checking that all labels are returned correctly:

0 commit comments

Comments
 (0)