Skip to content

Commit d9a97e6

Browse files
committed
Add RAII object to manage mapping of the op's arguments.
1 parent 60aa135 commit d9a97e6

File tree

1 file changed

+57
-50
lines changed

1 file changed

+57
-50
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 57 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,40 +1011,51 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
10111011
return success();
10121012
}
10131013

1014-
/// Replace the region arguments of the parallel op (which correspond to private
1015-
/// variables) with the actual private variables they correspond to. This
1016-
/// prepares the parallel op so that it matches what is expected by the
1017-
/// OMPIRBuilder. Instead of editing the original op in-place, this function
1018-
/// does the required changes to a cloned version which should then be erased by
1019-
/// the caller.
1020-
static omp::ParallelOp
1021-
prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
1022-
Region &region = opInst.getRegion();
1023-
auto privateVars = opInst.getPrivateVars();
1024-
1025-
auto privateVarsIt = privateVars.begin();
1026-
// Reduction precede private arguments, so skip them first.
1027-
unsigned privateArgBeginIdx = opInst.getNumReductionVars();
1028-
unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size();
1029-
1030-
mlir::IRMapping mapping;
1031-
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1032-
++argIdx, ++privateVarsIt)
1033-
mapping.map(region.getArgument(argIdx), *privateVarsIt);
1034-
1035-
mlir::OpBuilder cloneBuilder(opInst);
1036-
omp::ParallelOp opInstClone =
1037-
llvm::cast<omp::ParallelOp>(cloneBuilder.clone(*opInst, mapping));
1038-
1039-
return opInstClone;
1040-
}
1014+
/// A RAII class that on construction replaces the region arguments of the
1015+
/// parallel op (which correspond to private variables) with the actual private
1016+
/// variables they correspond to. This prepares the parallel op so that it
1017+
/// matches what is expected by the OMPIRBuilder. Instead of editing the
1018+
/// original op in-place, this function does the required changes to a cloned
1019+
/// version which should then be erased by the caller.
1020+
///
1021+
/// On desctruction, it restores the original state of the operation so that on
1022+
/// the MLIR side, the op is not affected by conversion to LLVM IR.
1023+
class OmpParallelOpConversionManager {
1024+
public:
1025+
OmpParallelOpConversionManager(omp::ParallelOp opInst)
1026+
: region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1027+
privateArgBeginIdx(opInst.getNumReductionVars()),
1028+
privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1029+
auto privateVarsIt = privateVars.begin();
1030+
1031+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1032+
++argIdx, ++privateVarsIt)
1033+
mlir::replaceAllUsesInRegionWith(region.getArgument(argIdx),
1034+
*privateVarsIt, region);
1035+
}
1036+
1037+
~OmpParallelOpConversionManager() {
1038+
auto privateVarsIt = privateVars.begin();
1039+
1040+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1041+
++argIdx, ++privateVarsIt)
1042+
mlir::replaceAllUsesInRegionWith(*privateVarsIt,
1043+
region.getArgument(argIdx), region);
1044+
}
1045+
1046+
private:
1047+
Region &region;
1048+
OperandRange privateVars;
1049+
unsigned privateArgBeginIdx;
1050+
unsigned privateArgEndIdx;
1051+
};
10411052

10421053
/// Converts the OpenMP parallel operation to LLVM IR.
10431054
static LogicalResult
10441055
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10451056
LLVM::ModuleTranslation &moduleTranslation) {
10461057
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1047-
omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization(opInst);
1058+
OmpParallelOpConversionManager raii(opInst);
10481059

10491060
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
10501061
// relying on captured variables.
@@ -1054,12 +1065,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10541065
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
10551066
// Collect reduction declarations
10561067
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1057-
collectReductionDecls(opInstClone, reductionDecls);
1068+
collectReductionDecls(opInst, reductionDecls);
10581069

10591070
// Allocate reduction vars
10601071
SmallVector<llvm::Value *> privateReductionVariables;
10611072
DenseMap<Value, llvm::Value *> reductionVariableMap;
1062-
allocReductionVars(opInstClone, builder, moduleTranslation, allocaIP,
1073+
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
10631074
reductionDecls, privateReductionVariables,
10641075
reductionVariableMap);
10651076

@@ -1071,7 +1082,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10711082

10721083
// Initialize reduction vars
10731084
builder.restoreIP(allocaIP);
1074-
for (unsigned i = 0; i < opInstClone.getNumReductionVars(); ++i) {
1085+
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
10751086
SmallVector<llvm::Value *> phis;
10761087
if (failed(inlineConvertOmpRegions(
10771088
reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
@@ -1092,19 +1103,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10921103
// ParallelOp has only one region associated with it.
10931104
builder.restoreIP(codeGenIP);
10941105
auto regionBlock =
1095-
convertOmpOpRegions(opInstClone.getRegion(), "omp.par.region", builder,
1106+
convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
10961107
moduleTranslation, bodyGenStatus);
10971108

10981109
// Process the reductions if required.
1099-
if (opInstClone.getNumReductionVars() > 0) {
1110+
if (opInst.getNumReductionVars() > 0) {
11001111
// Collect reduction info
11011112
SmallVector<OwningReductionGen> owningReductionGens;
11021113
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
11031114
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1104-
collectReductionInfo(opInstClone, builder, moduleTranslation,
1105-
reductionDecls, owningReductionGens,
1106-
owningAtomicReductionGens, privateReductionVariables,
1107-
reductionInfos);
1115+
collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
1116+
owningReductionGens, owningAtomicReductionGens,
1117+
privateReductionVariables, reductionInfos);
11081118

11091119
// Move to region cont block
11101120
builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1117,8 +1127,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11171127
ompBuilder->createReductions(builder.saveIP(), allocaIP,
11181128
reductionInfos, false);
11191129
if (!contInsertPoint.getBlock()) {
1120-
bodyGenStatus = opInstClone->emitOpError()
1121-
<< "failed to convert reductions";
1130+
bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
11221131
return;
11231132
}
11241133

@@ -1140,9 +1149,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11401149
// returned.
11411150
auto [privVar, privatizerClone] =
11421151
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1143-
if (!opInstClone.getPrivateVars().empty()) {
1144-
auto privVars = opInstClone.getPrivateVars();
1145-
auto privatizers = opInstClone.getPrivatizers();
1152+
if (!opInst.getPrivateVars().empty()) {
1153+
auto privVars = opInst.getPrivateVars();
1154+
auto privatizers = opInst.getPrivatizers();
11461155

11471156
for (auto [privVar, privatizerAttr] :
11481157
llvm::zip_equal(privVars, *privatizers)) {
@@ -1155,7 +1164,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11551164
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
11561165
omp::PrivateClauseOp privatizer =
11571166
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1158-
opInstClone, privSym);
1167+
opInst, privSym);
11591168

11601169
// Clone the privatizer in case it is used by more than one parallel
11611170
// region. The privatizer is processed in-place (see below) before it
@@ -1192,9 +1201,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11921201
SmallVector<llvm::Value *, 1> yieldedValues;
11931202
if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
11941203
moduleTranslation, &yieldedValues))) {
1195-
opInstClone.emitError(
1196-
"failed to inline `alloc` region of an `omp.private` "
1197-
"op in the parallel region");
1204+
opInst.emitError("failed to inline `alloc` region of an `omp.private` "
1205+
"op in the parallel region");
11981206
bodyGenStatus = failure();
11991207
} else {
12001208
assert(yieldedValues.size() == 1);
@@ -1213,13 +1221,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
12131221
auto finiCB = [&](InsertPointTy codeGenIP) {};
12141222

12151223
llvm::Value *ifCond = nullptr;
1216-
if (auto ifExprVar = opInstClone.getIfExprVar())
1224+
if (auto ifExprVar = opInst.getIfExprVar())
12171225
ifCond = moduleTranslation.lookupValue(ifExprVar);
12181226
llvm::Value *numThreads = nullptr;
1219-
if (auto numThreadsVar = opInstClone.getNumThreadsVar())
1227+
if (auto numThreadsVar = opInst.getNumThreadsVar())
12201228
numThreads = moduleTranslation.lookupValue(numThreadsVar);
12211229
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1222-
if (auto bind = opInstClone.getProcBindVal())
1230+
if (auto bind = opInst.getProcBindVal())
12231231
pbKind = getProcBindKind(*bind);
12241232
// TODO: Is the Parallel construct cancellable?
12251233
bool isCancellable = false;
@@ -1232,7 +1240,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
12321240
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
12331241
ifCond, numThreads, pbKind, isCancellable));
12341242

1235-
opInstClone.erase();
12361243
return bodyGenStatus;
12371244
}
12381245

0 commit comments

Comments
 (0)