Skip to content

Commit 51d524a

Browse files
author
Rolf Morel
committed
[SCF][Transform] Add support for scf.for in LoopFuseSibling op
Adds support for fusing two scf.for loops occurring in the same block. Implementation mirrors that of LoopFuseSibling's support for scf.forall, including only rudimentary checks, like the target loop's operands being dominated by the source loop. Fixes a bug in the dominance check whereby it was checked that values in the target loop themselves dominated the source loop rather than (the ops) where these values originate. Adds tests for using LoopFuseSibling on scf.for loops, including one which fails without the fix for the dominance check.
1 parent a8ab830 commit 51d524a

File tree

5 files changed

+300
-18
lines changed

5 files changed

+300
-18
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,11 +342,13 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
342342
Fuses the `target` loop into the `source` loop assuming they are
343343
independent of each other. It is the responsibility of the user to ensure
344344
that the given two loops are independent of each other, this operation will
345-
not performa any legality checks and will simply fuse the two given loops.
345+
not perform any legality checks and will simply fuse the two given loops.
346346

347-
Currently, the only fusion supported is when both `target` and `source`
348-
are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
349-
mapping must match, otherwise a silencable failure is produced.
347+
Currently, fusion is only supported in case both `target` and `source` are
348+
`scf.for` operations or both are `scf.forall` operations. For `scf.for`
349+
fusion the bounds and step size must match. For `scf.forall` fusion the
350+
bounds and the mapping must match. Otherwise a silencable failure is
351+
produced.
350352

351353
The input handles `target` and `source` must map to exactly one operation,
352354
a definite failure is produced otherwise.

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
162162
scf::ForallOp source,
163163
RewriterBase &rewriter);
164164

165+
/// Given two scf.for loops, `target` and `source`, fuses `target` into
166+
/// `source`. Assumes that the given loops are siblings and are independent of
167+
/// each other.
168+
///
169+
/// This function does not perform any legality checks and simply fuses the
170+
/// loops. The caller is responsible for ensuring that the loops are legal to
171+
/// fuse.
172+
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
173+
RewriterBase &rewriter);
174+
165175
} // namespace mlir
166176

167177
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
441441
bool failed = false;
442442
OpOperand *failedValue = nullptr;
443443
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
444-
if (!domInfo.properlyDominates(operand->getOwner(), source,
445-
/*enclosingOpOk=*/false)) {
444+
Operation *operandOp = operand->get().getDefiningOp();
445+
if (operandOp && !domInfo.properlyDominates(operandOp, source,
446+
/*enclosingOpOk=*/false)) {
447+
// `operand` is not a block argument and its defining op does not
448+
// dominate `source`
446449
failed = true;
447450
failedValue = operand;
448451
}
@@ -476,15 +479,34 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
476479
targetOp.getMapping() == sourceOp.getMapping();
477480
}
478481

479-
/// Fuse `target` into `source` assuming they are siblings and indepndent.
480-
/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
482+
static bool isForWithIdenticalConfiguration(Operation *target,
483+
Operation *source) {
484+
auto targetOp = dyn_cast<scf::ForOp>(target);
485+
auto sourceOp = dyn_cast<scf::ForOp>(source);
486+
if (!targetOp || !sourceOp)
487+
return false;
488+
489+
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
490+
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
491+
targetOp.getStep() == sourceOp.getStep();
492+
}
493+
494+
/// Fuse `target` into `source` assuming they are siblings and independent.
495+
/// TODO: Support fusion for operations besides scf.for and scf.forall.
481496
static Operation *fuseSiblings(Operation *target, Operation *source,
482497
RewriterBase &rewriter) {
483-
auto targetOp = dyn_cast<scf::ForallOp>(target);
484-
auto sourceOp = dyn_cast<scf::ForallOp>(source);
485-
if (!targetOp || !sourceOp)
486-
return nullptr;
487-
return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
498+
auto targetForOp = dyn_cast<scf::ForOp>(target);
499+
auto sourceForOp = dyn_cast<scf::ForOp>(source);
500+
if (targetForOp && sourceForOp)
501+
return fuseIndependentSiblingForLoops(targetForOp, sourceForOp, rewriter);
502+
503+
auto targetForallOp = dyn_cast<scf::ForallOp>(target);
504+
auto sourceForallOp = dyn_cast<scf::ForallOp>(source);
505+
if (targetForallOp && sourceForallOp)
506+
return fuseIndependentSiblingForallLoops(targetForallOp, sourceForallOp,
507+
rewriter);
508+
509+
return nullptr;
488510
}
489511

