Skip to content

Commit 116563b

Browse files
committed
[mlir] update remaining transform tests to main pass
Use the main transform interpreter pass instead of the test pass. The only tests that are not updated are specific to the operation of the test pass.
1 parent 6e20cb5 commit 116563b

15 files changed

+562
-423
lines changed

mlir/lib/Dialect/Transform/Transforms/InterpreterPass.cpp

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,79 @@ static Operation *findPayloadRoot(Operation *passRoot, StringRef tag) {
5050
return WalkResult::interrupt();
5151
});
5252

53+
if (!target) {
54+
passRoot->emitError()
55+
<< "could not find the operation with transform.target_tag=\"" << tag
56+
<< "\" attribute";
57+
return nullptr;
58+
}
59+
5360
return walkResult.wasInterrupted() ? nullptr : target;
5461
}
5562

5663
namespace {
5764
class InterpreterPass
5865
: public transform::impl::InterpreterPassBase<InterpreterPass> {
66+
// Parses the pass arguments to bind trailing arguments of the entry point.
67+
std::optional<RaggedArray<transform::MappedValue>>
68+
parseArguments(Operation *payloadRoot) {
69+
MLIRContext *context = payloadRoot->getContext();
70+
71+
SmallVector<SmallVector<transform::MappedValue>, 2> trailingBindings;
72+
trailingBindings.resize(debugBindTrailingArgs.size());
73+
74+
// Construct lists of op names to match.
75+
SmallVector<std::optional<OperationName>> debugBindNames;
76+
debugBindNames.reserve(debugBindTrailingArgs.size());
77+
for (auto &&[position, nameString] :
78+
llvm::enumerate(debugBindTrailingArgs)) {
79+
StringRef name = nameString;
80+
81+
// Parse the integer literals.
82+
if (name.starts_with("#")) {
83+
debugBindNames.push_back(std::nullopt);
84+
StringRef lhs = "";
85+
StringRef rhs = name.drop_front();
86+
do {
87+
std::tie(lhs, rhs) = rhs.split(';');
88+
int64_t value;
89+
if (lhs.getAsInteger(10, value)) {
90+
emitError(UnknownLoc::get(context))
91+
<< "couldn't parse integer pass argument " << name;
92+
return std::nullopt;
93+
}
94+
trailingBindings[position].push_back(
95+
Builder(context).getI64IntegerAttr(value));
96+
} while (!rhs.empty());
97+
} else if (name.starts_with("^")) {
98+
debugBindNames.emplace_back(OperationName(name.drop_front(), context));
99+
} else {
100+
debugBindNames.emplace_back(OperationName(name, context));
101+
}
102+
}
103+
104+
// Collect operations or results for extra bindings.
105+
payloadRoot->walk([&](Operation *payload) {
106+
for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
107+
if (!name || payload->getName() != *name)
108+
continue;
109+
110+
if (StringRef(*std::next(debugBindTrailingArgs.begin(), position))
111+
.starts_with("^")) {
112+
llvm::append_range(trailingBindings[position], payload->getResults());
113+
} else {
114+
trailingBindings[position].push_back(payload);
115+
}
116+
}
117+
});
118+
119+
RaggedArray<transform::MappedValue> bindings;
120+
bindings.push_back(ArrayRef<Operation *>{payloadRoot});
121+
for (SmallVector<transform::MappedValue> &trailing : trailingBindings)
122+
bindings.push_back(std::move(trailing));
123+
return bindings;
124+
}
125+
59126
public:
60127
using Base::Base;
61128

@@ -67,34 +134,18 @@ class InterpreterPass
67134
findPayloadRoot(getOperation(), debugPayloadRootTag);
68135
if (!payloadRoot)
69136
return signalPassFailure();
70-
auto debugBindNames = llvm::map_to_vector(
71-
debugBindTrailingArgs,
72-
[&](const std::string &name) { return OperationName(name, context); });
73-
SmallVector<SmallVector<Operation *>, 2> trailingBindings;
74-
trailingBindings.resize(debugBindNames.size());
75-
payloadRoot->walk([&](Operation *payload) {
76-
for (auto &&[position, name] : llvm::enumerate(debugBindNames)) {
77-
if (payload->getName() == name)
78-
trailingBindings[position].push_back(payload);
79-
}
80-
});
81137

82138
Operation *transformEntryPoint = transform::detail::findTransformEntryPoint(
83139
getOperation(), transformModule, entryPoint);
84-
if (!transformEntryPoint) {
85-
getOperation()->emitError()
86-
<< "could not find transform entry point: " << entryPoint
87-
<< " in either payload or transform module";
140+
if (!transformEntryPoint)
88141
return signalPassFailure();
89-
}
90-
91-
RaggedArray<transform::MappedValue> bindings;
92-
bindings.push_back(ArrayRef<Operation *>{payloadRoot});
93-
for (SmallVector<Operation *> &trailing : trailingBindings)
94-
bindings.push_back(std::move(trailing));
95142

143+
std::optional<RaggedArray<transform::MappedValue>> bindings =
144+
parseArguments(payloadRoot);
145+
if (!bindings)
146+
return signalPassFailure();
96147
if (failed(transform::applyTransformNamedSequence(
97-
bindings,
148+
*bindings,
98149
cast<transform::TransformOpInterface>(transformEntryPoint),
99150
transformModule,
100151
options.enableExpensiveChecks(!disableExpensiveChecks)))) {
Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
// RUN: mlir-opt %s
22
// No need to check anything else than parsing here, this is being used by another test as data.
33

4-
transform.with_pdl_patterns {
5-
^bb0(%arg0: !transform.any_op):
6-
pdl.pattern @func_return : benefit(1) {
7-
%0 = pdl.operation "func.return"
8-
pdl.rewrite %0 with "transform.dialect"
9-
}
4+
module attributes {transform.with_named_sequence} {
5+
transform.named_sequence @__transform_main(%root: !transform.any_op) {
6+
transform.with_pdl_patterns %root : !transform.any_op {
7+
^bb0(%arg0: !transform.any_op):
8+
pdl.pattern @func_return : benefit(1) {
9+
%0 = pdl.operation "func.return"
10+
pdl.rewrite %0 with "transform.dialect"
11+
}
1012

11-
sequence %arg0 : !transform.any_op failures(propagate) {
12-
^bb1(%arg1: !transform.any_op):
13-
%0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return">
14-
transform.debug.emit_remark_at %0, "matched" : !transform.op<"func.return">
13+
sequence %arg0 : !transform.any_op failures(propagate) {
14+
^bb1(%arg1: !transform.any_op):
15+
%0 = pdl_match @func_return in %arg1 : (!transform.any_op) -> !transform.op<"func.return">
16+
transform.debug.emit_remark_at %0, "matched" : !transform.op<"func.return">
17+
}
18+
}
19+
transform.yield
1520
}
1621
}
Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
// RUN: mlir-opt %s
22
// No need to check anything else than parsing here, this is being used by another test as data.
33

4-
transform.sequence failures(propagate) {
5-
^bb0(%arg0: !transform.any_op):
6-
transform.debug.emit_remark_at %arg0, "outer" : !transform.any_op
7-
transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
8-
^bb1(%arg1: !transform.any_op):
9-
transform.debug.emit_remark_at %arg1, "inner" : !transform.any_op
4+
module attributes {transform.with_named_sequence} {
5+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
6+
transform.debug.emit_remark_at %arg0, "outer" : !transform.any_op
7+
transform.sequence %arg0 : !transform.any_op failures(propagate) attributes {transform.target_tag="transform"} {
8+
^bb1(%arg1: !transform.any_op):
9+
transform.debug.emit_remark_at %arg1, "inner" : !transform.any_op
10+
}
11+
transform.yield
1012
}
1113
}

mlir/test/Dialect/Transform/multi-arg-top-level-ops.mlir

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-ops=func.func bind-second-extra-to-ops=func.return})' \
2-
// RUN: --split-input-file --verify-diagnostics
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(transform-interpreter{\
2+
// RUN: debug-bind-trailing-args=func.func,func.return})" \
3+
// RUN: --split-input-file --verify-diagnostics
34

4-
transform.sequence failures(propagate) {
5-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
6-
transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_op
7-
transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_op
5+
module attributes {transform.with_named_sequence} {
6+
transform.named_sequence @__transform_main(
7+
%arg0: !transform.any_op, %arg1: !transform.any_op,
8+
%arg2: !transform.any_op) {
9+
transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_op
10+
transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_op
11+
transform.yield
12+
}
813
}
914

1015
// expected-remark @below {{first extra}}
@@ -26,9 +31,13 @@ func.func @bar(%arg0: i1) {
2631

2732
// -----
2833

29-
transform.sequence failures(propagate) {
30-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
31-
// expected-error @above {{wrong kind of value provided for top-level parameter}}
34+
module attributes {transform.with_named_sequence} {
35+
transform.named_sequence @__transform_main(
36+
%arg0: !transform.any_op, %arg1: !transform.any_op,
37+
%arg2: !transform.param<i64>) {
38+
// expected-error @above {{wrong kind of value provided for top-level parameter}}
39+
transform.yield
40+
}
3241
}
3342

3443
func.func @foo() {
@@ -37,9 +46,13 @@ func.func @foo() {
3746

3847
// -----
3948

40-
transform.sequence failures(propagate) {
41-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
42-
// expected-error @above {{wrong kind of value provided for the top-level value handle}}
49+
module attributes {transform.with_named_sequence} {
50+
transform.named_sequence @__transform_main(
51+
%arg0: !transform.any_op, %arg1: !transform.any_op,
52+
%arg2: !transform.any_value) {
53+
// expected-error @above {{wrong kind of value provided for the top-level value handle}}
54+
transform.yield
55+
}
4356
}
4457

4558
func.func @foo() {
@@ -48,19 +61,27 @@ func.func @foo() {
4861

4962
// -----
5063

51-
// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
52-
transform.sequence failures(propagate) {
53-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op):
64+
65+
module attributes {transform.with_named_sequence} {
66+
// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
67+
transform.named_sequence @__transform_main(
68+
%arg0: !transform.any_op, %arg1: !transform.any_op) {
69+
transform.yield
70+
}
5471
}
5572

5673
// -----
5774

58-
transform.sequence failures(propagate) {
59-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_op):
60-
transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
61-
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
62-
transform.debug.emit_remark_at %arg4, "first extra" : !transform.any_op
63-
transform.debug.emit_remark_at %arg5, "second extra" : !transform.any_op
75+
module attributes {transform.with_named_sequence} {
76+
transform.named_sequence @__transform_main(
77+
%arg0: !transform.any_op, %arg1: !transform.any_op,
78+
%arg2: !transform.any_op) {
79+
transform.sequence %arg0, %arg1, %arg2 : !transform.any_op, !transform.any_op, !transform.any_op failures(propagate) {
80+
^bb0(%arg3: !transform.any_op, %arg4: !transform.any_op, %arg5: !transform.any_op):
81+
transform.debug.emit_remark_at %arg4, "first extra" : !transform.any_op
82+
transform.debug.emit_remark_at %arg5, "second extra" : !transform.any_op
83+
}
84+
transform.yield
6485
}
6586
}
6687

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,37 @@
1-
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-params=1,2,3 bind-second-extra-to-params=42,45})' \
1+
// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter{\
2+
// RUN: debug-bind-trailing-args=#1;2;3,#42;45})' \
23
// RUN: --split-input-file --verify-diagnostics
34

4-
transform.sequence failures(propagate) {
5-
^bb0(%arg0: !transform.any_op, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>):
6-
// expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
7-
transform.debug.emit_param_as_remark %arg1 : !transform.param<i64>
8-
// expected-remark @below {{42 : i64, 45 : i64}}
9-
transform.debug.emit_param_as_remark %arg2 : !transform.param<i64>
5+
module attributes {transform.with_named_sequence} {
6+
transform.named_sequence @__transform_main(
7+
%arg0: !transform.any_op, %arg1: !transform.param<i64>,
8+
%arg2: !transform.param<i64>) {
9+
// expected-remark @below {{1 : i64, 2 : i64, 3 : i64}}
10+
transform.debug.emit_param_as_remark %arg1 : !transform.param<i64>
11+
// expected-remark @below {{42 : i64, 45 : i64}}
12+
transform.debug.emit_param_as_remark %arg2 : !transform.param<i64>
13+
transform.yield
14+
}
1015
}
1116

1217
// -----
1318

14-
transform.sequence failures(propagate) {
15-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.param<i64>):
16-
// expected-error @above {{wrong kind of value provided for top-level operation handle}}
19+
module attributes {transform.with_named_sequence} {
20+
transform.named_sequence @__transform_main(
21+
%arg0: !transform.any_op, %arg1: !transform.any_op,
22+
// expected-error @above {{wrong kind of value provided for top-level operation handle}}
23+
%arg2: !transform.param<i64>) {
24+
transform.yield
25+
}
1726
}
1827

1928
// -----
2029

21-
// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
22-
transform.sequence failures(propagate) {
23-
^bb0(%arg0: !transform.any_op, %arg1: !transform.param<i64>, %arg2: !transform.param<i64>, %arg3: !transform.param<i64>):
30+
module attributes {transform.with_named_sequence} {
31+
// expected-error @below {{operation expects 3 extra value bindings, but 2 were provided to the interpreter}}
32+
transform.named_sequence @__transform_main(
33+
%arg0: !transform.any_op, %arg1: !transform.param<i64>,
34+
%arg2: !transform.param<i64>, %arg3: !transform.param<i64>) {
35+
transform.yield
36+
}
2437
}

mlir/test/Dialect/Transform/multi-arg-top-level-values.mlir

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-opt %s --pass-pipeline='builtin.module(test-transform-dialect-interpreter{bind-first-extra-to-results-of-ops=test.some_returning_op bind-second-extra-to-results-of-ops=test.some_other_returning_op})' \
1+
// RUN: mlir-opt %s --pass-pipeline='builtin.module(transform-interpreter{\
2+
// RUN: debug-bind-trailing-args=^test.some_returning_op,^test.some_other_returning_op})' \
23
// RUN: --split-input-file --verify-diagnostics
34

45
// Note that diagnostic checker will merge two diagnostics with the same message
@@ -21,25 +22,34 @@
2122
// expected-note @below {{value handle points to an op result #1}}
2223
%2:2 = "test.some_other_returning_op"() : () -> (f32, f64)
2324

24-
transform.sequence failures(propagate) {
25-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value, %arg2: !transform.any_value):
26-
transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_value
27-
transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_value
25+
module attributes {transform.with_named_sequence} {
26+
transform.named_sequence @__transform_main(
27+
%arg0: !transform.any_op, %arg1: !transform.any_value,
28+
%arg2: !transform.any_value) {
29+
transform.debug.emit_remark_at %arg1, "first extra" : !transform.any_value
30+
transform.debug.emit_remark_at %arg2, "second extra" : !transform.any_value
31+
transform.yield
32+
}
2833
}
2934

3035
// -----
3136

3237
%0:2 = "test.some_returning_op"() : () -> (i32, i64)
3338
%1 = "test.some_returning_op"() : () -> index
3439

35-
transform.sequence failures(propagate) {
36-
// expected-error @below {{wrong kind of value provided for top-level operation handle}}
37-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value):
40+
module attributes {transform.with_named_sequence} {
41+
transform.named_sequence @__transform_main(
42+
// expected-error @below {{wrong kind of value provided for top-level operation handle}}
43+
%arg0: !transform.any_op, %arg1: !transform.any_op, %arg2: !transform.any_value) {
44+
transform.yield
45+
}
3846
}
3947

4048
// -----
4149

42-
// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
43-
transform.sequence failures(propagate) {
44-
^bb0(%arg0: !transform.any_op, %arg1: !transform.any_value):
50+
module attributes {transform.with_named_sequence} {
51+
// expected-error @below {{operation expects 1 extra value bindings, but 2 were provided to the interpreter}}
52+
transform.named_sequence @__transform_main(%arg0: !transform.any_op, %arg1: !transform.any_value) {
53+
transform.yield
54+
}
4555
}

0 commit comments

Comments
 (0)