Skip to content

Commit 5468f88

Browse files
authored
[mlir] update remaining transform tests to main pass (#81279)
Use the main transform interpreter pass instead of the test pass. The only tests that are not updated are specific to the operation of the test pass.
1 parent 3fa9102 commit 5468f88

16 files changed

+584
-425
lines changed

mlir/include/mlir/Dialect/Transform/Transforms/Passes.td

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,25 @@ def InterpreterPass : Pass<"transform-interpreter"> {
6666
let description = [{
6767
This pass runs the transform dialect interpreter and applies the named
6868
sequence transformation specified by the provided name (defaults to
69-
`TransformDialect::kTransformEntryPointSymbolName` (i.e. `__transform_main`)).
69+
`TransformDialect::kTransformEntryPointSymbolName`,
70+
i.e. `__transform_main`).
71+
72+
Additional options can be used to narrow down the pass applicability for
73+
debugging purposes:
74+
* `debugPayloadRootTag` makes the transform script apply to the payload
75+
operation that has a `transform.target_tag` string attribute with the
76+
given value, rather than to the anchor operation of the pass.
77+
* `debugBindTrailingArgs` allows one to bind values to trailing arguments
78+
of the transform entry point as follows:
79+
* arguments of `TransformHandleTypeInterface` type can be bound to all
80+
payload operations with the name provided as a simple string;
81+
* arguments of `TransformValueHandleTypeInterface` type can be bound to
82+
a flattened list of results of all operations with the name provided
83+
as a string prefixed with `^`;
84+
* arguments of `TransformParamTypeInterface` type can be bound to
85+
integer constants provided as `;`-separated list prefixed with `#`.
86+
* `entryPoint` specifies the name of the transform symbol to serve as the
87+
entry point.
7088
}];
7189
let dependentDialects = ["::mlir::transform::TransformDialect"];
7290
let options = [
@@ -83,7 +101,9 @@ def InterpreterPass : Pass<"transform-interpreter"> {
83101
"false",
84102
"Disable expensive checks in the interpreter for a faster run.">,
85103
Option<"entryPoint", "entry-point", "std::string",
86-
/*default=*/[{TransformDialect::kTransformEntryPointSymbolName.str()}],
104+
/*default=*/[{
105+
TransformDialect::kTransformEntryPointSymbolName.str()
106+
}],
87107
"Entry point of the pass pipeline.">,
88108
];
89109
}

mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,79 @@ static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
5050
return WalkResult::interrupt();
5151
});
5252

53+
if (!target) {
54+
passRoot->emitError()
55+
<< "could not find the operation with transform.target_tag=\"" << tag
56+
<< "\" attribute";
57+
return nullptr;
58+
}
59+
5360
return walkResult.wasInterrupted() ? nullptr : target;
5461
}
5562

5663
namespace {
5764
class InterpreterPass
5865
: public transform::impl::InterpreterPassBase<InterpreterPass> {
66+
// Parses the pass arguments to bind trailing arguments of the entry point.
67+
std::optional<RaggedArray<transform::MappedValue>>
68+
parseArguments(Operation *payloadRoot) {
69+
MLIRContext *context = payloadRoot->getContext();
70+
71+
SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
72+
trailingBindings.resize(debugBindTrailingArgs.size());
73+
74+
// Construct lists of op names to match.
75+
SmallVector<std::optional<OperationName>> debugBindNames;
76+
debugBindNames.reserve(debugBindTrailingArgs.size());
77+
for (auto &&[position, nameString] :
78+
llvm::enumerate(debugBindTrailingArgs)) {
79+
StringRef name = nameString;
80+
81+
// Parse the integer literals.
82+
if (name.starts_with("#")) {
83+
debugBindNames.push_back(std::nullopt);
84+
StringRef lhs = "";
85+
StringRef rhs = name.drop_front();
86+
do {
87+
std::tie(lhs, rhs) = rhs.split(';');
88+
int64_t value;
89+
if (lhs.getAsInteger(10, value)) {
90+
emitError(UnknownLoc::get(context))
91+
<< "couldn't parse integer pass argument " << name;
92+
return std::nullopt;
93+
}
94+
trailingBindings[position].push_back(
95+
Builder(context).getI64IntegerAttr(value));
96+
} while (!rhs.empty());
97+
} else if (name.starts_with("^")) {
98+
debugBindNames.emplace_back(OperationName(name.drop_front(), context));
99+
} else {
100+
debugBindNames.emplace_back(OperationName(name, context));
101+
}
102+
}
103+
104+
// Collect operations or results for extra bindings.
105+
payloadRoot->walk([&](Operation *payload) {
106+
for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
107+
if (!name || payload->getName() != *name)
108+
continue;
109+
110+
if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
111+
.starts_with("^")) {
112+
llvm::append_range(trailingBindings[position], payload->getResults());
113+
} else {
114+
trailingBindings[position].push_back(payload);
115+
}
116+
}
117+
});
118+
119+
RaggedArray<transform::MappedValue> bindings;
120+
bindings.push_back(ArrayRef<Operation *>{payloadRoot});
121+
for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
122+
bindings.push_back(std::move(trailing));
123+
return bindings;
124+
}
125+
59126
public:
60127
using Base::Base;
61128

