Skip to content

Commit 98e2ab4

Browse files
author
Rolf Morel
committed
[SCF][Transform] Add support for scf.for in LoopFuseSibling op
Adds support for fusing two scf.for loops occurring in the same block. Uses the rudimentary checks already in place for scf.for_all (like the target loop's operands being dominated by the source loop). - Fixes a bug in the dominance check whereby it was checked that values in the target loop themselves dominated the source loop rather than the ops that define these operands. - Renames the LoopFuseSibling op to LoopFuseSiblingOp. - Updates the LoopFuseSiblingOp's description. - Adds tests for using LoopFuseSiblingOp on scf.for loops, including one which fails without the fix for the dominance check. - Adds tests checking the different failure modes of the dominance checker. - Adds test for case whereby scf.yield is automatically generated when there are no loop-carried variables.
1 parent 8ecc377 commit 98e2ab4

File tree

5 files changed

+359
-84
lines changed

5 files changed

+359
-84
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -333,23 +333,24 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
333333
}];
334334
}
335335

336-
def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
336+
def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
337337
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
338338
DeclareOpInterfaceMethods<TransformOpInterface>]> {
339339
let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
340340

341341
let description = [{
342342
Fuses the `target` loop into the `source` loop assuming they are
343-
independent of each other. It is the responsibility of the user to ensure
344-
that the given two loops are independent of each other, this operation will
345-
not performa any legality checks and will simply fuse the two given loops.
343+
independent of each other. In the fused loop, the arguments, body and
344+
results of `target` are placed _before_ those of `source`.
346345

347-
Currently, the only fusion supported is when both `target` and `source`
348-
are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
349-
mapping must match, otherwise a silencable failure is produced.
346+
For fusion of two `scf.for` loops, the bounds and step size must match. For
347+
fusion of two `scf.forall` loops, the bounds and the mapping must match.
348+
Otherwise a silencable failure is produced.
350349

351-
The input handles `target` and `source` must map to exactly one operation,
352-
a definite failure is produced otherwise.
350+
The `target` and `source` handles must refer to exactly one operation,
351+
otherwise a definite failure is produced. It is the responsibility of the
352+
user to ensure that the `target` and `source` loops are independent of each
353+
other -- this op will only perform rudimentary legality checks.
353354

354355
#### Return modes
355356

@@ -362,10 +363,6 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
362363
let results = (outs TransformHandleTypeInterface:$fused_loop);
363364
let assemblyFormat = "$target `into` $source attr-dict "
364365
" `:` functional-type(operands, results)";
365-
366-
let builders = [
367-
OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
368-
];
369366
}
370367

371368
#endif // SCF_TRANSFORM_OPS

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
162162
scf::ForallOp source,
163163
RewriterBase &rewriter);
164164

165+
/// Given two scf.for loops, `target` and `source`, fuses `target` into
166+
/// `source`. Assumes that the given loops are siblings and are independent of
167+
/// each other.
168+
///
169+
/// This function does not perform any legality checks and simply fuses the
170+
/// loops. The caller is responsible for ensuring that the loops are legal to
171+
/// fuse.
172+
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
173+
RewriterBase &rewriter);
174+
165175
} // namespace mlir
166176

167177
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
384384
}
385385

386386
//===----------------------------------------------------------------------===//
387-
// LoopFuseSibling
387+
// LoopFuseSiblingOp
388388
//===----------------------------------------------------------------------===//
389389

