@@ -1000,11 +1000,39 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
1000
1000
return success ();
1001
1001
}
1002
1002
1003
+ // / Replace the region arguments of the parallel op (which correspond to private
1004
+ // / variables) with the actual private variables they correspond to. This
1005
+ // / prepares the parallel op so that it matches what is expected by the
1006
+ // / OMPIRBuilder. Instead of editing the original op in-place, this function
1007
+ // / does the required changes to a cloned version which should then be erased by
1008
+ // / the caller.
1009
+ static omp::ParallelOp
1010
+ prepareOmpParallelForPrivatization (omp::ParallelOp opInst) {
1011
+ mlir::OpBuilder cloneBuilder (opInst);
1012
+ omp::ParallelOp opInstClone =
1013
+ llvm::cast<omp::ParallelOp>(cloneBuilder.clone (*opInst));
1014
+
1015
+ Region &region = opInstClone.getRegion ();
1016
+ auto privateVars = opInstClone.getPrivateVars ();
1017
+
1018
+ auto privateVarsIt = privateVars.begin ();
1019
+ // Reduction arguments precede private arguments, so skip them first.
1020
+ unsigned privateArgBeginIdx = opInstClone.getNumReductionVars ();
1021
+ unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size ();
1022
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1023
+ ++argIdx, ++privateVarsIt)
1024
+ replaceAllUsesInRegionWith (region.getArgument (argIdx), *privateVarsIt,
1025
+ region);
1026
+ return opInstClone;
1027
+ }
1028
+
1003
1029
// / Converts the OpenMP parallel operation to LLVM IR.
1004
1030
static LogicalResult
1005
1031
convertOmpParallel (omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1006
1032
LLVM::ModuleTranslation &moduleTranslation) {
1007
1033
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1034
+ omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization (opInst);
1035
+
1008
1036
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
1009
1037
// relying on captured variables.
1010
1038
LogicalResult bodyGenStatus = success ();
@@ -1013,12 +1041,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1013
1041
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1014
1042
// Collect reduction declarations
1015
1043
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1016
- collectReductionDecls (opInst , reductionDecls);
1044
+ collectReductionDecls (opInstClone , reductionDecls);
1017
1045
1018
1046
// Allocate reduction vars
1019
1047
SmallVector<llvm::Value *> privateReductionVariables;
1020
1048
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021
- allocReductionVars (opInst , builder, moduleTranslation, allocaIP,
1049
+ allocReductionVars (opInstClone , builder, moduleTranslation, allocaIP,
1022
1050
reductionDecls, privateReductionVariables,
1023
1051
reductionVariableMap);
1024
1052
@@ -1030,7 +1058,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1030
1058
1031
1059
// Initialize reduction vars
1032
1060
builder.restoreIP (allocaIP);
1033
- for (unsigned i = 0 ; i < opInst .getNumReductionVars (); ++i) {
1061
+ for (unsigned i = 0 ; i < opInstClone .getNumReductionVars (); ++i) {
1034
1062
SmallVector<llvm::Value *> phis;
1035
1063
if (failed (inlineConvertOmpRegions (
1036
1064
reductionDecls[i].getInitializerRegion (), " omp.reduction.neutral" ,
@@ -1051,18 +1079,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1051
1079
// ParallelOp has only one region associated with it.
1052
1080
builder.restoreIP (codeGenIP);
1053
1081
auto regionBlock =
1054
- convertOmpOpRegions (opInst .getRegion (), " omp.par.region" , builder,
1082
+ convertOmpOpRegions (opInstClone .getRegion (), " omp.par.region" , builder,
1055
1083
moduleTranslation, bodyGenStatus);
1056
1084
1057
1085
// Process the reductions if required.
1058
- if (opInst .getNumReductionVars () > 0 ) {
1086
+ if (opInstClone .getNumReductionVars () > 0 ) {
1059
1087
// Collect reduction info
1060
1088
SmallVector<OwningReductionGen> owningReductionGens;
1061
1089
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1062
1090
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1063
- collectReductionInfo (opInst, builder, moduleTranslation, reductionDecls,
1064
- owningReductionGens, owningAtomicReductionGens,
1065
- privateReductionVariables, reductionInfos);
1091
+ collectReductionInfo (opInstClone, builder, moduleTranslation,
1092
+ reductionDecls, owningReductionGens,
1093
+ owningAtomicReductionGens, privateReductionVariables,
1094
+ reductionInfos);
1066
1095
1067
1096
// Move to region cont block
1068
1097
builder.SetInsertPoint (regionBlock->getTerminator ());
@@ -1075,7 +1104,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1075
1104
ompBuilder->createReductions (builder.saveIP (), allocaIP,
1076
1105
reductionInfos, false );
1077
1106
if (!contInsertPoint.getBlock ()) {
1078
- bodyGenStatus = opInst->emitOpError () << " failed to convert reductions" ;
1107
+ bodyGenStatus = opInstClone->emitOpError ()
1108
+ << " failed to convert reductions" ;
1079
1109
return ;
1080
1110
}
1081
1111
@@ -1086,12 +1116,97 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1086
1116
1087
1117
// TODO: Perform appropriate actions according to the data-sharing
1088
1118
// attribute (shared, private, firstprivate, ...) of variables.
1089
- // Currently defaults to shared .
1119
+ // Currently shared and private are supported .
1090
1120
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1091
1121
llvm::Value &, llvm::Value &vPtr,
1092
1122
llvm::Value *&replacementValue) -> InsertPointTy {
1093
1123
replacementValue = &vPtr;
1094
1124
1125
+ // If this is a private value, this lambda will return the corresponding
1126
+ // mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1127
+ // returned.
1128
+ auto [privVar, privatizerClone] =
1129
+ [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1130
+ if (!opInstClone.getPrivateVars ().empty ()) {
1131
+ auto privVars = opInstClone.getPrivateVars ();
1132
+ auto privatizers = opInstClone.getPrivatizers ();
1133
+
1134
+ for (auto [privVar, privatizerAttr] :
1135
+ llvm::zip_equal (privVars, *privatizers)) {
1136
+ // Find the MLIR private variable corresponding to the LLVM value
1137
+ // being privatized.
1138
+ llvm::Value *llvmPrivVar = moduleTranslation.lookupValue (privVar);
1139
+ if (llvmPrivVar != &vPtr)
1140
+ continue ;
1141
+
1142
+ SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1143
+ omp::PrivateClauseOp privatizer =
1144
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1145
+ opInstClone, privSym);
1146
+
1147
+ // Clone the privatizer in case it is used by more than one parallel
1148
+ // region. The privatizer is processed in-place (see below) before it
1149
+ // gets inlined in the parallel region and therefore processing the
1150
+ // original op is dangerous.
1151
+ return {privVar, privatizer.clone ()};
1152
+ }
1153
+ }
1154
+
1155
+ return {mlir::Value (), omp::PrivateClauseOp ()};
1156
+ }();
1157
+
1158
+ if (privVar) {
1159
+ if (privatizerClone.getDataSharingType () ==
1160
+ omp::DataSharingClauseType::FirstPrivate) {
1161
+ privatizerClone.emitOpError (
1162
+ " TODO: delayed privatization is not "
1163
+ " supported for `firstprivate` clauses yet." );
1164
+ bodyGenStatus = failure ();
1165
+ return codeGenIP;
1166
+ }
1167
+
1168
+ Region &allocRegion = privatizerClone.getAllocRegion ();
1169
+
1170
+ if (!allocRegion.hasOneBlock ()) {
1171
+ privatizerClone.emitOpError (
1172
+ " TODO: multi-block alloc regions are not supported yet." );
1173
+ bodyGenStatus = failure ();
1174
+ return codeGenIP;
1175
+ }
1176
+
1177
+ // Replace the privatizer block argument with the MLIR value being privatized.
1178
+ // This way, the body of the privatizer will be changed from using the
1179
+ // region/block argument to the value being privatized.
1180
+ auto allocRegionArg = allocRegion.getArgument (0 );
1181
+ replaceAllUsesInRegionWith (allocRegionArg, privVar, allocRegion);
1182
+
1183
+ auto oldIP = builder.saveIP ();
1184
+ builder.restoreIP (allocaIP);
1185
+
1186
+ // Temporarily unlink the terminator from its parent since
1187
+ // `inlineConvertOmpRegions` expects the insertion block to **not**
1188
+ // contain a terminator.
1189
+ llvm::Instruction &allocaTerminator = builder.GetInsertBlock ()->back ();
1190
+ assert (allocaTerminator.isTerminator ());
1191
+ allocaTerminator.removeFromParent ();
1192
+
1193
+ SmallVector<llvm::Value *, 1 > yieldedValues;
1194
+ if (failed (inlineConvertOmpRegions (allocRegion, " omp.privatizer" , builder,
1195
+ moduleTranslation, &yieldedValues))) {
1196
+ opInstClone.emitError (
1197
+ " failed to inline `alloc` region of an `omp.private` "
1198
+ " op in the parallel region" );
1199
+ bodyGenStatus = failure ();
1200
+ } else {
1201
+ assert (yieldedValues.size () == 1 );
1202
+ replacementValue = yieldedValues.front ();
1203
+ }
1204
+
1205
+ allocaTerminator.insertAfter (&builder.GetInsertBlock ()->back ());
1206
+ privatizerClone.erase ();
1207
+ builder.restoreIP (oldIP);
1208
+ }
1209
+
1095
1210
return codeGenIP;
1096
1211
};
1097
1212
@@ -1100,13 +1215,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1100
1215
auto finiCB = [&](InsertPointTy codeGenIP) {};
1101
1216
1102
1217
llvm::Value *ifCond = nullptr ;
1103
- if (auto ifExprVar = opInst .getIfExprVar ())
1218
+ if (auto ifExprVar = opInstClone .getIfExprVar ())
1104
1219
ifCond = moduleTranslation.lookupValue (ifExprVar);
1105
1220
llvm::Value *numThreads = nullptr ;
1106
- if (auto numThreadsVar = opInst .getNumThreadsVar ())
1221
+ if (auto numThreadsVar = opInstClone .getNumThreadsVar ())
1107
1222
numThreads = moduleTranslation.lookupValue (numThreadsVar);
1108
1223
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1109
- if (auto bind = opInst .getProcBindVal ())
1224
+ if (auto bind = opInstClone .getProcBindVal ())
1110
1225
pbKind = getProcBindKind (*bind);
1111
1226
// TODO: Is the Parallel construct cancellable?
1112
1227
bool isCancellable = false ;
@@ -1119,6 +1234,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1119
1234
ompBuilder->createParallel (ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1120
1235
ifCond, numThreads, pbKind, isCancellable));
1121
1236
1237
+ opInstClone.erase ();
1122
1238
return bodyGenStatus;
1123
1239
}
1124
1240
@@ -3009,12 +3125,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
3009
3125
.Case ([&](omp::TargetOp) {
3010
3126
return convertOmpTarget (*op, builder, moduleTranslation);
3011
3127
})
3012
- .Case <omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
3013
- // No-op, should be handled by relevant owning operations e.g.
3014
- // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3015
- // discarded
3016
- return success ();
3017
- })
3128
+ .Case <omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3129
+ [&](auto op) {
3130
+ // No-op, should be handled by relevant owning operations e.g.
3131
+ // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3132
+ // discarded
3133
+ return success ();
3134
+ })
3018
3135
.Default ([&](Operation *inst) {
3019
3136
return inst->emitError (" unsupported OpenMP operation: " )
3020
3137
<< inst->getName ();
0 commit comments