@@ -67,34 +134,18 @@ class InterpreterPass
67134
findPayloadRoot(getOperation(), debugPayloadRootTag);
68135
if (!payloadRoot)
69136
return signalPassFailure();
70-
auto debugBindNames = llvm::map_to_vector(
71-
debugBindTrailingArgs,
72-
[&](const std::string &name) { return OperationName(name, context); });
73-
SmallVector<SmallVector<Operation *>, 2> trailingBindings;
74-
trailingBindings.resize(debugBindNames.size());
75-
payloadRoot->walk([&](Operation *payload) {
76-
for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
77-
if (payload->getName() == name)
78-
trailingBindings[position].push_back(payload);
79-
}
80-
});
81137

82138
Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
83139
getOperation(), transformModule, entryPoint);
84-
if (!transformEntryPoint) {
85-
getOperation()->emitError()
86-
<< "could not find transform entry point: " << entryPoint
87-
<< " in either payload or transform module";
140+
if (!transformEntryPoint)
88141
return signalPassFailure();
89-
}
90-
91-
RaggedArray<transform::MappedValue> bindings;
92-
bindings.push_back(ArrayRef<Operation *>{payloadRoot});
93-
for (SmallVector<Operation *> &trailing : trailingBindings)
94-
bindings.push_back(std::move(trailing));
95142

