@@ -79,40 +79,6 @@ inline void assert_true(bool expr, const std::string & msg) {
79
79
}
80
80
81
81
82
- inline void set_input_array_shapes (const py::buffer_info& buffer, size_t * rows, size_t * features) {
83
- if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
84
- if (buffer.ndim == 2 ) {
85
- *rows = buffer.shape [0 ];
86
- *features = buffer.shape [1 ];
87
- } else {
88
- *rows = 1 ;
89
- *features = buffer.shape [0 ];
90
- }
91
- }
92
-
93
-
94
- inline std::vector<size_t > get_input_ids (const py::object& ids_, size_t rows) {
95
- std::vector<size_t > ids;
96
- if (!ids_.is_none ()) {
97
- py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
98
- auto ids_numpy = items.request ();
99
- if (ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) {
100
- std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
101
- for (size_t i = 0 ; i < ids1.size (); i++) {
102
- ids1[i] = items.data ()[i];
103
- }
104
- ids.swap (ids1);
105
- } else if (ids_numpy.ndim == 0 && rows == 1 ) {
106
- ids.push_back (*items.data ());
107
- } else {
108
- throw std::runtime_error (" wrong dimensionality of the labels" );
109
- }
110
- }
111
-
112
- return ids;
113
- }
114
-
115
-
116
82
template <typename dist_t , typename data_t = float >
117
83
class Index {
118
84
public:
@@ -222,7 +188,15 @@ class Index {
222
188
num_threads = num_threads_default;
223
189
224
190
size_t rows, features;
225
- set_input_array_shapes (buffer, &rows, &features);
191
+
192
+ if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
193
+ if (buffer.ndim == 2 ) {
194
+ rows = buffer.shape [0 ];
195
+ features = buffer.shape [1 ];
196
+ } else {
197
+ rows = 1 ;
198
+ features = buffer.shape [0 ];
199
+ }
226
200
227
201
if (features != dim)
228
202
throw std::runtime_error (" wrong dimensionality of the vectors" );
@@ -232,7 +206,23 @@ class Index {
232
206
num_threads = 1 ;
233
207
}
234
208
235
- std::vector<size_t > ids = get_input_ids (ids_, rows);
209
+ std::vector<size_t > ids;
210
+
211
+ if (!ids_.is_none ()) {
212
+ py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
213
+ auto ids_numpy = items.request ();
214
+ if (ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) {
215
+ std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
216
+ for (size_t i = 0 ; i < ids1.size (); i++) {
217
+ ids1[i] = items.data ()[i];
218
+ }
219
+ ids.swap (ids1);
220
+ } else if (ids_numpy.ndim == 0 && rows == 1 ) {
221
+ ids.push_back (*items.data ());
222
+ } else {
223
+ throw std::runtime_error (" wrong dimensionality of the labels" );
224
+ }
225
+ }
236
226
237
227
{
238
228
int start = 0 ;
@@ -571,7 +561,15 @@ class Index {
571
561
572
562
{
573
563
py::gil_scoped_release l;
574
- set_input_array_shapes (buffer, &rows, &features);
564
+
565
+ if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
566
+ if (buffer.ndim == 2 ) {
567
+ rows = buffer.shape [0 ];
568
+ features = buffer.shape [1 ];
569
+ } else {
570
+ rows = 1 ;
571
+ features = buffer.shape [0 ];
572
+ }
575
573
576
574
// avoid using threads when the number of searches is small:
577
575
if (rows <= num_threads * 4 ) {
@@ -727,12 +725,36 @@ class BFIndex {
727
725
py::array_t < dist_t , py::array::c_style | py::array::forcecast > items (input);
728
726
auto buffer = items.request ();
729
727
size_t rows, features;
730
- set_input_array_shapes (buffer, &rows, &features);
728
+
729
+ if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
730
+ if (buffer.ndim == 2 ) {
731
+ rows = buffer.shape [0 ];
732
+ features = buffer.shape [1 ];
733
+ } else {
734
+ rows = 1 ;
735
+ features = buffer.shape [0 ];
736
+ }
731
737
732
738
if (features != dim)
733
739
throw std::runtime_error (" wrong dimensionality of the vectors" );
734
740
735
- std::vector<size_t > ids = get_input_ids (ids_, rows);
741
+ std::vector<size_t > ids;
742
+
743
+ if (!ids_.is_none ()) {
744
+ py::array_t < size_t , py::array::c_style | py::array::forcecast > items (ids_);
745
+ auto ids_numpy = items.request ();
746
+ if (ids_numpy.ndim == 1 && ids_numpy.shape [0 ] == rows) {
747
+ std::vector<size_t > ids1 (ids_numpy.shape [0 ]);
748
+ for (size_t i = 0 ; i < ids1.size (); i++) {
749
+ ids1[i] = items.data ()[i];
750
+ }
751
+ ids.swap (ids1);
752
+ } else if (ids_numpy.ndim == 0 && rows == 1 ) {
753
+ ids.push_back (*items.data ());
754
+ } else {
755
+ throw std::runtime_error (" wrong dimensionality of the labels" );
756
+ }
757
+ }
736
758
737
759
{
738
760
for (size_t row = 0 ; row < rows; row++) {
@@ -780,7 +802,14 @@ class BFIndex {
780
802
{
781
803
py::gil_scoped_release l;
782
804
783
- set_input_array_shapes (buffer, &rows, &features);
805
+ if (buffer.ndim != 2 && buffer.ndim != 1 ) throw std::runtime_error (" data must be a 1d/2d array" );
806
+ if (buffer.ndim == 2 ) {
807
+ rows = buffer.shape [0 ];
808
+ features = buffer.shape [1 ];
809
+ } else {
810
+ rows = 1 ;
811
+ features = buffer.shape [0 ];
812
+ }
784
813
785
814
data_numpy_l = new hnswlib::labeltype[rows * k];
786
815
data_numpy_d = new dist_t [rows * k];
0 commit comments