Skip to content

Commit 69e5620

Browse files
author
Rolf Morel
committed
[SCF][Transform] Add support for scf.for in LoopFuseSibling op
Adds support for fusing two scf.for loops occurring in the same block. The implementation mirrors LoopFuseSibling's support for scf.forall, including only rudimentary checks, like the target loop's operands being dominated by the source loop. Fixes a bug in the dominance check: it tested whether the ops that use outside values inside the target loop dominated the source loop, rather than the ops that define those values. Adds tests for using LoopFuseSibling on scf.for loops, including one which fails without the fix for the dominance check.
1 parent 8ecc377 commit 69e5620

File tree

5 files changed

+324
-84
lines changed

5 files changed

+324
-84
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -333,23 +333,25 @@ def TakeAssumedBranchOp : Op<Transform_Dialect, "scf.take_assumed_branch", [
333333
}];
334334
}
335335

336-
def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
336+
def LoopFuseSiblingOp : Op<Transform_Dialect, "loop.fuse_sibling",
337337
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
338338
DeclareOpInterfaceMethods<TransformOpInterface>]> {
339339
let summary = "Fuse a loop into another loop, assuming the fusion is legal.";
340340

341341
let description = [{
342342
Fuses the `target` loop into the `source` loop assuming they are
343-
independent of each other. It is the responsibility of the user to ensure
344-
that the given two loops are independent of each other, this operation will
345-
not performa any legality checks and will simply fuse the two given loops.
343+
independent of each other. In the fused loop, the arguments, body and
344+
results of `target` are placed _before_ those of `source`.
346345

347-
Currently, the only fusion supported is when both `target` and `source`
348-
are `scf.forall` operations. For `scf.forall` fusion, the bounds and the
349-
mapping must match, otherwise a silencable failure is produced.
346+
For fusion of two `scf.for` loops, the bounds and step size must match. For
347+
fusion of two `scf.forall` loops, the bounds and the mapping must match.
348+
Otherwise a silenceable failure is produced. Attempting to fuse any other kinds
349+
of loops/ops will produce a definite failure.
350350

351-
The input handles `target` and `source` must map to exactly one operation,
352-
a definite failure is produced otherwise.
351+
The `target` and `source` handles must refer to exactly one operation,
352+
otherwise a definite failure is produced. It is the responsibility of the
353+
user to ensure that the `target` and `source` loops are independent of each
354+
other -- this op will not perform any legality checks.
353355

354356
#### Return modes
355357

@@ -362,10 +364,6 @@ def LoopFuseSibling : Op<Transform_Dialect, "loop.fuse_sibling",
362364
let results = (outs TransformHandleTypeInterface:$fused_loop);
363365
let assemblyFormat = "$target `into` $source attr-dict "
364366
" `:` functional-type(operands, results)";
365-
366-
let builders = [
367-
OpBuilder<(ins "Value":$loop, "Value":$fused_loop)>
368-
];
369367
}
370368

371369
#endif // SCF_TRANSFORM_OPS

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
162162
scf::ForallOp source,
163163
RewriterBase &rewriter);
164164

165+
/// Given two scf.for loops, `target` and `source`, fuses `target` into
166+
/// `source`. Assumes that the given loops are siblings and are independent of
167+
/// each other.
168+
///
169+
/// This function does not perform any legality checks and simply fuses the
170+
/// loops. The caller is responsible for ensuring that the loops are legal to
171+
/// fuse.
172+
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
173+
RewriterBase &rewriter);
174+
165175
} // namespace mlir
166176

167177
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ void transform::TakeAssumedBranchOp::getEffects(
384384
}
385385

386386
//===----------------------------------------------------------------------===//
387-
// LoopFuseSibling
387+
// LoopFuseSiblingOp
388388
//===----------------------------------------------------------------------===//
389389

