Skip to content

Commit 4e04706

Browse files
author
Rolf Morel
committed
[SCF][Transform] Add support for scf.for in LoopFuseSibling op
Adds support for fusing two scf.for loops occurring in the same block. Implementation mirrors that of LoopFuseSibling's support for scf.forall, including only rudimentary checks, like the target loop's operands being dominated by the source loop. Fixes a bug in the dominance check whereby it was checked that values in the target loop themselves dominated the source loop rather than (the ops) where these values originate. Adds tests for using LoopFuseSibling on scf.for loops, including one which fails without the fix for the dominance check.
1 parent 8ecc377 commit 4e04706

File tree

5 files changed

+337
-51
lines changed

5 files changed

+337
-51
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
333333
}];
334334
}
335335

336-
def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
336+
def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
337337
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
338338
DeclareOpInterfaceMethods<TransformOpInterface>]> {
339339
let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
@@ -342,11 +342,13 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
342342
Fuses the `target` loop into the `source` loop assuming they are
343343
independent of each other. It is the responsibility of the user to ensure
344344
that the given two loops are independent of each other, this operation will
345-
not performa any legality checks and will simply fuse the two given loops.
345+
not perform any legality checks and will simply fuse the two given loops.
346346

347-
Currently, the only fusion supported is when both `target` and `source`
348-
are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
349-
mapping must match, otherwise a silencable failure is produced.
347+
Currently, fusion is only supported in case both `target` and `source` are
348+
`scf.for` operations or both are `scf.forall` operations. For `scf.for`
349+
fusion the bounds and step size must match. For `scf.forall` fusion the
350+
bounds and the mapping must match. Otherwise a silencable failure is
351+
produced.
350352

351353
The input handles `target` and `source` must map to exactly one operation,
352354
a definite failure is produced otherwise.

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
162162
scf::ForallOp source,
163163
RewriterBase &rewriter);
164164

165+
/// Given two scf.for loops, `target` and `source`, fuses `target` into
166+
/// `source`. Assumes that the given loops are siblings and are independent of
167+
/// each other.
168+
///
169+
/// This function does not perform any legality checks and simply fuses the
170+
/// loops. The caller is responsible for ensuring that the loops are legal to
171+
/// fuse.
172+
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
173+
RewriterBase &rewriter);
174+
165175
} // namespace mlir
166176

167177
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
384384
}
385385

386386
//===----------------------------------------------------------------------===//
387-
// LoopFuseSibling
387+
// LoopFuseSiblingOp
388388
//===----------------------------------------------------------------------===//
389389

390390
/// Check if `target` and `source` are siblings, in the context that `target`
@@ -441,8 +441,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
441441
bool failed = false;
442442
OpOperand *failedValue = nullptr;
443443
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
444-
if (!domInfo.properlyDominates(operand->getOwner(), source,
445-
/*enclosingOpOk=*/false)) {
444+
Operation *operandOp = operand->get().getDefiningOp();
445+
if (operandOp && !domInfo.properlyDominates(operandOp, source,
446+
/*enclosingOpOk=*/false)) {
447+
// `operand` is not a block argument and its defining op does not
448+
// dominate `source`
446449
failed = true;
447450
failedValue = operand;
448451
}
@@ -476,21 +479,40 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
476479
targetOp.getMapping() == sourceOp.getMapping();
477480
}
478481

479-
/// Fuse `target` into `source` assuming they are siblings and indepndent.
480-
/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
482+
static bool isForWithIdenticalConfiguration(Operation *target,
483+
Operation *source) {
484+
auto targetOp = dyn_cast<scf::ForOp>(target);
485+
auto sourceOp = dyn_cast<scf::ForOp>(source);
486+
if (!targetOp || !sourceOp)
487+
return false;
488+
489+
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
490+
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
491+
targetOp.getStep() == sourceOp.getStep();
492+
}
493+
494+
/// Fuse `target` into `source` assuming they are siblings and independent.
495+
/// TODO: Support fusion for operations besides scf.for and scf.forall.
481496
static Operation *fuseSiblings(Operation *target, Operation *source,
482497
RewriterBase &rewriter) {
483-
auto targetOp = dyn_cast<scf::ForallOp>(target);
484-
auto sourceOp = dyn_cast<scf::ForallOp>(source);
485-
if (!targetOp || !sourceOp)
486-
return nullptr;
487-
return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
498+
auto targetForOp = dyn_cast<scf::ForOp>(target);
499+
auto sourceForOp = dyn_cast<scf::ForOp>(source);
500+
if (targetForOp && sourceForOp)
501+
return fuseIndependentSiblingForLoops(targetForOp, sourceForOp, rewriter);
502+
503+
auto targetForallOp = dyn_cast<scf::ForallOp>(target);
504+
auto sourceForallOp = dyn_cast<scf::ForallOp>(source);
505+
if (targetForallOp && sourceForallOp)
506+
return fuseIndependentSiblingForallLoops(targetForallOp, sourceForallOp,
507+
rewriter);
508+
509+
return nullptr;
488510
}
489511