490512
DiagnosedSilenceableFailure
@@ -511,7 +533,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
511533
return diag;
512534

513535
// Check if the target can be fused into source.
514-
if (!isForallWithIdenticalConfiguration(target, source)) {
536+
if (!isForallWithIdenticalConfiguration(target, source) &&
537+
!isForWithIdenticalConfiguration(target, source)) {
515538
return emitSilenceableFailure(target->getLoc())
516539
<< "operations cannot be fused";
517540
}

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,3 +970,69 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
970970

971971
return fusedLoop;
972972
}
973+
974+
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
975+
scf::ForOp source,
976+
RewriterBase &rewriter) {
977+
// Create fused init_args.
978+
auto targetInitArgs = target.getInitArgs();
979+
auto sourceInitArgs = source.getInitArgs();
980+
SmallVector<Value> fusedInitArgs;
981+
fusedInitArgs.reserve(targetInitArgs.size() + sourceInitArgs.size());
982+
fusedInitArgs.append(sourceInitArgs.begin(), sourceInitArgs.end());
983+
fusedInitArgs.append(targetInitArgs.begin(), targetInitArgs.end());
984+
985+
// Create a new scf::for op after the source loop.
986+
rewriter.setInsertionPointAfter(source);
987+
scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
988+
source.getLoc(), source.getLowerBound(), source.getUpperBound(),
989+
source.getStep(), fusedInitArgs);
990+
991+
SmallVector<Value> yieldResults;
992+
993+
// First merge source loop into the new (fused) for loop and then target loop.
994+
rewriter.setInsertionPointToEnd(fusedLoop.getBody());
995+
for (auto loopAndInitArgsBegin :
996+
{std::pair(source, (unsigned int)0),
997+
std::pair(target, source.getNumRegionIterArgs())}) {
998+
auto origLoop = loopAndInitArgsBegin.first;
999+
IRMapping mapping;
1000+
1001+
mapping.map(origLoop.getInductionVar(), fusedLoop.getInductionVar());
1002+
for (size_t i = 0; i < origLoop.getNumRegionIterArgs(); ++i) {
1003+
mapping.map(
1004+
origLoop.getRegionIterArgs()[i],
1005+
fusedLoop.getRegionIterArgs()[loopAndInitArgsBegin.second + i]);
1006+
}
1007+
1008+
for (Operation &op : origLoop.getBody()->getOperations()) {
1009+
rewriter.clone(op, mapping);
1010+
}
1011+
1012+
if (origLoop.getNumResults() > 0) {
1013+
scf::YieldOp yieldFromOrigLoop =
1014+
cast<scf::YieldOp>(fusedLoop.getBody()->getTerminator());
1015+
yieldResults.append(yieldFromOrigLoop.getOperands().begin(),
1016+
yieldFromOrigLoop.getOperands().end());
1017+
rewriter.eraseOp(yieldFromOrigLoop);
1018+
}
1019+
}
1020+
1021+
// Construct combined YieldOp
1022+
rewriter.setInsertionPointToEnd(fusedLoop.getBody());
1023+
rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
1024+
1025+
// Replace all uses of the old loops with the fused loop.
1026+
unsigned numSourceOuts = source.getNumResults();
1027+
rewriter.replaceAllUsesWith(source.getResults(),
1028+
fusedLoop.getResults().slice(0, numSourceOuts));
1029+
rewriter.replaceAllUsesWith(
1030+
target.getResults(),
1031+
fusedLoop.getResults().slice(numSourceOuts, target.getNumResults()));
1032+
1033+
// Erase the old loops.
1034+
rewriter.eraseOp(target);
1035+
rewriter.eraseOp(source);
1036+
1037+
return fusedLoop;
1038+
}

0 commit comments

Comments
 (0)