@@ -61,6 +61,8 @@ namespace hnswlib {
61
61
maxlevel_ = -1 ;
62
62
63
63
linkLists_ = (char **) malloc (sizeof (void *) * max_elements_);
64
+ if (linkLists_ == nullptr )
65
+ throw std::runtime_error (" Not enough memory: HierarchicalNSW failed to allocate linklists" );
64
66
size_links_per_element_ = maxM_ * sizeof (tableint) + sizeof (linklistsizeint);
65
67
mult_ = 1 / log (1.0 * M_);
66
68
revSize_ = 1.0 / mult_;
@@ -150,7 +152,7 @@ namespace hnswlib {
150
152
}
151
153
152
154
std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst>
153
- searchBaseLayer (tableint ep_id, void *data_point, int layer) {
155
+ searchBaseLayer (tableint ep_id, const void *data_point, int layer) {
154
156
VisitedList *vl = visited_list_pool_->getFreeVisitedList ();
155
157
vl_type *visited_array = vl->mass ;
156
158
vl_type visited_array_tag = vl->curV ;
@@ -371,7 +373,7 @@ namespace hnswlib {
371
373
return (linklistsizeint *) (linkLists_[internal_id] + (level - 1 ) * size_links_per_element_);
372
374
};
373
375
374
- void mutuallyConnectNewElement (void *data_point, tableint cur_c,
376
+ void mutuallyConnectNewElement (const void *data_point, tableint cur_c,
375
377
std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> top_candidates,
376
378
int level) {
377
379
@@ -484,6 +486,8 @@ namespace hnswlib {
484
486
485
487
486
488
std::priority_queue<std::pair<dist_t , tableint>> searchKnnInternal (void *query_data, int k) {
489
+ std::priority_queue<std::pair<dist_t , tableint >> top_candidates;
490
+ if (cur_element_count == 0 ) return top_candidates;
487
491
tableint currObj = enterpoint_node_;
488
492
dist_t curdist = fstdistfunc_ (query_data, getDataByInternalId (enterpoint_node_), dist_func_param_);
489
493
@@ -510,8 +514,6 @@ namespace hnswlib {
510
514
}
511
515
}
512
516
513
-
514
- std::priority_queue<std::pair<dist_t , tableint >> top_candidates;
515
517
if (has_deletions_) {
516
518
std::priority_queue<std::pair<dist_t , tableint >> top_candidates1=searchBaseLayerST<true >(currObj, query_data,
517
519
ef_);
@@ -546,12 +548,16 @@ namespace hnswlib {
546
548
547
549
// Reallocate base layer
548
550
char * data_level0_memory_new = (char *) malloc (new_max_elements * size_data_per_element_);
551
+ if (data_level0_memory_new == nullptr )
552
+ throw std::runtime_error (" Not enough memory: resizeIndex failed to allocate base layer" );
549
553
memcpy (data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_);
550
554
free (data_level0_memory_);
551
555
data_level0_memory_=data_level0_memory_new;
552
556
553
557
// Reallocate all other layers
554
558
char ** linkLists_new = (char **) malloc (sizeof (void *) * new_max_elements);
559
+ if (linkLists_new == nullptr )
560
+ throw std::runtime_error (" Not enough memory: resizeIndex failed to allocate other layers" );
555
561
memcpy (linkLists_new, linkLists_,cur_element_count * sizeof (void *));
556
562
free (linkLists_);
557
563
linkLists_=linkLists_new;
@@ -659,6 +665,8 @@ namespace hnswlib {
659
665
660
666
661
667
data_level0_memory_ = (char *) malloc (max_elements * size_data_per_element_);
668
+ if (data_level0_memory_ == nullptr )
669
+ throw std::runtime_error (" Not enough memory: loadIndex failed to allocate level0" );
662
670
input.read (data_level0_memory_, cur_element_count * size_data_per_element_);
663
671
664
672
@@ -675,6 +683,8 @@ namespace hnswlib {
675
683
676
684
677
685
linkLists_ = (char **) malloc (sizeof (void *) * max_elements);
686
+ if (linkLists_ == nullptr )
687
+ throw std::runtime_error (" Not enough memory: loadIndex failed to allocate linklists" );
678
688
element_levels_ = std::vector<int >(max_elements);
679
689
revSize_ = 1.0 / mult_;
680
690
ef_ = 10 ;
@@ -689,6 +699,8 @@ namespace hnswlib {
689
699
} else {
690
700
element_levels_[i] = linkListSize / size_links_per_element_;
691
701
linkLists_[i] = (char *) malloc (linkListSize);
702
+ if (linkLists_[i] == nullptr )
703
+ throw std::runtime_error (" Not enough memory: loadIndex failed to allocate linklist" );
692
704
input.read (linkLists_[i], linkListSize);
693
705
}
694
706
}
@@ -779,11 +791,11 @@ namespace hnswlib {
779
791
*((unsigned short int *)(ptr))=*((unsigned short int *)&size);
780
792
}
781
793
782
- void addPoint (void *data_point, labeltype label) {
794
+ void addPoint (const void *data_point, labeltype label) {
783
795
addPoint (data_point, label,-1 );
784
796
}
785
797
786
- tableint addPoint (void *data_point, labeltype label, int level) {
798
+ tableint addPoint (const void *data_point, labeltype label, int level) {
787
799
tableint cur_c = 0 ;
788
800
{
789
801
std::unique_lock <std::mutex> lock (cur_element_count_guard_);
@@ -797,6 +809,7 @@ namespace hnswlib {
797
809
auto search = label_lookup_.find (label);
798
810
if (search != label_lookup_.end ()) {
799
811
std::unique_lock <std::mutex> lock_el (link_list_locks_[search->second ]);
812
+ has_deletions_ = true ;
800
813
markDeletedInternal (search->second );
801
814
}
802
815
label_lookup_[label] = cur_c;
@@ -827,6 +840,8 @@ namespace hnswlib {
827
840
828
841
if (curlevel) {
829
842
linkLists_[cur_c] = (char *) malloc (size_links_per_element_ * curlevel + 1 );
843
+ if (linkLists_[cur_c] == nullptr )
844
+ throw std::runtime_error (" Not enough memory: addPoint failed to allocate linklist" );
830
845
memset (linkLists_[cur_c], 0 , size_links_per_element_ * curlevel + 1 );
831
846
}
832
847
@@ -895,7 +910,11 @@ namespace hnswlib {
895
910
return cur_c;
896
911
};
897
912
898
- std::priority_queue<std::pair<dist_t , labeltype >> searchKnn (const void *query_data, size_t k) const {
913
+ std::priority_queue<std::pair<dist_t , labeltype >>
914
+ searchKnn (const void *query_data, size_t k) const {
915
+ std::priority_queue<std::pair<dist_t , labeltype >> result;
916
+ if (cur_element_count == 0 ) return result;
917
+
899
918
tableint currObj = enterpoint_node_;
900
919
dist_t curdist = fstdistfunc_ (query_data, getDataByInternalId (enterpoint_node_), dist_func_param_);
901
920
@@ -934,18 +953,34 @@ namespace hnswlib {
934
953
currObj, query_data, std::max (ef_, k));
935
954
top_candidates.swap (top_candidates1);
936
955
}
937
- std::priority_queue<std::pair<dist_t , labeltype >> results;
938
956
while (top_candidates.size () > k) {
939
957
top_candidates.pop ();
940
958
}
941
959
while (top_candidates.size () > 0 ) {
942
960
std::pair<dist_t , tableint> rez = top_candidates.top ();
943
- results .push (std::pair<dist_t , labeltype>(rez.first , getExternalLabel (rez.second )));
961
+ result .push (std::pair<dist_t , labeltype>(rez.first , getExternalLabel (rez.second )));
944
962
top_candidates.pop ();
945
963
}
946
- return results ;
964
+ return result ;
947
965
};
948
966
967
+ template <typename Comp>
968
+ std::vector<std::pair<dist_t , labeltype>>
969
+ searchKnn (const void * query_data, size_t k, Comp comp) {
970
+ std::vector<std::pair<dist_t , labeltype>> result;
971
+ if (cur_element_count == 0 ) return result;
972
+
973
+ auto ret = searchKnn (query_data, k);
974
+
975
+ while (!ret.empty ()) {
976
+ result.push_back (ret.top ());
977
+ ret.pop ();
978
+ }
979
+
980
+ std::sort (result.begin (), result.end (), comp);
981
+
982
+ return result;
983
+ }
949
984
950
985
};
951
986
0 commit comments