490512
DiagnosedSilenceableFailure
491-
transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
492-
transform::TransformResults &results,
493-
transform::TransformState &state) {
513+
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
514+
transform::TransformResults &results,
515+
transform::TransformState &state) {
494516
auto targetOps = state.getPayloadOps(getTarget());
495517
auto sourceOps = state.getPayloadOps(getSource());
496518

@@ -511,7 +533,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
511533
return diag;
512534

513535
// Check if the target can be fused into source.
514-
if (!isForallWithIdenticalConfiguration(target, source)) {
536+
if (!isForallWithIdenticalConfiguration(target, source) &&
537+
!isForWithIdenticalConfiguration(target, source)) {
515538
return emitSilenceableFailure(target->getLoc())
516539
<< "operations cannot be fused";
517540
}

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 65 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -910,39 +910,34 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
910910
unsigned numTargetOuts = target.getNumResults();
911911
unsigned numSourceOuts = source.getNumResults();
912912

913-
OperandRange targetOuts = target.getOutputs();
914-
OperandRange sourceOuts = source.getOutputs();
915-
916913
// Create fused shared_outs.
917914
SmallVector<Value> fusedOuts;
918-
fusedOuts.reserve(numTargetOuts + numSourceOuts);
919-
fusedOuts.append(targetOuts.begin(), targetOuts.end());
920-
fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
915+
llvm::append_range(fusedOuts, target.getOutputs());
916+
llvm::append_range(fusedOuts, source.getOutputs());
921917

922-
// Create a new scf::forall op after the source loop.
918+
// Create a new scf.forall op after the source loop.
923919
rewriter.setInsertionPointAfter(source);
924920
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
925921
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
926922
source.getMixedStep(), fusedOuts, source.getMapping());
927923

928924
// Map control operands.
929-
IRMapping fusedMapping;
930-
fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
931-
fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
925+
IRMapping mapping;
926+
mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
927+
mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
932928

933929
// Map shared outs.
934-
fusedMapping.map(target.getRegionIterArgs(),
935-
fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
936-
fusedMapping.map(
937-
source.getRegionIterArgs(),
938-
fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
930+
mapping.map(target.getRegionIterArgs(),
931+
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
932+
mapping.map(source.getRegionIterArgs(),
933+
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
939934

940935
// Append everything except the terminator into the fused operation.
941936
rewriter.setInsertionPointToStart(fusedLoop.getBody());
942937
for (Operation &op : target.getBody()->without_terminator())
943-
rewriter.clone(op, fusedMapping);
938+
rewriter.clone(op, mapping);
944939
for (Operation &op : source.getBody()->without_terminator())
945-
rewriter.clone(op, fusedMapping);
940+
rewriter.clone(op, mapping);
946941

947942
// Fuse the old terminator in_parallel ops into the new one.
948943
scf::InParallelOp targetTerm = target.getTerminator();
@@ -951,20 +946,62 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
951946

952947
rewriter.setInsertionPointToStart(fusedTerm.getBody());
953948
for (Operation &op : targetTerm.getYieldingOps())
954-
rewriter.clone(op, fusedMapping);
949+
rewriter.clone(op, mapping);
955950
for (Operation &op : sourceTerm.getYieldingOps())
956-
rewriter.clone(op, fusedMapping);
951+
rewriter.clone(op, mapping);
957952

958953
// Replace all uses of the old loops with the fused loop.
959-
rewriter.replaceAllUsesWith(target.getResults(),
960-
fusedLoop.getResults().slice(0, numTargetOuts));
961-
rewriter.replaceAllUsesWith(
962-
source.getResults(),
963-
fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
964-
965-
// Erase the old loops.
966-
rewriter.eraseOp(target);
967-
rewriter.eraseOp(source);
954+
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
955+
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
956+
957+
return fusedLoop;
958+
}
959+
960+
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
961+
scf::ForOp source,
962+
RewriterBase &rewriter) {
963+
unsigned numTargetOuts = target.getNumResults();
964+
unsigned numSourceOuts = source.getNumResults();
965+
966+
// Create fused init_args, with target's init_args before source's init_args.
967+
SmallVector<Value> fusedInitArgs;
968+
llvm::append_range(fusedInitArgs, target.getInitArgs());
969+
llvm::append_range(fusedInitArgs, source.getInitArgs());
970+
971+
// Create a new scf.for op after the source loop.
972+
rewriter.setInsertionPointAfter(source);
973+
scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
974+
source.getLoc(), source.getLowerBound(), source.getUpperBound(),
975+
source.getStep(), fusedInitArgs);
976+
977+
// Map original induction variables and operands to those of the fused loop.
978+
IRMapping mapping;
979+
mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
980+
mapping.map(target.getRegionIterArgs(),
981+
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
982+
983+
mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
984+
mapping.map(source.getRegionIterArgs(),
985+
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
986+
987+
// Merge target's body into the new (fused) for loop and then source's body.
988+
rewriter.setInsertionPointToStart(fusedLoop.getBody());
989+
for (Operation &op : target.getBody()->without_terminator())
990+
rewriter.clone(op, mapping);
991+
for (Operation &op : source.getBody()->without_terminator())
992+
rewriter.clone(op, mapping);
993+
994+
// Build fused yield results by appropriately mapping original yield operands.
995+
SmallVector<Value> yieldResults;
996+
for (Value operand : target.getBody()->getTerminator()->getOperands())
997+
yieldResults.push_back(mapping.lookupOrDefault(operand));
998+
for (Value operand : source.getBody()->getTerminator()->getOperands())
999+
yieldResults.push_back(mapping.lookupOrDefault(operand));
1000+
rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
1001+
1002+
// Replace old loops by substituting their uses by results of the fused loop.
1003+
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1004+
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
9681005

9691006
return fusedLoop;
9701007
}

0 commit comments

Comments
 (0)