@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
94
94
RewriterBase &rewriter);
95
95
/// Emits the epilogue; this creates `maxStage` parts, each of which contains
96
96
/// operations from stages [i; maxStage], where i is the part index.
97
- void emitEpilogue (RewriterBase &rewriter,
98
- llvm::SmallVector<Value> &returnValues);
97
+ LogicalResult emitEpilogue (RewriterBase &rewriter,
98
+ llvm::SmallVector<Value> &returnValues);
99
99
};
100
100
101
101
bool LoopPipelinerInternal::initializeLoopInfo (
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
133
133
LDBG (" --no epilogue or predicate set -> BAIL" );
134
134
return false ;
135
135
}
136
- if (dynamicLoop && peelEpilogue) {
137
- LDBG (" --dynamic loop doesn't support epilogue yet -> BAIL" );
138
- return false ;
139
- }
140
136
std::vector<std::pair<Operation *, unsigned >> schedule;
141
137
options.getScheduleFn (forOp, schedule);
142
138
if (schedule.empty ()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
313
309
});
314
310
int predicateIdx = i - stages[op];
315
311
if (predicates[predicateIdx]) {
312
+ OpBuilder::InsertionGuard insertGuard (rewriter);
316
313
newOp = predicateFn (rewriter, newOp, predicates[predicateIdx]);
317
314
assert (newOp && " failed to predicate op." );
318
315
}
319
- rewriter.setInsertionPointAfter (newOp);
320
316
if (annotateFn)
321
317
annotateFn (newOp, PipeliningOption::PipelinerPart::Prologue, i);
322
318
for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
@@ -561,14 +557,14 @@ LogicalResult LoopPipelinerInternal::createKernel(
561
557
}
562
558
563
559
if (predicates[useStage]) {
560
+ OpBuilder::InsertionGuard insertGuard (rewriter);
564
561
newOp = predicateFn (rewriter, newOp, predicates[useStage]);
565
562
if (!newOp)
566
563
return failure ();
567
564
// Remap the results to the new predicated one.
568
565
for (auto values : llvm::zip (op->getResults (), newOp->getResults ()))
569
566
mapping.map (std::get<0 >(values), std::get<1 >(values));
570
567
}
571
- rewriter.setInsertionPointAfter (newOp);
572
568
if (annotateFn)
573
569
annotateFn (newOp, PipeliningOption::PipelinerPart::Kernel, 0 );
574
570
}
@@ -640,70 +636,121 @@ LogicalResult LoopPipelinerInternal::createKernel(
640
636
return success ();
641
637
}
642
638
643
- void LoopPipelinerInternal::emitEpilogue (
644
- RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
639
+ LogicalResult
640
+ LoopPipelinerInternal::emitEpilogue (RewriterBase &rewriter,
641
+ llvm::SmallVector<Value> &returnValues) {
642
+ Location loc = forOp.getLoc ();
645
643
// Emit different versions of the induction variable. They will be
646
644
// removed by dead code if not used.
645
+
646
+ // bounds_range = ub - lb
647
+ // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
648
+ Type t = lb.getType ();
649
+ Value minus1 =
650
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
651
+
652
+ Value const_0 =
653
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, 0 ));
654
+ Value const_1 =
655
+ rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, 1 ));
656
+ Value boundsRange = rewriter.create <arith::SubIOp>(loc, ub, lb);
657
+ Value boundsRem = rewriter.create <arith::RemUIOp>(loc, boundsRange, step);
658
+ Value hasRem = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
659
+ boundsRem, const_0);
660
+ Value totalIterations = rewriter.create <arith::AddIOp>(
661
+ loc, rewriter.create <arith::DivUIOp>(loc, boundsRange, step),
662
+ rewriter.create <arith::SelectOp>(loc, hasRem, const_1, const_0));
663
+
664
+ SmallVector<Value> predicates (maxStage + 1 );
647
665
for (int64_t i = 0 ; i < maxStage; i++) {
648
- Location loc = forOp.getLoc ();
649
- Type t = lb.getType ();
650
- Value minusOne =
651
- rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -1 ));
652
- // number of iterations = ((ub - 1) - lb) / step
653
- Value totalNumIteration = rewriter.create <arith::DivUIOp>(
654
- loc,
655
- rewriter.create <arith::SubIOp>(
656
- loc, rewriter.create <arith::AddIOp>(loc, ub, minusOne), lb),
657
- step);
658
- // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
666
+ // iterI = total_iters - 1 - i
667
+ // May go negative...
659
668
Value minusI =
660
669
rewriter.create <arith::ConstantOp>(loc, rewriter.getIntegerAttr (t, -i));
670
+ Value iterI = rewriter.create <arith::AddIOp>(
671
+ loc, rewriter.create <arith::AddIOp>(loc, totalIterations, minus1),
672
+ minusI);
673
+ // newLastIter = lb + step * iterI
661
674
Value newlastIter = rewriter.create <arith::AddIOp>(
662
- loc, lb,
663
- rewriter.create <arith::MulIOp>(
664
- loc, step,
665
- rewriter.create <arith::AddIOp>(loc, totalNumIteration, minusI)));
675
+ loc, lb, rewriter.create <arith::MulIOp>(loc, step, iterI));
676
+
666
677
setValueMapping (forOp.getInductionVar (), newlastIter, maxStage - i);
678
+
679
+ if (dynamicLoop) {
680
+ // pred = iterI >= lb
681
+ predicates[i + 1 ] = rewriter.create <arith::CmpIOp>(
682
+ loc, arith::CmpIPredicate::sge, iterI, lb);
683
+ }
667
684
}
685
+
668
686
// Emit `maxStage - 1` epilogue part that includes operations from stages
669
687
// [i; maxStage].
670
688
for (int64_t i = 1 ; i <= maxStage; i++) {
689
+ SmallVector<std::pair<Value, unsigned >> returnMap (returnValues.size ());
671
690
for (Operation *op : opOrder) {
672
691
if (stages[op] < i)
673
692
continue ;
693
+ unsigned currentVersion = maxStage - stages[op] + i;
694
+ unsigned nextVersion = currentVersion + 1 ;
674
695
Operation *newOp =
675
696
cloneAndUpdateOperands (rewriter, op, [&](OpOperand *newOperand) {
676
697
auto it = valueMapping.find (newOperand->get ());
677
698
if (it != valueMapping.end ()) {
678
- Value replacement = it->second [maxStage - stages[op] + i ];
699
+ Value replacement = it->second [currentVersion ];
679
700
newOperand->set (replacement);
680
701
}
681
702
});
682
703
if (annotateFn)
683
704
annotateFn (newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1 );
684
- for (unsigned destId : llvm::seq (unsigned (0 ), op->getNumResults ())) {
685
- setValueMapping (op->getResult (destId), newOp->getResult (destId),
686
- maxStage - stages[op] + i);
705
+ if (dynamicLoop) {
706
+ OpBuilder::InsertionGuard insertGuard (rewriter);
707
+ newOp = predicateFn (rewriter, newOp, predicates[currentVersion]);
708
+ if (!newOp)
709
+ return failure ();
710
+ }
711
+
712
+ for (auto [opRes, newRes] :
713
+ llvm::zip (op->getResults (), newOp->getResults ())) {
714
+ setValueMapping (opRes, newRes, currentVersion);
687
715
// If the value is a loop carried dependency update the loop argument
688
716
// mapping and keep track of the last version to replace the original
689
717
// forOp uses.
690
718
for (OpOperand &operand :
691
719
forOp.getBody ()->getTerminator ()->getOpOperands ()) {
692
- if (operand.get () != op-> getResult (destId) )
720
+ if (operand.get () != opRes )
693
721
continue ;
694
- unsigned version = maxStage - stages[op] + i + 1 ;
695
722
// If the version is greater than maxStage it means it maps to the
696
723
// original forOp returned value.
697
- if (version > maxStage) {
698
- returnValues[operand.getOperandNumber ()] = newOp->getResult (destId);
699
- continue ;
700
- }
701
- setValueMapping (forOp.getRegionIterArgs ()[operand.getOperandNumber ()],
702
- newOp->getResult (destId), version);
724
+ unsigned ri = operand.getOperandNumber ();
725
+ returnValues[ri] = newRes;
726
+ Value mapVal = forOp.getRegionIterArgs ()[ri];
727
+ returnMap[ri] = std::make_pair (mapVal, currentVersion);
728
+ if (nextVersion <= maxStage)
729
+ setValueMapping (mapVal, newRes, nextVersion);
730
+ }
731
+ }
732
+ }
733
+ if (dynamicLoop) {
734
+ // Select return values from this stage (live outs) based on predication.
735
+ // If the stage is valid select the peeled value, else use previous stage
736
+ // value.
737
+ for (auto pair : llvm::enumerate (returnValues)) {
738
+ unsigned ri = pair.index ();
739
+ auto [mapVal, currentVersion] = returnMap[ri];
740
+ if (mapVal) {
741
+ unsigned nextVersion = currentVersion + 1 ;
742
+ Value pred = predicates[currentVersion];
743
+ Value prevValue = valueMapping[mapVal][currentVersion];
744
+ auto selOp = rewriter.create <arith::SelectOp>(loc, pred, pair.value (),
745
+ prevValue);
746
+ returnValues[ri] = selOp;
747
+ if (nextVersion <= maxStage)
748
+ setValueMapping (mapVal, selOp, nextVersion);
703
749
}
704
750
}
705
751
}
706
752
}
753
+ return success ();
707
754
}
708
755
709
756
void LoopPipelinerInternal::setValueMapping (Value key, Value el, int64_t idx) {
@@ -760,7 +807,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
760
807
if (options.peelEpilogue ) {
761
808
// 4. Emit the epilogue after the new forOp.
762
809
rewriter.setInsertionPointAfter (newForOp);
763
- pipeliner.emitEpilogue (rewriter, returnValues);
810
+ if (failed (pipeliner.emitEpilogue (rewriter, returnValues)))
811
+ return failure ();
764
812
}
765
813
// 5. Erase the original loop and replace the uses with the epilogue output.
766
814
if (forOp->getNumResults () > 0 )
0 commit comments