Skip to content

Commit ebf0599

Browse files
authored
[MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. (#106436)
Allow speculative execution and predicate results per stage.
1 parent c1667f9 commit ebf0599

File tree

3 files changed

+179
-44
lines changed

3 files changed

+179
-44
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 78 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
9494
RewriterBase &rewriter);
9595
/// Emits the epilogue. This creates `maxStage - 1` parts, each of which will contain
9696
/// operations from stages [i; maxStage], where i is the part index.
97-
void emitEpilogue(RewriterBase &rewriter,
98-
llvm::SmallVector<Value> &returnValues);
97+
LogicalResult emitEpilogue(RewriterBase &rewriter,
98+
llvm::SmallVector<Value> &returnValues);
9999
};
100100

101101
bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
133133
LDBG("--no epilogue or predicate set -> BAIL");
134134
return false;
135135
}
136-
if (dynamicLoop && peelEpilogue) {
137-
LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
138-
return false;
139-
}
140136
std::vector<std::pair<Operation *, unsigned>> schedule;
141137
options.getScheduleFn(forOp, schedule);
142138
if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
313309
});
314310
int predicateIdx = i - stages[op];
315311
if (predicates[predicateIdx]) {
312+
OpBuilder::InsertionGuard insertGuard(rewriter);
316313
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
317314
assert(newOp && "failed to predicate op.");
318315
}
319-
rewriter.setInsertionPointAfter(newOp);
320316
if (annotateFn)
321317
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
322318
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,14 +557,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
561557
}
562558

563559
if (predicates[useStage]) {
560+
OpBuilder::InsertionGuard insertGuard(rewriter);
564561
newOp = predicateFn(rewriter, newOp, predicates[useStage]);
565562
if (!newOp)
566563
return failure();
567564
// Remap the results to the new predicated one.
568565
for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
569566
mapping.map(std::get<0>(values), std::get<1>(values));
570567
}
571-
rewriter.setInsertionPointAfter(newOp);
572568
if (annotateFn)
573569
annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
574570
}
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
640636
return success();
641637
}
642638

643-
void LoopPipelinerInternal::emitEpilogue(
644-
RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
639+
LogicalResult
640+
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
641+
llvm::SmallVector<Value> &returnValues) {
642+
Location loc = forOp.getLoc();
645643
// Emit different versions of the induction variable. They will be
646644
// removed by dead code elimination if not used.
645+
646+
// bounds_range = ub - lb
647+
// total_iterations = (bounds_range + step - 1) / step
648+
Type t = lb.getType();
649+
Value minus1 =
650+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
651+
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
652+
Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
653+
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
654+
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
655+
656+
SmallVector<Value> predicates(maxStage + 1);
647657
for (int64_t i = 0; i < maxStage; i++) {
648-
Location loc = forOp.getLoc();
649-
Type t = lb.getType();
650-
Value minusOne =
651-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
652-
// number of iterations = ((ub - 1) - lb) / step
653-
Value totalNumIteration = rewriter.create<arith::DivUIOp>(
654-
loc,
655-
rewriter.create<arith::SubIOp>(
656-
loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
657-
step);
658-
// newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
658+
// iterI = total_iters - 1 - i
659+
// May go negative...
659660
Value minusI =
660661
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
662+
Value iterI = rewriter.create<arith::AddIOp>(
663+
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
664+
minusI);
665+
// newLastIter = lb + step * iterI
661666
Value newlastIter = rewriter.create<arith::AddIOp>(
662-
loc, lb,
663-
rewriter.create<arith::MulIOp>(
664-
loc, step,
665-
rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
667+
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
668+
666669
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
670+
671+
if (dynamicLoop) {
672+
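// pred = iterI >= lb. NOTE(review): iterI is a zero-based iteration index while lb is an induction-variable bound, so a comparison against 0 may be what is intended here — confirm.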
673+
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
674+
loc, arith::CmpIPredicate::sge, iterI, lb);
675+
}
667676
}
677+
668678
// Emit `maxStage - 1` epilogue parts, each of which includes operations from stages
669679
// [i; maxStage].
670680
for (int64_t i = 1; i <= maxStage; i++) {
681+
SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
671682
for (Operation *op : opOrder) {
672683
if (stages[op] < i)
673684
continue;
685+
unsigned currentVersion = maxStage - stages[op] + i;
686+
unsigned nextVersion = currentVersion + 1;
674687
Operation *newOp =
675688
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
676689
auto it = valueMapping.find(newOperand->get());
677690
if (it != valueMapping.end()) {
678-
Value replacement = it->second[maxStage - stages[op] + i];
691+
Value replacement = it->second[currentVersion];
679692
newOperand->set(replacement);
680693
}
681694
});
695+
if (dynamicLoop) {
696+
OpBuilder::InsertionGuard insertGuard(rewriter);
697+
newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
698+
if (!newOp)
699+
return failure();
700+
}
682701
if (annotateFn)
683702
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
684-
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
685-
setValueMapping(op->getResult(destId), newOp->getResult(destId),
686-
maxStage - stages[op] + i);
703+
704+
for (auto [opRes, newRes] :
705+
llvm::zip(op->getResults(), newOp->getResults())) {
706+
setValueMapping(opRes, newRes, currentVersion);
687707
// If the value is a loop carried dependency update the loop argument
688708
// mapping and keep track of the last version to replace the original
689709
// forOp uses.
690710
for (OpOperand &operand :
691711
forOp.getBody()->getTerminator()->getOpOperands()) {
692-
if (operand.get() != op->getResult(destId))
712+
if (operand.get() != opRes)
693713
continue;
694-
unsigned version = maxStage - stages[op] + i + 1;
695714
// If the version is greater than maxStage it means it maps to the
696715
// original forOp returned value.
697-
if (version > maxStage) {
698-
returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
699-
continue;
700-
}
701-
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
702-
newOp->getResult(destId), version);
716+
unsigned ri = operand.getOperandNumber();
717+
returnValues[ri] = newRes;
718+
Value mapVal = forOp.getRegionIterArgs()[ri];
719+
returnMap[ri] = std::make_pair(mapVal, currentVersion);
720+
if (nextVersion <= maxStage)
721+
setValueMapping(mapVal, newRes, nextVersion);
722+
}
723+
}
724+
}
725+
if (dynamicLoop) {
726+
// Select return values from this stage (live outs) based on predication.
727+
// If the stage is valid select the peeled value, else use previous stage
728+
// value.
729+
for (auto pair : llvm::enumerate(returnValues)) {
730+
unsigned ri = pair.index();
731+
auto [mapVal, currentVersion] = returnMap[ri];
732+
if (mapVal) {
733+
unsigned nextVersion = currentVersion + 1;
734+
Value pred = predicates[currentVersion];
735+
Value prevValue = valueMapping[mapVal][currentVersion];
736+
auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
737+
prevValue);
738+
returnValues[ri] = selOp;
739+
if (nextVersion <= maxStage)
740+
setValueMapping(mapVal, selOp, nextVersion);
703741
}
704742
}
705743
}
706744
}
745+
return success();
707746
}
708747

709748
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +799,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
760799
if (options.peelEpilogue) {
761800
// 4. Emit the epilogue after the new forOp.
762801
rewriter.setInsertionPointAfter(newForOp);
763-
pipeliner.emitEpilogue(rewriter, returnValues);
802+
if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
803+
return failure();
764804
}
765805
// 5. Erase the original loop and replace the uses with the epilogue output.
766806
if (forOp->getNumResults() > 0)

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -764,11 +764,44 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
764764
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
765765
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32
766766

767-
// In case dynamic loop pipelining is off check that the transformation didn't
768-
// apply.
767+
// Check for predicated epilogue for dynamic loop.
769768
// CHECK-LABEL: dynamic_loop(
770-
// CHECK-NOT: memref.load
771-
// CHECK: scf.for
769+
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
770+
// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
771+
// CHECK: %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
772+
// CHECK: %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
773+
// CHECK: %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %[[MULI_25]]
774+
// CHECK: %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
775+
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
776+
// CHECK: }
777+
// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
778+
// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
779+
// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
780+
// CHECK: %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
781+
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
782+
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
783+
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
784+
// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %{{.*}}
785+
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
786+
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
787+
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
788+
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
789+
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %{{.*}}
790+
// CHECK: scf.if %[[CMPI_17]] {
791+
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
792+
// CHECK: } else {
793+
// CHECK: }
794+
// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
795+
// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
796+
// CHECK: scf.yield %[[ADDF_24]]
797+
// CHECK: } else {
798+
// CHECK: scf.yield %{{.*}}
799+
// CHECK: }
800+
// CHECK: scf.if %[[CMPI_22]] {
801+
// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
802+
// CHECK: } else {
803+
// CHECK: }
804+
// CHECK: return
772805
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
773806
%cf = arith.constant 1.0 : f32
774807
scf.for %i0 = %lb to %ub step %step {
@@ -781,6 +814,68 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
781814

782815
// -----
783816

817+
// NOEPILOGUE-LABEL: func.func @dynamic_loop_result
818+
// NOEPILOGUE: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
819+
// NOEPILOGUE: %[[SUBI_3:.*]] = arith.subi %{{.*}}, %{{.*}}
820+
// NOEPILOGUE: %[[CMPI_4:.*]] = arith.cmpi slt, %[[ARG5]], %[[SUBI_3]]
821+
// NOEPILOGUE: %[[ADDF_5:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
822+
// NOEPILOGUE: %[[MULF_6:.*]] = arith.mulf %[[ADDF_5]], %{{.*}}
823+
// NOEPILOGUE: %[[ADDI_7:.*]] = arith.addi %[[ARG5]], %{{.*}}
824+
// NOEPILOGUE: %[[IF_8:.*]] = scf.if %[[CMPI_4]]
825+
// NOEPILOGUE: %[[LOAD_9:.*]] = memref.load %{{.*}}[%[[ADDI_7]]]
826+
// NOEPILOGUE: scf.yield %[[LOAD_9]]
827+
// NOEPILOGUE: } else {
828+
// NOEPILOGUE: scf.yield %{{.*}}
829+
// NOEPILOGUE: }
830+
// NOEPILOGUE: scf.yield %[[MULF_6]], %[[IF_8]]
831+
// NOEPILOGUE: }
832+
// NOEPILOGUE: memref.store %{{.*}}#0, %{{.*}}[%{{.*}}]
833+
834+
// Check for predicated epilogue for dynamic loop.
835+
// CHECK-LABEL: func.func @dynamic_loop_result
836+
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
837+
// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
838+
// CHECK: %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}
839+
// CHECK: %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}}
840+
// CHECK: %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]]
841+
// CHECK: scf.yield %[[MULF_14]], %[[LOAD_16]]
842+
// CHECK: }
843+
// CHECK: %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}}
844+
// CHECK: %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}}
845+
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
846+
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
847+
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
848+
// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
849+
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
850+
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
851+
// CHECK: scf.yield %[[ADDF_13]]
852+
// CHECK: } else {
853+
// CHECK: scf.yield %{{.*}}
854+
// CHECK: }
855+
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_9]]
856+
// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
857+
// CHECK: scf.yield %[[MULF_13]]
858+
// CHECK: } else {
859+
// CHECK: scf.yield %{{.*}}
860+
// CHECK: }
861+
// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0
862+
// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
863+
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
864+
%cf0 = arith.constant 1.0 : f32
865+
%cf1 = arith.constant 33.0 : f32
866+
%cst = arith.constant 0 : index
867+
%res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
868+
%A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
869+
%A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
870+
%A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
871+
scf.yield %A2_elem : f32
872+
} { __test_pipelining_loop__ }
873+
memref.store %res#0, %result[%cst] : memref<?xf32>
874+
return
875+
}
876+
877+
// -----
878+
784879
// CHECK-LABEL: yield_constant_loop(
785880
// CHECK-SAME: %[[A:.*]]: memref<?xf32>) -> f32 {
786881
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
214214
RewritePatternSet patterns(&getContext());
215215
mlir::scf::PipeliningOption options;
216216
options.getScheduleFn = getSchedule;
217+
options.supportDynamicLoops = true;
218+
options.predicateFn = predicateOp;
217219
if (annotatePipeline)
218220
options.annotateFn = annotate;
219221
if (noEpiloguePeeling) {
220-
options.supportDynamicLoops = true;
221222
options.peelEpilogue = false;
222-
options.predicateFn = predicateOp;
223223
}
224224
scf::populateSCFLoopPipeliningPatterns(patterns, options);
225225
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

0 commit comments

Comments
 (0)