Skip to content

Commit 9315645

Browse files
authored
[mlir][python] auto attribute casting (#97786)
1 parent 3bb2563 commit 9315645

File tree

7 files changed

+51
-10
lines changed

7 files changed

+51
-10
lines changed

mlir/include/mlir/Bindings/Python/PybindAdaptors.h

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,21 +406,25 @@ class pure_subclass {
406406
class mlir_attribute_subclass : public pure_subclass {
407407
public:
408408
using IsAFunctionTy = bool (*)(MlirAttribute);
409+
using GetTypeIDFunctionTy = MlirTypeID (*)();
409410

410411
/// Subclasses by looking up the super-class dynamically.
411412
mlir_attribute_subclass(py::handle scope, const char *attrClassName,
412-
IsAFunctionTy isaFunction)
413+
IsAFunctionTy isaFunction,
414+
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
413415
: mlir_attribute_subclass(
414416
scope, attrClassName, isaFunction,
415417
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
416-
.attr("Attribute")) {}
418+
.attr("Attribute"),
419+
getTypeIDFunction) {}
417420

418421
/// Subclasses with a provided mlir.ir.Attribute super-class. This must
419422
/// be used if the subclass is being defined in the same extension module
420423
/// as the mlir.ir class (otherwise, it will trigger a recursive
421424
/// initialization).
422425
mlir_attribute_subclass(py::handle scope, const char *typeClassName,
423-
IsAFunctionTy isaFunction, const py::object &superCls)
426+
IsAFunctionTy isaFunction, const py::object &superCls,
427+
GetTypeIDFunctionTy getTypeIDFunction = nullptr)
424428
: pure_subclass(scope, typeClassName, superCls) {
425429
// Casting constructor. Note that it hard, if not impossible, to properly
426430
// call chain to parent `__init__` in pybind11 due to its special handling
@@ -454,6 +458,20 @@ class mlir_attribute_subclass : public pure_subclass {
454458
"isinstance",
455459
[isaFunction](MlirAttribute other) { return isaFunction(other); },
456460
py::arg("other_attribute"));
461+
def("__repr__", [superCls, captureTypeName](py::object self) {
462+
return py::repr(superCls(self))
463+
.attr("replace")(superCls.attr("__name__"), captureTypeName);
464+
});
465+
if (getTypeIDFunction) {
466+
def_staticmethod("get_static_typeid",
467+
[getTypeIDFunction]() { return getTypeIDFunction(); });
468+
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
469+
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
470+
getTypeIDFunction())(pybind11::cpp_function(
471+
[thisClass = thisClass](const py::object &mlirAttribute) {
472+
return thisClass(mlirAttribute);
473+
}));
474+
}
457475
}
458476
};
459477

mlir/test/python/dialects/python_test.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,11 +307,23 @@ def testOptionalOperandOp():
307307
# CHECK-LABEL: TEST: testCustomAttribute
308308
@run
309309
def testCustomAttribute():
310-
with Context() as ctx:
310+
with Context() as ctx, Location.unknown():
311311
a = test.TestAttr.get()
312312
# CHECK: #python_test.test_attr
313313
print(a)
314314

315+
# CHECK: python_test.custom_attributed_op {
316+
# CHECK: #python_test.test_attr
317+
# CHECK: }
318+
op2 = test.CustomAttributedOp(a)
319+
print(f"{op2}")
320+
321+
# CHECK: #python_test.test_attr
322+
print(f"{op2.test_attr}")
323+
324+
# CHECK: TestAttr(#python_test.test_attr)
325+
print(repr(op2.test_attr))
326+
315327
# The following cast must not assert.
316328
b = test.TestAttr(a)
317329

mlir/test/python/lib/PythonTestCAPI.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
2323
return wrap(python_test::TestAttrAttr::get(unwrap(context)));
2424
}
2525

26+
MlirTypeID mlirPythonTestTestAttributeGetTypeID(void) {
27+
return wrap(python_test::TestAttrAttr::getTypeID());
28+
}
29+
2630
bool mlirTypeIsAPythonTestTestType(MlirType type) {
2731
return llvm::isa<python_test::TestTypeType>(unwrap(type));
2832
}

mlir/test/python/lib/PythonTestCAPI.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
2323
MLIR_CAPI_EXPORTED MlirAttribute
2424
mlirPythonTestTestAttributeGet(MlirContext context);
2525

26+
MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestAttributeGetTypeID(void);
27+
2628
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
2729

2830
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);

mlir/test/python/lib/PythonTestDialect.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616

1717
#include "PythonTestDialect.h.inc"
1818

19-
#define GET_OP_CLASSES
20-
#include "PythonTestOps.h.inc"
21-
2219
#define GET_ATTRDEF_CLASSES
2320
#include "PythonTestAttributes.h.inc"
2421

2522
#define GET_TYPEDEF_CLASSES
2623
#include "PythonTestTypes.h.inc"
2724

25+
#define GET_OP_CLASSES
26+
#include "PythonTestOps.h.inc"
27+
2828
#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H

mlir/test/python/lib/PythonTestModule.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
4444
py::arg("registry"));
4545

4646
mlir_attribute_subclass(m, "TestAttr",
47-
mlirAttributeIsAPythonTestTestAttribute)
47+
mlirAttributeIsAPythonTestTestAttribute,
48+
mlirPythonTestTestAttributeGetTypeID)
4849
.def_classmethod(
4950
"get",
50-
[](py::object cls, MlirContext ctx) {
51+
[](const py::object &cls, MlirContext ctx) {
5152
return cls(mlirPythonTestTestAttributeGet(ctx));
5253
},
5354
py::arg("cls"), py::arg("context") = py::none());
@@ -56,7 +57,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
5657
mlirPythonTestTestTypeGetTypeID)
5758
.def_classmethod(
5859
"get",
59-
[](py::object cls, MlirContext ctx) {
60+
[](const py::object &cls, MlirContext ctx) {
6061
return cls(mlirPythonTestTestTypeGet(ctx));
6162
},
6263
py::arg("cls"), py::arg("context") = py::none());

mlir/test/python/python_test_ops.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ def AttributedOp : TestOp<"attributed_op"> {
5858
UnitAttr:$unit);
5959
}
6060

61+
def CustomAttributedOp : TestOp<"custom_attributed_op"> {
62+
let arguments = (ins TestAttr:$test_attr);
63+
}
64+
6165
def AttributesOp : TestOp<"attributes_op"> {
6266
let arguments = (ins
6367
AffineMapArrayAttr:$x_affinemaparr,

0 commit comments

Comments
 (0)