Skip to content

Commit 9aaf007

Browse files
committed
[SCF][Transform] Add transform.loop.fuse_sibling
This patch adds a new transform operation `transform.loop.fuse_sibling`, which given two loops, fuses them, assuming that they are independent. The transform operation itself performs very basic checks to ensure IR legality, and leaves the responsibility of ensuring independence on the user. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D157069
1 parent d163ae8 commit 9aaf007

File tree

5 files changed

+365
-0
lines changed

5 files changed

+365
-0
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,4 +310,39 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
310310
}];
311311
}
312312

313+
def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
314+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
315+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
316+
let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
317+
318+
let description = [{
319+
Fuses the `target` loop into the `source` loop assuming they are
320+
independent of each other. It is the responsibility of the user to ensure
321+
that the given two loops are independent of each other, this operation will
322+
not perform any legality checks and will simply fuse the two given loops.
323+
324+
Currently, the only fusion supported is when both `target` and `source`
325+
are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
326+
mapping must match, otherwise a silenceable failure is produced.
327+
328+
The input handles `target` and `source` must map to exactly one operation,
329+
a definite failure is produced otherwise.
330+
331+
#### Return modes
332+
333+
This operation consumes the `target` and `source` handles and produces the
334+
`fused_loop` handle, which points to the fused loop.
335+
}];
336+
337+
let arguments = (ins TransformHandleTypeInterface:$target,
338+
TransformHandleTypeInterface:$source);
339+
let results = (outs TransformHandleTypeInterface:$fused_loop);
340+
let assemblyFormat = "$target `into` $source attr-dict "
341+
" `:` functional-type(operands, results)";
342+
343+
let builders = [
344+
OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
345+
];
346+
}
347+
313348
#endif // SCF_TRANSFORM_OPS

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,17 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
185185
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
186186
scf::ForOp root);
187187

188+
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
189+
/// `source`. Assumes that the given loops are siblings and are independent of
190+
/// each other.
191+
///
192+
/// This function does not perform any legality checks and simply fuses the
193+
/// loops. The caller is responsible for ensuring that the loops are legal to
194+
/// fuse.
///
/// The fused loop is a newly created scf.forall placed right after `source`;
/// it reuses the bounds, steps and mapping of `source`, and its shared outputs
/// (and results) are those of `target` followed by those of `source`. All uses
/// of the results of both input loops are redirected to the fused loop and
/// both input loops are erased. The insertion point of `rewriter` is modified.
///
/// Returns the fused loop.
195+
scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
196+
scf::ForallOp source,
197+
RewriterBase &rewriter);
198+
188199
} // namespace mlir
189200

190201
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1919
#include "mlir/Dialect/Transform/IR/TransformOps.h"
2020
#include "mlir/Dialect/Vector/IR/VectorOps.h"
21+
#include "mlir/IR/Dominance.h"
2122

2223
using namespace mlir;
2324
using namespace mlir::affine;
@@ -318,6 +319,146 @@ void transform::TakeAssumedBranchOp::getEffects(
318319
modifiesPayload(effects);
319320
}
320321