143+
std::optional<RaggedArray<transform::MappedValue>> bindings =
144+
parseArguments(payloadRoot);
145+
if (!bindings)
146+
return signalPassFailure();
96147
if (failed(transform::applyTransformNamedSequence(
97-
bindings,
148+
*bindings,
98149
cast<transform::TransformOpInterface>(transformEntryPoint),
99150
transformModule,
100151
options.enableExpensiveChecks(!disableExpensiveChecks)))) {
Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
// RUN: mlir-opt %s
22
// No need to check anything else than parsing here, this is being used by another test as data.
33

4-
transform.with_pdl_patterns {
5-
^bb0(%arg0: !transform.any_op):
6-
pdl.pattern @func_return : benefit(1) {
7-
%0 = pdl.operation "func.return"
8-
pdl.rewrite %0 with "transform.dialect"
9-
}
4+
module attributes {transform.with_named_sequence} {
5+
transform.named_sequence @__transform_main(%root: !transform.any_op) {
6+
transform.with_pdl_patterns %root : !transform.any_op {
7+
^bb0(%arg0: !transform.any_op):
8+
pdl.pattern @func_return : benefit(1) {
9+
%0 = pdl.operation "func.return"
10+
pdl.rewrite %0 with "transform.dialect"
11+
}
1012

11-
sequence %arg0 : !transform.any_op failures(propagate) {
12-
^bb1(%arg1: !transform.any_op):
13-
%0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return">
14-
transform.debug.emit_remark_at %0, "matched" : !transform.op<"func.return">
13+
sequence %arg0 : !transform.any_op failures(propagate) {
14+
^bb1(%arg1: !transform.any_op):
15+
%0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return">
16+
transform.debug.emit_remark_at %0, "matched" : !transform.op<"func.return">
17+
}
18+
}
19+
transform.yield
1520
}
1621
}
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
// RUN: mlir-opt %s
22
// No need to check anything else than parsing here, this is being used by another test as data.
33

4-
transform.sequence failures(propagate) {
5-
^bb0(%arg0: !transform.any_op):
6-
transform.debug.emit_remark_at %arg0, "outer" : !transform.any_op
7-
transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
8-
^bb1(%arg1: !transform.any_op):
9-
transform.debug.emit_remark_at %arg1, "inner" : !transform.any_op
4+
module attributes {transform.with_named_sequence} {
5+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
6+
transform.debug.emit_remark_at %arg0, "outer" : !transform.any_op
7+
transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
8+
^bb1(%arg1: !transform.any_op):
9+
transform.debug.emit_remark_at %arg1, "inner" : !transform.any_op
10+
}
11+
transform.yield
1012
}
1113
}

mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \
2-
// RUN: --split-input-file --verify-diagnostics
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(transform-interpreter{\
2+
// RUN: debug-bind-trailing-args=func.func,func.return})" \
3+
// RUN: --split-input-file --verify-diagnostics
34

4-
transform.sequence failures(propagate) {
5-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
6-
transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_op
7-
transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_op
5+
module attributes {transform.with_named_sequence} {
6+
transform.named_sequence @__transform_main(
7+
%arg0: !transform.any_op, %arg1: !transform.any_op,
8+
%arg2: !transform.any_op) {
9+
transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_op
10+
transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_op
11+
transform.yield
12+
}
813
}
914

1015
// expected-remark @below {{first extra}}
@@ -26,9 +31,13 @@ func.func @bar(%arg0: i1) {
2631

2732
// -----
2833

29-
transform.sequence failures(propagate) {
30-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
31-
// expected-error @above {{wrong kind of value provided for top-level parameter}}
34+
module attributes {transform.with_named_sequence} {
35+
transform.named_sequence @__transform_main(
36+
%arg0: !transform.any_op, %arg1: !transform.any_op,
37+
%arg2: !transform.param<i64>) {
38+
// expected-error @above {{wrong kind of value provided for top-level parameter}}
39+
transform.yield
40+
}
3241
}
3342

3443
func.func @foo() {
@@ -37,9 +46,13 @@ func.func @foo() {
3746

3847
// -----
3948

40-
transform.sequence failures(propagate) {
41-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
42-
// expected-error @above {{wrong kind of value provided for the top-level value handle}}
49+
module attributes {transform.with_named_sequence} {
50+
transform.named_sequence @__transform_main(
51+
%arg0: !transform.any_op, %arg1: !transform.any_op,
52+
%arg2: !transform.any_value) {
53+
// expected-error @above {{wrong kind of value provided for the top-level value handle}}
54+
transform.yield
55+
}
4356
}
4457

4558
func.func @foo() {
@@ -48,19 +61,27 @@ func.func @foo() {
4861

4962
// -----
5063

51-
// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
52-
transform.sequence failures(propagate) {
53-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
64+
65+
module attributes {transform.with_named_sequence} {
66+
// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
67+
transform.named_sequence @__transform_main(
68+
%arg0: !transform.any_op, %arg1: !transform.any_op) {
69+
transform.yield
70+
}
5471
}
5572

5673
// -----
5774

58-
transform.sequence failures(propagate) {
59-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
60-
transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
61-
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
62-
transform.debug.emit_remark_at %arg4, "first extra" : !transform.any_op
63-
transform.debug.emit_remark_at %arg5, "second extra" : !transform.any_op
75+
module attributes {transform.with_named_sequence} {
76+
transform.named_sequence @__transform_main(
77+
%arg0: !transform.any_op, %arg1: !transform.any_op,
78+
%arg2: !transform.any_op) {
79+
transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
80+
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
81+
transform.debug.emit_remark_at %arg4, "first extra" : !transform.any_op
82+
transform.debug.emit_remark_at %arg5, "second extra" : !transform.any_op
83+
}
84+
transform.yield
6485
}
6586
}
6687

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,37 @@
1-
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \
1+
// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter{\
2+
// RUN: debug-bind-trailing-args=#1;2;3,#42;45})' \
23
// RUN: --split-input-file --verify-diagnostics
34

4-
transform.sequence failures(propagate) {
5-
^bb0(%arg0: !transform.any_op, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>):
6-
// expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
7-
transform.debug.emit_param_as_remark %arg1 : !transform.param<i64>
8-
// expected-remark @below {{42 : i64, 45 : i64}}
9-
transform.debug.emit_param_as_remark %arg2 : !transform.param<i64>
5+
module attributes {transform.with_named_sequence} {
6+
transform.named_sequence @__transform_main(
7+
%arg0: !transform.any_op, %arg1: !transform.param<i64>,
8+
%arg2: !transform.param<i64>) {
9+
// expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
10+
transform.debug.emit_param_as_remark %arg1 : !transform.param<i64>
11+
// expected-remark @below {{42 : i64, 45 : i64}}
12+
transform.debug.emit_param_as_remark %arg2 : !transform.param<i64>
13+
transform.yield
14+
}
1015
}
1116

1217
// -----
1318

14-
transform.sequence failures(propagate) {
15-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
16-
// expected-error @above {{wrong kind of value provided for top-level operation handle}}
19+
module attributes {transform.with_named_sequence} {
20+
transform.named_sequence @__transform_main(
21+
%arg0: !transform.any_op, %arg1: !transform.any_op,
22+
// expected-error @above {{wrong kind of value provided for top-level operation handle}}
23+
%arg2: !transform.param<i64>) {
24+
transform.yield
25+
}
1726
}
1827

1928
// -----
2029

21-
// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
22-
transform.sequence failures(propagate) {
23-
^bb0(%arg0: !transform.any_op, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>, %arg3: !transform.param<i64>):
30+
module attributes {transform.with_named_sequence} {
31+
// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
32+
transform.named_sequence @__transform_main(
33+
%arg0: !transform.any_op, %arg1: !transform.param<i64>,
34+
%arg2: !transform.param<i64>, %arg3: !transform.param<i64>) {
35+
transform.yield
36+
}
2437
}

0 commit comments

Comments
 (0)