@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
94
94
RewriterBase &rewriter);
95
95
/// Emits the epilogue. This creates `maxStage - 1` parts, each of which contains
96
96
/// operations from stages [i; maxStage], where i is the part index.
97
- void emitEpilogue (RewriterBase &rewriter,
98
- llvm::SmallVector<Value> &returnValues);
97
+ LogicalResult emitEpilogue (RewriterBase &rewriter,
98
+ llvm::SmallVector<Value> &returnValues);
99
99
};
100
100
101
101
bool LoopPipelinerInternal::initializeLoopInfo (
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
133
133
LDBG (" --no epilogue or predicate set -> BAIL" );
134
134
return false ;
135
135
}
136
- if (dynamicLoop && peelEpilogue) {
137
- LDBG (" --dynamic loop doesn't support epilogue yet -> BAIL" );
138
- return false ;
139
- }
140
136
std::vector<std::pair<Operation *, unsigned >> schedule;
141
137
options.getScheduleFn (forOp, schedule);
142
138
if (schedule.empty ()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
313
309
});
314
310
int predicateIdx = i - stages[op];
315
311
if (predicates[predicateIdx]) {
312
+ OpBuilder::InsertionGuard insertGuard (rewriter);
316
313
newOp = predicateFn (rewriter, newOp, predicates[predicateIdx]);
317
314
assert (newOp && " failed to predicate op." );
318
315
}
319
- rewriter.setInsertionPointAfter (newOp);
320
316
if (annotateFn)
321
317
annotateFn (newOp, PipeliningOption::PipelinerPart::Prologue, i);
322
318
for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
@@ -561,14 +557,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
561
557
}
562
558
563
559
if (predicates[useStage]) {
560
+ OpBuilder::InsertionGuard insertGuard (rewriter);
564
561
newOp = predicateFn (rewriter, newOp, predicates[useStage]);
565
562
if (!newOp)
566
563
return failure ();
567
564
// Remap the results to the new predicated one.
568
565
for (auto values : llvm::zip (op->getResults (), newOp->getResults ()))
569
566
mapping.map (std::get<0 >(values), std::get<1 >(values));
570
567
}
571
- rewriter.setInsertionPointAfter (newOp);
572
568
if (annotateFn)
573
569
annotateFn (newOp, PipeliningOption::PipelinerPart::Kernel, 0 );
574
570
}
@@ -640,70 +636,123 @@ LogicalResult LoopPipelinerInternal::createKernel(
640
636
return success ();
641
637
}
642
638
643
- void LoopPipelinerInternal::emitEpilogue (
644
- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
639
+ LogicalResult
640
+ LoopPipelinerInternal::emitEpilogue (RewriterBase &rewriter,
641
+ llvm::SmallVector<Value> &returnValues) {
642
+ Location loc = forOp.getLoc ();
645
643
// Emit different versions of the induction variable. They will be
646
644
// removed by dead code if not used.
645
+
646
+ // bounds_range = ub - lb
647
+ // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
648
+ Type t = lb.getType ();
649
+ Value minus1 =
650
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
651
+
652
+ Value const_0 =
653
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, 0 ));
654
+ Value const_1 =
655
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, 1 ));
656
+ Value boundsRange = rewriter.create <arith::SubIOp>(loc, ub, lb);
657
+ Value boundsRem = rewriter.create <arith::RemUIOp>(loc, boundsRange, step);
658
+ Value hasRem = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
659
+ boundsRem, const_0);
660
+ Value selRem =
661
+ rewriter.create <arith::SelectOp>(loc, hasRem, const_1, const_0);
662
+ Value boundsDiv = rewriter.create <arith::DivUIOp>(loc, boundsRange, step);
663
+ Value totalIterations =
664
+ rewriter.create <arith::AddIOp>(loc, boundsDiv, selRem);
665
+
666
+ SmallVector<Value> predicates (maxStage + 1 );
647
667
for (int64_t i = 0 ; i < maxStage; i++) {
648
- Location loc = forOp.getLoc ();
649
- Type t = lb.getType ();
650
- Value minusOne =
651
- rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
652
- // number of iterations = ((ub - 1) - lb) / step
653
- Value totalNumIteration = rewriter.create <arith::DivUIOp>(
654
- loc,
655
- rewriter.create <arith::SubIOp>(
656
- loc, rewriter.create <arith::AddIOp>(loc, ub, minusOne), lb),
657
- step);
658
- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
668
+ // iterI = total_iters - 1 - i
669
+ // May go negative...
659
670
Value minusI =
660
671
rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -i));
672
+ Value iterI = rewriter.create <arith::AddIOp>(
673
+ loc, rewriter.create <arith::AddIOp>(loc, totalIterations, minus1),
674
+ minusI);
675
+ // newLastIter = lb + step * iterI
661
676
Value newlastIter = rewriter.create <arith::AddIOp>(
662
- loc, lb,
663
- rewriter.create <arith::MulIOp>(
664
- loc, step,
665
- rewriter.create <arith::AddIOp>(loc, totalNumIteration, minusI)));
677
+ loc, lb, rewriter.create <arith::MulIOp>(loc, step, iterI));
678
+
666
679
setValueMapping (forOp.getInductionVar (), newlastIter, maxStage - i);
680
+
681
+ if (dynamicLoop) {
682
+ // pred = iterI >= lb
683
+ predicates[i + 1 ] = rewriter.create <arith::CmpIOp>(
684
+ loc, arith::CmpIPredicate::sge, iterI, lb);
685
+ }
667
686
}
687
+
668
688
// Emit `maxStage - 1` epilogue part that includes operations from stages
669
689
// [i; maxStage].
670
690
for (int64_t i = 1 ; i <= maxStage; i++) {
691
+ SmallVector<std::pair<Value, unsigned >> returnMap (returnValues.size ());
671
692
for (Operation *op : opOrder) {
672
693
if (stages[op] < i)
673
694
continue ;
695
+ unsigned currentVersion = maxStage - stages[op] + i;
696
+ unsigned nextVersion = currentVersion + 1 ;
674
697
Operation *newOp =
675
698
cloneAndUpdateOperands (rewriter, op, [&](OpOperand *newOperand) {
676
699
auto it = valueMapping.find (newOperand->get ());
677
700
if (it != valueMapping.end ()) {
678
- Value replacement = it->second [maxStage - stages[op] + i ];
701
+ Value replacement = it->second [currentVersion ];
679
702
newOperand->set (replacement);
680
703
}
681
704
});
682
705
if (annotateFn)
683
706
annotateFn (newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1 );
684
- for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
685
- setValueMapping (op->getResult (destId), newOp->getResult (destId),
686
- maxStage - stages[op] + i);
707
+ if (dynamicLoop) {
708
+ OpBuilder::InsertionGuard insertGuard (rewriter);
709
+ newOp = predicateFn (rewriter, newOp, predicates[currentVersion]);
710
+ if (!newOp)
711
+ return failure ();
712
+ }
713
+
714
+ for (auto [opRes, newRes] :
715
+ llvm::zip (op->getResults (), newOp->getResults ())) {
716
+ setValueMapping (opRes, newRes, currentVersion);
687
717
// If the value is a loop carried dependency update the loop argument
688
718
// mapping and keep track of the last version to replace the original
689
719
// forOp uses.
690
720
for (OpOperand &operand :
691
721
forOp.getBody ()->getTerminator ()->getOpOperands ()) {
692
- if (operand.get () != op-> getResult (destId) )
722
+ if (operand.get () != opRes )
693
723
continue ;
694
- unsigned version = maxStage - stages[op] + i + 1 ;
695
724
// If the version is greater than maxStage it means it maps to the
696
725
// original forOp returned value.
697
- if (version > maxStage) {
698
- returnValues[operand.getOperandNumber ()] = newOp->getResult (destId);
699
- continue ;
700
- }
701
- setValueMapping (forOp.getRegionIterArgs ()[operand.getOperandNumber ()],
702
- newOp->getResult (destId), version);
726
+ unsigned ri = operand.getOperandNumber ();
727
+ returnValues[ri] = newRes;
728
+ Value mapVal = forOp.getRegionIterArgs ()[ri];
729
+ returnMap[ri] = std::make_pair (mapVal, currentVersion);
730
+ if (nextVersion <= maxStage)
731
+ setValueMapping (mapVal, newRes, nextVersion);
732
+ }
733
+ }
734
+ }
735
+ if (dynamicLoop) {
736
+ // Select return values from this stage (live outs) based on predication.
737
+ // If the stage is valid select the peeled value, else use previous stage
738
+ // value.
739
+ for (auto pair : llvm::enumerate (returnValues)) {
740
+ unsigned ri = pair.index ();
741
+ auto [mapVal, currentVersion] = returnMap[ri];
742
+ if (mapVal) {
743
+ unsigned nextVersion = currentVersion + 1 ;
744
+ Value pred = predicates[currentVersion];
745
+ Value prevValue = valueMapping[mapVal][currentVersion];
746
+ auto selOp = rewriter.create <arith::SelectOp>(loc, pred, pair.value (),
747
+ prevValue);
748
+ returnValues[ri] = selOp;
749
+ if (nextVersion <= maxStage)
750
+ setValueMapping (mapVal, selOp, nextVersion);
703
751
}
704
752
}
705
753
}
706
754
}
755
+ return success ();
707
756
}
708
757
709
758
void LoopPipelinerInternal::setValueMapping (Value key, Value el, int64_t idx) {
@@ -760,7 +809,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
760
809
if (options.peelEpilogue ) {
761
810
// 4. Emit the epilogue after the new forOp.
762
811
rewriter.setInsertionPointAfter (newForOp);
763
- pipeliner.emitEpilogue (rewriter, returnValues);
812
+ if (failed (pipeliner.emitEpilogue (rewriter, returnValues)))
813
+ return failure ();
764
814
}
765
815
// 5. Erase the original loop and replace the uses with the epilogue output.
766
816
if (forOp->getNumResults () > 0 )
0 commit comments