322+
//===----------------------------------------------------------------------===//
323+
// LoopFuseSibling
324+
//===----------------------------------------------------------------------===//
325+
326+
/// Check if `target` and `source` are siblings, in the context that `target`
327+
/// is being fused into `source`.
328+
///
329+
/// This is a simple check that just checks if both operations are in the same
330+
/// block and some checks to ensure that the fused IR does not violate
331+
/// dominance.
332+
static DiagnosedSilenceableFailure isOpSibling(Operation *target,
333+
Operation *source) {
334+
// Check if both operations are same.
335+
if (target == source)
336+
return emitSilenceableFailure(source)
337+
<< "target and source need to be different loops";
338+
339+
// Check if both operations are in the same block.
340+
if (target->getBlock() != source->getBlock())
341+
return emitSilenceableFailure(source)
342+
<< "target and source are not in the same block";
343+
344+
// Check if fusion will violate dominance.
345+
DominanceInfo domInfo(source);
346+
if (target->isBeforeInBlock(source)) {
347+
// Since, `target` is before `source`, all users of results of `target`
348+
// need to be dominated by `source`.
349+
for (Operation *user : target->getUsers()) {
350+
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
351+
return emitSilenceableFailure(target)
352+
<< "user of results of target should be properly dominated by "
353+
"source";
354+
}
355+
}
356+
} else {
357+
// Since `target` is after `source`, all values used by `target` need
358+
// to dominate `source`.
359+
360+
// Check if operands of `target` are dominated by `source`.
361+
for (Value operand : target->getOperands()) {
362+
Operation *operandOp = operand.getDefiningOp();
363+
// If operand does not have a defining operation, it is a block arguement,
364+
// which will always dominate `source`, since `target` and `source` are in
365+
// the same block and the operand dominated `source` before.
366+
if (!operandOp)
367+
continue;
368+
369+
// Operand's defining operation should properly dominate `source`.
370+
if (!domInfo.properlyDominates(operandOp, source,
371+
/*enclosingOpOk=*/false))
372+
return emitSilenceableFailure(target)
373+
<< "operands of target should be properly dominated by source";
374+
}
375+
376+
// Check if values used by `target` are dominated by `source`.
377+
bool failed = false;
378+
OpOperand *failedValue = nullptr;
379+
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
380+
if (!domInfo.properlyDominates(operand->getOwner(), source,
381+
/*enclosingOpOk=*/false)) {
382+
failed = true;
383+
failedValue = operand;
384+
}
385+
});
386+
387+
if (failed)
388+
return emitSilenceableFailure(failedValue->getOwner())
389+
<< "values used inside regions of target should be properly "
390+
"dominated by source";
391+
}
392+
393+
return DiagnosedSilenceableFailure::success();
394+
}
395+
396+
/// Check if `target` can be fused into `source`.
397+
///
398+
/// This is a simple check that just checks if both loops have same
399+
/// bounds, steps and mapping. This check does not ensure that the side effects
400+
/// of `target` are independent of `source` or vice-versa. It is the
401+
/// responsibility of the caller to ensure that.
402+
static bool isForallWithIdenticalConfiguration(Operation *target,
403+
Operation *source) {
404+
auto targetOp = dyn_cast<scf::ForallOp>(target);
405+
auto sourceOp = dyn_cast<scf::ForallOp>(source);
406+
if (!targetOp || !sourceOp)
407+
return false;
408+
409+
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
410+
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
411+
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
412+
targetOp.getMapping() == sourceOp.getMapping();
413+
}
414+
415+
/// Fuse `target` into `source` assuming they are siblings and indepndent.
416+
/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
417+
static Operation *fuseSiblings(Operation *target, Operation *source,
418+
RewriterBase &rewriter) {
419+
auto targetOp = dyn_cast<scf::ForallOp>(target);
420+
auto sourceOp = dyn_cast<scf::ForallOp>(source);
421+
if (!targetOp || !sourceOp)
422+
return nullptr;
423+
return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
424+
}
425+
426+
/// Fuses the payload loop associated with the `target` handle into the payload
/// loop associated with the `source` handle and binds the `fused_loop` result
/// to the resulting operation.
///
/// Produces a definite failure unless each handle maps to exactly one payload
/// operation. Produces a silenceable failure when the two operations fail the
/// sibling check or are not scf.forall loops with identical configuration.
DiagnosedSilenceableFailure
transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
                                  transform::TransformResults &results,
                                  transform::TransformState &state) {
  auto targetPayload = state.getPayloadOps(getTarget());
  auto sourcePayload = state.getPayloadOps(getSource());

  // Both handles must point to exactly one payload operation.
  const bool eachHandleHasOneOp = llvm::hasSingleElement(targetPayload) &&
                                  llvm::hasSingleElement(sourcePayload);
  if (!eachHandleHasOneOp) {
    return emitDefiniteFailure()
           << "requires exactly one target handle (got "
           << llvm::range_size(targetPayload) << ") and exactly one "
           << "source handle (got " << llvm::range_size(sourcePayload) << ")";
  }

  Operation *targetOp = *targetPayload.begin();
  Operation *sourceOp = *sourcePayload.begin();

  // The operations have to be siblings for the fused IR to stay valid.
  DiagnosedSilenceableFailure siblingCheck = isOpSibling(targetOp, sourceOp);
  if (!siblingCheck.succeeded())
    return siblingCheck;

  // Only loops with matching configuration can be fused.
  if (!isForallWithIdenticalConfiguration(targetOp, sourceOp))
    return emitSilenceableFailure(targetOp->getLoc())
           << "operations cannot be fused";

  Operation *fused = fuseSiblings(targetOp, sourceOp, rewriter);
  assert(fused && "failed to fuse operations");

  results.set(cast<OpResult>(getFusedLoop()), {fused});
  return DiagnosedSilenceableFailure::success();
}
461+
321462
//===----------------------------------------------------------------------===//
322463
// Transform op registration
323464
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,3 +970,68 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp,
970970

