@@ -89,7 +89,7 @@ struct LoopPipelinerInternal {
89
89
bool initializeLoopInfo (ForOp op, const triton::PipeliningOption &options);
90
90
// / Emits the prologue, this creates `maxStage - 1` part which will contain
91
91
// / operations from stages [0; i], where i is the part index.
92
- void emitPrologue (RewriterBase &rewriter);
92
+ LogicalResult emitPrologue (RewriterBase &rewriter);
93
93
// / Gather liverange information for Values that are used in a different stage
94
94
// / than its definition.
95
95
llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues ();
@@ -275,7 +275,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
275
275
return clone;
276
276
}
277
277
278
- void LoopPipelinerInternal::emitPrologue (RewriterBase &rewriter) {
278
+ LogicalResult LoopPipelinerInternal::emitPrologue (RewriterBase &rewriter) {
279
279
// Initialize the iteration argument to the loop initiale values.
280
280
for (auto [arg, operand] :
281
281
llvm::zip (forOp.getRegionIterArgs (), forOp.getInitsMutable ())) {
@@ -323,7 +323,8 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
323
323
if (predicates[predicateIdx]) {
324
324
OpBuilder::InsertionGuard insertGuard (rewriter);
325
325
newOp = predicateFn (rewriter, newOp, predicates[predicateIdx]);
326
- assert (newOp && " failed to predicate op." );
326
+ if (newOp == nullptr )
327
+ return failure ();
327
328
}
328
329
if (annotateFn)
329
330
annotateFn (newOp, triton::PipeliningOption::PipelinerPart::Prologue, i);
@@ -351,6 +352,7 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
351
352
}
352
353
}
353
354
}
355
+ return success ();
354
356
}
355
357
356
358
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -787,7 +789,8 @@ mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
787
789
*modifiedIR = true ;
788
790
789
791
// 1. Emit prologue.
790
- pipeliner.emitPrologue (rewriter);
792
+ if (failed (pipeliner.emitPrologue (rewriter)))
793
+ return failure ();
791
794
792
795
// 2. Track values used across stages. When a value cross stages it will
793
796
// need to be passed as loop iteration arguments.
0 commit comments