Skip to content

Commit 34fe7f1

Browse files
author
Dmitry Yashunin
committed
Fix possible multithreading issues
1 parent b440cbd commit 34fe7f1

File tree

1 file changed

+75
-41
lines changed

1 file changed

+75
-41
lines changed

hnswlib/hnswalg.h

Lines changed: 75 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
3838
// Locks to prevent race condition during update/insert of an element at same time.
3939
// Note: Locks for additions can also be used to prevent this race condition
4040
// if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
41-
std::vector<std::mutex> link_list_update_locks_;
41+
mutable std::vector<std::mutex> link_list_update_locks_;
4242

4343
std::mutex global;
44-
std::mutex cur_element_count_guard_;
4544
std::vector<std::mutex> link_list_locks_;
4645

4746
tableint enterpoint_node_{0};
@@ -57,7 +56,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
5756

5857
DISTFUNC<dist_t> fstdistfunc_;
5958
void *dist_func_param_{nullptr};
60-
std::mutex label_lookup_lock;
59+
60+
mutable std::mutex label_lookup_lock; // lock for label_lookup_
6161
std::unordered_map<labeltype, tableint> label_lookup_;
6262

6363
std::default_random_engine level_generator_;
@@ -68,7 +68,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
6868

6969
bool replace_deleted_ = false;
7070

71-
std::mutex deleted_elements_lock;
71+
std::mutex deleted_elements_lock; // lock for deleted_elements
7272
std::unordered_set<tableint> deleted_elements;
7373

7474

@@ -714,14 +714,16 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
714714

715715
template<typename data_t>
716716
std::vector<data_t> getDataByLabel(labeltype label) const {
717-
tableint label_internal;
717+
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
718718
auto search = label_lookup_.find(label);
719719
if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
720720
throw std::runtime_error("Label not found");
721721
}
722-
label_internal = search->second;
723-
724-
char* data_ptrv = getDataByInternalId(label_internal);
722+
tableint internalId = search->second;
723+
lock_table.unlock();
724+
// wait for element addition or update
725+
std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(internalId & (max_update_element_locks - 1))]);
726+
char* data_ptrv = getDataByInternalId(internalId);
725727
size_t dim = *((size_t *) dist_func_param_);
726728
std::vector<data_t> data;
727729
data_t* data_ptr = (data_t*) data_ptrv;
@@ -737,11 +739,15 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
737739
* Marks an element with the given label deleted, does NOT really change the current graph.
738740
*/
739741
void markDelete(labeltype label) {
742+
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
740743
auto search = label_lookup_.find(label);
741744
if (search == label_lookup_.end()) {
742745
throw std::runtime_error("Label not found");
743746
}
744747
tableint internalId = search->second;
748+
lock_table.unlock();
749+
// wait for element addition or update
750+
std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(internalId & (max_update_element_locks - 1))]);
745751
markDeletedInternal(internalId);
746752
}
747753

@@ -756,7 +762,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
756762
unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2;
757763
*ll_cur |= DELETE_MARK;
758764
num_deleted_ += 1;
759-
if (replace_deleted_) deleted_elements.insert(internalId);
765+
if (replace_deleted_) {
766+
std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
767+
deleted_elements.insert(internalId);
768+
}
760769
} else {
761770
throw std::runtime_error("The requested to delete element is already deleted");
762771
}
@@ -767,25 +776,36 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
767776
* Remove the deleted mark of the node, does NOT really change the current graph.
768777
*/
769778
void unmarkDelete(labeltype label) {
779+
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
770780
auto search = label_lookup_.find(label);
771781
if (search == label_lookup_.end()) {
772782
throw std::runtime_error("Label not found");
773783
}
774784
tableint internalId = search->second;
785+
lock_table.unlock();
786+
// wait for element addition or update
787+
std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(internalId & (max_update_element_locks - 1))]);
775788
unmarkDeletedInternal(internalId);
776789
}
777790

778791

