-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][transform] Support for multiple top-level transform ops #69615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][transform] Support for multiple top-level transform ops #69615
Conversation
@llvm/pr-subscribers-mlir Author: None (martin-luecke) ChangesThis adds a flag to the This also aligns the function Full diff: https://github.com/llvm/llvm-project/pull/69615.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 7b37245fc3d117b..60eb48a764eb4b2 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -95,11 +95,24 @@ class TransformOptions {
return *this;
}
+ // Ensures that only a single top-level transform op is present in the IR.
+ TransformOptions &
+ enableEnforceSingleToplevelTransformOp(bool enable = true) {
+ enforceSingleToplevelTransformOp = enable;
+ return *this;
+ }
+
/// Returns true if the expensive checks are requested.
bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
+ // Returns true if enforcing a single top-level transform op is requested.
+ bool getEnforceSingleToplevelTransformOp() const {
+ return enforceSingleToplevelTransformOp;
+ }
+
private:
bool expensiveChecksEnabled = true;
+ bool enforceSingleToplevelTransformOp = true;
};
/// Entry point to the Transform dialect infrastructure. Applies the
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index 538c81fe39fddb2..741456e7ebbfb86 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -56,10 +56,11 @@ constexpr static llvm::StringLiteral
/// Reports an error if there is more than one such operation and returns the
/// first one found. Reports an error returns nullptr if no such operation
/// found.
-static Operation *findTopLevelTransform(Operation *root,
- StringRef filenameOption) {
+static Operation *
+findTopLevelTransform(Operation *root, StringRef filenameOption,
+ mlir::transform::TransformOptions options) {
::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
- WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
+ root->walk<WalkOrder::PreOrder>(
[&](::mlir::transform::TransformOpInterface transformOp) {
if (!transformOp
->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
@@ -68,14 +69,15 @@ static Operation *findTopLevelTransform(Operation *root,
topLevelTransform = transformOp;
return WalkResult::skip();
}
- auto diag = transformOp.emitError()
- << "more than one top-level transform op";
- diag.attachNote(topLevelTransform.getLoc())
- << "previous top-level transform op";
- return WalkResult::interrupt();
+ if (options.getEnforceSingleToplevelTransformOp()) {
+ auto diag = transformOp.emitError()
+ << "more than one top-level transform op";
+ diag.attachNote(topLevelTransform.getLoc())
+ << "previous top-level transform op";
+ return WalkResult::interrupt();
+ }
+ return WalkResult::skip();
});
- if (walkResult.wasInterrupted())
- return nullptr;
if (!topLevelTransform) {
auto diag = root->emitError()
<< "could not find a nested top-level transform op";
@@ -310,7 +312,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *transformRoot =
debugTransformRootTag.empty()
? findTopLevelTransform(transformContainer,
- transformFileName.getArgStr())
+ transformFileName.getArgStr(), options)
: findOpWithTag(transformContainer, kTransformDialectTagAttrName,
debugTransformRootTag);
if (!transformRoot)
diff --git a/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
new file mode 100644
index 000000000000000..db7fecdf753e984
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-multiple-top-level-ops.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter='enforce-single-top-level-transform-op=0' -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+transform.sequence failures(propagate) {
+// CHECK: transform.sequence
+^bb0(%arg0: !transform.any_op):
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %match = transform.structured.match ops{["transform.get_parent_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ transform.test_print_remark_at_operand %match, "found get_parent_op" : !transform.any_op
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ %op = transform.structured.match ops{[]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ // expected-remark @below{{found get_parent_op}}
+ %1 = transform.get_parent_op %op : (!transform.any_op) -> !transform.any_op
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index c60b21c918338b4..756b7f669b0c5bf 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -158,6 +158,8 @@ class TestTransformDialectInterpreterPass
}
options = options.enableExpensiveChecks(enableExpensiveChecks);
+ options = options.enableEnforceSingleToplevelTransformOp(
+ enforceSingleToplevelTransformOp);
if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
getOperation(), getArgument(), getSharedTransformModule(),
getTransformLibraryModule(), extraMapping, options,
@@ -170,6 +172,10 @@ class TestTransformDialectInterpreterPass
*this, "enable-expensive-checks", llvm::cl::init(false),
llvm::cl::desc("perform expensive checks to better report errors in the "
"transform IR")};
+ Option<bool> enforceSingleToplevelTransformOp{
+ *this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
+ llvm::cl::desc("Ensure that only a single top-level transform op is "
+ "present in the IR.")};
Option<std::string> bindFirstExtraToOps{
*this, "bind-first-extra-to-ops",
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's okay to have, but we are generally moving away from the model where a top-level op is executed and towards a model where there is a named entry point.
7b0f4c9
to
23add3c
Compare
This adds a flag to the
TransformDialectInterpreter
that relaxes the requirement for only a single top-level transform op.This is useful for supporting transforms that take transform IR as payload.
This also aligns the function
findTopLevelTransform
here with its documentation:In the presence of multiple top-level transform ops it now correctly returns the first of them after reporting the error instead of returning a
nullptr
.