@@ -396,9 +396,9 @@ collectReductionDecls(T loop,
396
396
397
397
// / Translates the blocks contained in the given region and appends them at
398
398
// / the current insertion point of `builder`. The operations of the entry block
399
- // / are appended to the current insertion block, which is not expected to have a
400
- // / terminator. If set, `continuationBlockArgs` is populated with translated
401
- // / values that correspond to the values omp.yield'ed from the region.
399
+ // / are appended to the current insertion block. If set, `continuationBlockArgs`
400
+ // / is populated with translated values that correspond to the values
401
+ // / omp.yield'ed from the region.
402
402
static LogicalResult inlineConvertOmpRegions (
403
403
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
404
404
LLVM::ModuleTranslation &moduleTranslation,
@@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
409
409
// Special case for single-block regions that don't create additional blocks:
410
410
// insert operations without creating additional blocks.
411
411
if (llvm::hasSingleElement (region)) {
412
+ llvm::Instruction *potentialTerminator =
413
+ builder.GetInsertBlock ()->empty () ? nullptr
414
+ : &builder.GetInsertBlock ()->back ();
415
+
416
+ if (potentialTerminator && potentialTerminator->isTerminator ())
417
+ potentialTerminator->removeFromParent ();
412
418
moduleTranslation.mapBlock (&region.front (), builder.GetInsertBlock ());
419
+
413
420
if (failed (moduleTranslation.convertBlock (
414
421
region.front (), /* ignoreArguments=*/ true , builder)))
415
422
return failure ();
@@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
423
430
// Drop the mapping that is no longer necessary so that the same region can
424
431
// be processed multiple times.
425
432
moduleTranslation.forgetMapping (region);
433
+
434
+ if (potentialTerminator && potentialTerminator->isTerminator ())
435
+ potentialTerminator->insertAfter (&builder.GetInsertBlock ()->back ());
436
+
426
437
return success ();
427
438
}
428
439
@@ -1000,11 +1011,39 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
1000
1011
return success ();
1001
1012
}
1002
1013
1014
+ // / Replace the region arguments of the parallel op (which correspond to private
1015
+ // / variables) with the actual private variables they correspond to. This
1016
+ // / prepares the parallel op so that it matches what is expected by the
1017
+ // / OMPIRBuilder. Instead of editing the original op in-place, this function
1018
+ // / does the required changes to a cloned version which should then be erased by
1019
+ // / the caller.
1020
+ static omp::ParallelOp
1021
+ prepareOmpParallelForPrivatization (omp::ParallelOp opInst) {
1022
+ mlir::OpBuilder cloneBuilder (opInst);
1023
+ omp::ParallelOp opInstClone =
1024
+ llvm::cast<omp::ParallelOp>(cloneBuilder.clone (*opInst));
1025
+
1026
+ Region ®ion = opInstClone.getRegion ();
1027
+ auto privateVars = opInstClone.getPrivateVars ();
1028
+
1029
+ auto privateVarsIt = privateVars.begin ();
1030
+ // Reduction precede private arguments, so skip them first.
1031
+ unsigned privateArgBeginIdx = opInstClone.getNumReductionVars ();
1032
+ unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size ();
1033
+ for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1034
+ ++argIdx, ++privateVarsIt)
1035
+ replaceAllUsesInRegionWith (region.getArgument (argIdx), *privateVarsIt,
1036
+ region);
1037
+ return opInstClone;
1038
+ }
1039
+
1003
1040
// / Converts the OpenMP parallel operation to LLVM IR.
1004
1041
static LogicalResult
1005
1042
convertOmpParallel (omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1006
1043
LLVM::ModuleTranslation &moduleTranslation) {
1007
1044
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1045
+ omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization (opInst);
1046
+
1008
1047
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
1009
1048
// relying on captured variables.
1010
1049
LogicalResult bodyGenStatus = success ();
@@ -1013,12 +1052,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1013
1052
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1014
1053
// Collect reduction declarations
1015
1054
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1016
- collectReductionDecls (opInst , reductionDecls);
1055
+ collectReductionDecls (opInstClone , reductionDecls);
1017
1056
1018
1057
// Allocate reduction vars
1019
1058
SmallVector<llvm::Value *> privateReductionVariables;
1020
1059
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021
- allocReductionVars (opInst , builder, moduleTranslation, allocaIP,
1060
+ allocReductionVars (opInstClone , builder, moduleTranslation, allocaIP,
1022
1061
reductionDecls, privateReductionVariables,
1023
1062
reductionVariableMap);
1024
1063
@@ -1030,7 +1069,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1030
1069
1031
1070
// Initialize reduction vars
1032
1071
builder.restoreIP (allocaIP);
1033
- for (unsigned i = 0 ; i < opInst .getNumReductionVars (); ++i) {
1072
+ for (unsigned i = 0 ; i < opInstClone .getNumReductionVars (); ++i) {
1034
1073
SmallVector<llvm::Value *> phis;
1035
1074
if (failed (inlineConvertOmpRegions (
1036
1075
reductionDecls[i].getInitializerRegion (), " omp.reduction.neutral" ,
@@ -1051,18 +1090,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1051
1090
// ParallelOp has only one region associated with it.
1052
1091
builder.restoreIP (codeGenIP);
1053
1092
auto regionBlock =
1054
- convertOmpOpRegions (opInst .getRegion (), " omp.par.region" , builder,
1093
+ convertOmpOpRegions (opInstClone .getRegion (), " omp.par.region" , builder,
1055
1094
moduleTranslation, bodyGenStatus);
1056
1095
1057
1096
// Process the reductions if required.
1058
- if (opInst .getNumReductionVars () > 0 ) {
1097
+ if (opInstClone .getNumReductionVars () > 0 ) {
1059
1098
// Collect reduction info
1060
1099
SmallVector<OwningReductionGen> owningReductionGens;
1061
1100
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1062
1101
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1063
- collectReductionInfo (opInst, builder, moduleTranslation, reductionDecls,
1064
- owningReductionGens, owningAtomicReductionGens,
1065
- privateReductionVariables, reductionInfos);
1102
+ collectReductionInfo (opInstClone, builder, moduleTranslation,
1103
+ reductionDecls, owningReductionGens,
1104
+ owningAtomicReductionGens, privateReductionVariables,
1105
+ reductionInfos);
1066
1106
1067
1107
// Move to region cont block
1068
1108
builder.SetInsertPoint (regionBlock->getTerminator ());
@@ -1075,7 +1115,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1075
1115
ompBuilder->createReductions (builder.saveIP (), allocaIP,
1076
1116
reductionInfos, false );
1077
1117
if (!contInsertPoint.getBlock ()) {
1078
- bodyGenStatus = opInst->emitOpError () << " failed to convert reductions" ;
1118
+ bodyGenStatus = opInstClone->emitOpError ()
1119
+ << " failed to convert reductions" ;
1079
1120
return ;
1080
1121
}
1081
1122
@@ -1086,12 +1127,82 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1086
1127
1087
1128
// TODO: Perform appropriate actions according to the data-sharing
1088
1129
// attribute (shared, private, firstprivate, ...) of variables.
1089
- // Currently defaults to shared .
1130
+ // Currently shared and private are supported .
1090
1131
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
1091
1132
llvm::Value &, llvm::Value &vPtr,
1092
1133
llvm::Value *&replacementValue) -> InsertPointTy {
1093
1134
replacementValue = &vPtr;
1094
1135
1136
+ // If this is a private value, this lambda will return the corresponding
1137
+ // mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1138
+ // returned.
1139
+ auto [privVar, privatizerClone] =
1140
+ [&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1141
+ if (!opInstClone.getPrivateVars ().empty ()) {
1142
+ auto privVars = opInstClone.getPrivateVars ();
1143
+ auto privatizers = opInstClone.getPrivatizers ();
1144
+
1145
+ for (auto [privVar, privatizerAttr] :
1146
+ llvm::zip_equal (privVars, *privatizers)) {
1147
+ // Find the MLIR private variable corresponding to the LLVM value
1148
+ // being privatized.
1149
+ llvm::Value *llvmPrivVar = moduleTranslation.lookupValue (privVar);
1150
+ if (llvmPrivVar != &vPtr)
1151
+ continue ;
1152
+
1153
+ SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1154
+ omp::PrivateClauseOp privatizer =
1155
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1156
+ opInstClone, privSym);
1157
+
1158
+ // Clone the privatizer in case it is used by more than one parallel
1159
+ // region. The privatizer is processed in-place (see below) before it
1160
+ // gets inlined in the parallel region and therefore processing the
1161
+ // original op is dangerous.
1162
+ return {privVar, privatizer.clone ()};
1163
+ }
1164
+ }
1165
+
1166
+ return {mlir::Value (), omp::PrivateClauseOp ()};
1167
+ }();
1168
+
1169
+ if (privVar) {
1170
+ if (privatizerClone.getDataSharingType () ==
1171
+ omp::DataSharingClauseType::FirstPrivate) {
1172
+ privatizerClone.emitOpError (
1173
+ " TODO: delayed privatization is not "
1174
+ " supported for `firstprivate` clauses yet." );
1175
+ bodyGenStatus = failure ();
1176
+ return codeGenIP;
1177
+ }
1178
+
1179
+ Region &allocRegion = privatizerClone.getAllocRegion ();
1180
+
1181
+ // Replace the privatizer block argument with mlir value being privatized.
1182
+ // This way, the body of the privatizer will be changed from using the
1183
+ // region/block argument to the value being privatized.
1184
+ auto allocRegionArg = allocRegion.getArgument (0 );
1185
+ replaceAllUsesInRegionWith (allocRegionArg, privVar, allocRegion);
1186
+
1187
+ auto oldIP = builder.saveIP ();
1188
+ builder.restoreIP (allocaIP);
1189
+
1190
+ SmallVector<llvm::Value *, 1 > yieldedValues;
1191
+ if (failed (inlineConvertOmpRegions (allocRegion, " omp.privatizer" , builder,
1192
+ moduleTranslation, &yieldedValues))) {
1193
+ opInstClone.emitError (
1194
+ " failed to inline `alloc` region of an `omp.private` "
1195
+ " op in the parallel region" );
1196
+ bodyGenStatus = failure ();
1197
+ } else {
1198
+ assert (yieldedValues.size () == 1 );
1199
+ replacementValue = yieldedValues.front ();
1200
+ }
1201
+
1202
+ privatizerClone.erase ();
1203
+ builder.restoreIP (oldIP);
1204
+ }
1205
+
1095
1206
return codeGenIP;
1096
1207
};
1097
1208
@@ -1100,13 +1211,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1100
1211
auto finiCB = [&](InsertPointTy codeGenIP) {};
1101
1212
1102
1213
llvm::Value *ifCond = nullptr ;
1103
- if (auto ifExprVar = opInst .getIfExprVar ())
1214
+ if (auto ifExprVar = opInstClone .getIfExprVar ())
1104
1215
ifCond = moduleTranslation.lookupValue (ifExprVar);
1105
1216
llvm::Value *numThreads = nullptr ;
1106
- if (auto numThreadsVar = opInst .getNumThreadsVar ())
1217
+ if (auto numThreadsVar = opInstClone .getNumThreadsVar ())
1107
1218
numThreads = moduleTranslation.lookupValue (numThreadsVar);
1108
1219
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1109
- if (auto bind = opInst .getProcBindVal ())
1220
+ if (auto bind = opInstClone .getProcBindVal ())
1110
1221
pbKind = getProcBindKind (*bind);
1111
1222
// TODO: Is the Parallel construct cancellable?
1112
1223
bool isCancellable = false ;
@@ -1119,6 +1230,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
1119
1230
ompBuilder->createParallel (ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
1120
1231
ifCond, numThreads, pbKind, isCancellable));
1121
1232
1233
+ opInstClone.erase ();
1122
1234
return bodyGenStatus;
1123
1235
}
1124
1236
@@ -1635,7 +1747,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
1635
1747
// A small helper structure to contain data gathered
1636
1748
// for map lowering and coalesce it into one area and
1637
1749
// avoiding extra computations such as searches in the
1638
- // llvm module for lowered mapped varibles or checking
1750
+ // llvm module for lowered mapped variables or checking
1639
1751
// if something is declare target (and retrieving the
1640
1752
// value) more than necessary.
1641
1753
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
@@ -3009,12 +3121,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
3009
3121
.Case ([&](omp::TargetOp) {
3010
3122
return convertOmpTarget (*op, builder, moduleTranslation);
3011
3123
})
3012
- .Case <omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
3013
- // No-op, should be handled by relevant owning operations e.g.
3014
- // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3015
- // discarded
3016
- return success ();
3017
- })
3124
+ .Case <omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3125
+ [&](auto op) {
3126
+ // No-op, should be handled by relevant owning operations e.g.
3127
+ // TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3128
+ // discarded
3129
+ return success ();
3130
+ })
3018
3131
.Default ([&](Operation *inst) {
3019
3132
return inst->emitError (" unsupported OpenMP operation: " )
3020
3133
<< inst->getName ();
0 commit comments