Skip to content

[MLIR][OpenMP] Prevent loop wrapper translation crashes #115475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 24, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 33 additions & 20 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

llvm::BasicBlock *continuationBlock =
splitBB(builder, true, "omp.region.cont");
llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
Expand All @@ -407,30 +409,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(

// Terminators (namely YieldOp) may be forwarding values to the region that
// need to be available in the continuation block. Collect the types of these
// operands in preparation of creating PHI nodes.
// operands in preparation of creating PHI nodes. This is skipped for loop
// wrapper operations, for which we know in advance they have no terminators.
SmallVector<llvm::Type *> continuationBlockPHITypes;
bool operandsProcessed = false;
unsigned numYields = 0;
for (Block &bb : region.getBlocks()) {
if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
if (!operandsProcessed) {
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
continuationBlockPHITypes.push_back(
moduleTranslation.convertType(yield->getOperand(i).getType()));
}
operandsProcessed = true;
} else {
assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
"mismatching number of values yielded from the region");
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
llvm::Type *operandType =
moduleTranslation.convertType(yield->getOperand(i).getType());
(void)operandType;
assert(continuationBlockPHITypes[i] == operandType &&
"values of mismatching types yielded from the region");

if (!isLoopWrapper) {
bool operandsProcessed = false;
for (Block &bb : region.getBlocks()) {
if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
if (!operandsProcessed) {
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
continuationBlockPHITypes.push_back(
moduleTranslation.convertType(yield->getOperand(i).getType()));
}
operandsProcessed = true;
} else {
assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
"mismatching number of values yielded from the region");
for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
llvm::Type *operandType =
moduleTranslation.convertType(yield->getOperand(i).getType());
(void)operandType;
assert(continuationBlockPHITypes[i] == operandType &&
"values of mismatching types yielded from the region");
}
}
numYields++;
}
numYields++;
}
}

Expand Down Expand Up @@ -468,6 +474,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
return llvm::make_error<PreviouslyReportedError>();

// Create a direct branch here for loop wrappers to prevent their lack of a
// terminator from causing a crash below.
if (isLoopWrapper) {
builder.CreateBr(continuationBlock);
continue;
}

// Special handling for `omp.yield` and `omp.terminator` (we may have more
// than one): they return the control to the parent OpenMP dialect operation
// so replace them with the branch to the continuation block. We handle this
Expand Down
Loading