Skip to content

Commit 9ccf01f

Browse files
[mlir][transform] Support for multiple top-level transform ops (llvm#69615)
This adds a flag to the `TransformDialectInterpreter` that relaxes the requirement for only a single top-level transform op. This is useful for supporting transforms that take transform IR as payload. This also aligns the function `findTopLevelTransform` [here](llvm@7b0f4c9#diff-551f92bb609487ccf981daf9571f0f1b1703ab2330560a388a5f0d133e520be4L59) with its documentation: In the presence of multiple top-level transform ops it now correctly returns the first of them after reporting the error instead of returning a `nullptr`.
1 parent e45f6e9 commit 9ccf01f

File tree

4 files changed

+57
-11
lines changed

4 files changed

+57
-11
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,23 @@ class TransformOptions {
9595
return *this;
9696
}
9797

98+
// Ensures that only a single top-level transform op is present in the IR.
99+
TransformOptions &enableEnforceSingleToplevelTransformOp(bool enable = true) {
100+
enforceSingleToplevelTransformOp = enable;
101+
return *this;
102+
}
103+
98104
/// Returns true if the expensive checks are requested.
99105
bool getExpensiveChecksEnabled() const { return expensiveChecksEnabled; }
100106

107+
// Returns true if enforcing a single top-level transform op is requested.
108+
bool getEnforceSingleToplevelTransformOp() const {
109+
return enforceSingleToplevelTransformOp;
110+
}
111+
101112
private:
102113
bool expensiveChecksEnabled = true;
114+
bool enforceSingleToplevelTransformOp = true;
103115
};
104116

105117
/// Entry point to the Transform dialect infrastructure. Applies the

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,11 @@ constexpr static llvm::StringLiteral
5656
/// Reports an error if there is more than one such operation and returns the
5757
/// first one found. Reports an error returns nullptr if no such operation
5858
/// found.
59-
static Operation *findTopLevelTransform(Operation *root,
60-
StringRef filenameOption) {
59+
static Operation *
60+
findTopLevelTransform(Operation *root, StringRef filenameOption,
61+
mlir::transform::TransformOptions options) {
6162
::mlir::transform::TransformOpInterface topLevelTransform = nullptr;
62-
WalkResult walkResult = root->walk<WalkOrder::PreOrder>(
63+
root->walk<WalkOrder::PreOrder>(
6364
[&](::mlir::transform::TransformOpInterface transformOp) {
6465
if (!transformOp
6566
->hasTrait<transform::PossibleTopLevelTransformOpTrait>())
@@ -68,14 +69,15 @@ static Operation *findTopLevelTransform(Operation *root,
6869
topLevelTransform = transformOp;
6970
return WalkResult::skip();
7071
}
71-
auto diag = transformOp.emitError()
72-
<< "more than one top-level transform op";
73-
diag.attachNote(topLevelTransform.getLoc())
74-
<< "previous top-level transform op";
75-
return WalkResult::interrupt();
72+
if (options.getEnforceSingleToplevelTransformOp()) {
73+
auto diag = transformOp.emitError()
74+
<< "more than one top-level transform op";
75+
diag.attachNote(topLevelTransform.getLoc())
76+
<< "previous top-level transform op";
77+
return WalkResult::interrupt();
78+
}
79+
return WalkResult::skip();
7680
});
77-
if (walkResult.wasInterrupted())
78-
return nullptr;
7981
if (!topLevelTransform) {
8082
auto diag = root->emitError()
8183
<< "could not find a nested top-level transform op";
@@ -310,7 +312,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
310312
Operation *transformRoot =
311313
debugTransformRootTag.empty()
312314
? findTopLevelTransform(transformContainer,
313-
transformFileName.getArgStr())
315+
transformFileName.getArgStr(), options)
314316
: findOpWithTag(transformContainer, kTransformDialectTagAttrName,
315317
debugTransformRootTag);
316318
if (!transformRoot)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-opt %s --test-transform-dialect-interpreter='enforce-single-top-level-transform-op=0' -allow-unregistered-dialect --split-input-file --verify-diagnostics | FileCheck %s
2+
3+
transform.sequence failures(propagate) {
4+
// CHECK: transform.sequence
5+
^bb0(%arg0: !transform.any_op):
6+
}
7+
8+
transform.sequence failures(propagate) {
9+
// CHECK: transform.sequence
10+
^bb0(%arg0: !transform.any_op):
11+
}
12+
13+
// -----
14+
15+
transform.sequence failures(propagate) {
16+
^bb0(%arg0: !transform.any_op):
17+
%match = transform.structured.match ops{["transform.get_parent_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
18+
transform.test_print_remark_at_operand %match, "found get_parent_op" : !transform.any_op
19+
}
20+
21+
transform.sequence failures(propagate) {
22+
^bb0(%arg0: !transform.any_op):
23+
%op = transform.structured.match ops{[]} in %arg0 : (!transform.any_op) -> !transform.any_op
24+
// expected-remark @below{{found get_parent_op}}
25+
%1 = transform.get_parent_op %op : (!transform.any_op) -> !transform.any_op
26+
}

mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ class TestTransformDialectInterpreterPass
158158
}
159159

160160
options = options.enableExpensiveChecks(enableExpensiveChecks);
161+
options = options.enableEnforceSingleToplevelTransformOp(
162+
enforceSingleToplevelTransformOp);
161163
if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
162164
getOperation(), getArgument(), getSharedTransformModule(),
163165
getTransformLibraryModule(), extraMapping, options,
@@ -170,6 +172,10 @@ class TestTransformDialectInterpreterPass
170172
*this, "enable-expensive-checks", llvm::cl::init(false),
171173
llvm::cl::desc("perform expensive checks to better report errors in the "
172174
"transform IR")};
175+
Option<bool> enforceSingleToplevelTransformOp{
176+
*this, "enforce-single-top-level-transform-op", llvm::cl::init(true),
177+
llvm::cl::desc("Ensure that only a single top-level transform op is "
178+
"present in the IR.")};
173179

174180
Option<std::string> bindFirstExtraToOps{
175181
*this, "bind-first-extra-to-ops",

0 commit comments

Comments
 (0)