Skip to content

[mlir][python] auto attribute casting #97786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 5, 2024

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Jul 5, 2024

This PR implements auto attribute casting for downstream attributes just like we have for downstream types.

Use case: https://github.com/openxla/shardy

@llvmbot
Copy link
Member

llvmbot commented Jul 5, 2024

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

This PR implements auto attribute casting for downstream attributes just like we have for downstream types.

Use case: https://github.com/openxla/shardy


Full diff: https://github.com/llvm/llvm-project/pull/97786.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Bindings/Python/PybindAdaptors.h (+26-3)
  • (modified) mlir/test/python/dialects/python_test.py (+13-1)
  • (modified) mlir/test/python/lib/PythonTestCAPI.cpp (+4)
  • (modified) mlir/test/python/lib/PythonTestCAPI.h (+2)
  • (modified) mlir/test/python/lib/PythonTestDialect.h (+3-3)
  • (modified) mlir/test/python/lib/PythonTestModule.cpp (+4-3)
  • (modified) mlir/test/python/python_test_ops.td (+4)
diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
index ebf50109f72f23..67cc48277efcbe 100644
--- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h
@@ -406,21 +406,25 @@ class pure_subclass {
 class mlir_attribute_subclass : public pure_subclass {
 public:
   using IsAFunctionTy = bool (*)(MlirAttribute);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
 
   /// Subclasses by looking up the super-class dynamically.
   mlir_attribute_subclass(py::handle scope, const char *attrClassName,
-                          IsAFunctionTy isaFunction)
+                          IsAFunctionTy isaFunction,
+                          GetTypeIDFunctionTy getTypeIDFunction = nullptr)
       : mlir_attribute_subclass(
             scope, attrClassName, isaFunction,
             py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                .attr("Attribute")) {}
+                .attr("Attribute"),
+            getTypeIDFunction) {}
 
   /// Subclasses with a provided mlir.ir.Attribute super-class. This must
   /// be used if the subclass is being defined in the same extension module
   /// as the mlir.ir class (otherwise, it will trigger a recursive
   /// initialization).
   mlir_attribute_subclass(py::handle scope, const char *typeClassName,
-                          IsAFunctionTy isaFunction, const py::object &superCls)
+                          IsAFunctionTy isaFunction, const py::object &superCls,
+                          GetTypeIDFunctionTy getTypeIDFunction = nullptr)
       : pure_subclass(scope, typeClassName, superCls) {
     // Casting constructor. Note that it hard, if not impossible, to properly
     // call chain to parent `__init__` in pybind11 due to its special handling
@@ -454,6 +458,25 @@ class mlir_attribute_subclass : public pure_subclass {
         "isinstance",
         [isaFunction](MlirAttribute other) { return isaFunction(other); },
         py::arg("other_attribute"));
+    def("__repr__", [superCls, captureTypeName](py::object self) {
+      return py::repr(superCls(self))
+          .attr("replace")(superCls.attr("__name__"), captureTypeName);
+    });
+    if (getTypeIDFunction) {
+      // 'get_static_typeid' method.
+      // This is modeled as a static method instead of a static property because
+      // `def_property_readonly_static` is not available in `pure_subclass` and
+      // we do not want to introduce the complexity that pybind uses to
+      // implement it.
+      def_staticmethod("get_static_typeid",
+                       [getTypeIDFunction]() { return getTypeIDFunction(); });
+      py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+          .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
+              getTypeIDFunction())(pybind11::cpp_function(
+              [thisClass = thisClass](const py::object &mlirAttribute) {
+                return thisClass(mlirAttribute);
+              }));
+    }
   }
 };
 
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 70927b22d4749c..a76f3f2b5e4583 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -307,11 +307,23 @@ def testOptionalOperandOp():
 # CHECK-LABEL: TEST: testCustomAttribute
 @run
 def testCustomAttribute():
-    with Context() as ctx:
+    with Context() as ctx, Location.unknown():
         a = test.TestAttr.get()
         # CHECK: #python_test.test_attr
         print(a)
 
