@@ -90,7 +90,8 @@ struct LoopPipelinerInternal {
90
90
RewriterBase &rewriter);
91
91
/// Emits the epilogue; this creates `maxStage - 1` parts which will contain
92
92
/// operations from stages [i; maxStage], where i is the part index.
93
- llvm::SmallVector<Value> emitEpilogue (RewriterBase &rewriter);
93
+ void emitEpilogue (RewriterBase &rewriter,
94
+ llvm::SmallVector<Value> &returnValues);
94
95
};
95
96
96
97
bool LoopPipelinerInternal::initializeLoopInfo (
@@ -175,15 +176,18 @@ bool LoopPipelinerInternal::initializeLoopInfo(
175
176
}
176
177
}
177
178
178
- // Only support loop carried dependency with a distance of 1. This means the
179
- // source of all the scf.yield operands needs to be defined by operations in
180
- // the loop.
179
+ // Support only loop-carried dependencies with a distance of one iteration or
180
+ // those defined outside of the loop. This means that any dependency within a
181
+ // loop should either be on the immediately preceding iteration, the current
182
+ // iteration, or on variables whose values are set before entering the loop.
181
183
if (llvm::any_of (forOp.getBody ()->getTerminator ()->getOperands (),
182
184
[this ](Value operand) {
183
185
Operation *def = operand.getDefiningOp ();
184
- return !def || !stages.contains (def);
186
+ return !def ||
187
+ (!stages.contains (def) && forOp->isAncestor (def));
185
188
})) {
186
- LDBG (" --only support loop carried dependency with a distance of 1 -> BAIL" );
189
+ LDBG (" --only support loop carried dependency with a distance of 1 or "
190
+ " defined outside of the loop -> BAIL" );
187
191
return false ;
188
192
}
189
193
annotateFn = options.annotateFn ;
@@ -341,12 +345,17 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
341
345
for (const auto &retVal :
342
346
llvm::enumerate (forOp.getBody ()->getTerminator ()->getOperands ())) {
343
347
Operation *def = retVal.value ().getDefiningOp ();
344
- assert (def && " Only support loop carried dependencies of distance 1" );
345
- unsigned defStage = stages[def];
346
- Value valueVersion = valueMapping[forOp.getRegionIterArgs ()[retVal.index ()]]
347
- [maxStage - defStage];
348
- assert (valueVersion);
349
- newLoopArg.push_back (valueVersion);
348
+ assert (def && " Only support loop carried dependencies of distance of 1 or "
349
+ " outside the loop" );
350
+ auto defStage = stages.find (def);
351
+ if (defStage != stages.end ()) {
352
+ Value valueVersion =
353
+ valueMapping[forOp.getRegionIterArgs ()[retVal.index ()]]
354
+ [maxStage - defStage->second ];
355
+ assert (valueVersion);
356
+ newLoopArg.push_back (valueVersion);
357
+ } else
358
+ newLoopArg.push_back (forOp.getInitArgs ()[retVal.index ()]);
350
359
}
351
360
for (auto escape : crossStageValues) {
352
361
LiverangeInfo &info = escape.second ;
@@ -551,21 +560,25 @@ LogicalResult LoopPipelinerInternal::createKernel(
551
560
for (const auto &retVal :
552
561
llvm::enumerate (forOp.getBody ()->getTerminator ()->getOperands ())) {
553
562
Operation *def = retVal.value ().getDefiningOp ();
554
- assert (def && " Only support loop carried dependencies of distance 1" );
555
- unsigned defStage = stages[def];
556
- if (defStage > 0 ) {
563
+ assert (def && " Only support loop carried dependencies of distance of 1 or "
564
+ " defined outside the loop" );
565
+ auto defStage = stages.find (def);
566
+ if (defStage == stages.end ()) {
567
+ for (unsigned int stage = 1 ; stage <= maxStage; stage++)
568
+ setValueMapping (forOp.getRegionIterArgs ()[retVal.index ()],
569
+ retVal.value (), stage);
570
+ } else if (defStage->second > 0 ) {
557
571
setValueMapping (forOp.getRegionIterArgs ()[retVal.index ()],
558
572
newForOp->getResult (retVal.index ()),
559
- maxStage - defStage + 1 );
573
+ maxStage - defStage-> second + 1 );
560
574
}
561
575
}
562
576
rewriter.create <scf::YieldOp>(forOp.getLoc (), yieldOperands);
563
577
return success ();
564
578
}
565
579
566
- llvm::SmallVector<Value>
567
- LoopPipelinerInternal::emitEpilogue (RewriterBase &rewriter) {
568
- llvm::SmallVector<Value> returnValues (forOp->getNumResults ());
580
+ void LoopPipelinerInternal::emitEpilogue (
581
+ RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
569
582
// Emit different versions of the induction variable. They will be
570
583
// removed by dead code if not used.
571
584
for (int64_t i = 0 ; i < maxStage; i++) {
@@ -628,7 +641,6 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
628
641
}
629
642
}
630
643
}
631
- return returnValues;
632
644
}
633
645
634
646
void LoopPipelinerInternal::setValueMapping (Value key, Value el, int64_t idx) {
@@ -685,7 +697,7 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
685
697
if (options.peelEpilogue ) {
686
698
// 4. Emit the epilogue after the new forOp.
687
699
rewriter.setInsertionPointAfter (newForOp);
688
- returnValues = pipeliner.emitEpilogue (rewriter);
700
+ pipeliner.emitEpilogue (rewriter, returnValues );
689
701
}
690
702
// 5. Erase the original loop and replace the uses with the epilogue output.
691
703
if (forOp->getNumResults () > 0 )
0 commit comments