Skip to content

Commit 429b0cf

Browse files
[mlir][python] Directly implement sequence protocol on Sliceable.
* While annoying, this is the only way to get C++ exception handling out of the happy path for normal iteration. * Implements sq_length and sq_item for the sequence protocol (used for iteration, including list() construction). * Implements mp_subscript for general use (i.e. foo[1] and foo[1:1]). * For constructing a `list(op.results)`, this reduces the time from ~4-5us to ~1.5us on my machine (give or take measurement overhead) and eliminates C++ exceptions, which is a worthy goal in itself. * Compared to a baseline of similar construction of a three-integer list, which takes 450ns (might just be measuring function call overhead). * See issue discussed on the pybind side: pybind/pybind11#2842 Differential Revision: https://reviews.llvm.org/D119691
1 parent e404e22 commit 429b0cf

File tree

2 files changed

+85
-31
lines changed

2 files changed

+85
-31
lines changed

mlir/lib/Bindings/Python/PybindUtils.h

Lines changed: 74 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ struct PySinglePartStringAccumulator {
207207
/// constructs a new instance of the derived pseudo-container with the
208208
/// given slice parameters (to be forwarded to the Sliceable constructor).
209209
///
210+
/// The getNumElements() and getElement(intptr_t) callbacks must not throw.
211+
///
210212
/// A derived class may additionally define:
211213
/// - a `static void bindDerived(ClassTy &)` method to bind additional methods
212214
/// the python class.
@@ -215,49 +217,53 @@ class Sliceable {
215217
protected:
216218
using ClassTy = pybind11::class_<Derived>;
217219

220+
// Transforms `index` into a legal value to access the underlying sequence.
221+
// Returns <0 on failure.
218222
intptr_t wrapIndex(intptr_t index) {
219223
if (index < 0)
220224
index = length + index;
221-
if (index < 0 || index >= length) {
222-
throw python::SetPyError(PyExc_IndexError,
223-
"attempt to access out of bounds");
224-
}
225+
if (index < 0 || index >= length)
226+
return -1;
225227
return index;
226228
}
227229

228-
public:
229-
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
230-
: startIndex(startIndex), length(length), step(step) {
231-
assert(length >= 0 && "expected non-negative slice length");
232-
}
233-
234-
/// Returns the length of the slice.
235-
intptr_t dunderLen() const { return length; }
236-
237230
/// Returns the element at the given slice index. Supports negative indices
238-
/// by taking elements in inverse order. Throws if the index is out of bounds.
239-
ElementTy dunderGetItem(intptr_t index) {
231+
/// by taking elements in inverse order. Returns a nullptr object if out
232+
/// of bounds.
233+
pybind11::object getItem(intptr_t index) {
240234
// Negative indices mean we count from the end.
241235
index = wrapIndex(index);
236+
if (index < 0) {
237+
PyErr_SetString(PyExc_IndexError, "index out of range");
238+
return {};
239+
}
242240

243241
// Compute the linear index given the current slice properties.
244242
int linearIndex = index * step + startIndex;
245243
assert(linearIndex >= 0 &&
246244
linearIndex < static_cast<Derived *>(this)->getNumElements() &&
247245
"linear index out of bounds, the slice is ill-formed");
248-
return static_cast<Derived *>(this)->getElement(linearIndex);
246+
return pybind11::cast(
247+
static_cast<Derived *>(this)->getElement(linearIndex));
249248
}
250249

251250
/// Returns a new instance of the pseudo-container restricted to the given
252-
/// slice.
253-
Derived dunderGetItemSlice(pybind11::slice slice) {
251+
/// slice. Returns a nullptr object on failure.
252+
pybind11::object getItemSlice(PyObject *slice) {
254253
ssize_t start, stop, extraStep, sliceLength;
255-
if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) {
256-
throw python::SetPyError(PyExc_IndexError,
257-
"attempt to access out of bounds");
254+
if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep,
255+
&sliceLength) != 0) {
256+
PyErr_SetString(PyExc_IndexError, "index out of range");
257+
return {};
258258
}
259-
return static_cast<Derived *>(this)->slice(startIndex + start * step,
260-
sliceLength, step * extraStep);
259+
return pybind11::cast(static_cast<Derived *>(this)->slice(
260+
startIndex + start * step, sliceLength, step * extraStep));
261+
}
262+
263+
public:
264+
explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
265+
: startIndex(startIndex), length(length), step(step) {
266+
assert(length >= 0 && "expected non-negative slice length");
261267
}
262268

263269
/// Returns a new vector (mapped to Python list) containing elements from two
@@ -267,10 +273,10 @@ class Sliceable {
267273
std::vector<ElementTy> elements;
268274
elements.reserve(length + other.length);
269275
for (intptr_t i = 0; i < length; ++i) {
270-
elements.push_back(dunderGetItem(i));
276+
elements.push_back(static_cast<Derived *>(this)->getElement(i));
271277
}
272278
for (intptr_t i = 0; i < other.length; ++i) {
273-
elements.push_back(other.dunderGetItem(i));
279+
elements.push_back(static_cast<Derived *>(this)->getElement(i));
274280
}
275281
return elements;
276282
}
@@ -279,11 +285,51 @@ class Sliceable {
279285
static void bind(pybind11::module &m) {
280286
auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName,
281287
pybind11::module_local())
282-
.def("__len__", &Sliceable::dunderLen)
283-
.def("__getitem__", &Sliceable::dunderGetItem)
284-
.def("__getitem__", &Sliceable::dunderGetItemSlice)
285288
.def("__add__", &Sliceable::dunderAdd);
286289
Derived::bindDerived(clazz);
290+
291+
// Manually implement the sequence protocol via the C API. We do this
292+
// because it is approx 4x faster than via pybind11, largely because that
293+
// formulation requires a C++ exception to be thrown to detect end of
294+
// sequence.
295+
// Since we are in a C-context, any C++ exception that happens here
296+
// will terminate the program. There is nothing in this implementation
297+
// that should throw in a non-terminal way, so we forgo further
298+
// exception marshalling.
299+
// See: https://github.com/pybind/pybind11/issues/2842
300+
auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr());
301+
assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE &&
302+
"must be heap type");
303+
heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t {
304+
auto self = pybind11::cast<Derived *>(rawSelf);
305+
return self->length;
306+
};
307+
// sq_item is called as part of the sequence protocol for iteration,
308+
// list construction, etc.
309+
heap_type->as_sequence.sq_item =
310+
+[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * {
311+
auto self = pybind11::cast<Derived *>(rawSelf);
312+
return self->getItem(index).release().ptr();
313+
};
314+
// mp_subscript is used for both slices and integer lookups.
315+
heap_type->as_mapping.mp_subscript =
316+
+[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * {
317+
auto self = pybind11::cast<Derived *>(rawSelf);
318+
Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError);
319+
if (!PyErr_Occurred()) {
320+
// Integer indexing.
321+
return self->getItem(index).release().ptr();
322+
}
323+
PyErr_Clear();
324+
325+
// Assume slice-based indexing.
326+
if (PySlice_Check(rawSubscript)) {
327+
return self->getItemSlice(rawSubscript).release().ptr();
328+
}
329+
330+
PyErr_SetString(PyExc_ValueError, "expected integer or slice");
331+
return nullptr;
332+
};
287333
}
288334

289335
/// Hook for derived classes willing to bind more methods.

mlir/test/python/ir/operation.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ def run(f):
1414
return f
1515

1616

17+
def expect_index_error(callback):
18+
try:
19+
_ = callback()
20+
raise RuntimeError("Expected IndexError")
21+
except IndexError:
22+
pass
23+
24+
1725
# Verify iterator based traversal of the op/region/block hierarchy.
1826
# CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators
1927
@run
@@ -418,7 +426,9 @@ def testOperationResultList():
418426
for t in call.results.types:
419427
print(f"Result type {t}")
420428

421-
429+
# Out of range
430+
expect_index_error(lambda: call.results[3])
431+
expect_index_error(lambda: call.results[-4])
422432

423433

424434
# CHECK-LABEL: TEST: testOperationResultListSlice
@@ -470,8 +480,6 @@ def testOperationResultListSlice():
470480
print(f"Result {res.result_number}, type {res.type}")
471481

472482

473-
474-
475483
# CHECK-LABEL: TEST: testOperationAttributes
476484
@run
477485
def testOperationAttributes():

0 commit comments

Comments
 (0)