@@ -61,6 +61,11 @@ struct LoopPipelinerInternal {
61
61
// / `idx` of `key` in the epilogue.
62
62
void setValueMapping (Value key, Value el, int64_t idx);
63
63
64
+ // / Return the defining op of the given value, if the Value is an argument of
65
+ // / the loop return the associated defining op in the loop and its distance to
66
+ // / the Value.
67
+ std::pair<Operation *, int64_t > getDefiningOpAndDistance (Value value);
68
+
64
69
public:
65
70
// / Initalize the information for the given `op`, return true if it
66
71
// / satisfies the pre-condition to apply pipelining.
@@ -240,11 +245,12 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
240
245
unsigned stage = stages[op];
241
246
242
247
auto analyzeOperand = [&](OpOperand &operand) {
243
- Operation * def = operand.get (). getDefiningOp ( );
248
+ auto [ def, distance] = getDefiningOpAndDistance ( operand.get ());
244
249
if (!def)
245
250
return ;
246
251
auto defStage = stages.find (def);
247
- if (defStage == stages.end () || defStage->second == stage)
252
+ if (defStage == stages.end () || defStage->second == stage ||
253
+ defStage->second == stage + distance)
248
254
return ;
249
255
assert (stage > defStage->second );
250
256
LiverangeInfo &info = crossStageValues[operand.get ()];
@@ -261,6 +267,25 @@ LoopPipelinerInternal::analyzeCrossStageValues() {
261
267
return crossStageValues;
262
268
}
263
269
270
+ std::pair<Operation *, int64_t >
271
+ LoopPipelinerInternal::getDefiningOpAndDistance (Value value) {
272
+ int64_t distance = 0 ;
273
+ if (auto arg = dyn_cast<BlockArgument>(value)) {
274
+ if (arg.getOwner () != forOp.getBody ())
275
+ return {nullptr , 0 };
276
+ // Ignore induction variable.
277
+ if (arg.getArgNumber () == 0 )
278
+ return {nullptr , 0 };
279
+ distance++;
280
+ value =
281
+ forOp.getBody ()->getTerminator ()->getOperand (arg.getArgNumber () - 1 );
282
+ }
283
+ Operation *def = value.getDefiningOp ();
284
+ if (!def)
285
+ return {nullptr , 0 };
286
+ return {def, distance};
287
+ }
288
+
264
289
scf::ForOp LoopPipelinerInternal::createKernelLoop (
265
290
const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
266
291
&crossStageValues,
@@ -366,10 +391,9 @@ LogicalResult LoopPipelinerInternal::createKernel(
366
391
rewriter.setInsertionPointAfter (newOp);
367
392
continue ;
368
393
}
369
- auto arg = dyn_cast<BlockArgument>(operand->get ());
394
+ Value source = operand->get ();
395
+ auto arg = dyn_cast<BlockArgument>(source);
370
396
if (arg && arg.getOwner () == forOp.getBody ()) {
371
- // If the value is a loop carried value coming from stage N + 1 remap,
372
- // it will become a direct use.
373
397
Value ret = forOp.getBody ()->getTerminator ()->getOperand (
374
398
arg.getArgNumber () - 1 );
375
399
Operation *dep = ret.getDefiningOp ();
@@ -378,15 +402,19 @@ LogicalResult LoopPipelinerInternal::createKernel(
378
402
auto stageDep = stages.find (dep);
379
403
if (stageDep == stages.end () || stageDep->second == useStage)
380
404
continue ;
381
- assert (stageDep->second == useStage + 1 );
382
- nestedNewOp->setOperand (operand->getOperandNumber (),
383
- mapping.lookupOrDefault (ret));
384
- continue ;
405
+ // If the value is a loop carried value coming from stage N + 1 remap,
406
+ // it will become a direct use.
407
+ if (stageDep->second == useStage + 1 ) {
408
+ nestedNewOp->setOperand (operand->getOperandNumber (),
409
+ mapping.lookupOrDefault (ret));
410
+ continue ;
411
+ }
412
+ source = ret;
385
413
}
386
414
// For operands defined in a previous stage we need to remap it to use
387
415
// the correct region argument. We look for the right version of the
388
416
// Value based on the stage where it is used.
389
- Operation *def = operand-> get () .getDefiningOp ();
417
+ Operation *def = source .getDefiningOp ();
390
418
if (!def)
391
419
continue ;
392
420
auto stageDef = stages.find (def);
@@ -418,9 +446,29 @@ LogicalResult LoopPipelinerInternal::createKernel(
418
446
// We create a mapping between original values and the associated loop
419
447
// returned values that will be needed by the epilogue.
420
448
llvm::SmallVector<Value> yieldOperands;
421
- for (Value retVal : forOp.getBody ()->getTerminator ()->getOperands ()) {
422
- yieldOperands.push_back (mapping.lookupOrDefault (retVal));
449
+ for (OpOperand &yieldOperand :
450
+ forOp.getBody ()->getTerminator ()->getOpOperands ()) {
451
+ Value source = mapping.lookupOrDefault (yieldOperand.get ());
452
+ // When we don't peel the epilogue and the yield value is used outside the
453
+ // loop we need to make sure we return the version from numStages -
454
+ // defStage.
455
+ if (!peelEpilogue &&
456
+ !forOp.getResult (yieldOperand.getOperandNumber ()).use_empty ()) {
457
+ Operation *def = getDefiningOpAndDistance (yieldOperand.get ()).first ;
458
+ if (def) {
459
+ auto defStage = stages.find (def);
460
+ if (defStage != stages.end () && defStage->second < maxStage) {
461
+ Value pred = predicates[defStage->second ];
462
+ source = rewriter.create <arith::SelectOp>(
463
+ pred.getLoc (), pred, source,
464
+ newForOp.getBody ()
465
+ ->getArguments ()[yieldOperand.getOperandNumber () + 1 ]);
466
+ }
467
+ }
468
+ }
469
+ yieldOperands.push_back (source);
423
470
}
471
+
424
472
for (auto &it : crossStageValues) {
425
473
int64_t version = maxStage - it.second .lastUseStage + 1 ;
426
474
unsigned numVersionReturned = it.second .lastUseStage - it.second .defStage ;
@@ -444,9 +492,11 @@ LogicalResult LoopPipelinerInternal::createKernel(
444
492
Operation *def = retVal.value ().getDefiningOp ();
445
493
assert (def && " Only support loop carried dependencies of distance 1" );
446
494
unsigned defStage = stages[def];
447
- setValueMapping (forOp.getRegionIterArgs ()[retVal.index ()],
448
- newForOp->getResult (retVal.index ()),
449
- maxStage - defStage + 1 );
495
+ if (defStage > 0 ) {
496
+ setValueMapping (forOp.getRegionIterArgs ()[retVal.index ()],
497
+ newForOp->getResult (retVal.index ()),
498
+ maxStage - defStage + 1 );
499
+ }
450
500
}
451
501
rewriter.create <scf::YieldOp>(forOp.getLoc (), yieldOperands);
452
502
return success ();
0 commit comments