@@ -89,7 +89,7 @@ struct LoopPipelinerInternal {
89
89
bool initializeLoopInfo (ForOp op, const triton::PipeliningOption &options);
90
90
// / Emits the prologue, this creates `maxStage - 1` part which will contain
91
91
// / operations from stages [0; i], where i is the part index.
92
- void emitPrologue (RewriterBase &rewriter);
92
+ LogicalResult emitPrologue (RewriterBase &rewriter);
93
93
// / Gather liverange information for Values that are used in a different stage
94
94
// / than its definition.
95
95
llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues ();
@@ -275,7 +275,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
275
275
return clone;
276
276
}
277
277
278
- void LoopPipelinerInternal::emitPrologue (RewriterBase &rewriter) {
278
+ LogicalResult LoopPipelinerInternal::emitPrologue (RewriterBase &rewriter) {
279
279
// Initialize the iteration argument to the loop initiale values.
280
280
for (auto [arg, operand] :
281
281
llvm::zip (forOp.getRegionIterArgs (), forOp.getInitsMutable ())) {
@@ -323,7 +323,8 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
323
323
if (predicates[predicateIdx]) {
324
324
OpBuilder::InsertionGuard insertGuard (rewriter);
325
325
newOp = predicateFn (rewriter, newOp, predicates[predicateIdx]);
326
- assert (newOp && " failed to predicate op." );
326
+ if (newOp == nullptr )
327
+ return failure ();
327
328
}
328
329
if (annotateFn)
329
330
annotateFn (newOp, triton::PipeliningOption::PipelinerPart::Prologue, i);
@@ -341,6 +342,7 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
341
342
}
342
343
}
343
344
}
345
+ return success ();
344
346
}
345
347
346
348
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
@@ -777,7 +779,8 @@ mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
777
779
*modifiedIR = true ;
778
780
779
781
// 1. Emit prologue.
780
- pipeliner.emitPrologue (rewriter);
782
+ if (failed (pipeliner.emitPrologue (rewriter)))
783
+ return failure ();
781
784
782
785
// 2. Track values used across stages. When a value cross stages it will
783
786
// need to be passed as loop iteration arguments.
0 commit comments