@@ -1003,26 +1003,36 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
1003
1003
// / Replace the region arguments of the parallel op (which correspond to private
1004
1004
// / variables) with the actual private variables they correspond to. This
1005
1005
// / prepares the parallel op so that it matches what is expected by the
1006
- // / OMPIRBuilder.
1007
- static void prepareOmpParallelForPrivatization (omp::ParallelOp opInst) {
1008
- Region &region = opInst.getRegion ();
1009
- auto privateVars = opInst.getPrivateVars ();
1006
+ // / OMPIRBuilder. Instead of editing the original op in-place, this function
1007
+ // / does the required changes to a cloned version which should then be erased by
1008
+ // / the caller.
1009
+ static omp::ParallelOp
1010
+ prepareOmpParallelForPrivatization (omp::ParallelOp opInst) {
1011
+ mlir::OpBuilder cloneBuilder (opInst);
1012
+ omp::ParallelOp opInstClone =
1013
+ llvm::cast<omp::ParallelOp>(cloneBuilder.clone (*opInst));
1014
+
1015
+ Region &region = opInstClone.getRegion ();
1016
+ auto privateVars = opInstClone.getPrivateVars ();
1010
1017
1011
1018
auto privateVarsIt = privateVars.begin ();
1012
1019
// Reductions precede private arguments, so skip them first.
1013
- unsigned privateArgBeginIdx = opInst .getNumReductionVars ();
1020
+ unsigned privateArgBeginIdx = opInstClone .getNumReductionVars ();
1014
1021
unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size ();
1015
1022
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1016
1023
++argIdx, ++privateVarsIt)
1017
1024
replaceAllUsesInRegionWith (region.getArgument (argIdx), *privateVarsIt,
1018
1025
region);
1026
+ return opInstClone;
1019
1027
}
1020
1028
1021
1029
// / Converts the OpenMP parallel operation to LLVM IR.
1022
1030
static LogicalResult
1023
1031
convertOmpParallel (omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1024
1032
LLVM::ModuleTranslation &moduleTranslation) {
1025
1033
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1034
+ omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization (opInst);
1035
+
1026
1036
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
1027
1037
// relying on captured variables.
1028
1038
LogicalResult bodyGenStatus = success ();
@@ -1031,12 +1041,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1031
1041
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1032
1042
// Collect reduction declarations
1033
1043
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1034
- collectReductionDecls (opInst , reductionDecls);
1044
+ collectReductionDecls (opInstClone , reductionDecls);
1035
1045
1036
1046
// Allocate reduction vars
1037
1047
SmallVector<llvm::Value *> privateReductionVariables;
1038
1048
DenseMap<Value, llvm::Value *> reductionVariableMap;
1039
- allocReductionVars (opInst , builder, moduleTranslation, allocaIP,
1049
+ allocReductionVars (opInstClone , builder, moduleTranslation, allocaIP,
1040
1050
reductionDecls, privateReductionVariables,
1041
1051
reductionVariableMap);
1042
1052
@@ -1048,7 +1058,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1048
1058
1049
1059
// Initialize reduction vars
1050
1060
builder.restoreIP (allocaIP);
1051
- for (unsigned i = 0 ; i < opInst .getNumReductionVars (); ++i) {
1061
+ for (unsigned i = 0 ; i < opInstClone .getNumReductionVars (); ++i) {
1052
1062
SmallVector<llvm::Value *> phis;
1053
1063
if (failed (inlineConvertOmpRegions (
1054
1064
reductionDecls[i].getInitializerRegion (), " omp.reduction.neutral" ,
@@ -1061,8 +1071,6 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1061
1071
builder.CreateStore (phis[0 ], privateReductionVariables[i]);
1062
1072
}
1063
1073
1064
- prepareOmpParallelForPrivatization (opInst);
1065
-
1066
1074
// Save the alloca insertion point on ModuleTranslation stack for use in
1067
1075
// nested regions.
1068
1076
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame (
@@ -1071,18 +1079,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1071
1079
// ParallelOp has only one region associated with it.
1072
1080
builder.restoreIP (codeGenIP);
1073
1081
auto regionBlock =
1074
- convertOmpOpRegions (opInst .getRegion (), " omp.par.region" , builder,
1082
+ convertOmpOpRegions (opInstClone .getRegion (), " omp.par.region" , builder,
1075
1083
moduleTranslation, bodyGenStatus);
1076
1084
1077
1085
// Process the reductions if required.
1078
- if (opInst .getNumReductionVars () > 0 ) {
1086
+ if (opInstClone .getNumReductionVars () > 0 ) {
1079
1087
// Collect reduction info
1080
1088
SmallVector<OwningReductionGen> owningReductionGens;
1081
1089
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1082
1090
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1083
- collectReductionInfo (opInst, builder, moduleTranslation, reductionDecls,
1084
- owningReductionGens, owningAtomicReductionGens,
1085
- privateReductionVariables, reductionInfos);
1091
+ collectReductionInfo (opInstClone, builder, moduleTranslation,
1092
+ reductionDecls, owningReductionGens,
1093
+ owningAtomicReductionGens, privateReductionVariables,
1094
+ reductionInfos);
1086
1095
1087
1096
// Move to region cont block
1088
1097
builder.SetInsertPoint (regionBlock->getTerminator ());
@@ -1095,7 +1104,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1095
1104
ompBuilder->createReductions (builder.saveIP (), allocaIP,
1096
1105
reductionInfos, false );
1097
1106
if (!contInsertPoint.getBlock ()) {
1098
- bodyGenStatus = opInst->emitOpError () << " failed to convert reductions" ;
1107
+ bodyGenStatus = opInstClone->emitOpError ()
1108
+ << " failed to convert reductions" ;
1099
1109
return ;
1100
1110
}
1101
1111
@@ -1117,9 +1127,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1117
1127
// returned.
1118
1128
auto [privVar, privatizerClone] =
1119
1129
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1120
- if (!opInst .getPrivateVars ().empty ()) {
1121
- auto privVars = opInst .getPrivateVars ();
1122
- auto privatizers = opInst .getPrivatizers ();
1130
+ if (!opInstClone .getPrivateVars ().empty ()) {
1131
+ auto privVars = opInstClone .getPrivateVars ();
1132
+ auto privatizers = opInstClone .getPrivatizers ();
1123
1133
1124
1134
for (auto [privVar, privatizerAttr] :
1125
1135
llvm::zip_equal (privVars, *privatizers)) {
@@ -1132,7 +1142,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1132
1142
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1133
1143
omp::PrivateClauseOp privatizer =
1134
1144
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1135
- opInst , privSym);
1145
+ opInstClone , privSym);
1136
1146
1137
1147
// Clone the privatizer in case it is used by more than one parallel
1138
1148
// region. The privatizer is processed in-place (see below) before it
@@ -1159,9 +1169,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1159
1169
1160
1170
if (!allocRegion.hasOneBlock ()) {
1161
1171
privatizerClone.emitOpError (
1162
- " TODO: multi-block alloc regions are not supported yet. Seems "
1163
- " like there is a difference in `inlineConvertOmpRegions`'s "
1164
- " pre-conditions for single- and multi-block regions." );
1172
+ " TODO: multi-block alloc regions are not supported yet." );
1165
1173
bodyGenStatus = failure ();
1166
1174
return codeGenIP;
1167
1175
}
@@ -1185,8 +1193,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1185
1193
SmallVector<llvm::Value *, 1 > yieldedValues;
1186
1194
if (failed (inlineConvertOmpRegions (allocRegion, " omp.privatizer" , builder,
1187
1195
moduleTranslation, &yieldedValues))) {
1188
- opInst.emitError (" failed to inline `alloc` region of an `omp.private` "
1189
- " op in the parallel region" );
1196
+ opInstClone.emitError (
1197
+ " failed to inline `alloc` region of an `omp.private` "
1198
+ " op in the parallel region" );
1190
1199
bodyGenStatus = failure ();
1191
1200
} else {
1192
1201
assert (yieldedValues.size () == 1 );
@@ -1206,13 +1215,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1206
1215
auto finiCB = [&](InsertPointTy codeGenIP) {};
1207
1216
1208
1217
llvm::Value *ifCond = nullptr ;
1209
- if (auto ifExprVar = opInst .getIfExprVar ())
1218
+ if (auto ifExprVar = opInstClone .getIfExprVar ())
1210
1219
ifCond = moduleTranslation.lookupValue (ifExprVar);
1211
1220
llvm::Value *numThreads = nullptr ;
1212
- if (auto numThreadsVar = opInst .getNumThreadsVar ())
1221
+ if (auto numThreadsVar = opInstClone .getNumThreadsVar ())
1213
1222
numThreads = moduleTranslation.lookupValue (numThreadsVar);
1214
1223
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1215
- if (auto bind = opInst .getProcBindVal ())
1224
+ if (auto bind = opInstClone .getProcBindVal ())
1216
1225
pbKind = getProcBindKind (*bind);
1217
1226
// TODO: Is the Parallel construct cancellable?
1218
1227
bool isCancellable = false ;
@@ -1225,6 +1234,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1225
1234
ompBuilder->createParallel (ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1226
1235
ifCond, numThreads, pbKind, isCancellable));
1227
1236
1237
+ opInstClone.erase ();
1228
1238
return bodyGenStatus;
1229
1239
}
1230
1240
0 commit comments