Skip to content

Commit a7ac3cf

Browse files
committed
use detection idiom inside of getItem instead of virtual member
1 parent 69ae05c commit a7ac3cf

File tree

5 files changed

+21
-26
lines changed

5 files changed

+21
-26
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,13 +2218,6 @@ class PyBlockArgumentList
22182218
step),
22192219
operation(std::move(operation)), block(block) {}
22202220

2221-
pybind11::object getItem(intptr_t index) override {
2222-
auto item = this->SliceableT::getItem(index);
2223-
if (item.ptr() != nullptr)
2224-
return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
2225-
return item;
2226-
}
2227-
22282221
static void bindDerived(ClassTy &c) {
22292222
c.def_property_readonly("types", [](PyBlockArgumentList &self) {
22302223
return getValueTypes(self, self.operation->getContext());
@@ -2274,13 +2267,6 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
22742267
step),
22752268
operation(operation) {}
22762269

2277-
pybind11::object getItem(intptr_t index) override {
2278-
auto item = this->SliceableT::getItem(index);
2279-
if (item.ptr() != nullptr)
2280-
return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
2281-
return item;
2282-
}
2283-
22842270
void dunderSetItem(intptr_t index, PyValue value) {
22852271
index = wrapIndex(index);
22862272
mlirOperationSetOperand(operation->get(), index, value.get());
@@ -2337,13 +2323,6 @@ class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
23372323
step),
23382324
operation(std::move(operation)) {}
23392325

2340-
pybind11::object getItem(intptr_t index) override {
2341-
auto item = this->SliceableT::getItem(index);
2342-
if (item.ptr() != nullptr)
2343-
return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)();
2344-
return item;
2345-
}
2346-
23472326
static void bindDerived(ClassTy &c) {
23482327
c.def_property_readonly("types", [](PyOpResultList &self) {
23492328
return getValueTypes(self, self.operation->getContext());

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,7 @@ class PyConcreteAttribute : public BaseTy {
11101110
/// bindings so such operation always exists).
11111111
class PyValue {
11121112
public:
1113+
virtual ~PyValue() = default;
11131114
PyValue(PyOperationRef parentOperation, MlirValue value)
11141115
: parentOperation(std::move(parentOperation)), value(value) {}
11151116
operator MlirValue() const { return value; }
@@ -1122,7 +1123,7 @@ class PyValue {
11221123
/// Gets a capsule wrapping the void* within the MlirValue.
11231124
pybind11::object getCapsule();
11241125

1125-
virtual pybind11::object maybeDownCast();
1126+
pybind11::object maybeDownCast();
11261127

11271128
/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
11281129
/// the underlying MlirValue is still tied to the owning operation.

mlir/lib/Bindings/Python/PybindUtils.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
1111

1212
#include "mlir-c/Support.h"
13+
#include "llvm/ADT/STLExtras.h"
1314
#include "llvm/ADT/Twine.h"
1415
#include "llvm/Support/DataTypes.h"
1516

@@ -228,19 +229,29 @@ class Sliceable {
228229
return linearIndex;
229230
}
230231

232+
/// Trait to check if T provides a `maybeDownCast` method.
233+
/// Note, you need the & to detect inherited members.
234+
template <typename T, typename... Args>
235+
using has_maybe_downcast = decltype(&T::maybeDownCast);
236+
231237
/// Returns the element at the given slice index. Supports negative indices
232238
/// by taking elements in inverse order. Returns a nullptr object if out
233239
/// of bounds.
234-
virtual pybind11::object getItem(intptr_t index) {
240+
pybind11::object getItem(intptr_t index) {
235241
// Negative indices mean we count from the end.
236242
index = wrapIndex(index);
237243
if (index < 0) {
238244
PyErr_SetString(PyExc_IndexError, "index out of range");
239245
return {};
240246
}
241247

242-
return pybind11::cast(
243-
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
248+
if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
249+
return static_cast<Derived *>(this)
250+
->getRawElement(linearizeIndex(index))
251+
.maybeDownCast();
252+
else
253+
return pybind11::cast(
254+
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
244255
}
245256

246257
/// Returns a new instance of the pseudo-container restricted to the given

mlir/test/python/dialects/arith_dialect.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ def __str__(self):
7474

7575
with InsertionPoint(module.body):
7676
a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
77+
# CHECK: ArithValue(%cst = arith.constant 4.240
78+
print(a)
79+
7780
b = a + a
7881
# CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
7982
print(b)

mlir/test/python/ir/value.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,8 @@ def dont_cast_int_shouldnt_register(v):
370370
print(e)
371371

372372
@register_value_caster(IntegerType.static_typeid, replace=True)
373-
def dont_cast_int(v) -> Value:
373+
def dont_cast_int(v) -> OpResult:
374+
assert isinstance(v, OpResult)
374375
print("don't cast", v.result_number, v)
375376
return v
376377

0 commit comments

Comments
 (0)