390390
/// Check if `target` and `source` are siblings, in the context that `target`
@@ -408,7 +408,7 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
408408
// Check if fusion will violate dominance.
409409
DominanceInfo domInfo(source);
410410
if (target->isBeforeInBlock(source)) {
411-
// Since, `target` is before `source`, all users of results of `target`
411+
// Since `target` is before `source`, all users of results of `target`
412412
// need to be dominated by `source`.
413413
for (Operation *user : target->getUsers()) {
414414
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
@@ -424,9 +424,8 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
424424
// Check if operands of `target` are dominated by `source`.
425425
for (Value operand : target->getOperands()) {
426426
Operation *operandOp = operand.getDefiningOp();
427-
// If operand does not have a defining operation, it is a block arguement,
428-
// which will always dominate `source`, since `target` and `source` are in
429-
// the same block and the operand dominated `source` before.
427+
// Operands without defining operations are block arguments. When `target`
428+
// and `source` occur in the same block, these operands dominate `source`.
430429
if (!operandOp)
431430
continue;
432431

@@ -441,8 +440,11 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
441440
bool failed = false;
442441
OpOperand *failedValue = nullptr;
443442
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
444-
if (!domInfo.properlyDominates(operand->getOwner(), source,
445-
/*enclosingOpOk=*/false)) {
443+
Operation *operandOp = operand->get().getDefiningOp();
444+
if (operandOp && !domInfo.properlyDominates(operandOp, source,
445+
/*enclosingOpOk=*/false)) {
446+
// `operand` is not an argument of an enclosing block and the defining
447+
// op of `operand` is outside `target` but does not dominate `source`.
446448
failed = true;
447449
failedValue = operand;
448450
}
@@ -476,21 +478,40 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
476478
targetOp.getMapping() == sourceOp.getMapping();
477479
}
478480

479-
/// Fuse `target` into `source` assuming they are siblings and indepndent.
480-
/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
481+
static bool isForWithIdenticalConfiguration(Operation *target,
482+
Operation *source) {
483+
auto targetOp = dyn_cast<scf::ForOp>(target);
484+
auto sourceOp = dyn_cast<scf::ForOp>(source);
485+
if (!targetOp || !sourceOp)
486+
return false;
487+
488+
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
489+
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
490+
targetOp.getStep() == sourceOp.getStep();
491+
}
492+
493+
/// Fuse `target` into `source` assuming they are siblings and independent.
494+
/// TODO: Support fusion for operations besides scf.for and scf.forall.
481495
static Operation *fuseSiblings(Operation *target, Operation *source,
482496
RewriterBase &rewriter) {
483-
auto targetOp = dyn_cast<scf::ForallOp>(target);
484-
auto sourceOp = dyn_cast<scf::ForallOp>(source);
485-
if (!targetOp || !sourceOp)
486-
return nullptr;
487-
return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
497+
auto targetForOp = dyn_cast<scf::ForOp>(target);
498+
auto sourceForOp = dyn_cast<scf::ForOp>(source);
499+
if (targetForOp && sourceForOp)
500+
return fuseIndependentSiblingForLoops(targetForOp, sourceForOp, rewriter);
501+
502+
auto targetForallOp = dyn_cast<scf::ForallOp>(target);
503+
auto sourceForallOp = dyn_cast<scf::ForallOp>(source);
504+
if (targetForallOp && sourceForallOp)
505+
return fuseIndependentSiblingForallLoops(targetForallOp, sourceForallOp,
506+
rewriter);
507+
508+
return nullptr;
488509
}
489510

490511
DiagnosedSilenceableFailure
491-
transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
492-
transform::TransformResults &results,
493-
transform::TransformState &state) {
512+
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
513+
transform::TransformResults &results,
514+
transform::TransformState &state) {
494515
auto targetOps = state.getPayloadOps(getTarget());
495516
auto sourceOps = state.getPayloadOps(getSource());
496517

@@ -511,7 +532,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
511532
return diag;
512533

513534
// Check if the target can be fused into source.
514-
if (!isForallWithIdenticalConfiguration(target, source)) {
535+
if (!isForallWithIdenticalConfiguration(target, source) &&
536+
!isForWithIdenticalConfiguration(target, source)) {
515537
return emitSilenceableFailure(target->getLoc())
516538
<< "operations cannot be fused";
517539
}

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -910,61 +910,98 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
910910
unsigned numTargetOuts = target.getNumResults();
911911
unsigned numSourceOuts = source.getNumResults();
912912

913-
OperandRange targetOuts = target.getOutputs();
914-
OperandRange sourceOuts = source.getOutputs();
915-
916913
// Create fused shared_outs.
917914
SmallVector<Value> fusedOuts;
918-
fusedOuts.reserve(numTargetOuts + numSourceOuts);
919-
fusedOuts.append(targetOuts.begin(), targetOuts.end());
920-
fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
915+
llvm::append_range(fusedOuts, target.getOutputs());
916+
llvm::append_range(fusedOuts, source.getOutputs());
921917

922-
// Create a new scf::forall op after the source loop.
918+
// Create a new scf.forall op after the source loop.
923919
rewriter.setInsertionPointAfter(source);
924920
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
925921
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
926922
source.getMixedStep(), fusedOuts, source.getMapping());
927923

928924
// Map control operands.
929-
IRMapping fusedMapping;
930-
fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
931-
fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
925+
IRMapping mapping;
926+
mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
927+
mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
932928

933929
// Map shared outs.
934-
fusedMapping.map(target.getRegionIterArgs(),
935-
fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
936-
fusedMapping.map(
937-
source.getRegionIterArgs(),
938-
fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
930+
mapping.map(target.getRegionIterArgs(),
931+
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
932+
mapping.map(source.getRegionIterArgs(),
933+
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
939934

940935
// Append everything except the terminator into the fused operation.
941936
rewriter.setInsertionPointToStart(fusedLoop.getBody());
942937
for (Operation &op : target.getBody()->without_terminator())
943-
rewriter.clone(op, fusedMapping);
938+
rewriter.clone(op, mapping);
944939
for (Operation &op : source.getBody()->without_terminator())
945-
rewriter.clone(op, fusedMapping);
940+
rewriter.clone(op, mapping);
946941

947942
// Fuse the old terminator in_parallel ops into the new one.
948943
scf::InParallelOp targetTerm = target.getTerminator();
949944
scf::InParallelOp sourceTerm = source.getTerminator();
950945
scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
951-
952946
rewriter.setInsertionPointToStart(fusedTerm.getBody());
953947
for (Operation &op : targetTerm.getYieldingOps())
954-
rewriter.clone(op, fusedMapping);
948+
rewriter.clone(op, mapping);
955949
for (Operation &op : sourceTerm.getYieldingOps())
956-
rewriter.clone(op, fusedMapping);
957-
958-
// Replace all uses of the old loops with the fused loop.
959-
rewriter.replaceAllUsesWith(target.getResults(),
960-
fusedLoop.getResults().slice(0, numTargetOuts));
961-
rewriter.replaceAllUsesWith(
962-
source.getResults(),
963-
fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
964-
965-
// Erase the old loops.
966-
rewriter.eraseOp(target);
967-
rewriter.eraseOp(source);
950+
rewriter.clone(op, mapping);
951+
952+
// Replace old loops by substituting their uses by results of the fused loop.
953+
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
954+
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
955+
956+
return fusedLoop;
957+
}
958+
959+
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
960+
scf::ForOp source,
961+
RewriterBase &rewriter) {
962+
unsigned numTargetOuts = target.getNumResults();
963+
unsigned numSourceOuts = source.getNumResults();
964+
965+
// Create fused init_args, with target's init_args before source's init_args.
966+
SmallVector<Value> fusedInitArgs;
967+
llvm::append_range(fusedInitArgs, target.getInitArgs());
968+
llvm::append_range(fusedInitArgs, source.getInitArgs());
969+
970+
// Create a new scf.for op after the source loop (with scf.yield terminator
971+
// (without arguments) only in case its init_args is empty).
972+
rewriter.setInsertionPointAfter(source);
973+
scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
974+
source.getLoc(), source.getLowerBound(), source.getUpperBound(),
975+
source.getStep(), fusedInitArgs);
976+
977+
// Map original induction variables and operands to those of the fused loop.
978+
IRMapping mapping;
979+
mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
980+
mapping.map(target.getRegionIterArgs(),
981+
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
982+
mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
983+
mapping.map(source.getRegionIterArgs(),
984+
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
985+
986+
// Merge target's body into the new (fused) for loop and then source's body.
987+
rewriter.setInsertionPointToStart(fusedLoop.getBody());
988+
for (Operation &op : target.getBody()->without_terminator())
989+
rewriter.clone(op, mapping);
990+
for (Operation &op : source.getBody()->without_terminator())
991+
rewriter.clone(op, mapping);
992+
993+
// Build fused yield results by appropriately mapping original yield operands.
994+
SmallVector<Value> yieldResults;
995+
for (Value operand : target.getBody()->getTerminator()->getOperands())
996+
yieldResults.push_back(mapping.lookupOrDefault(operand));
997+
for (Value operand : source.getBody()->getTerminator()->getOperands())
998+
yieldResults.push_back(mapping.lookupOrDefault(operand));
999+
if (yieldResults.size())
1000+
rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
1001+
1002+
// Replace old loops by substituting their uses by results of the fused loop.
1003+
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1004+
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
9681005

9691006
return fusedLoop;
9701007
}

0 commit comments

Comments
 (0)