@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
94
94
RewriterBase &rewriter);
95
95
// / Emits the epilogue, this creates `maxStage - 1` part which will contain
96
96
// / operations from stages [i; maxStage], where i is the part index.
97
- void emitEpilogue (RewriterBase &rewriter,
98
- llvm::SmallVector<Value> &returnValues);
97
+ LogicalResult emitEpilogue (RewriterBase &rewriter,
98
+ llvm::SmallVector<Value> &returnValues);
99
99
};
100
100
101
101
bool LoopPipelinerInternal::initializeLoopInfo (
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
133
133
LDBG (" --no epilogue or predicate set -> BAIL" );
134
134
return false ;
135
135
}
136
- if (dynamicLoop && peelEpilogue) {
137
- LDBG (" --dynamic loop doesn't support epilogue yet -> BAIL" );
138
- return false ;
139
- }
140
136
std::vector<std::pair<Operation *, unsigned >> schedule;
141
137
options.getScheduleFn (forOp, schedule);
142
138
if (schedule.empty ()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
313
309
});
314
310
int predicateIdx = i - stages[op];
315
311
if (predicates[predicateIdx]) {
312
+ OpBuilder::InsertionGuard insertGuard (rewriter);
316
313
newOp = predicateFn (rewriter, newOp, predicates[predicateIdx]);
317
314
assert (newOp && " failed to predicate op." );
318
315
}
319
- rewriter.setInsertionPointAfter (newOp);
320
316
if (annotateFn)
321
317
annotateFn (newOp, PipeliningOption::PipelinerPart::Prologue, i);
322
318
for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
@@ -561,14 +557,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
561
557
}
562
558
563
559
if (predicates[useStage]) {
560
+ OpBuilder::InsertionGuard insertGuard (rewriter);
564
561
newOp = predicateFn (rewriter, newOp, predicates[useStage]);
565
562
if (!newOp)
566
563
return failure ();
567
564
// Remap the results to the new predicated one.
568
565
for (auto values : llvm::zip (op->getResults (), newOp->getResults ()))
569
566
mapping.map (std::get<0 >(values), std::get<1 >(values));
570
567
}
571
- rewriter.setInsertionPointAfter (newOp);
572
568
if (annotateFn)
573
569
annotateFn (newOp, PipeliningOption::PipelinerPart::Kernel, 0 );
574
570
}
@@ -640,70 +636,113 @@ LogicalResult LoopPipelinerInternal::createKernel(
640
636
return success ();
641
637
}
642
638
643
- void LoopPipelinerInternal::emitEpilogue (
644
- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
639
+ LogicalResult
640
+ LoopPipelinerInternal::emitEpilogue (RewriterBase &rewriter,
641
+ llvm::SmallVector<Value> &returnValues) {
642
+ Location loc = forOp.getLoc ();
645
643
// Emit different versions of the induction variable. They will be
646
644
// removed by dead code if not used.
645
+
646
+ // bounds_range = ub - lb
647
+ // total_iterations = (bounds_range + step - 1) / step
648
+ Type t = lb.getType ();
649
+ Value minus1 =
650
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
651
+ Value boundsRange = rewriter.create <arith::SubIOp>(loc, ub, lb);
652
+ Value rangeIncr = rewriter.create <arith::AddIOp>(loc, boundsRange, step);
653
+ Value rangeDecr = rewriter.create <arith::AddIOp>(loc, rangeIncr, minus1);
654
+ Value totalIterations = rewriter.create <arith::DivUIOp>(loc, rangeDecr, step);
655
+
656
+ SmallVector<Value> predicates (maxStage + 1 );
647
657
for (int64_t i = 0 ; i < maxStage; i++) {
648
- Location loc = forOp.getLoc ();
649
- Type t = lb.getType ();
650
- Value minusOne =
651
- rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
652
- // number of iterations = ((ub - 1) - lb) / step
653
- Value totalNumIteration = rewriter.create <arith::DivUIOp>(
654
- loc,
655
- rewriter.create <arith::SubIOp>(
656
- loc, rewriter.create <arith::AddIOp>(loc, ub, minusOne), lb),
657
- step);
658
- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
658
+ // iterI = total_iters - 1 - i
659
+ // May go negative...
659
660
Value minusI =
660
661
rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -i));
662
+ Value iterI = rewriter.create <arith::AddIOp>(
663
+ loc, rewriter.create <arith::AddIOp>(loc, totalIterations, minus1),
664
+ minusI);
665
+ // newLastIter = lb + step * iterI
661
666
Value newlastIter = rewriter.create <arith::AddIOp>(
662
- loc, lb,
663
- rewriter.create <arith::MulIOp>(
664
- loc, step,
665
- rewriter.create <arith::AddIOp>(loc, totalNumIteration, minusI)));
667
+ loc, lb, rewriter.create <arith::MulIOp>(loc, step, iterI));
668
+
666
669
setValueMapping (forOp.getInductionVar (), newlastIter, maxStage - i);
670
+
671
+ if (dynamicLoop) {
672
+ // pred = iterI >= lb
673
+ predicates[i + 1 ] = rewriter.create <arith::CmpIOp>(
674
+ loc, arith::CmpIPredicate::sge, iterI, lb);
675
+ }
667
676
}
677
+
668
678
// Emit `maxStage - 1` epilogue part that includes operations from stages
669
679
// [i; maxStage].
670
680
for (int64_t i = 1 ; i <= maxStage; i++) {
681
+ SmallVector<std::pair<Value, unsigned >> returnMap (returnValues.size ());
671
682
for (Operation *op : opOrder) {
672
683
if (stages[op] < i)
673
684
continue ;
685
+ unsigned currentVersion = maxStage - stages[op] + i;
686
+ unsigned nextVersion = currentVersion + 1 ;
674
687
Operation *newOp =
675
688
cloneAndUpdateOperands (rewriter, op, [&](OpOperand *newOperand) {
676
689
auto it = valueMapping.find (newOperand->get ());
677
690
if (it != valueMapping.end ()) {
678
- Value replacement = it->second [maxStage - stages[op] + i ];
691
+ Value replacement = it->second [currentVersion ];
679
692
newOperand->set (replacement);
680
693
}
681
694
});
695
+ if (dynamicLoop) {
696
+ OpBuilder::InsertionGuard insertGuard (rewriter);
697
+ newOp = predicateFn (rewriter, newOp, predicates[currentVersion]);
698
+ if (!newOp)
699
+ return failure ();
700
+ }
682
701
if (annotateFn)
683
702
annotateFn (newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1 );
684
- for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
685
- setValueMapping (op->getResult (destId), newOp->getResult (destId),
686
- maxStage - stages[op] + i);
703
+
704
+ for (auto [opRes, newRes] :
705
+ llvm::zip (op->getResults (), newOp->getResults ())) {
706
+ setValueMapping (opRes, newRes, currentVersion);
687
707
// If the value is a loop carried dependency update the loop argument
688
708
// mapping and keep track of the last version to replace the original
689
709
// forOp uses.
690
710
for (OpOperand &operand :
691
711
forOp.getBody ()->getTerminator ()->getOpOperands ()) {
692
- if (operand.get () != op-> getResult (destId) )
712
+ if (operand.get () != opRes )
693
713
continue ;
694
- unsigned version = maxStage - stages[op] + i + 1 ;
695
714
// If the version is greater than maxStage it means it maps to the
696
715
// original forOp returned value.
697
- if (version > maxStage) {
698
- returnValues[operand.getOperandNumber ()] = newOp->getResult (destId);
699
- continue ;
700
- }
701
- setValueMapping (forOp.getRegionIterArgs ()[operand.getOperandNumber ()],
702
- newOp->getResult (destId), version);
716
+ unsigned ri = operand.getOperandNumber ();
717
+ returnValues[ri] = newRes;
718
+ Value mapVal = forOp.getRegionIterArgs ()[ri];
719
+ returnMap[ri] = std::make_pair (mapVal, currentVersion);
720
+ if (nextVersion <= maxStage)
721
+ setValueMapping (mapVal, newRes, nextVersion);
722
+ }
723
+ }
724
+ }
725
+ if (dynamicLoop) {
726
+ // Select return values from this stage (live outs) based on predication.
727
+ // If the stage is valid select the peeled value, else use previous stage
728
+ // value.
729
+ for (auto pair : llvm::enumerate (returnValues)) {
730
+ unsigned ri = pair.index ();
731
+ auto [mapVal, currentVersion] = returnMap[ri];
732
+ if (mapVal) {
733
+ unsigned nextVersion = currentVersion + 1 ;
734
+ Value pred = predicates[currentVersion];
735
+ Value prevValue = valueMapping[mapVal][currentVersion];
736
+ auto selOp = rewriter.create <arith::SelectOp>(loc, pred, pair.value (),
737
+ prevValue);
738
+ returnValues[ri] = selOp;
739
+ if (nextVersion <= maxStage)
740
+ setValueMapping (mapVal, selOp, nextVersion);
703
741
}
704
742
}
705
743
}
706
744
}
745
+ return success ();
707
746
}
708
747
709
748
void LoopPipelinerInternal::setValueMapping (Value key, Value el, int64_t idx) {
@@ -760,7 +799,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
760
799
if (options.peelEpilogue ) {
761
800
// 4. Emit the epilogue after the new forOp.
762
801
rewriter.setInsertionPointAfter (newForOp);
763
- pipeliner.emitEpilogue (rewriter, returnValues);
802
+ if (failed (pipeliner.emitEpilogue (rewriter, returnValues)))
803
+ return failure ();
764
804
}
765
805
// 5. Erase the original loop and replace the uses with the epilogue output.
766
806
if (forOp->getNumResults () > 0 )
0 commit comments