@@ -372,15 +372,16 @@ static void generateUnrolledLoop(
372
372
loopBodyBlock->getTerminator ()->setOperands (lastYielded);
373
373
}
374
374
375
- // / Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
376
- LogicalResult mlir::loopUnrollByFactor (
375
+ // / Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
376
+ // / eplilog loop, if the loop is unrolled.
377
+ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor (
377
378
scf::ForOp forOp, uint64_t unrollFactor,
378
379
function_ref<void (unsigned , Operation *, OpBuilder)> annotateFn) {
379
380
assert (unrollFactor > 0 && " expected positive unroll factor" );
380
381
381
382
// Return if the loop body is empty.
382
383
if (llvm::hasSingleElement (forOp.getBody ()->getOperations ()))
383
- return success () ;
384
+ return UnrolledLoopInfo{forOp, std::nullopt} ;
384
385
385
386
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
386
387
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -402,7 +403,7 @@ LogicalResult mlir::loopUnrollByFactor(
402
403
if (*constTripCount == 1 &&
403
404
failed (forOp.promoteIfSingleIteration (rewriter)))
404
405
return failure ();
405
- return success () ;
406
+ return UnrolledLoopInfo{forOp, std::nullopt} ;
406
407
}
407
408
408
409
int64_t tripCountEvenMultiple =
@@ -450,6 +451,8 @@ LogicalResult mlir::loopUnrollByFactor(
450
451
boundsBuilder.create <arith::MulIOp>(loc, step, unrollFactorCst);
451
452
}
452
453
454
+ UnrolledLoopInfo resultLoops;
455
+
453
456
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
454
457
if (generateEpilogueLoop) {
455
458
OpBuilder epilogueBuilder (forOp->getContext ());
@@ -467,7 +470,8 @@ LogicalResult mlir::loopUnrollByFactor(
467
470
}
468
471
epilogueForOp->setOperands (epilogueForOp.getNumControlOperands (),
469
472
epilogueForOp.getInitArgs ().size (), results);
470
- (void )epilogueForOp.promoteIfSingleIteration (rewriter);
473
+ if (epilogueForOp.promoteIfSingleIteration (rewriter).failed ())
474
+ resultLoops.epilogueLoopOp = epilogueForOp;
471
475
}
472
476
473
477
// Create unrolled loop.
@@ -489,8 +493,9 @@ LogicalResult mlir::loopUnrollByFactor(
489
493
},
490
494
annotateFn, iterArgs, yieldedValues);
491
495
// Promote the loop body up if this has turned into a single iteration loop.
492
- (void )forOp.promoteIfSingleIteration (rewriter);
493
- return success ();
496
+ if (forOp.promoteIfSingleIteration (rewriter).failed ())
497
+ resultLoops.mainLoopOp = forOp;
498
+ return resultLoops;
494
499
}
495
500
496
501
// / Check if bounds of all inner loops are defined outside of `forOp`
0 commit comments