Skip to content

[mlir][python] Expose transform param types #67421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 26, 2023

Conversation

martin-luecke
Copy link
Contributor

This exposes the Transform dialect types AnyParamType and ParamType via the Python bindings.

@llvmbot
Copy link
Member

llvmbot commented Sep 26, 2023

@llvm/pr-subscribers-mlir

Changes

This exposes the Transform dialect types AnyParamType and ParamType via the Python bindings.


Full diff: https://github.com/llvm/llvm-project/pull/67421.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/Dialect/Transform.h (+19)
  • (modified) mlir/lib/Bindings/Python/DialectTransform.cpp (+36)
  • (modified) mlir/lib/CAPI/Dialect/Transform.cpp (+28)
  • (modified) mlir/test/python/dialects/transform.py (+10)
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 954575925cc5c45..91c99b1f869f22c 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
 
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
+
 //===---------------------------------------------------------------------===//
 // AnyValueType
 //===---------------------------------------------------------------------===//
@@ -49,6 +57,17 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
 MLIR_CAPI_EXPORTED MlirStringRef
 mlirTransformOperationTypeGetOperationName(MlirType type);
 
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
+                                                      MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 932e40220057c13..e7d73c12d3db3d5 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
       py::arg("context") = py::none());
 
+  //===-------------------------------------------------------------------===//
+  // AnyParamType
+  //===-------------------------------------------------------------------===//
+
+  auto anyParamType =
+      mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+  anyParamType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirTransformAnyParamTypeGet(ctx));
+      },
+      "Get an instance of AnyParamType in the given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+
   //===-------------------------------------------------------------------===//
   // AnyValueType
   //===-------------------------------------------------------------------===//
@@ -71,6 +85,28 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
         return py::str(operationName.data, operationName.length);
       },
       "Get the name of the payload operation accepted by the handle.");
+
+  //===-------------------------------------------------------------------===//
+  // ParamType
+  //===-------------------------------------------------------------------===//
+
+  auto paramType =
+      mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+  paramType.def_classmethod(
+      "get",
+      [](py::object cls, MlirType type, MlirContext ctx) {
+        return cls(mlirTransformParamTypeGet(ctx, type));
+      },
+      "Get an instance of ParamType for the given type in the given context.",
+      py::arg("cls"), py::arg("type") = py::none(),
+      py::arg("context") = py::none());
+  paramType.def_property_readonly(
+      "type",
+      [](MlirType type) {
+        MlirType paramType = mlirTransformParamTypeGetType(type);
+        return paramType;
+      },
+      "Get the type this ParamType is associated with.");
 }
 
 PYBIND11_MODULE(_mlirDialectsTransform, m) {
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 5841f6783ad5f1d..3f7f8b8e2113fe4 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
   return wrap(transform::AnyOpType::get(unwrap(ctx)));
 }
 
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformAnyParamType(MlirType type) {
+  return isa<transform::AnyParamType>(unwrap(type));
+}
+
+MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
+  return wrap(transform::AnyParamType::get(unwrap(ctx)));
+}
+
 //===---------------------------------------------------------------------===//
 // AnyValueType
 //===---------------------------------------------------------------------===//
@@ -62,3 +74,19 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
 MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
   return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
 }
+
+//===---------------------------------------------------------------------===//
+// AnyOpType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformParamType(MlirType type) {
+  return isa<transform::ParamType>(unwrap(type));
+}
+
+MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
+  return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
+}
+
+MlirType mlirTransformParamTypeGetType(MlirType type) {
+  return wrap(cast<transform::ParamType>(unwrap(type)).getType());
+}
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 5df125694256a4e..481d7745720101d 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -22,6 +22,10 @@ def testTypes():
     any_op = transform.AnyOpType.get()
     print(any_op)
 
+    # CHECK: !transform.any_param
+    any_param = transform.AnyParamType.get()
+    print(any_param)
+
     # CHECK: !transform.any_value
     any_value = transform.AnyValueType.get()
     print(any_value)
@@ -32,6 +36,12 @@ def testTypes():
     print(concrete_op)
     print(concrete_op.operation_name)
 
+    # CHECK: !transform.param<i32>
+    # CHECK: i32
+    param = transform.ParamType.get(IntegerType.get_signless(32))
+    print(param)
+    print(param.type)
+
 
 @run
 def testSequenceOp():

return cls(mlirTransformParamTypeGet(ctx, type));
},
"Get an instance of ParamType for the given type in the given context.",
py::arg("cls"), py::arg("type") = py::none(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if "type" is None here? I don't think we can construct a parameter type around a null type, can we?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case the bindings throw a TypeError Exception. I removed the default argument for type here as it is always required.

Copy link
Contributor

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo the question @ftynse mentioned. Thanks!

@martin-luecke martin-luecke force-pushed the expose_transform_param_type branch from 943a599 to 19c5585 Compare September 26, 2023 13:06
@martin-luecke martin-luecke merged commit 97f9f1a into llvm:main Sep 26, 2023
@martin-luecke martin-luecke deleted the expose_transform_param_type branch September 26, 2023 14:10
legrosbuffle pushed a commit to legrosbuffle/llvm-project that referenced this pull request Sep 29, 2023
This exposes the Transform dialect types `AnyParamType` and `ParamType`
via the Python bindings.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants