Skip to content

Commit 9be66a1

Browse files
author
SJW
committed
[MLIR][SCF] Add support for loop pipeline peeling for dynamic loops.
* Allow speculative execution and predicate results per stage.
1 parent fee4836 commit 9be66a1

File tree

3 files changed: +129 additions, -44 deletions (lines changed)

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 88 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
9494
RewriterBase &rewriter);
9595
/// Emits the epilogue, this creates `maxStage - 1` part which will contain
9696
/// operations from stages [i; maxStage], where i is the part index.
97-
void emitEpilogue(RewriterBase &rewriter,
98-
llvm::SmallVector<Value> &returnValues);
97+
LogicalResult emitEpilogue(RewriterBase &rewriter,
98+
llvm::SmallVector<Value> &returnValues);
9999
};
100100

101101
bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
133133
LDBG("--no epilogue or predicate set -> BAIL");
134134
return false;
135135
}
136-
if (dynamicLoop && peelEpilogue) {
137-
LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
138-
return false;
139-
}
140136
std::vector<std::pair<Operation *, unsigned>> schedule;
141137
options.getScheduleFn(forOp, schedule);
142138
if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
313309
});
314310
int predicateIdx = i - stages[op];
315311
if (predicates[predicateIdx]) {
312+
OpBuilder::InsertionGuard insertGuard(rewriter);
316313
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
317314
assert(newOp && "failed to predicate op.");
318315
}
319-
rewriter.setInsertionPointAfter(newOp);
320316
if (annotateFn)
321317
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
322318
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,14 +557,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
561557
}
562558

563559
if (predicates[useStage]) {
560+
OpBuilder::InsertionGuard insertGuard(rewriter);
564561
newOp = predicateFn(rewriter, newOp, predicates[useStage]);
565562
if (!newOp)
566563
return failure();
567564
// Remap the results to the new predicated one.
568565
for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
569566
mapping.map(std::get<0>(values), std::get<1>(values));
570567
}
571-
rewriter.setInsertionPointAfter(newOp);
572568
if (annotateFn)
573569
annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
574570
}
@@ -640,70 +636,123 @@ LogicalResult LoopPipelinerInternal::createKernel(
640636
return success();
641637
}
642638

643-
void LoopPipelinerInternal::emitEpilogue(
644-
RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
639+
LogicalResult
640+
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
641+
llvm::SmallVector<Value> &returnValues) {
642+
Location loc = forOp.getLoc();
645643
// Emit different versions of the induction variable. They will be
646644
// removed by dead code if not used.
645+
646+
// bounds_range = ub - lb
647+
// total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
648+
Type t = lb.getType();
649+
Value minus1 =
650+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
651+
652+
Value const_0 =
653+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
654+
Value const_1 =
655+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
656+
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
657+
Value boundsRem = rewriter.create<arith::RemUIOp>(loc, boundsRange, step);
658+
Value hasRem = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
659+
boundsRem, const_0);
660+
Value selRem =
661+
rewriter.create<arith::SelectOp>(loc, hasRem, const_1, const_0);
662+
Value boundsDiv = rewriter.create<arith::DivUIOp>(loc, boundsRange, step);
663+
Value totalIterations =
664+
rewriter.create<arith::AddIOp>(loc, boundsDiv, selRem);
665+
666+
SmallVector<Value> predicates(maxStage + 1);
647667
for (int64_t i = 0; i < maxStage; i++) {
648-
Location loc = forOp.getLoc();
649-
Type t = lb.getType();
650-
Value minusOne =
651-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
652-
// number of iterations = ((ub - 1) - lb) / step
653-
Value totalNumIteration = rewriter.create<arith::DivUIOp>(
654-
loc,
655-
rewriter.create<arith::SubIOp>(
656-
loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
657-
step);
658-
// newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
668+
// iterI = total_iters - 1 - i
669+
// May go negative...
659670
Value minusI =
660671
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
672+
Value iterI = rewriter.create<arith::AddIOp>(
673+
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
674+
minusI);
675+
// newLastIter = lb + step * iterI
661676
Value newlastIter = rewriter.create<arith::AddIOp>(
662-
loc, lb,
663-
rewriter.create<arith::MulIOp>(
664-
loc, step,
665-
rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
677+
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
678+
666679
setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
680+
681+
if (dynamicLoop) {
682+
// pred = iterI >= lb
683+
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
684+
loc, arith::CmpIPredicate::sge, iterI, lb);
685+
}
667686
}
687+
668688
// Emit `maxStage - 1` epilogue part that includes operations from stages
669689
// [i; maxStage].
670690
for (int64_t i = 1; i <= maxStage; i++) {
691+
SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
671692
for (Operation *op : opOrder) {
672693
if (stages[op] < i)
673694
continue;
695+
unsigned currentVersion = maxStage - stages[op] + i;
696+
unsigned nextVersion = currentVersion + 1;
674697
Operation *newOp =
675698
cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
676699
auto it = valueMapping.find(newOperand->get());
677700
if (it != valueMapping.end()) {
678-
Value replacement = it->second[maxStage - stages[op] + i];
701+
Value replacement = it->second[currentVersion];
679702
newOperand->set(replacement);
680703
}
681704
});
682705
if (annotateFn)
683706
annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
684-
for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
685-
setValueMapping(op->getResult(destId), newOp->getResult(destId),
686-
maxStage - stages[op] + i);
707+
if (dynamicLoop) {
708+
OpBuilder::InsertionGuard insertGuard(rewriter);
709+
newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
710+
if (!newOp)
711+
return failure();
712+
}
713+
714+
for (auto [opRes, newRes] :
715+
llvm::zip(op->getResults(), newOp->getResults())) {
716+
setValueMapping(opRes, newRes, currentVersion);
687717
// If the value is a loop carried dependency update the loop argument
688718
// mapping and keep track of the last version to replace the original
689719
// forOp uses.
690720
for (OpOperand &operand :
691721
forOp.getBody()->getTerminator()->getOpOperands()) {
692-
if (operand.get() != op->getResult(destId))
722+
if (operand.get() != opRes)
693723
continue;
694-
unsigned version = maxStage - stages[op] + i + 1;
695724
// If the version is greater than maxStage it means it maps to the
696725
// original forOp returned value.
697-
if (version > maxStage) {
698-
returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
699-
continue;
700-
}
701-
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
702-
newOp->getResult(destId), version);
726+
unsigned ri = operand.getOperandNumber();
727+
returnValues[ri] = newRes;
728+
Value mapVal = forOp.getRegionIterArgs()[ri];
729+
returnMap[ri] = std::make_pair(mapVal, currentVersion);
730+
if (nextVersion <= maxStage)
731+
setValueMapping(mapVal, newRes, nextVersion);
732+
}
733+
}
734+
}
735+
if (dynamicLoop) {
736+
// Select return values from this stage (live outs) based on predication.
737+
// If the stage is valid select the peeled value, else use previous stage
738+
// value.
739+
for (auto pair : llvm::enumerate(returnValues)) {
740+
unsigned ri = pair.index();
741+
auto [mapVal, currentVersion] = returnMap[ri];
742+
if (mapVal) {
743+
unsigned nextVersion = currentVersion + 1;
744+
Value pred = predicates[currentVersion];
745+
Value prevValue = valueMapping[mapVal][currentVersion];
746+
auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
747+
prevValue);
748+
returnValues[ri] = selOp;
749+
if (nextVersion <= maxStage)
750+
setValueMapping(mapVal, selOp, nextVersion);
703751
}
704752
}
705753
}
706754
}
755+
return success();
707756
}
708757

709758
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +809,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
760809
if (options.peelEpilogue) {
761810
// 4. Emit the epilogue after the new forOp.
762811
rewriter.setInsertionPointAfter(newForOp);
763-
pipeliner.emitEpilogue(rewriter, returnValues);
812+
if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
813+
return failure();
764814
}
765815
// 5. Erase the original loop and replace the uses with the epilogue output.
766816
if (forOp->getNumResults() > 0)

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 39 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -764,11 +764,46 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
764764
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
765765
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32
766766

767-
// In case dynamic loop pipelining is off check that the transformation didn't
768-
// apply.
767+
// Check for predicated epilogue for dynamic loop.
769768
// CHECK-LABEL: dynamic_loop(
770-
// CHECK-NOT: memref.load
771-
// CHECK: scf.for
769+
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
770+
// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
771+
// CHECK: %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}}
772+
// CHECK: %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}}
773+
// CHECK: %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]]
774+
// CHECK: %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]]
775+
// CHECK: scf.yield %[[ADDF_26]], %[[LOAD_29]]
776+
// CHECK: }
777+
// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
778+
// CHECK: %[[REMUI_11:.*]] = arith.remui %[[SUBI_10]], %{{.*}}
779+
// CHECK: %[[CMPI_12:.*]] = arith.cmpi ne, %[[REMUI_11]], %{{.*}}
780+
// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_12]], %{{.*}}, %{{.*}}
781+
// CHECK: %[[DIVUI_14:.*]] = arith.divui %[[SUBI_10]], %{{.*}}
782+
// CHECK: %[[ADDI_15:.*]] = arith.addi %[[DIVUI_14]], %[[SELECT_13]]
783+
// CHECK: %[[ADDI_16:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1
784+
// CHECK: %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[ADDI_16]]
785+
// CHECK: %[[ADDI_18:.*]] = arith.addi %{{.*}}, %[[MULI_17]]
786+
// CHECK: %[[CMPI_19:.*]] = arith.cmpi sge, %[[ADDI_16]], %{{.*}}
787+
// CHECK: %[[ADDI_20:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1
788+
// CHECK: %[[ADDI_21:.*]] = arith.addi %[[ADDI_20]], %{{.*}}-1
789+
// CHECK: %[[MULI_22:.*]] = arith.muli %{{.*}}, %[[ADDI_21]]
790+
// CHECK: %[[ADDI_23:.*]] = arith.addi %{{.*}}, %[[MULI_22]]
791+
// CHECK: %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}}
792+
// CHECK: scf.if %[[CMPI_19]] {
793+
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_23]]]
794+
// CHECK: } else {
795+
// CHECK: }
796+
// CHECK: %[[IF_25:.*]] = scf.if %[[CMPI_24]] -> (f32) {
797+
// CHECK: %[[ADDF_26:.*]] = arith.addf %{{.*}}#1, %{{.*}}
798+
// CHECK: scf.yield %[[ADDF_26]]
799+
// CHECK: } else {
800+
// CHECK: scf.yield %{{.*}}
801+
// CHECK: }
802+
// CHECK: scf.if %[[CMPI_24]] {
803+
// CHECK: memref.store %[[IF_25]], %{{.*}}[%[[ADDI_18]]]
804+
// CHECK: } else {
805+
// CHECK: }
806+
// CHECK: return
772807
func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
773808
%cf = arith.constant 1.0 : f32
774809
scf.for %i0 = %lb to %ub step %step {

mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
214214
RewritePatternSet patterns(&getContext());
215215
mlir::scf::PipeliningOption options;
216216
options.getScheduleFn = getSchedule;
217+
options.supportDynamicLoops = true;
218+
options.predicateFn = predicateOp;
217219
if (annotatePipeline)
218220
options.annotateFn = annotate;
219221
if (noEpiloguePeeling) {
220-
options.supportDynamicLoops = true;
221222
options.peelEpilogue = false;
222-
options.predicateFn = predicateOp;
223223
}
224224
scf::populateSCFLoopPipeliningPatterns(patterns, options);
225225
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

0 commit comments

Comments
 (0)