Skip to content

Commit 1892666

Browse files
authored
[MLIR][SCF] Loop pipelining fails on failed predication (no assert) (llvm#107442)
The SCFLoopPipelining allows predication on peeled or loop ops. When the predicationFn returns a nullptr this signifies the op type is unsupported and the pipeliner fails except in `emitPrologue` where it asserts. This patch fixes handling in the prologue to gracefully fail.
1 parent d219c63 commit 1892666

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ struct LoopPipelinerInternal {
7777
bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
7878
/// Emits the prologue, this creates `maxStage - 1` part which will contain
7979
/// operations from stages [0; i], where i is the part index.
80-
void emitPrologue(RewriterBase &rewriter);
80+
LogicalResult emitPrologue(RewriterBase &rewriter);
8181
/// Gather liverange information for Values that are used in a different stage
8282
/// than its definition.
8383
llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
@@ -263,7 +263,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
263263
return clone;
264264
}
265265

266-
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
266+
LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
267267
// Initialize the iteration argument to the loop initial values.
268268
for (auto [arg, operand] :
269269
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
@@ -311,7 +311,8 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
311311
if (predicates[predicateIdx]) {
312312
OpBuilder::InsertionGuard insertGuard(rewriter);
313313
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
314-
assert(newOp && "failed to predicate op.");
314+
if (newOp == nullptr)
315+
return failure();
315316
}
316317
if (annotateFn)
317318
annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
@@ -339,6 +340,7 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
339340
}
340341
}
341342
}
343+
return success();
342344
}
343345

344346
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -772,7 +774,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
772774
*modifiedIR = true;
773775

774776
// 1. Emit prologue.
775-
pipeliner.emitPrologue(rewriter);
777+
if (failed(pipeliner.emitPrologue(rewriter)))
778+
return failure();
776779

777780
// 2. Track values used across stages. When a value cross stages it will
778781
// need to be passed as loop iteration arguments.

0 commit comments

Comments
 (0)