Skip to content

Commit 8e9349b

Browse files
committed
Handle review comments
1 parent 3fda4f3 commit 8e9349b

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,26 +1003,36 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
10031003
/// Replace the region arguments of the parallel op (which correspond to private
10041004
/// variables) with the actual private varibles they correspond to. This
10051005
/// prepares the parallel op so that it matches what is expected by the
1006-
/// OMPIRBuilder.
1007-
static void prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
1008-
Region &region = opInst.getRegion();
1009-
auto privateVars = opInst.getPrivateVars();
1006+
/// OMPIRBuilder. Instead of editing the original op in-place, this function
1007+
/// does the required changes to a cloned version which should then be erased by
1008+
/// the caller.
1009+
static omp::ParallelOp
1010+
prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
1011+
mlir::OpBuilder cloneBuilder(opInst);
1012+
omp::ParallelOp opInstClone =
1013+
llvm::cast<omp::ParallelOp>(cloneBuilder.clone(*opInst));
1014+
1015+
Region &region = opInstClone.getRegion();
1016+
auto privateVars = opInstClone.getPrivateVars();
10101017

10111018
auto privateVarsIt = privateVars.begin();
10121019
// Reduction precede private arguments, so skip them first.
1013-
unsigned privateArgBeginIdx = opInst.getNumReductionVars();
1020+
unsigned privateArgBeginIdx = opInstClone.getNumReductionVars();
10141021
unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size();
10151022
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
10161023
++argIdx, ++privateVarsIt)
10171024
replaceAllUsesInRegionWith(region.getArgument(argIdx), *privateVarsIt,
10181025
region);
1026+
return opInstClone;
10191027
}
10201028

10211029
/// Converts the OpenMP parallel operation to LLVM IR.
10221030
static LogicalResult
10231031
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10241032
LLVM::ModuleTranslation &moduleTranslation) {
10251033
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1034+
omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization(opInst);
1035+
10261036
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
10271037
// relying on captured variables.
10281038
LogicalResult bodyGenStatus = success();
@@ -1031,12 +1041,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10311041
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
10321042
// Collect reduction declarations
10331043
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1034-
collectReductionDecls(opInst, reductionDecls);
1044+
collectReductionDecls(opInstClone, reductionDecls);
10351045

10361046
// Allocate reduction vars
10371047
SmallVector<llvm::Value *> privateReductionVariables;
10381048
DenseMap<Value, llvm::Value *> reductionVariableMap;
1039-
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
1049+
allocReductionVars(opInstClone, builder, moduleTranslation, allocaIP,
10401050
reductionDecls, privateReductionVariables,
10411051
reductionVariableMap);
10421052

@@ -1048,7 +1058,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10481058

10491059
// Initialize reduction vars
10501060
builder.restoreIP(allocaIP);
1051-
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
1061+
for (unsigned i = 0; i < opInstClone.getNumReductionVars(); ++i) {
10521062
SmallVector<llvm::Value *> phis;
10531063
if (failed(inlineConvertOmpRegions(
10541064
reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
@@ -1061,8 +1071,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10611071
builder.CreateStore(phis[0], privateReductionVariables[i]);
10621072
}
10631073

1064-
prepareOmpParallelForPrivatization(opInst);
1065-
10661074
// Save the alloca insertion point on ModuleTranslation stack for use in
10671075
// nested regions.
10681076
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
@@ -1071,18 +1079,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10711079
// ParallelOp has only one region associated with it.
10721080
builder.restoreIP(codeGenIP);
10731081
auto regionBlock =
1074-
convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
1082+
convertOmpOpRegions(opInstClone.getRegion(), "omp.par.region", builder,
10751083
moduleTranslation, bodyGenStatus);
10761084

10771085
// Process the reductions if required.
1078-
if (opInst.getNumReductionVars() > 0) {
1086+
if (opInstClone.getNumReductionVars() > 0) {
10791087
// Collect reduction info
10801088
SmallVector<OwningReductionGen> owningReductionGens;
10811089
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
10821090
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1083-
collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
1084-
owningReductionGens, owningAtomicReductionGens,
1085-
privateReductionVariables, reductionInfos);
1091+
collectReductionInfo(opInstClone, builder, moduleTranslation,
1092+
reductionDecls, owningReductionGens,
1093+
owningAtomicReductionGens, privateReductionVariables,
1094+
reductionInfos);
10861095

10871096
// Move to region cont block
10881097
builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1095,7 +1104,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10951104
ompBuilder->createReductions(builder.saveIP(), allocaIP,
10961105
reductionInfos, false);
10971106
if (!contInsertPoint.getBlock()) {
1098-
bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
1107+
bodyGenStatus = opInstClone->emitOpError()
1108+
<< "failed to convert reductions";
10991109
return;
11001110
}
11011111

@@ -1117,9 +1127,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11171127
// returned.
11181128
auto [privVar, privatizerClone] =
11191129
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1120-
if (!opInst.getPrivateVars().empty()) {
1121-
auto privVars = opInst.getPrivateVars();
1122-
auto privatizers = opInst.getPrivatizers();
1130+
if (!opInstClone.getPrivateVars().empty()) {
1131+
auto privVars = opInstClone.getPrivateVars();
1132+
auto privatizers = opInstClone.getPrivatizers();
11231133

11241134
for (auto [privVar, privatizerAttr] :
11251135
llvm::zip_equal(privVars, *privatizers)) {
@@ -1132,7 +1142,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11321142
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
11331143
omp::PrivateClauseOp privatizer =
11341144
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1135-
opInst, privSym);
1145+
opInstClone, privSym);
11361146

11371147
// Clone the privatizer in case it used by more than one parallel
11381148
// region. The privatizer is processed in-place (see below) before it
@@ -1159,9 +1169,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11591169

11601170
if (!allocRegion.hasOneBlock()) {
11611171
privatizerClone.emitOpError(
1162-
"TODO: multi-block alloc regions are not supported yet. Seems "
1163-
"like there is a difference in `inlineConvertOmpRegions`'s "
1164-
"pre-conditions for single- and multi-block regions.");
1172+
"TODO: multi-block alloc regions are not supported yet.");
11651173
bodyGenStatus = failure();
11661174
return codeGenIP;
11671175
}
@@ -1185,8 +1193,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11851193
SmallVector<llvm::Value *, 1> yieldedValues;
11861194
if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
11871195
moduleTranslation, &yieldedValues))) {
1188-
opInst.emitError("failed to inline `alloc` region of an `omp.private` "
1189-
"op in the parallel region");
1196+
opInstClone.emitError(
1197+
"failed to inline `alloc` region of an `omp.private` "
1198+
"op in the parallel region");
11901199
bodyGenStatus = failure();
11911200
} else {
11921201
assert(yieldedValues.size() == 1);
@@ -1206,13 +1215,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
12061215
auto finiCB = [&](InsertPointTy codeGenIP) {};
12071216

12081217
llvm::Value *ifCond = nullptr;
1209-
if (auto ifExprVar = opInst.getIfExprVar())
1218+
if (auto ifExprVar = opInstClone.getIfExprVar())
12101219
ifCond = moduleTranslation.lookupValue(ifExprVar);
12111220
llvm::Value *numThreads = nullptr;
1212-
if (auto numThreadsVar = opInst.getNumThreadsVar())
1221+
if (auto numThreadsVar = opInstClone.getNumThreadsVar())
12131222
numThreads = moduleTranslation.lookupValue(numThreadsVar);
12141223
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1215-
if (auto bind = opInst.getProcBindVal())
1224+
if (auto bind = opInstClone.getProcBindVal())
12161225
pbKind = getProcBindKind(*bind);
12171226
// TODO: Is the Parallel construct cancellable?
12181227
bool isCancellable = false;
@@ -1225,6 +1234,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
12251234
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
12261235
ifCond, numThreads, pbKind, isCancellable));
12271236

1237+
opInstClone.erase();
12281238
return bodyGenStatus;
12291239
}
12301240

0 commit comments

Comments
 (0)