Skip to content

Commit 7af386e

Browse files
committed
Revert code duplication changes
1 parent 39b3cab commit 7af386e

File tree

1 file changed

+69
-40
lines changed

1 file changed

+69
-40
lines changed

python_bindings/bindings.cpp

Lines changed: 69 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -79,40 +79,6 @@ inline void assert_true(bool expr, const std::string & msg) {
7979
}
8080

8181

82-
/*
 * Reads the row count and per-row feature count out of a buffer descriptor.
 *
 *   buffer   - descriptor whose `ndim` and `shape` fields are inspected
 *   rows     - out: 1 for a 1-D buffer, shape[0] for a 2-D buffer
 *   features - out: shape[0] for a 1-D buffer, shape[1] for a 2-D buffer
 *
 * Throws std::runtime_error if the buffer is neither 1-D nor 2-D; in that
 * case neither output is written.
 */
inline void set_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) {
    const auto dims = buffer.ndim;

    // A 1-D buffer is treated as a single row.
    if (dims == 1) {
        *rows = 1;
        *features = buffer.shape[0];
        return;
    }

    if (dims == 2) {
        *rows = buffer.shape[0];
        *features = buffer.shape[1];
        return;
    }

    throw std::runtime_error("data must be a 1d/2d array");
}
92-
93-
94-
/*
 * Converts an optional Python-side collection of ids into a vector.
 *
 *   ids_ - Python object; when it is None an empty vector is returned.
 *          Otherwise it is wrapped in a py::array_t<size_t>
 *          (c_style | forcecast) before being inspected.
 *   rows - number of data rows the ids are expected to match
 *
 * Accepted shapes:
 *   - 1-D array whose length equals `rows`: every element is copied;
 *   - 0-D (scalar) array when `rows` is 1: the single value is used.
 * Anything else throws std::runtime_error.
 */
inline std::vector<size_t> get_input_ids(const py::object& ids_, size_t rows) {
    std::vector<size_t> labels;
    if (ids_.is_none()) {
        return labels;
    }

    py::array_t < size_t, py::array::c_style | py::array::forcecast > label_array(ids_);
    auto label_info = label_array.request();

    // shape[0] is only read when ndim == 1 (short-circuit), as a 0-D array has no shape entry.
    const bool one_per_row = label_info.ndim == 1 && label_info.shape[0] == rows;
    const bool lone_scalar = label_info.ndim == 0 && rows == 1;

    if (one_per_row) {
        const size_t* first = label_array.data();
        labels.assign(first, first + label_info.shape[0]);
    } else if (lone_scalar) {
        labels.push_back(*label_array.data());
    } else {
        throw std::runtime_error("wrong dimensionality of the labels");
    }

    return labels;
}
114-
115-
11682
template<typename dist_t, typename data_t = float>
11783
class Index {
11884
public:
@@ -222,7 +188,15 @@ class Index {
222188
num_threads = num_threads_default;
223189

224190
size_t rows, features;
225-
set_input_array_shapes(buffer, &rows, &features);
191+
192+
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
193+
if (buffer.ndim == 2) {
194+
rows = buffer.shape[0];
195+
features = buffer.shape[1];
196+
} else {
197+
rows = 1;
198+
features = buffer.shape[0];
199+
}
226200

227201
if (features != dim)
228202
throw std::runtime_error("wrong dimensionality of the vectors");
@@ -232,7 +206,23 @@ class Index {
232206
num_threads = 1;
233207
}
234208

235-
std::vector<size_t> ids = get_input_ids(ids_, rows);
209+
std::vector<size_t> ids;
210+
211+
if (!ids_.is_none()) {
212+
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
213+
auto ids_numpy = items.request();
214+
if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) {
215+
std::vector<size_t> ids1(ids_numpy.shape[0]);
216+
for (size_t i = 0; i < ids1.size(); i++) {
217+
ids1[i] = items.data()[i];
218+
}
219+
ids.swap(ids1);
220+
} else if (ids_numpy.ndim == 0 && rows == 1) {
221+
ids.push_back(*items.data());
222+
} else {
223+
throw std::runtime_error("wrong dimensionality of the labels");
224+
}
225+
}
236226

237227
{
238228
int start = 0;
@@ -571,7 +561,15 @@ class Index {
571561

572562
{
573563
py::gil_scoped_release l;
574-
set_input_array_shapes(buffer, &rows, &features);
564+
565+
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
566+
if (buffer.ndim == 2) {
567+
rows = buffer.shape[0];
568+
features = buffer.shape[1];
569+
} else {
570+
rows = 1;
571+
features = buffer.shape[0];
572+
}
575573

576574
// avoid using threads when the number of searches is small:
577575
if (rows <= num_threads * 4) {
@@ -727,12 +725,36 @@ class BFIndex {
727725
py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input);
728726
auto buffer = items.request();
729727
size_t rows, features;
730-
set_input_array_shapes(buffer, &rows, &features);
728+
729+
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
730+
if (buffer.ndim == 2) {
731+
rows = buffer.shape[0];
732+
features = buffer.shape[1];
733+
} else {
734+
rows = 1;
735+
features = buffer.shape[0];
736+
}
731737

732738
if (features != dim)
733739
throw std::runtime_error("wrong dimensionality of the vectors");
734740

735-
std::vector<size_t> ids = get_input_ids(ids_, rows);
741+
std::vector<size_t> ids;
742+
743+
if (!ids_.is_none()) {
744+
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
745+
auto ids_numpy = items.request();
746+
if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) {
747+
std::vector<size_t> ids1(ids_numpy.shape[0]);
748+
for (size_t i = 0; i < ids1.size(); i++) {
749+
ids1[i] = items.data()[i];
750+
}
751+
ids.swap(ids1);
752+
} else if (ids_numpy.ndim == 0 && rows == 1) {
753+
ids.push_back(*items.data());
754+
} else {
755+
throw std::runtime_error("wrong dimensionality of the labels");
756+
}
757+
}
736758

737759
{
738760
for (size_t row = 0; row < rows; row++) {
@@ -780,7 +802,14 @@ class BFIndex {
780802
{
781803
py::gil_scoped_release l;
782804

783-
set_input_array_shapes(buffer, &rows, &features);
805+
if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array");
806+
if (buffer.ndim == 2) {
807+
rows = buffer.shape[0];
808+
features = buffer.shape[1];
809+
} else {
810+
rows = 1;
811+
features = buffer.shape[0];
812+
}
784813

785814
data_numpy_l = new hnswlib::labeltype[rows * k];
786815
data_numpy_d = new dist_t[rows * k];

0 commit comments

Comments
 (0)