Skip to content

Commit 47bc565

Browse files
authored
[MLIR] [Transforms] Let transform.structured.convert_to_loops return handles to loops (#83984)
This lets `transform.structured.convert_to_loops` return handles to the generated loops, making this transformation more useful to use for (transformation-)nesting purposes. This is modelled after SCFs `transform.loop.forall_to_for` which returns handles to loops. Introduced in commit aa2a96a, with a note that they might move out of the `Linalg`-Dialect, but no reason given for the non-return of handles. As far as I can see, this transform always returns loops.
1 parent f0eb0c5 commit 47bc565

File tree

3 files changed

+101
-31
lines changed

3 files changed

+101
-31
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,33 +1274,29 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
12741274
}];
12751275
}
12761276

1277+
//===----------------------------------------------------------------------===//
1278+
// ConvertToLoopsOp
1279+
//===----------------------------------------------------------------------===//
1280+
12771281
def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops",
12781282
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
1279-
TransformOpInterface, TransformEachOpTrait,
1283+
DeclareOpInterfaceMethods<TransformOpInterface>,
12801284
ReportTrackingListenerFailuresOpTrait]> {
12811285
let description = [{
12821286
For operations that implement the `TilingInterface`, and implement
12831287
the `generateScalarImplementation` method, lowers the operation to
1284-
loops. This operation does not return any handles.
1288+
loops. The return handle points to all generated loops.
1289+
Fails if the payload ops cannot be lowered to loops.
12851290
}];
12861291

12871292
let arguments = (ins TransformHandleTypeInterface:$target);
1288-
let results = (outs);
1293+
let results = (outs TransformHandleTypeInterface:$result);
12891294

12901295
let assemblyFormat = [{
1291-
$target attr-dict `:` type($target)
1292-
}];
1293-
1294-
let extraClassDeclaration = [{
1295-
::mlir::DiagnosedSilenceableFailure applyToOne(
1296-
::mlir::transform::TransformRewriter &rewriter,
1297-
::mlir::TilingInterface target,
1298-
::mlir::transform::ApplyToEachResultList &results,
1299-
::mlir::transform::TransformState &state);
1296+
$target attr-dict `:` functional-type(operands, results)
13001297
}];
13011298
}
13021299

1303-
13041300
//===----------------------------------------------------------------------===//
13051301
// DecomposeInterfaceOp
13061302
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2112,16 +2112,31 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
21122112
// ConvertToLoopsOp
21132113
//===----------------------------------------------------------------------===//
21142114

