Skip to content

Commit 49b688c

Browse files
committed
Let transform.structured.convert_to_loops return handles to loops
This lets `transform.structured.convert_to_loops` return handles to the generated loops, making this transformation more useful to use for (transformation-)nesting purposes. This is modelled after SCFs `transform.loop.forall_to_for` which returns handles to loops.
1 parent ff66e9b commit 49b688c

File tree

3 files changed

+22
-11
lines changed

3 files changed

+22
-11
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1281,14 +1281,14 @@ def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
12811281
let description = [{
12821282
For operations that implement the `TilingInterface`, and implement
12831283
the `generateScalarImplementation` method, lowers the operation to
1284-
loops. This operation does not return any handles.
1284+
loops. The return handles point to the generated loops.
12851285
}];
12861286

12871287
let arguments = (ins TransformHandleTypeInterface:$target);
1288-
let results = (outs);
1288+
let results = (outs Variadic<TransformHandleTypeInterface>:$result);
12891289

12901290
let assemblyFormat = [{
1291-
$target attr-dict `:` type($target)
1291+
$target attr-dict `:` functional-type(operands, results)
12921292
}];
12931293

12941294
let extraClassDeclaration = [{

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2122,6 +2122,9 @@ DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
21222122
if (failed(loops))
21232123
return emitDefaultDefiniteFailure(target);
21242124
rewriter.eraseOp(target);
2125+
for (auto &loop : *loops) {
2126+
results.push_back(loop);
2127+
}
21252128
return DiagnosedSilenceableFailure::success();
21262129
}
21272130

mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
1111
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
1212
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
1313
: (!transform.any_op) -> !transform.any_op
14-
transform.structured.convert_to_loops %matmul : !transform.any_op
14+
%0:3 = transform.structured.convert_to_loops %matmul
15+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
1516
transform.yield
1617
}
1718
}
@@ -66,7 +67,8 @@ module attributes {transform.with_named_sequence} {
6667
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
6768
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
6869
: (!transform.any_op) -> !transform.any_op
69-
transform.structured.convert_to_loops %generic : !transform.any_op
70+
%0:2 = transform.structured.convert_to_loops %generic
71+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
7072
transform.yield
7173
}
7274
}
@@ -111,7 +113,8 @@ module attributes {transform.with_named_sequence} {
111113
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
112114
%conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
113115
: (!transform.any_op) -> !transform.any_op
114-
transform.structured.convert_to_loops %conv : !transform.any_op
116+
%0:7 = transform.structured.convert_to_loops %conv
117+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
115118
transform.yield
116119
}
117120
}
@@ -165,7 +168,8 @@ module attributes {transform.with_named_sequence} {
165168
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
166169
%pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
167170
: (!transform.any_op) -> !transform.any_op
168-
transform.structured.convert_to_loops %pool : !transform.any_op
171+
%0:6 = transform.structured.convert_to_loops %pool
172+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
169173
transform.yield
170174
}
171175
}
@@ -216,7 +220,8 @@ module attributes {transform.with_named_sequence} {
216220
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
217221
%map = transform.structured.match ops{["linalg.map"]} in %arg1
218222
: (!transform.any_op) -> !transform.any_op
219-
transform.structured.convert_to_loops %map : !transform.any_op
223+
%0 = transform.structured.convert_to_loops %map
224+
: (!transform.any_op) -> (!transform.any_op)
220225
transform.yield
221226
}
222227
}
@@ -248,7 +253,8 @@ module attributes {transform.with_named_sequence} {
248253
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
249254
%transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
250255
: (!transform.any_op) -> !transform.any_op
251-
transform.structured.convert_to_loops %transpose : !transform.any_op
256+
%0:3 = transform.structured.convert_to_loops %transpose
257+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
252258
transform.yield
253259
}
254260
}
@@ -285,7 +291,8 @@ module attributes {transform.with_named_sequence} {
285291
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
286292
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
287293
: (!transform.any_op) -> !transform.any_op
288-
transform.structured.convert_to_loops %reduce : !transform.any_op
294+
%0:3 = transform.structured.convert_to_loops %reduce
295+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
289296
transform.yield
290297
}
291298
}
@@ -322,7 +329,8 @@ module attributes {transform.with_named_sequence} {
322329
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
323330
%broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
324331
: (!transform.any_op) -> !transform.any_op
325-
transform.structured.convert_to_loops %broadcast : !transform.any_op
332+
%0:3 = transform.structured.convert_to_loops %broadcast
333+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
326334
transform.yield
327335
}
328336
}

0 commit comments

Comments
 (0)