971971
return tileLoops;
972972
}
973+
974+
/// Builds a single scf.forall that executes the body of `target` followed by
/// the body of `source`, redirects all uses of the results of both loops to
/// the new loop, erases both loops and returns the new loop.
///
/// The new loop takes its bounds, steps and mapping from `source`. Its shared
/// outputs, and therefore its results, are ordered as: outputs of `target`
/// first, outputs of `source` second. No legality checks are performed.
scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
                                                      scf::ForallOp source,
                                                      RewriterBase &rewriter) {
  const unsigned targetResultCount = target.getNumResults();
  const unsigned sourceResultCount = source.getNumResults();

  // Shared outputs of the fused loop: those of `target`, then of `source`.
  OperandRange targetInits = target.getOutputs();
  OperandRange sourceInits = source.getOutputs();
  SmallVector<Value> combinedInits;
  combinedInits.reserve(targetResultCount + sourceResultCount);
  combinedInits.append(targetInits.begin(), targetInits.end());
  combinedInits.append(sourceInits.begin(), sourceInits.end());

  // The fused loop is created right after `source`.
  rewriter.setInsertionPointAfter(source);
  scf::ForallOp fused = rewriter.create<scf::ForallOp>(
      source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
      source.getMixedStep(), combinedInits, source.getMapping());

  // Redirect the block arguments of both old bodies to those of the fused
  // body: induction variables are shared, output arguments are partitioned.
  IRMapping valueMap;
  valueMap.map(target.getInductionVars(), fused.getInductionVars());
  valueMap.map(source.getInductionVars(), fused.getInductionVars());
  valueMap.map(target.getOutputBlockArguments(),
               fused.getOutputBlockArguments().slice(0, targetResultCount));
  valueMap.map(source.getOutputBlockArguments(),
               fused.getOutputBlockArguments().slice(targetResultCount,
                                                     sourceResultCount));

  // Clones every operation of `ops` at the current insertion point.
  auto cloneAll = [&](auto &&ops) {
    for (Operation &op : ops)
      rewriter.clone(op, valueMap);
  };

  // Body operations (terminators excluded): `target` first, then `source`.
  rewriter.setInsertionPointToStart(fused.getBody());
  cloneAll(target.getLoopBody().begin()->without_terminator());
  cloneAll(source.getLoopBody().begin()->without_terminator());

  // Merge the contents of both in_parallel terminators into the new one.
  scf::InParallelOp fusedTerminator = fused.getTerminator();
  rewriter.setInsertionPointToStart(fusedTerminator.getBody());
  cloneAll(target.getTerminator().getYieldingOps());
  cloneAll(source.getTerminator().getYieldingOps());

  // Forward the results of the old loops to the matching fused results.
  rewriter.replaceAllUsesWith(
      target.getResults(), fused.getResults().slice(0, targetResultCount));
  rewriter.replaceAllUsesWith(
      source.getResults(),
      fused.getResults().slice(targetResultCount, sourceResultCount));

  // The old loops are now dead.
  rewriter.eraseOp(target);
  rewriter.eraseOp(source);

  return fused;
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// RUN: mlir-opt %s -test-transform-dialect-interpreter --cse --canonicalize -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
4+
%zero = arith.constant 0.0 : f32
5+
%out_alloc = tensor.empty() : tensor<128x128xf32>
6+
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
7+
8+
// CHECK: scf.forall ([[I:%.*]]) in (4) shared_outs([[S1:%.*]] = [[IN1:%.*]], [[S2:%.*]] = [[IN2:%.*]]) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
9+
// CHECK: [[T:%.*]] = affine.apply
10+
// CHECK: tensor.extract_slice [[S1]][[[T]], 0] [32, 128] [1, 1]
11+
// CHECK: [[OUT1:%.*]] = linalg.matmul
12+
// CHECK: tensor.extract_slice [[S2]][[[T]], 0] [32, 128] [1, 1]
13+
// CHECK: [[OUT2:%.*]] = linalg.matmul
14+
// CHECK: scf.forall.in_parallel {
15+
// CHECK: tensor.parallel_insert_slice [[OUT1]] into [[S1]][[[T]], 0] [32, 128] [1, 1]
16+
// CHECK: tensor.parallel_insert_slice [[OUT2]] into [[S2]][[[T]], 0] [32, 128] [1, 1]
17+
// CHECK: }
18+
// CHECK: }
19+
%out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
20+
%out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
21+
22+
func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
23+
}
24+
25+
transform.sequence failures(propagate) {
26+
^bb0(%variant_op : !transform.any_op):
27+
%matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
28+
29+
%mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
30+
31+
%loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
32+
%loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
33+
34+
%fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
35+
}
36+
37+
// -----
38+
39+
func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
40+
%zero = arith.constant 0.0 : f32
41+
%out_alloc = tensor.empty() : tensor<128x128xf32>
42+
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
43+
44+
// expected-error @below {{user of results of target should be properly dominated by source}}
45+
%out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
46+
%out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
47+
48+
func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
49+
}
50+
51+
transform.sequence failures(propagate) {
52+
^bb0(%variant_op : !transform.any_op):
53+
%matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
54+
55+
%mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
56+
57+
%loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
58+
%loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
59+
60+
%fused_loop = transform.loop.fuse_sibling %loop1 into %loop2 : (!transform.any_op, !transform.any_op) -> !transform.any_op
61+
}
62+
63+
// -----
64+
65+
func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
66+
%zero = arith.constant 0.0 : f32
67+
%out_alloc = tensor.empty() : tensor<128x128xf32>
68+
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
69+
70+
%out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
71+
// expected-error @below {{values used inside regions of target should be properly dominated by source}}
72+
%out2 = linalg.matmul ins(%A, %out1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
73+
74+
func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
75+
}
76+
77+
transform.sequence failures(propagate) {
78+
^bb0(%variant_op : !transform.any_op):
79+
%matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
80+
81+
%mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
82+
83+
%loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
84+
%loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
85+
86+
%fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
87+
}
88+
89+
// -----
90+
91+
func.func @test(%A : tensor<128x128xf32>, %B1 : tensor<128x128xf32>, %B2 : tensor<128x128xf32>) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
92+
%zero = arith.constant 0.0 : f32
93+
%out_alloc = tensor.empty() : tensor<128x128xf32>
94+
%out = linalg.fill ins(%zero : f32) outs(%out_alloc : tensor<128x128xf32>) -> tensor<128x128xf32>
95+
96+
%out1 = linalg.matmul ins(%A, %B1 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out : tensor<128x128xf32>) -> tensor<128x128xf32>
97+
// expected-error @below {{operands of target should be properly dominated by source}}
98+
%out2 = linalg.matmul ins(%A, %B2 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%out1 : tensor<128x128xf32>) -> tensor<128x128xf32>
99+
100+
func.return %out1, %out2 : tensor<128x128xf32>, tensor<128x128xf32>
101+
}
102+
103+
transform.sequence failures(propagate) {
104+
^bb0(%variant_op : !transform.any_op):
105+
%matched = transform.structured.match ops{["linalg.matmul"]} in %variant_op : (!transform.any_op) -> (!transform.any_op)
106+
107+
%mm1, %mm2 = transform.split_handle %matched : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
108+
109+
%loop1, %tiled_mm1 = transform.structured.tile_to_forall_op %mm1 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
110+
%loop2, %tiled_mm2 = transform.structured.tile_to_forall_op %mm2 tile_sizes [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
111+
112+
%fused_loop = transform.loop.fuse_sibling %loop2 into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op
113+
}

0 commit comments

Comments
 (0)