-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][python] Expose transform param types #67421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][python] Expose transform param types #67421
Conversation
@llvm/pr-subscribers-mlir ChangesThis exposes the Transform dialect types Full diff: https://github.com/llvm/llvm-project/pull/67421.diff 4 Files Affected:
diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
index 954575925cc5c45..91c99b1f869f22c 100644
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx);
+
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
@@ -49,6 +57,17 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName);
MLIR_CAPI_EXPORTED MlirStringRef
mlirTransformOperationTypeGetOperationName(MlirType type);
+//===---------------------------------------------------------------------===//
+// ParamType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx,
+ MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type);
+
#ifdef __cplusplus
}
#endif
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
index 932e40220057c13..e7d73c12d3db3d5 100644
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
"Get an instance of AnyOpType in the given context.", py::arg("cls"),
py::arg("context") = py::none());
+ //===-------------------------------------------------------------------===//
+ // AnyParamType
+ //===-------------------------------------------------------------------===//
+
+ auto anyParamType =
+ mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType);
+ anyParamType.def_classmethod(
+ "get",
+ [](py::object cls, MlirContext ctx) {
+ return cls(mlirTransformAnyParamTypeGet(ctx));
+ },
+ "Get an instance of AnyParamType in the given context.", py::arg("cls"),
+ py::arg("context") = py::none());
+
//===-------------------------------------------------------------------===//
// AnyValueType
//===-------------------------------------------------------------------===//
@@ -71,6 +85,28 @@ void populateDialectTransformSubmodule(const pybind11::module &m) {
return py::str(operationName.data, operationName.length);
},
"Get the name of the payload operation accepted by the handle.");
+
+ //===-------------------------------------------------------------------===//
+ // ParamType
+ //===-------------------------------------------------------------------===//
+
+ auto paramType =
+ mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType);
+ paramType.def_classmethod(
+ "get",
+ [](py::object cls, MlirType type, MlirContext ctx) {
+ return cls(mlirTransformParamTypeGet(ctx, type));
+ },
+ "Get an instance of ParamType for the given type in the given context.",
+ py::arg("cls"), py::arg("type") = py::none(),
+ py::arg("context") = py::none());
+ paramType.def_property_readonly(
+ "type",
+ [](MlirType type) {
+ MlirType paramType = mlirTransformParamTypeGetType(type);
+ return paramType;
+ },
+ "Get the type this ParamType is associated with.");
}
PYBIND11_MODULE(_mlirDialectsTransform, m) {
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
index 5841f6783ad5f1d..3f7f8b8e2113fe4 100644
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) {
return wrap(transform::AnyOpType::get(unwrap(ctx)));
}
+//===---------------------------------------------------------------------===//
+// AnyParamType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformAnyParamType(MlirType type) {
+ return isa<transform::AnyParamType>(unwrap(type));
+}
+
+MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) {
+ return wrap(transform::AnyParamType::get(unwrap(ctx)));
+}
+
//===---------------------------------------------------------------------===//
// AnyValueType
//===---------------------------------------------------------------------===//
@@ -62,3 +74,19 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx,
MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) {
return wrap(cast<transform::OperationType>(unwrap(type)).getOperationName());
}
+
+//===---------------------------------------------------------------------===//
+// AnyOpType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformParamType(MlirType type) {
+ return isa<transform::ParamType>(unwrap(type));
+}
+
+MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) {
+ return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type)));
+}
+
+MlirType mlirTransformParamTypeGetType(MlirType type) {
+ return wrap(cast<transform::ParamType>(unwrap(type)).getType());
+}
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 5df125694256a4e..481d7745720101d 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -22,6 +22,10 @@ def testTypes():
any_op = transform.AnyOpType.get()
print(any_op)
+ # CHECK: !transform.any_param
+ any_param = transform.AnyParamType.get()
+ print(any_param)
+
# CHECK: !transform.any_value
any_value = transform.AnyValueType.get()
print(any_value)
@@ -32,6 +36,12 @@ def testTypes():
print(concrete_op)
print(concrete_op.operation_name)
+ # CHECK: !transform.param<i32>
+ # CHECK: i32
+ param = transform.ParamType.get(IntegerType.get_signless(32))
+ print(param)
+ print(param.type)
+
@run
def testSequenceOp():
|
return cls(mlirTransformParamTypeGet(ctx, type)); | ||
}, | ||
"Get an instance of ParamType for the given type in the given context.", | ||
py::arg("cls"), py::arg("type") = py::none(), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What happens if "type" is None here? I don't think we can construct a parameter type around a null type, can we?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case the bindings throw a TypeError
Exception. I removed the default argument for type
here as it is always required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo the question @ftynse mentioned. Thanks!
943a599
to
19c5585
Compare
This exposes the Transform dialect types `AnyParamType` and `ParamType` via the Python bindings.
This exposes the Transform dialect types
AnyParamType
andParamType
via the Python bindings.