792+
779793
/**
780-
* Remove the deleted mark of the node.
781-
*/
794+
* Remove the deleted mark of the node.
795+
*
796+
* Note: the method is not safe to use when replacement of deleted elements is enabled
797+
* bacause elements marked as deleted can be completely removed from the index
798+
*/
782799
void unmarkDeletedInternal(tableint internalId) {
783800
assert(internalId < cur_element_count);
784801
if (isMarkedDeleted(internalId)) {
785802
unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2;
786803
*ll_cur &= ~DELETE_MARK;
787804
num_deleted_ -= 1;
788-
if (replace_deleted_) deleted_elements.erase(internalId);
805+
if (replace_deleted_) {
806+
std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
807+
deleted_elements.erase(internalId);
808+
}
789809
} else {
790810
throw std::runtime_error("The requested to undelete element is not deleted");
791811
}
@@ -813,42 +833,49 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
813833

814834
/**
815835
* Adds point and replaces previously deleted point if any, updating it with new point
816-
*
817-
* If deleted point was replaced returns its label, else returns label of added point
836+
* If deleted point was replaced returns its label, else returns label of added or updated point
837+
*
838+
* Note:
839+
* Methods that can work with deleted elements unmarkDelete and addPoint are not safe to use
840+
* with this method. Because addPointToVacantPlace removes deleted elements from the index.
818841
*/
819842
labeltype addPointToVacantPlace(const void* data_point, labeltype label) {
820843
if (!replace_deleted_) {
821844
throw std::runtime_error("Can't use addPointToVacantPlace when replacement of deleted elements is disabled");
822845
}
823846

824-
std::unique_lock <std::mutex> tmp_del_el_lock(deleted_elements_lock);
825-
bool is_empty = deleted_elements.empty();
826-
tmp_del_el_lock.unlock();
827-
828-
if (is_empty) {
829-
addPoint(data_point, label);
830-
return label;
847+
// check if there is vacant place
848+
tableint internal_id_replaced;
849+
std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
850+
bool is_vacant_place = !deleted_elements.empty();
851+
if (is_vacant_place) {
852+
internal_id_replaced = *deleted_elements.begin();
853+
deleted_elements.erase(internal_id_replaced);
831854
}
832-
else {
833-
tmp_del_el_lock.lock();
834-
tableint id_replace = *deleted_elements.begin();
835-
deleted_elements.erase(id_replace);
836-
tmp_del_el_lock.unlock();
837-
838-
// use link list locks to not block calls for other elements
839-
std::unique_lock <std::mutex> lock_label_update(link_list_update_locks_[(id_replace & (max_update_element_locks - 1))]);
840-
labeltype label_replace = getExternalLabel(id_replace);
841-
setExternalLabel(id_replace, label);
842-
lock_label_update.unlock();
843-
844-
std::unique_lock <std::mutex> tmp_label_lookup_lock(label_lookup_lock);
845-
label_lookup_.erase(label_replace);
846-
label_lookup_[label] = id_replace;
847-
tmp_label_lookup_lock.unlock();
855+
lock_deleted_elements.unlock();
848856

857+
// if there is no vacant place then add or update point
858+
// else add point to vacant place
859+
if (!is_vacant_place) {
849860
addPoint(data_point, label);
850-
851-
return label_replace;
861+
return label;
862+
} else {
863+
// wait for element addition or update
864+
std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(internal_id_replaced & (max_update_element_locks - 1))]);
865+
labeltype label_replaced = getExternalLabel(internal_id_replaced);
866+
setExternalLabel(internal_id_replaced, label);
867+
lock_el_update.unlock();
868+
869+
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
870+
label_lookup_.erase(label_replaced);
871+
label_lookup_[label] = internal_id_replaced;
872+
lock_table.unlock();
873+
874+
lock_el_update.lock();
875+
unmarkDeletedInternal(internal_id_replaced);
876+
updatePoint(data_point, internal_id_replaced, 1.0);
877+
878+
return label_replaced;
852879
}
853880
}
854881

@@ -1024,11 +1051,18 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
10241051
{
10251052
// Checking if the element with the same label already exists
10261053
// if so, updating it *instead* of creating a new element.
1027-
std::unique_lock <std::mutex> templock_curr(cur_element_count_guard_);
1054+
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
10281055
auto search = label_lookup_.find(label);
10291056
if (search != label_lookup_.end()) {
10301057
tableint existingInternalId = search->second;
1031-
templock_curr.unlock();
1058+
if (replace_deleted_) {
1059+
// wait for element addition or update
1060+
std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);
1061+
if (isMarkedDeleted(existingInternalId)) {
1062+
throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
1063+
}
1064+
}
1065+
lock_table.unlock();
10321066

10331067
std::unique_lock <std::mutex> lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]);
10341068

0 commit comments

Comments
 (0)