390390
/// Check if `target` and `source` are siblings, in the context that `target`
@@ -408,7 +408,7 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
408408
// Check if fusion will violate dominance.
409409
DominanceInfo domInfo(source);
410410
if (target->isBeforeInBlock(source)) {
411-
// Since, `target` is before `source`, all users of results of `target`
411+
// Since `target` is before `source`, all users of results of `target`
412412
// need to be dominated by `source`.
413413
for (Operation *user : target->getUsers()) {
414414
if (!domInfo.properlyDominates(source, user, /*enclosingOpOk=*/false)) {
@@ -424,9 +424,8 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
424424
// Check if operands of `target` are dominated by `source`.
425425
for (Value operand : target->getOperands()) {
426426
Operation *operandOp = operand.getDefiningOp();
427-
// If operand does not have a defining operation, it is a block arguement,
428-
// which will always dominate `source`, since `target` and `source` are in
429-
// the same block and the operand dominated `source` before.
427+
// Operands without defining operations are block arguments. When `target`
428+
// and `source` occur in the same block, these operands dominate `source`.
430429
if (!operandOp)
431430
continue;
432431

@@ -441,8 +440,12 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
441440
bool failed = false;
442441
OpOperand *failedValue = nullptr;
443442
visitUsedValuesDefinedAbove(target->getRegions(), [&](OpOperand *operand) {
444-
if (!domInfo.properlyDominates(operand->getOwner(), source,
445-
/*enclosingOpOk=*/false)) {
443+
Operation *operandOp = operand->get().getDefiningOp();
444+
if (!operandOp && !domInfo.properlyDominates(operandOp, source,
445+
/*enclosingOpOk=*/false)) {
446+
// `operand` is not a block argument of an enclosing block or otherwise
447+
// `operand`'s defining op is outside `target` but does not dominate
448+
// `source`.
446449
failed = true;
447450
failedValue = operand;
448451
}
@@ -476,21 +479,40 @@ static bool isForallWithIdenticalConfiguration(Operation *target,
476479
targetOp.getMapping() == sourceOp.getMapping();
477480
}
478481

479-
/// Fuse `target` into `source` assuming they are siblings and indepndent.
480-
/// TODO: Add fusion for more operations. Currently, we handle only scf.forall.
482+
static bool isForWithIdenticalConfiguration(Operation *target,
483+
Operation *source) {
484+
auto targetOp = dyn_cast<scf::ForOp>(target);
485+
auto sourceOp = dyn_cast<scf::ForOp>(source);
486+
if (!targetOp || !sourceOp)
487+
return false;
488+
489+
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
490+
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
491+
targetOp.getStep() == sourceOp.getStep();
492+
}
493+
494+
/// Fuse `target` into `source` assuming they are siblings and independent.
495+
/// TODO: Support fusion for operations besides scf.for and scf.forall.
481496
static Operation *fuseSiblings(Operation *target, Operation *source,
482497
RewriterBase &rewriter) {
483-
auto targetOp = dyn_cast<scf::ForallOp>(target);
484-
auto sourceOp = dyn_cast<scf::ForallOp>(source);
485-
if (!targetOp || !sourceOp)
486-
return nullptr;
487-
return fuseIndependentSiblingForallLoops(targetOp, sourceOp, rewriter);
498+
auto targetForOp = dyn_cast<scf::ForOp>(target);
499+
auto sourceForOp = dyn_cast<scf::ForOp>(source);
500+
if (targetForOp && sourceForOp)
501+
return fuseIndependentSiblingForLoops(targetForOp, sourceForOp, rewriter);
502+
503+
auto targetForallOp = dyn_cast<scf::ForallOp>(target);
504+
auto sourceForallOp = dyn_cast<scf::ForallOp>(source);
505+
if (targetForallOp && sourceForallOp)
506+
return fuseIndependentSiblingForallLoops(targetForallOp, sourceForallOp,
507+
rewriter);
508+
509+
return nullptr;
488510
}
489511

490512
DiagnosedSilenceableFailure
491-
transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
492-
transform::TransformResults &results,
493-
transform::TransformState &state) {
513+
transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
514+
transform::TransformResults &results,
515+
transform::TransformState &state) {
494516
auto targetOps = state.getPayloadOps(getTarget());
495517
auto sourceOps = state.getPayloadOps(getSource());
496518

@@ -511,7 +533,8 @@ transform::LoopFuseSibling::apply(transform::TransformRewriter &rewriter,
511533
return diag;
512534

513535
// Check if the target can be fused into source.
514-
if (!isForallWithIdenticalConfiguration(target, source)) {
536+
if (!isForallWithIdenticalConfiguration(target, source) &&
537+
!isForWithIdenticalConfiguration(target, source)) {
515538
return emitSilenceableFailure(target->getLoc())
516539
<< "operations cannot be fused";
517540
}

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 66 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -910,61 +910,96 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
910910
unsigned numTargetOuts = target.getNumResults();
911911
unsigned numSourceOuts = source.getNumResults();
912912

913-
OperandRange targetOuts = target.getOutputs();
914-
OperandRange sourceOuts = source.getOutputs();
915-
916913
// Create fused shared_outs.
917914
SmallVector<Value> fusedOuts;
918-
fusedOuts.reserve(numTargetOuts + numSourceOuts);
919-
fusedOuts.append(targetOuts.begin(), targetOuts.end());
920-
fusedOuts.append(sourceOuts.begin(), sourceOuts.end());
915+
llvm::append_range(fusedOuts, target.getOutputs());
916+
llvm::append_range(fusedOuts, source.getOutputs());
921917

922-
// Create a new scf::forall op after the source loop.
918+
// Create a new scf.forall op after the source loop.
923919
rewriter.setInsertionPointAfter(source);
924920
scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>(
925921
source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(),
926922
source.getMixedStep(), fusedOuts, source.getMapping());
927923

928924
// Map control operands.
929-
IRMapping fusedMapping;
930-
fusedMapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
931-
fusedMapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
925+
IRMapping mapping;
926+
mapping.map(target.getInductionVars(), fusedLoop.getInductionVars());
927+
mapping.map(source.getInductionVars(), fusedLoop.getInductionVars());
932928

933929
// Map shared outs.
934-
fusedMapping.map(target.getRegionIterArgs(),
935-
fusedLoop.getRegionIterArgs().slice(0, numTargetOuts));
936-
fusedMapping.map(
937-
source.getRegionIterArgs(),
938-
fusedLoop.getRegionIterArgs().slice(numTargetOuts, numSourceOuts));
930+
mapping.map(target.getRegionIterArgs(),
931+
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
932+
mapping.map(source.getRegionIterArgs(),
933+
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
939934

940935
// Append everything except the terminator into the fused operation.
941936
rewriter.setInsertionPointToStart(fusedLoop.getBody());
942937
for (Operation &op : target.getBody()->without_terminator())
943-
rewriter.clone(op, fusedMapping);
938+
rewriter.clone(op, mapping);
944939
for (Operation &op : source.getBody()->without_terminator())
945-
rewriter.clone(op, fusedMapping);
940+
rewriter.clone(op, mapping);
946941

947942
// Fuse the old terminator in_parallel ops into the new one.
948943
scf::InParallelOp targetTerm = target.getTerminator();
949944
scf::InParallelOp sourceTerm = source.getTerminator();
950945
scf::InParallelOp fusedTerm = fusedLoop.getTerminator();
951-
952946
rewriter.setInsertionPointToStart(fusedTerm.getBody());
953947
for (Operation &op : targetTerm.getYieldingOps())
954-
rewriter.clone(op, fusedMapping);
948+
rewriter.clone(op, mapping);
955949
for (Operation &op : sourceTerm.getYieldingOps())
956-
rewriter.clone(op, fusedMapping);
957-
958-
// Replace all uses of the old loops with the fused loop.
959-
rewriter.replaceAllUsesWith(target.getResults(),
960-
fusedLoop.getResults().slice(0, numTargetOuts));
961-
rewriter.replaceAllUsesWith(
962-
source.getResults(),
963-
fusedLoop.getResults().slice(numTargetOuts, numSourceOuts));
964-
965-
// Erase the old loops.
966-
rewriter.eraseOp(target);
967-
rewriter.eraseOp(source);
950+
rewriter.clone(op, mapping);
951+
952+
// Replace old loops by substituting their uses by results of the fused loop.
953+
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
954+
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
955+
956+
return fusedLoop;
957+
}
958+
959+
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
960+
scf::ForOp source,
961+
RewriterBase &rewriter) {
962+
unsigned numTargetOuts = target.getNumResults();
963+
unsigned numSourceOuts = source.getNumResults();
964+
965+
// Create fused init_args, with target's init_args before source's init_args.
966+
SmallVector<Value> fusedInitArgs;
967+
llvm::append_range(fusedInitArgs, target.getInitArgs());
968+
llvm::append_range(fusedInitArgs, source.getInitArgs());
969+
970+
// Create a new scf.for op after the source loop.
971+
rewriter.setInsertionPointAfter(source);
972+
scf::ForOp fusedLoop = rewriter.create<scf::ForOp>(
973+
source.getLoc(), source.getLowerBound(), source.getUpperBound(),
974+
source.getStep(), fusedInitArgs);
975+
976+
// Map original induction variables and operands to those of the fused loop.
977+
IRMapping mapping;
978+
mapping.map(target.getInductionVar(), fusedLoop.getInductionVar());
979+
mapping.map(target.getRegionIterArgs(),
980+
fusedLoop.getRegionIterArgs().take_front(numTargetOuts));
981+
mapping.map(source.getInductionVar(), fusedLoop.getInductionVar());
982+
mapping.map(source.getRegionIterArgs(),
983+
fusedLoop.getRegionIterArgs().take_back(numSourceOuts));
984+
985+
// Merge target's body into the new (fused) for loop and then source's body.
986+
rewriter.setInsertionPointToStart(fusedLoop.getBody());
987+
for (Operation &op : target.getBody()->without_terminator())
988+
rewriter.clone(op, mapping);
989+
for (Operation &op : source.getBody()->without_terminator())
990+
rewriter.clone(op, mapping);
991+
992+
// Build fused yield results by appropriately mapping original yield operands.
993+
SmallVector<Value> yieldResults;
994+
for (Value operand : target.getBody()->getTerminator()->getOperands())
995+
yieldResults.push_back(mapping.lookupOrDefault(operand));
996+
for (Value operand : source.getBody()->getTerminator()->getOperands())
997+
yieldResults.push_back(mapping.lookupOrDefault(operand));
998+
rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults);
999+
1000+
// Replace old loops by substituting their uses by results of the fused loop.
1001+
rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts));
1002+
rewriter.replaceOp(source, fusedLoop.getResults().take_back(numSourceOuts));
9681003

9691004
return fusedLoop;
9701005
}

0 commit comments

Comments
 (0)