Skip to content

Commit 19b9c74

Browse files
committed
[mlir] Return new scf.forall handle in fuse_into_containing_op
Since the handle to the containing scf.forall is now consumed by transform.structured.fuse_into_containing_op, the op needs to return a handle to the new scf.forall. This patch does that and also ensures that the new bbArg added to the scf.forall is used in its body. Differential Revision: https://reviews.llvm.org/D151418
1 parent 13e3d4a commit 19b9c74

File tree

4 files changed

+101
-15
lines changed

4 files changed

+101
-15
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def FuseIntoContainingOp :
183183

184184
let arguments = (ins TransformHandleTypeInterface:$producer_op,
185185
TransformHandleTypeInterface:$containing_op);
186-
let results = (outs TransformHandleTypeInterface:$fused_op);
186+
let results = (outs TransformHandleTypeInterface:$fused_op,
187+
TransformHandleTypeInterface:$new_containing_op);
187188
let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
188189
" `:` functional-type(operands, results)";
189190

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,8 @@ void transform::FuseIntoContainingOp::build(OpBuilder &builder,
344344
Value producerOp,
345345
Value containingOp) {
346346
result.addOperands({producerOp, containingOp});
347-
result.addTypes(transform::AnyOpType::get(builder.getContext()));
347+
auto resultType = transform::AnyOpType::get(builder.getContext());
348+
result.addTypes({resultType, resultType});
348349
}
349350

350351
/// Add new operands to the forall op for users of the producerOp
@@ -388,8 +389,16 @@ static Operation *replaceForAllWithNewSignature(
388389
newforallOp.getRegion().takeBody(forallOp.getRegion());
389390

390391
// Add additional block argument for new value being returned
392+
// and replace all uses of the new output with the corresponding bbArg
393+
// inside the scf.forall to enable fusion into this new scf.forall.
391394
newforallOp.getBody()->addArgument(newOuts.back().getType(),
392395
newOuts.back().getLoc());
396+
auto bbArgs = newforallOp.getBody()->getArguments();
397+
rewriter.replaceUsesWithIf(newOuts.back(), bbArgs.back(),
398+
[&](OpOperand &use) {
399+
Operation *op = use.getOwner();
400+
return newforallOp->isProperAncestor(op);
401+
});
393402

394403
// Fix terminator
395404
scf::InParallelOp terminatorOp = newforallOp.getTerminator();
@@ -749,14 +758,15 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
749758
}
750759

751760
results.set(cast<OpResult>(getFusedOp()), fusedOps);
761+
results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
752762
return DiagnosedSilenceableFailure::success();
753763
}
754764

755765
void transform::FuseIntoContainingOp::getEffects(
756766
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
757767
consumesHandle(getProducerOp(), effects);
758-
onlyReadsHandle(getContainingOp(), effects);
759-
producesHandle(getFusedOp(), effects);
768+
consumesHandle(getContainingOp(), effects);
769+
producesHandle(getResults(), effects);
760770
modifiesPayload(effects);
761771
}
762772

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 84 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ module {
4848

4949
// linalg.fill is tileable. The op is tiled and fused.
5050
transform.structured.fuse_into_containing_op %0 into %1
51-
: (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op
51+
: (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
5252
}
5353
}
5454

@@ -92,7 +92,7 @@ module {
9292

9393
// tensor.empty is not tileable. The op is cloned and fused.
9494
transform.structured.fuse_into_containing_op %0 into %1
95-
: (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> !transform.any_op
95+
: (!transform.op<"tensor.empty">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
9696
}
9797
}
9898

@@ -139,7 +139,7 @@ module {
139139

140140
// linalg.fill is tileable. The op is tiled and fused.
141141
transform.structured.fuse_into_containing_op %0 into %1
142-
: (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> !transform.any_op
142+
: (!transform.op<"linalg.fill">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
143143
}
144144
}
145145

@@ -188,7 +188,7 @@ module {
188188

189189
// linalg.fill is tileable. The op is tiled and fused.
190190
transform.structured.fuse_into_containing_op %0 into %1
191-
: (!transform.any_op, !transform.any_op) -> !transform.any_op
191+
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
192192
}
193193
}
194194

@@ -249,7 +249,7 @@ module {
249249

250250
// linalg.generic is tileable. The op is tiled and fused.
251251
transform.structured.fuse_into_containing_op %0 into %1
252-
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
252+
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
253253
}
254254
}
255255

@@ -285,7 +285,7 @@ module {
285285
%2 = transform.merge_handles %0, %0 : !transform.any_op
286286

287287
// It shouldn't be a problem to fuse this handle.
288-
transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
288+
transform.structured.fuse_into_containing_op %2 into %1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
289289
}
290290
}
291291

