Skip to content

[MLIR][SCF] Loop pipelining fails on failed predication (no assert) #107442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ struct LoopPipelinerInternal {
bool initializeLoopInfo(ForOp op, const PipeliningOption &options);
/// Emits the prologue, this creates `maxStage - 1` part which will contain
/// operations from stages [0; i], where i is the part index.
void emitPrologue(RewriterBase &rewriter);
LogicalResult emitPrologue(RewriterBase &rewriter);
/// Gather liverange information for Values that are used in a different stage
/// than its definition.
llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
Expand Down Expand Up @@ -267,7 +267,7 @@ cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
return clone;
}

void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
// Initialize the iteration argument to the loop initial values.
for (auto [arg, operand] :
llvm::zip(forOp.getRegionIterArgs(), forOp.getInitsMutable())) {
Expand Down Expand Up @@ -314,7 +314,8 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
int predicateIdx = i - stages[op];
if (predicates[predicateIdx]) {
newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
assert(newOp && "failed to predicate op.");
if (newOp == nullptr)
return failure();
}
rewriter.setInsertionPointAfter(newOp);
if (annotateFn)
Expand Down Expand Up @@ -343,6 +344,7 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
}
}
}
return success();
}

llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
Expand Down Expand Up @@ -733,7 +735,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
*modifiedIR = true;

// 1. Emit prologue.
pipeliner.emitPrologue(rewriter);
if (failed(pipeliner.emitPrologue(rewriter)))
return failure();

// 2. Track values used across stages. When a value cross stages it will
// need to be passed as loop iteration arguments.
Expand Down
Loading