2115-
DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne(
2116-
transform::TransformRewriter &rewriter, TilingInterface target,
2117-
transform::ApplyToEachResultList &results,
2118-
transform::TransformState &state) {
2119-
rewriter.setInsertionPoint(target);
2120-
FailureOr<SmallVector<scf::ForOp>> loops =
2121-
scf::lowerToLoopsUsingSCFForOp(rewriter, target);
2122-
if (failed(loops))
2123-
return emitDefaultDefiniteFailure(target);
2124-
rewriter.eraseOp(target);
2115+
DiagnosedSilenceableFailure
2116+
transform::ConvertToLoopsOp::apply(transform::TransformRewriter &rewriter,
2117+
transform::TransformResults &results,
2118+
transform::TransformState &state) {
2119+
SmallVector<Operation *> loops;
2120+
for (Operation *target : state.getPayloadOps(getTarget())) {
2121+
auto tilingOp = dyn_cast<TilingInterface>(*target);
2122+
if (!target) {
2123+
DiagnosedSilenceableFailure diag =
2124+
emitSilenceableError()
2125+
<< "expected the payload to implement TilingInterface";
2126+
diag.attachNote(target->getLoc()) << "payload op";
2127+
return diag;
2128+
}
2129+
rewriter.setInsertionPoint(target);
2130+
FailureOr<SmallVector<scf::ForOp>> generatedLoops =
2131+
scf::lowerToLoopsUsingSCFForOp(rewriter, tilingOp);
2132+
if (failed(generatedLoops))
2133+
return emitDefaultDefiniteFailure(target);
2134+
for (scf::ForOp &loop : *generatedLoops) {
2135+
loops.push_back(loop.getOperation());
2136+
}
2137+
rewriter.eraseOp(target);
2138+
}
2139+
results.set(cast<OpResult>(getResult()), loops);
21252140
return DiagnosedSilenceableFailure::success();
21262141
}
21272142

mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ module attributes {transform.with_named_sequence} {
1111
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
1212
%matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
1313
: (!transform.any_op) -> !transform.any_op
14-
transform.structured.convert_to_loops %matmul : !transform.any_op
14+
%0 = transform.structured.convert_to_loops %matmul
15+
: (!transform.any_op) -> (!transform.any_op)
1516
transform.yield
1617
}
1718
}
@@ -37,6 +38,57 @@ module attributes {transform.with_named_sequence} {
3738

3839
// -----
3940

41+
func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
42+
%arg2 : memref<?x?xf32>, %arg3 : memref<?xf32>, %arg4 : memref<?xf32>) {
43+
linalg.matmul ins(%arg0, %arg1 : memref<?x?xf32>, memref<?x?xf32>)
44+
outs(%arg2 : memref<?x?xf32>)
45+
linalg.matvec ins(%arg0, %arg3 : memref<?x?xf32>, memref<?xf32>)
46+
outs(%arg4 : memref<?xf32>)
47+
return
48+
}
49+
50+
module attributes {transform.with_named_sequence} {
51+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
52+
%linalg_ops = transform.structured.match interface{TilingInterface} in %arg1
53+
: (!transform.any_op) -> !transform.any_op
54+
%0 = transform.structured.convert_to_loops %linalg_ops
55+
: (!transform.any_op) -> (!transform.any_op)
56+
%1:5 = transform.split_handle %0
57+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
58+
transform.yield
59+
}
60+
}
61+
// CHECK-LABEL: func @gemm
62+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
63+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
64+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
65+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: memref<?xf32>
66+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: memref<?xf32>
67+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
68+
// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
69+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
70+
// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
71+
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
72+
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
73+
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C1]]
74+
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
75+
// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[IV2]]]
76+
// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG1]][%[[IV2]], %[[IV1]]]
77+
// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]]]
78+
// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
79+
// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
80+
// CHECK: memref.store %[[ADDF]], %[[ARG2]][%[[IV0]], %[[IV1]]]
81+
// CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]]
82+
// CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C1]]
83+
// CHECK-DAG: %[[LHS:.+]] = memref.load %[[ARG0]][%[[IV3]], %[[IV4]]]
84+
// CHECK-DAG: %[[RHS:.+]] = memref.load %[[ARG3]][%[[IV4]]]
85+
// CHECK-DAG: %[[OUT:.+]] = memref.load %[[ARG4]][%[[IV3]]]
86+
// CHECK: %[[MULF:.+]] = arith.mulf %[[LHS]], %[[RHS]]
87+
// CHECK: %[[ADDF:.+]] = arith.addf %[[OUT]], %[[MULF]]
88+
// CHECK: memref.store %[[ADDF]], %[[ARG4]][%[[IV3]]]
89+
90+
// -----
91+
4092
func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>,
4193
%arg2 : memref<200xi8>, %arg3 : memref<300x200xi64>) {
4294
linalg.generic {
@@ -66,7 +118,8 @@ module attributes {transform.with_named_sequence} {
66118
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
67119
%generic = transform.structured.match ops{["linalg.generic"]} in %arg1
68120
: (!transform.any_op) -> !transform.any_op
69-
transform.structured.convert_to_loops %generic : !transform.any_op
121+
%0 = transform.structured.convert_to_loops %generic
122+
: (!transform.any_op) -> (!transform.any_op)
70123
transform.yield
71124
}
72125
}
@@ -111,7 +164,8 @@ module attributes {transform.with_named_sequence} {
111164
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
112165
%conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1
113166
: (!transform.any_op) -> !transform.any_op
114-
transform.structured.convert_to_loops %conv : !transform.any_op
167+
%0 = transform.structured.convert_to_loops %conv
168+
: (!transform.any_op) -> (!transform.any_op)
115169
transform.yield
116170
}
117171
}
@@ -165,7 +219,8 @@ module attributes {transform.with_named_sequence} {
165219
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
166220
%pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1
167221
: (!transform.any_op) -> !transform.any_op
168-
transform.structured.convert_to_loops %pool : !transform.any_op
222+
%0 = transform.structured.convert_to_loops %pool
223+
: (!transform.any_op) -> (!transform.any_op)
169224
transform.yield
170225
}
171226
}
@@ -216,7 +271,8 @@ module attributes {transform.with_named_sequence} {
216271
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
217272
%map = transform.structured.match ops{["linalg.map"]} in %arg1
218273
: (!transform.any_op) -> !transform.any_op
219-
transform.structured.convert_to_loops %map : !transform.any_op
274+
%0 = transform.structured.convert_to_loops %map
275+
: (!transform.any_op) -> (!transform.any_op)
220276
transform.yield
221277
}
222278
}
@@ -248,7 +304,8 @@ module attributes {transform.with_named_sequence} {
248304
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
249305
%transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1
250306
: (!transform.any_op) -> !transform.any_op
251-
transform.structured.convert_to_loops %transpose : !transform.any_op
307+
%0 = transform.structured.convert_to_loops %transpose
308+
: (!transform.any_op) -> (!transform.any_op)
252309
transform.yield
253310
}
254311
}
@@ -285,7 +342,8 @@ module attributes {transform.with_named_sequence} {
285342
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
286343
%reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1
287344
: (!transform.any_op) -> !transform.any_op
288-
transform.structured.convert_to_loops %reduce : !transform.any_op
345+
%0 = transform.structured.convert_to_loops %reduce
346+
: (!transform.any_op) -> (!transform.any_op)
289347
transform.yield
290348
}
291349
}
@@ -322,7 +380,8 @@ module attributes {transform.with_named_sequence} {
322380
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
323381
%broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1
324382
: (!transform.any_op) -> !transform.any_op
325-
transform.structured.convert_to_loops %broadcast : !transform.any_op
383+
%0 = transform.structured.convert_to_loops %broadcast
384+
: (!transform.any_op) -> (!transform.any_op)
326385
transform.yield
327386
}
328387
}

0 commit comments

Comments
 (0)