Skip to content

Commit fa57c7a

Browse files
authored
[mlir] Extend SCF loopUnrollByFactor to return the result loops (#114573)
There is a need to access the resulting epilogue loop from the SCF loop unroller. It would be clean and convenient to get that directly from the loop unroller instead of rescanning the whole function, as discussed in triton-lang/triton#5027. I'm changing the result type of `loopUnrollByFactor` for that.
1 parent 6127724 commit fa57c7a

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,18 @@ LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
111111
void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
112112
ArrayRef<std::vector<unsigned>> combinedDimensions);
113113

114-
/// Unrolls this for operation by the specified unroll factor. Returns failure
115-
/// if the loop cannot be unrolled either due to restrictions or due to invalid
116-
/// unroll factors. Requires positive loop bounds and step. If specified,
117-
/// annotates the Ops in each unrolled iteration by applying `annotateFn`.
118-
LogicalResult loopUnrollByFactor(
114+
struct UnrolledLoopInfo {
115+
std::optional<scf::ForOp> mainLoopOp = std::nullopt;
116+
std::optional<scf::ForOp> epilogueLoopOp = std::nullopt;
117+
};
118+
119+
/// Unrolls this for operation by the specified unroll factor. Returns the
120+
/// unrolled main loop and the epilogue loop, if the loop is unrolled. Otherwise
121+
/// returns failure if the loop cannot be unrolled either due to restrictions or
122+
/// due to invalid unroll factors. Requires positive loop bounds and step. If
123+
/// specified, annotates the Ops in each unrolled iteration by applying
124+
/// `annotateFn`.
125+
FailureOr<UnrolledLoopInfo> loopUnrollByFactor(
119126
scf::ForOp forOp, uint64_t unrollFactor,
120127
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr);
121128

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -372,15 +372,16 @@ static void generateUnrolledLoop(
372372
loopBodyBlock->getTerminator()->setOperands(lastYielded);
373373
}
374374

375-
/// Unrolls 'forOp' by 'unrollFactor', returns success if the loop is unrolled.
376-
LogicalResult mlir::loopUnrollByFactor(
375+
/// Unrolls 'forOp' by 'unrollFactor', returns the unrolled main loop and the
376+
/// epilogue loop, if the loop is unrolled.
377+
FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor(
377378
scf::ForOp forOp, uint64_t unrollFactor,
378379
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn) {
379380
assert(unrollFactor > 0 && "expected positive unroll factor");
380381

381382
// Return if the loop body is empty.
382383
if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
383-
return success();
384+
return UnrolledLoopInfo{forOp, std::nullopt};
384385

385386
// Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate
386387
// 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases.
@@ -402,7 +403,7 @@ LogicalResult mlir::loopUnrollByFactor(
402403
if (*constTripCount == 1 &&
403404
failed(forOp.promoteIfSingleIteration(rewriter)))
404405
return failure();
405-
return success();
406+
return UnrolledLoopInfo{forOp, std::nullopt};
406407
}
407408

408409
int64_t tripCountEvenMultiple =
@@ -450,6 +451,8 @@ LogicalResult mlir::loopUnrollByFactor(
450451
boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst);
451452
}
452453

454+
UnrolledLoopInfo resultLoops;
455+
453456
// Create epilogue clean up loop starting at 'upperBoundUnrolled'.
454457
if (generateEpilogueLoop) {
455458
OpBuilder epilogueBuilder(forOp->getContext());
@@ -467,7 +470,8 @@ LogicalResult mlir::loopUnrollByFactor(
467470
}
468471
epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(),
469472
epilogueForOp.getInitArgs().size(), results);
470-
(void)epilogueForOp.promoteIfSingleIteration(rewriter);
473+
if (epilogueForOp.promoteIfSingleIteration(rewriter).failed())
474+
resultLoops.epilogueLoopOp = epilogueForOp;
471475
}
472476

473477
// Create unrolled loop.
@@ -489,8 +493,9 @@ LogicalResult mlir::loopUnrollByFactor(
489493
},
490494
annotateFn, iterArgs, yieldedValues);
491495
// Promote the loop body up if this has turned into a single iteration loop.
492-
(void)forOp.promoteIfSingleIteration(rewriter);
493-
return success();
496+
if (forOp.promoteIfSingleIteration(rewriter).failed())
497+
resultLoops.mainLoopOp = forOp;
498+
return resultLoops;
494499
}
495500

496501
/// Check if bounds of all inner loops are defined outside of `forOp`

0 commit comments

Comments
 (0)