+        # CHECK: python_test.custom_attributed_op  {
+        # CHECK: #python_test.test_attr
+        # CHECK: }
+        op2 = test.CustomAttributedOp(a)
+        print(f"{op2}")
+
+        # CHECK: #python_test.test_attr
+        print(f"{op2.test_attr}")
+
+        # CHECK: TestAttr(#python_test.test_attr)
+        print(repr(op2.test_attr))
+
         # The following cast must not assert.
         b = test.TestAttr(a)
 
diff --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
index 71778a97d83a41..cb7d7677714fe6 100644
--- a/mlir/test/python/lib/PythonTestCAPI.cpp
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -23,6 +23,10 @@ MlirAttribute mlirPythonTestTestAttributeGet(MlirContext context) {
   return wrap(python_test::TestAttrAttr::get(unwrap(context)));
 }
 
+MlirTypeID mlirPythonTestTestAttributeGetTypeID(void) {
+  return wrap(python_test::TestAttrAttr::getTypeID());
+}
+
 bool mlirTypeIsAPythonTestTestType(MlirType type) {
   return llvm::isa<python_test::TestTypeType>(unwrap(type));
 }
diff --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
index 5f1ed3a5b2ad66..43f8fdcbfae125 100644
--- a/mlir/test/python/lib/PythonTestCAPI.h
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -23,6 +23,8 @@ mlirAttributeIsAPythonTestTestAttribute(MlirAttribute attr);
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirPythonTestTestAttributeGet(MlirContext context);
 
+MLIR_CAPI_EXPORTED MlirTypeID mlirPythonTestTestAttributeGetTypeID(void);
+
 MLIR_CAPI_EXPORTED bool mlirTypeIsAPythonTestTestType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirPythonTestTestTypeGet(MlirContext context);
diff --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
index 044381fcd4728d..889365e1136b4e 100644
--- a/mlir/test/python/lib/PythonTestDialect.h
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -16,13 +16,13 @@
 
 #include "PythonTestDialect.h.inc"
 
-#define GET_OP_CLASSES
-#include "PythonTestOps.h.inc"
-
 #define GET_ATTRDEF_CLASSES
 #include "PythonTestAttributes.h.inc"
 
 #define GET_TYPEDEF_CLASSES
 #include "PythonTestTypes.h.inc"
 
+#define GET_OP_CLASSES
+#include "PythonTestOps.h.inc"
+
 #endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
index f81b851f8759bf..a4f538dcb55944 100644
--- a/mlir/test/python/lib/PythonTestModule.cpp
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -44,10 +44,11 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
       py::arg("registry"));
 
   mlir_attribute_subclass(m, "TestAttr",
-                          mlirAttributeIsAPythonTestTestAttribute)
+                          mlirAttributeIsAPythonTestTestAttribute,
+                          mlirPythonTestTestAttributeGetTypeID)
       .def_classmethod(
           "get",
-          [](py::object cls, MlirContext ctx) {
+          [](const py::object &cls, MlirContext ctx) {
             return cls(mlirPythonTestTestAttributeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
@@ -56,7 +57,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) {
                      mlirPythonTestTestTypeGetTypeID)
       .def_classmethod(
           "get",
-          [](py::object cls, MlirContext ctx) {
+          [](const py::object &cls, MlirContext ctx) {
             return cls(mlirPythonTestTestTypeGet(ctx));
           },
           py::arg("cls"), py::arg("context") = py::none());
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 95301985e3fde0..5a82c00ae60802 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -58,6 +58,10 @@ def AttributedOp : TestOp<"attributed_op"> {
                    UnitAttr:$unit);
 }
 
+def CustomAttributedOp : TestOp<"custom_attributed_op"> {
+  let arguments = (ins TestAttr:$test_attr);
+}
+
 def AttributesOp : TestOp<"attributes_op"> {
   let arguments = (ins
                    AffineMapArrayAttr:$x_affinemaparr,

Copy link
Contributor

@bartchr808 bartchr808 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again thanks for making this change!

@makslevental makslevental merged commit 9315645 into llvm:main Jul 5, 2024
7 checks passed
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants