@@ -1011,40 +1011,51 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
1011
1011
return success ();
1012
1012
}
1013
1013
1014
- // / Replace the region arguments of the parallel op (which correspond to private
1015
- // / variables) with the actual private variables they correspond to. This
1016
- // / prepares the parallel op so that it matches what is expected by the
1017
- // / OMPIRBuilder. Instead of editing the original op in-place, this function
1018
- // / does the required changes to a cloned version which should then be erased by
1019
- // / the caller.
1020
- static omp::ParallelOp
1021
- prepareOmpParallelForPrivatization (omp::ParallelOp opInst) {
1022
- Region ®ion = opInst.getRegion ();
1023
- auto privateVars = opInst.getPrivateVars ();
1024
-
1025
- auto privateVarsIt = privateVars.begin ();
1026
- // Reduction precede private arguments, so skip them first.
1027
- unsigned privateArgBeginIdx = opInst.getNumReductionVars ();
1028
- unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size ();
1029
-
1030
- mlir::IRMapping mapping;
1031
- for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1032
- ++argIdx, ++privateVarsIt)
1033
- mapping.map (region.getArgument (argIdx), *privateVarsIt);
1034
-
1035
- mlir::OpBuilder cloneBuilder (opInst);
1036
- omp::ParallelOp opInstClone =
1037
- llvm::cast<omp::ParallelOp>(cloneBuilder.clone (*opInst, mapping));
1038
-
1039
- return opInstClone;
1040
- }
1014
+ // / A RAII class that on construction replaces the region arguments of the
1015
+ // / parallel op (which correspond to private variables) with the actual private
1016
+ // / variables they correspond to. This prepares the parallel op so that it
1017
+ // / matches what is expected by the OMPIRBuilder. Instead of editing the
1018
+ // / original op in-place, this function does the required changes to a cloned
1019
+ // / version which should then be erased by the caller.
1020
+ // /
1021
+ // / On desctruction, it restores the original state of the operation so that on
1022
+ // / the MLIR side, the op is not affected by conversion to LLVM IR.
1023
+ class OmpParallelOpConversionManager {
1024
+ public:
1025
+ OmpParallelOpConversionManager (omp::ParallelOp opInst)
1026
+ : region(opInst.getRegion()), privateVars(opInst.getPrivateVars()),
1027
+ privateArgBeginIdx (opInst.getNumReductionVars()),
1028
+ privateArgEndIdx(privateArgBeginIdx + privateVars.size()) {
1029
+ auto privateVarsIt = privateVars.begin ();
1030
+
1031
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1032
+ ++argIdx, ++privateVarsIt)
1033
+ mlir::replaceAllUsesInRegionWith (region.getArgument (argIdx),
1034
+ *privateVarsIt, region);
1035
+ }
1036
+
1037
+ ~OmpParallelOpConversionManager () {
1038
+ auto privateVarsIt = privateVars.begin ();
1039
+
1040
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1041
+ ++argIdx, ++privateVarsIt)
1042
+ mlir::replaceAllUsesInRegionWith (*privateVarsIt,
1043
+ region.getArgument (argIdx), region);
1044
+ }
1045
+
1046
+ private:
1047
+ Region ®ion;
1048
+ OperandRange privateVars;
1049
+ unsigned privateArgBeginIdx;
1050
+ unsigned privateArgEndIdx;
1051
+ };
1041
1052
1042
1053
// / Converts the OpenMP parallel operation to LLVM IR.
1043
1054
static LogicalResult
1044
1055
convertOmpParallel (omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1045
1056
LLVM::ModuleTranslation &moduleTranslation) {
1046
1057
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1047
- omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization (opInst);
1058
+ OmpParallelOpConversionManager raii (opInst);
1048
1059
1049
1060
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
1050
1061
// relying on captured variables.
@@ -1054,12 +1065,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1054
1065
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1055
1066
// Collect reduction declarations
1056
1067
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1057
- collectReductionDecls (opInstClone , reductionDecls);
1068
+ collectReductionDecls (opInst , reductionDecls);
1058
1069
1059
1070
// Allocate reduction vars
1060
1071
SmallVector<llvm::Value *> privateReductionVariables;
1061
1072
DenseMap<Value, llvm::Value *> reductionVariableMap;
1062
- allocReductionVars (opInstClone , builder, moduleTranslation, allocaIP,
1073
+ allocReductionVars (opInst , builder, moduleTranslation, allocaIP,
1063
1074
reductionDecls, privateReductionVariables,
1064
1075
reductionVariableMap);
1065
1076
@@ -1071,7 +1082,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1071
1082
1072
1083
// Initialize reduction vars
1073
1084
builder.restoreIP (allocaIP);
1074
- for (unsigned i = 0 ; i < opInstClone .getNumReductionVars (); ++i) {
1085
+ for (unsigned i = 0 ; i < opInst .getNumReductionVars (); ++i) {
1075
1086
SmallVector<llvm::Value *> phis;
1076
1087
if (failed (inlineConvertOmpRegions (
1077
1088
reductionDecls[i].getInitializerRegion (), " omp.reduction.neutral" ,
@@ -1092,19 +1103,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1092
1103
// ParallelOp has only one region associated with it.
1093
1104
builder.restoreIP (codeGenIP);
1094
1105
auto regionBlock =
1095
- convertOmpOpRegions (opInstClone .getRegion (), " omp.par.region" , builder,
1106
+ convertOmpOpRegions (opInst .getRegion (), " omp.par.region" , builder,
1096
1107
moduleTranslation, bodyGenStatus);
1097
1108
1098
1109
// Process the reductions if required.
1099
- if (opInstClone .getNumReductionVars () > 0 ) {
1110
+ if (opInst .getNumReductionVars () > 0 ) {
1100
1111
// Collect reduction info
1101
1112
SmallVector<OwningReductionGen> owningReductionGens;
1102
1113
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1103
1114
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1104
- collectReductionInfo (opInstClone, builder, moduleTranslation,
1105
- reductionDecls, owningReductionGens,
1106
- owningAtomicReductionGens, privateReductionVariables,
1107
- reductionInfos);
1115
+ collectReductionInfo (opInst, builder, moduleTranslation, reductionDecls,
1116
+ owningReductionGens, owningAtomicReductionGens,
1117
+ privateReductionVariables, reductionInfos);
1108
1118
1109
1119
// Move to region cont block
1110
1120
builder.SetInsertPoint (regionBlock->getTerminator ());
@@ -1117,8 +1127,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1117
1127
ompBuilder->createReductions (builder.saveIP (), allocaIP,
1118
1128
reductionInfos, false );
1119
1129
if (!contInsertPoint.getBlock ()) {
1120
- bodyGenStatus = opInstClone->emitOpError ()
1121
- << " failed to convert reductions" ;
1130
+ bodyGenStatus = opInst->emitOpError () << " failed to convert reductions" ;
1122
1131
return ;
1123
1132
}
1124
1133
@@ -1140,9 +1149,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1140
1149
// returned.
1141
1150
auto [privVar, privatizerClone] =
1142
1151
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1143
- if (!opInstClone .getPrivateVars ().empty ()) {
1144
- auto privVars = opInstClone .getPrivateVars ();
1145
- auto privatizers = opInstClone .getPrivatizers ();
1152
+ if (!opInst .getPrivateVars ().empty ()) {
1153
+ auto privVars = opInst .getPrivateVars ();
1154
+ auto privatizers = opInst .getPrivatizers ();
1146
1155
1147
1156
for (auto [privVar, privatizerAttr] :
1148
1157
llvm::zip_equal (privVars, *privatizers)) {
@@ -1155,7 +1164,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1155
1164
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1156
1165
omp::PrivateClauseOp privatizer =
1157
1166
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1158
- opInstClone , privSym);
1167
+ opInst , privSym);
1159
1168
1160
1169
// Clone the privatizer in case it is used by more than one parallel
1161
1170
// region. The privatizer is processed in-place (see below) before it
@@ -1192,9 +1201,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1192
1201
SmallVector<llvm::Value *, 1 > yieldedValues;
1193
1202
if (failed (inlineConvertOmpRegions (allocRegion, " omp.privatizer" , builder,
1194
1203
moduleTranslation, &yieldedValues))) {
1195
- opInstClone.emitError (
1196
- " failed to inline `alloc` region of an `omp.private` "
1197
- " op in the parallel region" );
1204
+ opInst.emitError (" failed to inline `alloc` region of an `omp.private` "
1205
+ " op in the parallel region" );
1198
1206
bodyGenStatus = failure ();
1199
1207
} else {
1200
1208
assert (yieldedValues.size () == 1 );
@@ -1213,13 +1221,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1213
1221
auto finiCB = [&](InsertPointTy codeGenIP) {};
1214
1222
1215
1223
llvm::Value *ifCond = nullptr ;
1216
- if (auto ifExprVar = opInstClone .getIfExprVar ())
1224
+ if (auto ifExprVar = opInst .getIfExprVar ())
1217
1225
ifCond = moduleTranslation.lookupValue (ifExprVar);
1218
1226
llvm::Value *numThreads = nullptr ;
1219
- if (auto numThreadsVar = opInstClone .getNumThreadsVar ())
1227
+ if (auto numThreadsVar = opInst .getNumThreadsVar ())
1220
1228
numThreads = moduleTranslation.lookupValue (numThreadsVar);
1221
1229
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1222
- if (auto bind = opInstClone .getProcBindVal ())
1230
+ if (auto bind = opInst .getProcBindVal ())
1223
1231
pbKind = getProcBindKind (*bind);
1224
1232
// TODO: Is the Parallel construct cancellable?
1225
1233
bool isCancellable = false ;
@@ -1232,7 +1240,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1232
1240
ompBuilder->createParallel (ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1233
1241
ifCond, numThreads, pbKind, isCancellable));
1234
1242
1235
- opInstClone.erase ();
1236
1243
return bodyGenStatus;
1237
1244
}
1238
1245
0 commit comments