Skip to content

Commit af31311

Browse files
committed
[MLIR][OpenMP] Support basic materialization for omp.private ops
Adds basic support for materializing delayed privatization. So far, the restrictions on the implementation are: - Only `private` clauses are supported (`firstprivate` support will be added in a later PR). - Only single-block `omp.private -> alloc` regions are supported (multi-block ones will be supported in a later PR).
1 parent 03203b7 commit af31311

File tree

3 files changed

+282
-24
lines changed

3 files changed

+282
-24
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1957,7 +1957,10 @@ LogicalResult PrivateClauseOp::verify() {
19571957
Type symType = getType();
19581958

19591959
auto verifyTerminator = [&](Operation *terminator) -> LogicalResult {
1960-
if (!terminator->hasSuccessors() && !llvm::isa<YieldOp>(terminator))
1960+
if (!terminator->getBlock()->getSuccessors().empty())
1961+
return success();
1962+
1963+
if (!llvm::isa<YieldOp>(terminator))
19611964
return mlir::emitError(terminator->getLoc())
19621965
<< "expected exit block terminator to be an `omp.yield` op.";
19631966

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 136 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,9 @@ collectReductionDecls(T loop,
396396

397397
/// Translates the blocks contained in the given region and appends them to at
398398
/// the current insertion point of `builder`. The operations of the entry block
399-
/// are appended to the current insertion block, which is not expected to have a
400-
/// terminator. If set, `continuationBlockArgs` is populated with translated
401-
/// values that correspond to the values omp.yield'ed from the region.
399+
/// are appended to the current insertion block. If set, `continuationBlockArgs`
400+
/// is populated with translated values that correspond to the values
401+
/// omp.yield'ed from the region.
402402
static LogicalResult inlineConvertOmpRegions(
403403
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
404404
LLVM::ModuleTranslation &moduleTranslation,
@@ -409,7 +409,14 @@ static LogicalResult inlineConvertOmpRegions(
409409
// Special case for single-block regions that don't create additional blocks:
410410
// insert operations without creating additional blocks.
411411
if (llvm::hasSingleElement(region)) {
412+
llvm::Instruction *potentialTerminator =
413+
builder.GetInsertBlock()->empty() ? nullptr
414+
: &builder.GetInsertBlock()->back();
415+
416+
if (potentialTerminator && potentialTerminator->isTerminator())
417+
potentialTerminator->removeFromParent();
412418
moduleTranslation.mapBlock(&region.front(), builder.GetInsertBlock());
419+
413420
if (failed(moduleTranslation.convertBlock(
414421
region.front(), /*ignoreArguments=*/true, builder)))
415422
return failure();
@@ -423,6 +430,10 @@ static LogicalResult inlineConvertOmpRegions(
423430
// Drop the mapping that is no longer necessary so that the same region can
424431
// be processed multiple times.
425432
moduleTranslation.forgetMapping(region);
433+
434+
if (potentialTerminator && potentialTerminator->isTerminator())
435+
potentialTerminator->insertAfter(&builder.GetInsertBlock()->back());
436+
426437
return success();
427438
}
428439

@@ -1000,11 +1011,39 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
10001011
return success();
10011012
}
10021013

1014+
/// Replace the region arguments of the parallel op (which correspond to private
1015+
/// variables) with the actual private variables they correspond to. This
1016+
/// prepares the parallel op so that it matches what is expected by the
1017+
/// OMPIRBuilder. Instead of editing the original op in-place, this function
1018+
/// does the required changes to a cloned version which should then be erased by
1019+
/// the caller.
1020+
static omp::ParallelOp
1021+
prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
1022+
mlir::OpBuilder cloneBuilder(opInst);
1023+
omp::ParallelOp opInstClone =
1024+
llvm::cast<omp::ParallelOp>(cloneBuilder.clone(*opInst));
1025+
1026+
Region &region = opInstClone.getRegion();
1027+
auto privateVars = opInstClone.getPrivateVars();
1028+
1029+
auto privateVarsIt = privateVars.begin();
1030+
// Reduction precede private arguments, so skip them first.
1031+
unsigned privateArgBeginIdx = opInstClone.getNumReductionVars();
1032+
unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size();
1033+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1034+
++argIdx, ++privateVarsIt)
1035+
replaceAllUsesInRegionWith(region.getArgument(argIdx), *privateVarsIt,
1036+
region);
1037+
return opInstClone;
1038+
}
1039+
10031040
/// Converts the OpenMP parallel operation to LLVM IR.
10041041
static LogicalResult
10051042
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10061043
LLVM::ModuleTranslation &moduleTranslation) {
10071044
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1045+
omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization(opInst);
1046+
10081047
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
10091048
// relying on captured variables.
10101049
LogicalResult bodyGenStatus = success();
@@ -1013,12 +1052,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10131052
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
10141053
// Collect reduction declarations
10151054
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1016-
collectReductionDecls(opInst, reductionDecls);
1055+
collectReductionDecls(opInstClone, reductionDecls);
10171056

10181057
// Allocate reduction vars
10191058
SmallVector<llvm::Value *> privateReductionVariables;
10201059
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021-
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
1060+
allocReductionVars(opInstClone, builder, moduleTranslation, allocaIP,
10221061
reductionDecls, privateReductionVariables,
10231062
reductionVariableMap);
10241063

@@ -1030,7 +1069,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10301069

10311070
// Initialize reduction vars
10321071
builder.restoreIP(allocaIP);
1033-
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
1072+
for (unsigned i = 0; i < opInstClone.getNumReductionVars(); ++i) {
10341073
SmallVector<llvm::Value *> phis;
10351074
if (failed(inlineConvertOmpRegions(
10361075
reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
@@ -1051,18 +1090,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10511090
// ParallelOp has only one region associated with it.
10521091
builder.restoreIP(codeGenIP);
10531092
auto regionBlock =
1054-
convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
1093+
convertOmpOpRegions(opInstClone.getRegion(), "omp.par.region", builder,
10551094
moduleTranslation, bodyGenStatus);
10561095

10571096
// Process the reductions if required.
1058-
if (opInst.getNumReductionVars() > 0) {
1097+
if (opInstClone.getNumReductionVars() > 0) {
10591098
// Collect reduction info
10601099
SmallVector<OwningReductionGen> owningReductionGens;
10611100
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
10621101
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1063-
collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
1064-
owningReductionGens, owningAtomicReductionGens,
1065-
privateReductionVariables, reductionInfos);
1102+
collectReductionInfo(opInstClone, builder, moduleTranslation,
1103+
reductionDecls, owningReductionGens,
1104+
owningAtomicReductionGens, privateReductionVariables,
1105+
reductionInfos);
10661106

10671107
// Move to region cont block
10681108
builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1075,7 +1115,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10751115
ompBuilder->createReductions(builder.saveIP(), allocaIP,
10761116
reductionInfos, false);
10771117
if (!contInsertPoint.getBlock()) {
1078-
bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
1118+
bodyGenStatus = opInstClone->emitOpError()
1119+
<< "failed to convert reductions";
10791120
return;
10801121
}
10811122

@@ -1086,12 +1127,82 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10861127

10871128
// TODO: Perform appropriate actions according to the data-sharing
10881129
// attribute (shared, private, firstprivate, ...) of variables.
1089-
// Currently defaults to shared.
1130+
// Currently shared and private are supported.
10901131
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
10911132
llvm::Value &, llvm::Value &vPtr,
10921133
llvm::Value *&replacementValue) -> InsertPointTy {
10931134
replacementValue = &vPtr;
10941135

1136+
// If this is a private value, this lambda will return the corresponding
1137+
// mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1138+
// returned.
1139+
auto [privVar, privatizerClone] =
1140+
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1141+
if (!opInstClone.getPrivateVars().empty()) {
1142+
auto privVars = opInstClone.getPrivateVars();
1143+
auto privatizers = opInstClone.getPrivatizers();
1144+
1145+
for (auto [privVar, privatizerAttr] :
1146+
llvm::zip_equal(privVars, *privatizers)) {
1147+
// Find the MLIR private variable corresponding to the LLVM value
1148+
// being privatized.
1149+
llvm::Value *llvmPrivVar = moduleTranslation.lookupValue(privVar);
1150+
if (llvmPrivVar != &vPtr)
1151+
continue;
1152+
1153+
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1154+
omp::PrivateClauseOp privatizer =
1155+
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1156+
opInstClone, privSym);
1157+
1158+
// Clone the privatizer in case it is used by more than one parallel
1159+
// region. The privatizer is processed in-place (see below) before it
1160+
// gets inlined in the parallel region and therefore processing the
1161+
// original op is dangerous.
1162+
return {privVar, privatizer.clone()};
1163+
}
1164+
}
1165+
1166+
return {mlir::Value(), omp::PrivateClauseOp()};
1167+
}();
1168+
1169+
if (privVar) {
1170+
if (privatizerClone.getDataSharingType() ==
1171+
omp::DataSharingClauseType::FirstPrivate) {
1172+
privatizerClone.emitOpError(
1173+
"TODO: delayed privatization is not "
1174+
"supported for `firstprivate` clauses yet.");
1175+
bodyGenStatus = failure();
1176+
return codeGenIP;
1177+
}
1178+
1179+
Region &allocRegion = privatizerClone.getAllocRegion();
1180+
1181+
// Replace the privatizer block argument with mlir value being privatized.
1182+
// This way, the body of the privatizer will be changed from using the
1183+
// region/block argument to the value being privatized.
1184+
auto allocRegionArg = allocRegion.getArgument(0);
1185+
replaceAllUsesInRegionWith(allocRegionArg, privVar, allocRegion);
1186+
1187+
auto oldIP = builder.saveIP();
1188+
builder.restoreIP(allocaIP);
1189+
1190+
SmallVector<llvm::Value *, 1> yieldedValues;
1191+
if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
1192+
moduleTranslation, &yieldedValues))) {
1193+
opInstClone.emitError(
1194+
"failed to inline `alloc` region of an `omp.private` "
1195+
"op in the parallel region");
1196+
bodyGenStatus = failure();
1197+
} else {
1198+
assert(yieldedValues.size() == 1);
1199+
replacementValue = yieldedValues.front();
1200+
}
1201+
1202+
privatizerClone.erase();
1203+
builder.restoreIP(oldIP);
1204+
}
1205+
10951206
return codeGenIP;
10961207
};
10971208

@@ -1100,13 +1211,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11001211
auto finiCB = [&](InsertPointTy codeGenIP) {};
11011212

11021213
llvm::Value *ifCond = nullptr;
1103-
if (auto ifExprVar = opInst.getIfExprVar())
1214+
if (auto ifExprVar = opInstClone.getIfExprVar())
11041215
ifCond = moduleTranslation.lookupValue(ifExprVar);
11051216
llvm::Value *numThreads = nullptr;
1106-
if (auto numThreadsVar = opInst.getNumThreadsVar())
1217+
if (auto numThreadsVar = opInstClone.getNumThreadsVar())
11071218
numThreads = moduleTranslation.lookupValue(numThreadsVar);
11081219
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1109-
if (auto bind = opInst.getProcBindVal())
1220+
if (auto bind = opInstClone.getProcBindVal())
11101221
pbKind = getProcBindKind(*bind);
11111222
// TODO: Is the Parallel construct cancellable?
11121223
bool isCancellable = false;
@@ -1119,6 +1230,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11191230
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
11201231
ifCond, numThreads, pbKind, isCancellable));
11211232

1233+
opInstClone.erase();
11221234
return bodyGenStatus;
11231235
}
11241236

@@ -1635,7 +1747,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
16351747
// A small helper structure to contain data gathered
16361748
// for map lowering and coalese it into one area and
16371749
// avoiding extra computations such as searches in the
1638-
// llvm module for lowered mapped varibles or checking
1750+
// llvm module for lowered mapped variables or checking
16391751
// if something is declare target (and retrieving the
16401752
// value) more than neccessary.
16411753
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
@@ -3009,12 +3121,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
30093121
.Case([&](omp::TargetOp) {
30103122
return convertOmpTarget(*op, builder, moduleTranslation);
30113123
})
3012-
.Case<omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
3013-
// No-op, should be handled by relevant owning operations e.g.
3014-
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3015-
// discarded
3016-
return success();
3017-
})
3124+
.Case<omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3125+
[&](auto op) {
3126+
// No-op, should be handled by relevant owning operations e.g.
3127+
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3128+
// discarded
3129+
return success();
3130+
})
30183131
.Default([&](Operation *inst) {
30193132
return inst->emitError("unsupported OpenMP operation: ")
30203133
<< inst->getName();

0 commit comments

Comments
 (0)