Skip to content

Commit dff787e

Browse files
authored
Merge pull request #362 from psobot/psobot/fix-crash-on-get-items
Throw exception instead of segfaulting when passing a scalar to get_items.
2 parents 41998c9 + 18fe0c7 commit dff787e

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

python_bindings/bindings.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,16 @@ class Index {
260260
if (!ids_.is_none()) {
261261
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
262262
auto ids_numpy = items.request();
263-
std::vector<size_t> ids1(ids_numpy.shape[0]);
264-
for (size_t i = 0; i < ids1.size(); i++) {
265-
ids1[i] = items.data()[i];
263+
264+
if (ids_numpy.ndim == 0) {
265+
throw std::invalid_argument("get_items accepts a list of indices and returns a list of vectors");
266+
} else {
267+
std::vector<size_t> ids1(ids_numpy.shape[0]);
268+
for (size_t i = 0; i < ids1.size(); i++) {
269+
ids1[i] = items.data()[i];
270+
}
271+
ids.swap(ids1);
266272
}
267-
ids.swap(ids1);
268273
}
269274

270275
std::vector<std::vector<data_t>> data;

python_bindings/tests/bindings_test_getdata.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def testGettingItems(self):
4141
print("Adding all elements (%d)" % (len(data)))
4242
p.add_items(data, labels)
4343

44+
# Getting data by label should raise an exception if a scalar is passed:
45+
self.assertRaises(ValueError, lambda: p.get_items(labels[0]))
46+
4447
# After adding them, all labels should be retrievable
4548
returned_items = p.get_items(labels)
4649
self.assertSequenceEqual(data.tolist(), returned_items)

0 commit comments

Comments
 (0)