@@ -351,7 +351,7 @@ module {
351351

352352
// linalg.generic is tileable. The op is tiled and fused.
353353
transform.structured.fuse_into_containing_op %0 into %1
354-
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
354+
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
355355
}
356356
}
357357

@@ -417,7 +417,7 @@ module {
417417

418418
// linalg.generic is tileable. The op is tiled and fused.
419419
transform.structured.fuse_into_containing_op %0 into %1
420-
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
420+
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
421421
}
422422
}
423423

@@ -482,6 +482,81 @@ module {
482482

483483
// linalg.generic is tileable. The op is tiled and fused.
484484
transform.structured.fuse_into_containing_op %0 into %1
485-
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> !transform.any_op
485+
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.any_op)
486+
}
487+
}
488+
489+
// -----
490+
491+
#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
492+
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
493+
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
494+
#map3 = affine_map<(d0) -> (d0)>
495+
496+
module {
497+
// CHECK-LABEL: func.func @fuse_tileable_using_new_handle
498+
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
499+
// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
500+
// CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
501+
// CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
502+
// CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
503+
func.func @fuse_tileable_using_new_handle(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>)
504+
-> (tensor<?xf32>, tensor<?xf32>) {
505+
%cst = arith.constant 4.200000e+01 : f32
506+
%c0 = arith.constant 0 : index
507+
508+
%0 = linalg.generic {
509+
indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
510+
} ins(%in : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
511+
^bb0(%a: f32, %b: f32):
512+
%d = arith.addf %a, %b : f32
513+
linalg.yield %d : f32
514+
} -> tensor<?xf32>
515+
516+
%1 = linalg.generic {
517+
indexing_maps = [#map3, #map3], iterator_types = ["parallel"]
518+
} ins(%0 : tensor<?xf32>) outs(%out_1 : tensor<?xf32>) {
519+
^bb0(%a: f32, %b: f32):
520+
%d = arith.mulf %a, %b : f32
521+
linalg.yield %d : f32
522+
} -> tensor<?xf32>
523+
%d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
524+
525+
%2 = affine.apply #map0()[%d0, %idx]
526+
527+
// CHECK: %[[R0:.*]]:2 = scf.forall (%[[ARG5:.*]]) in (%{{.*}}) shared_outs(%[[ARG6:.*]] = %[[OUT_2]], %[[ARG7:.*]] = %[[OUT_1]])
528+
// CHECK-SAME: -> (tensor<?xf32>, tensor<?xf32>) {
529+
%3 = scf.forall (%i) in (%2) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
530+
// CHECK: %[[I0:.*]] = affine.apply {{.*}}
531+
%4 = affine.apply #map1(%i)[%idx]
532+
// CHECK: %[[I1:.*]] = affine.min {{.*}}
533+
%5 = affine.min #map2(%i)[%d0, %idx]
534+
%6 = tensor.extract_slice %o[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
535+
536+
// CHECK: %[[T1:.*]] = linalg.generic {{.*}}
537+
// CHECK: %[[T2:.*]] = linalg.generic {{.*}}
538+
%7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
539+
540+
%8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
541+
scf.forall.in_parallel {
542+
// CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
543+
tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
544+
}
545+
}
546+
// CHECK: return %[[R0]]#0, %[[R0]]#1
547+
func.return %3, %1 : tensor<?xf32>, tensor<?xf32>
548+
// CHECK: }
549+
}
550+
551+
transform.sequence failures(propagate) {
552+
^bb1(%arg1: !transform.any_op):
553+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.op<"linalg.generic">
554+
%add, %reduce = transform.split_handle %0 : (!transform.op<"linalg.generic">) -> (!transform.op<"linalg.generic">, !transform.op<"linalg.generic">)
555+
%1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.op<"scf.forall">
556+
557+
%fused_ops, %new_forall = transform.structured.fuse_into_containing_op %reduce into %1
558+
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
559+
%fused_ops_2, %new_forall_2 = transform.structured.fuse_into_containing_op %add into %new_forall
560+
: (!transform.op<"linalg.generic">, !transform.op<"scf.forall">) -> (!transform.any_op, !transform.op<"scf.forall">)
486561
}
487562
}

mlir/test/Dialect/Linalg/transform-tile-and-fuse.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ module {
5252

5353
// Fuse all producers.
5454
transform.structured.fuse_into_containing_op %producers into %forall_op
55-
: (!transform.any_op, !transform.any_op) -> !transform.any_op
55+
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
5656
}
5757
}
5858

@@ -112,6 +112,6 @@ module {
112112

113113
// Fuse all producers.
114114
transform.structured.fuse_into_containing_op %reversed_producers into %forall_op
115-
: (!transform.any_op, !transform.any_op) -> !transform.any_op
115+
: (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
116116
}
117117
}

0 commit comments

Comments
 (0)