-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] auto attribute casting #97786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Maksim Levental (makslevental) ChangesThis PR implements auto attribute casting for downstream attributes just like we have for downstream types. Use case: https://github.com/openxla/shardy Full diff: https://github.com/llvm/llvm-project/pull/97786.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index ebf50109f72f23..67cc48277efcbe 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -406,21 +406,25 @@ class pure_subclass {
class mlir_attribute_subclass : public pure_subclass {
public:
using IsAFunctionTy = bool (*)(MlirAttribute);
+ using GetTypeIDFunctionTy = MlirTypeID (*)();
/// Subclasses by looking up the super-class dynamically.
mlir_attribute_subclass(py::handle scope, const char *attrClassName,
- IsAFunctionTy isaFunction)
+ IsAFunctionTy isaFunction,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: mlir_attribute_subclass(
scope, attrClassName, isaFunction,
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
- .attr("Attribute")) {}
+ .attr("Attribute"),
+ getTypeIDFunction) {}
/// Subclasses with a provided mlir.ir.Attribute super-class. This must
/// be used if the subclass is being defined in the same extension module
/// as the mlir.ir class (otherwise, it will trigger a recursive
/// initialization).
mlir_attribute_subclass(py::handle scope, const char *typeClassName,
- IsAFunctionTy isaFunction, const py::object &superCls)
+ IsAFunctionTy isaFunction, const py::object &superCls,
+ GetTypeIDFunctionTy getTypeIDFunction = nullptr)
: pure_subclass(scope, typeClassName, superCls) {
// Casting constructor. Note that it hard, if not impossible, to properly
// call chain to parent `__init__` in pybind11 due to its special handling
@@ -454,6 +458,25 @@ class mlir_attribute_subclass : public pure_subclass {
"isinstance",
[isaFunction](MlirAttribute other) { return isaFunction(other); },
py::arg("other_attribute"));
+ def("__repr__", [superCls, captureTypeName](py::object self) {
+ return py::repr(superCls(self))
+ .attr("replace")(superCls.attr("__name__"), captureTypeName);
+ });
+ if (getTypeIDFunction) {
+ // 'get_static_typeid' method.
+ // This is modeled as a static method instead of a static property because
+ // `def_property_readonly_static` is not available in `pure_subclass` and
+ // we do not want to introduce the complexity that pybind uses to
+ // implement it.
+ def_staticmethod("get_static_typeid",
+ [getTypeIDFunction]() { return getTypeIDFunction(); });
+ py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+ .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+ getTypeIDFunction())(pybind11::cpp_function(
+ [thisClass = thisClass](const py::object &mlirAttribute) {
+ return thisClass(mlirAttribute);
+ }));
+ }
}
};
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 70927b22d4749c..a76f3f2b5e4583 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -307,11 +307,23 @@ def testOptionalOperandOp():
# CHECK-LABEL: TEST: testCustomAttribute
@run
def testCustomAttribute():
- with Context() as ctx:
+ with Context() as ctx, Location.unknown():
a = test.TestAttr.get()
# CHECK: #python_test.test_attr
print(a)
+ # CHECK: python_test.custom_attributed_op {
+ # CHECK: #python_test.test_attr
+ # CHECK: }
+ op2 = test.CustomAttributedOp(a)
+ print(f"{op2}")
+
+ # CHECK: #python_test.test_attr
+ print(f"{op2.test_attr}")
+
+ # CHECK: TestAttr(#python_test.test_attr)
+ print(repr(op2.test_attr))
+
# The following cast must not assert.
b = test.TestAttr(a)
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 71778a97d83a41..cb7d7677714fe6 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -23,6 +23,10 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
return wrap(python_test::TestAttrAttr::get(unwrap(context)));
}
+MlirTypeID mlirPythonTestTestAttributeGetTypeID(void) {
+ return wrap(python_test::TestAttrAttr::getTypeID());
+}
+
bool mlirTypeIsAPythonTestTestType(MlirType type) {
return llvm::isa<python_test::TestTypeType>(unwrap(type));
}
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index 5f1ed3a5b2ad66..43f8fdcbfae125 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -23,6 +23,8 @@ mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirAttribute
mlirPythonTestTestAttributeGet(MlirContext context);
+MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestAttributeGetTypeID(void);
+
MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
index 044381fcd4728d..889365e1136b4e 100644
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -16,13 +16,13 @@
#include "PythonTestDialect.h.inc"
-#define GET_OP_CLASSES
-#include "PythonTestOps.h.inc"
-
#define GET_ATTRDEF_CLASSES
#include "PythonTestAttributes.h.inc"
#define GET_TYPEDEF_CLASSES
#include "PythonTestTypes.h.inc"
+#define GET_OP_CLASSES
+#include "PythonTestOps.h.inc"
+
#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f81b851f8759bf..a4f538dcb55944 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -44,10 +44,11 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
py::arg("registry"));
mlir_attribute_subclass(m, "TestAttr",
- mlirAttributeIsAPythonTestTestAttribute)
+ mlirAttributeIsAPythonTestTestAttribute,
+ mlirPythonTestTestAttributeGetTypeID)
.def_classmethod(
"get",
- [](py::object cls, MlirContext ctx) {
+ [](const py::object &cls, MlirContext ctx) {
return cls(mlirPythonTestTestAttributeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
@@ -56,7 +57,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
mlirPythonTestTestTypeGetTypeID)
.def_classmethod(
"get",
- [](py::object cls, MlirContext ctx) {
+ [](const py::object &cls, MlirContext ctx) {
return cls(mlirPythonTestTestTypeGet(ctx));
},
py::arg("cls"), py::arg("context") = py::none());
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 95301985e3fde0..5a82c00ae60802 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -58,6 +58,10 @@ def AttributedOp : TestOp<"attributed_op"> {
UnitAttr:$unit);
}
+def CustomAttributedOp : TestOp<"custom_attributed_op"> {
+ let arguments = (ins TestAttr:$test_attr);
+}
+
def AttributesOp : TestOp<"attributes_op"> {
let arguments = (ins
AffineMapArrayAttr:$x_affinemaparr,
|
3e5d810
to
3fcd41e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again thanks for making this change!
This PR implements auto attribute casting for downstream attributes just like we have for downstream types.
Use case: https://github.com/openxla/shardy