Skip to content

Commit 16e86b9

Browse files
committed
[MLIR][OpenMP] Support basic materialization for omp.private ops
Adds basic support for materializing delayed privatization. So far, the implementation has the following restrictions:
- Only `private` clauses are supported (`firstprivate` support will be added in a later PR).
- Only single-block `omp.private -> alloc` regions are supported (multi-block ones will be supported in a later PR).
1 parent d9f9775 commit 16e86b9

File tree

2 files changed

+227
-19
lines changed

2 files changed

+227
-19
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 136 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,11 +1000,39 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
10001000
return success();
10011001
}
10021002

1003+
/// Replace the region arguments of the parallel op (which correspond to private
1004+
/// variables) with the actual private variables they correspond to. This
1005+
/// prepares the parallel op so that it matches what is expected by the
1006+
/// OMPIRBuilder. Instead of editing the original op in-place, this function
1007+
/// does the required changes to a cloned version which should then be erased by
1008+
/// the caller.
1009+
static omp::ParallelOp
1010+
prepareOmpParallelForPrivatization(omp::ParallelOp opInst) {
1011+
mlir::OpBuilder cloneBuilder(opInst);
1012+
omp::ParallelOp opInstClone =
1013+
llvm::cast<omp::ParallelOp>(cloneBuilder.clone(*opInst));
1014+
1015+
Region &region = opInstClone.getRegion();
1016+
auto privateVars = opInstClone.getPrivateVars();
1017+
1018+
auto privateVarsIt = privateVars.begin();
1019+
// Reduction arguments precede private arguments, so skip them first.
1020+
unsigned privateArgBeginIdx = opInstClone.getNumReductionVars();
1021+
unsigned privateArgEndIdx = privateArgBeginIdx + privateVars.size();
1022+
for (size_t argIdx = privateArgBeginIdx; argIdx < privateArgEndIdx;
1023+
++argIdx, ++privateVarsIt)
1024+
replaceAllUsesInRegionWith(region.getArgument(argIdx), *privateVarsIt,
1025+
region);
1026+
return opInstClone;
1027+
}
1028+
10031029
/// Converts the OpenMP parallel operation to LLVM IR.
10041030
static LogicalResult
10051031
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10061032
LLVM::ModuleTranslation &moduleTranslation) {
10071033
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1034+
omp::ParallelOp opInstClone = prepareOmpParallelForPrivatization(opInst);
1035+
10081036
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
10091037
// relying on captured variables.
10101038
LogicalResult bodyGenStatus = success();
@@ -1013,12 +1041,12 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10131041
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
10141042
// Collect reduction declarations
10151043
SmallVector<omp::ReductionDeclareOp> reductionDecls;
1016-
collectReductionDecls(opInst, reductionDecls);
1044+
collectReductionDecls(opInstClone, reductionDecls);
10171045

10181046
// Allocate reduction vars
10191047
SmallVector<llvm::Value *> privateReductionVariables;
10201048
DenseMap<Value, llvm::Value *> reductionVariableMap;
1021-
allocReductionVars(opInst, builder, moduleTranslation, allocaIP,
1049+
allocReductionVars(opInstClone, builder, moduleTranslation, allocaIP,
10221050
reductionDecls, privateReductionVariables,
10231051
reductionVariableMap);
10241052

@@ -1030,7 +1058,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10301058

10311059
// Initialize reduction vars
10321060
builder.restoreIP(allocaIP);
1033-
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
1061+
for (unsigned i = 0; i < opInstClone.getNumReductionVars(); ++i) {
10341062
SmallVector<llvm::Value *> phis;
10351063
if (failed(inlineConvertOmpRegions(
10361064
reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
@@ -1051,18 +1079,19 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10511079
// ParallelOp has only one region associated with it.
10521080
builder.restoreIP(codeGenIP);
10531081
auto regionBlock =
1054-
convertOmpOpRegions(opInst.getRegion(), "omp.par.region", builder,
1082+
convertOmpOpRegions(opInstClone.getRegion(), "omp.par.region", builder,
10551083
moduleTranslation, bodyGenStatus);
10561084

10571085
// Process the reductions if required.
1058-
if (opInst.getNumReductionVars() > 0) {
1086+
if (opInstClone.getNumReductionVars() > 0) {
10591087
// Collect reduction info
10601088
SmallVector<OwningReductionGen> owningReductionGens;
10611089
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
10621090
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1063-
collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
1064-
owningReductionGens, owningAtomicReductionGens,
1065-
privateReductionVariables, reductionInfos);
1091+
collectReductionInfo(opInstClone, builder, moduleTranslation,
1092+
reductionDecls, owningReductionGens,
1093+
owningAtomicReductionGens, privateReductionVariables,
1094+
reductionInfos);
10661095

10671096
// Move to region cont block
10681097
builder.SetInsertPoint(regionBlock->getTerminator());
@@ -1075,7 +1104,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10751104
ompBuilder->createReductions(builder.saveIP(), allocaIP,
10761105
reductionInfos, false);
10771106
if (!contInsertPoint.getBlock()) {
1078-
bodyGenStatus = opInst->emitOpError() << "failed to convert reductions";
1107+
bodyGenStatus = opInstClone->emitOpError()
1108+
<< "failed to convert reductions";
10791109
return;
10801110
}
10811111

@@ -1086,12 +1116,97 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
10861116

10871117
// TODO: Perform appropriate actions according to the data-sharing
10881118
// attribute (shared, private, firstprivate, ...) of variables.
1089-
// Currently defaults to shared.
1119+
// Currently shared and private are supported.
10901120
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
10911121
llvm::Value &, llvm::Value &vPtr,
10921122
llvm::Value *&replacementValue) -> InsertPointTy {
10931123
replacementValue = &vPtr;
10941124

1125+
// If this is a private value, this lambda will return the corresponding
1126+
// mlir value and its `PrivateClauseOp`. Otherwise, empty values are
1127+
// returned.
1128+
auto [privVar, privatizerClone] =
1129+
[&]() -> std::pair<mlir::Value, omp::PrivateClauseOp> {
1130+
if (!opInstClone.getPrivateVars().empty()) {
1131+
auto privVars = opInstClone.getPrivateVars();
1132+
auto privatizers = opInstClone.getPrivatizers();
1133+
1134+
for (auto [privVar, privatizerAttr] :
1135+
llvm::zip_equal(privVars, *privatizers)) {
1136+
// Find the MLIR private variable corresponding to the LLVM value
1137+
// being privatized.
1138+
llvm::Value *llvmPrivVar = moduleTranslation.lookupValue(privVar);
1139+
if (llvmPrivVar != &vPtr)
1140+
continue;
1141+
1142+
SymbolRefAttr privSym = llvm::cast<SymbolRefAttr>(privatizerAttr);
1143+
omp::PrivateClauseOp privatizer =
1144+
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
1145+
opInstClone, privSym);
1146+
1147+
// Clone the privatizer in case it is used by more than one parallel
1148+
// region. The privatizer is processed in-place (see below) before it
1149+
// gets inlined in the parallel region and therefore processing the
1150+
// original op is dangerous.
1151+
return {privVar, privatizer.clone()};
1152+
}
1153+
}
1154+
1155+
return {mlir::Value(), omp::PrivateClauseOp()};
1156+
}();
1157+
1158+
if (privVar) {
1159+
if (privatizerClone.getDataSharingType() ==
1160+
omp::DataSharingClauseType::FirstPrivate) {
1161+
privatizerClone.emitOpError(
1162+
"TODO: delayed privatization is not "
1163+
"supported for `firstprivate` clauses yet.");
1164+
bodyGenStatus = failure();
1165+
return codeGenIP;
1166+
}
1167+
1168+
Region &allocRegion = privatizerClone.getAllocRegion();
1169+
1170+
if (!allocRegion.hasOneBlock()) {
1171+
privatizerClone.emitOpError(
1172+
"TODO: multi-block alloc regions are not supported yet.");
1173+
bodyGenStatus = failure();
1174+
return codeGenIP;
1175+
}
1176+
1177+
// Replace the privatizer block argument with the MLIR value being privatized.
1178+
// This way, the body of the privatizer will be changed from using the
1179+
// region/block argument to the value being privatized.
1180+
auto allocRegionArg = allocRegion.getArgument(0);
1181+
replaceAllUsesInRegionWith(allocRegionArg, privVar, allocRegion);
1182+
1183+
auto oldIP = builder.saveIP();
1184+
builder.restoreIP(allocaIP);
1185+
1186+
// Temporarily unlink the terminator from its parent since
1187+
// `inlineConvertOmpRegions` expects the insertion block to **not**
1188+
// contain a terminator.
1189+
llvm::Instruction &allocaTerminator = builder.GetInsertBlock()->back();
1190+
assert(allocaTerminator.isTerminator());
1191+
allocaTerminator.removeFromParent();
1192+
1193+
SmallVector<llvm::Value *, 1> yieldedValues;
1194+
if (failed(inlineConvertOmpRegions(allocRegion, "omp.privatizer", builder,
1195+
moduleTranslation, &yieldedValues))) {
1196+
opInstClone.emitError(
1197+
"failed to inline `alloc` region of an `omp.private` "
1198+
"op in the parallel region");
1199+
bodyGenStatus = failure();
1200+
} else {
1201+
assert(yieldedValues.size() == 1);
1202+
replacementValue = yieldedValues.front();
1203+
}
1204+
1205+
allocaTerminator.insertAfter(&builder.GetInsertBlock()->back());
1206+
privatizerClone.erase();
1207+
builder.restoreIP(oldIP);
1208+
}
1209+
10951210
return codeGenIP;
10961211
};
10971212

@@ -1100,13 +1215,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11001215
auto finiCB = [&](InsertPointTy codeGenIP) {};
11011216

11021217
llvm::Value *ifCond = nullptr;
1103-
if (auto ifExprVar = opInst.getIfExprVar())
1218+
if (auto ifExprVar = opInstClone.getIfExprVar())
11041219
ifCond = moduleTranslation.lookupValue(ifExprVar);
11051220
llvm::Value *numThreads = nullptr;
1106-
if (auto numThreadsVar = opInst.getNumThreadsVar())
1221+
if (auto numThreadsVar = opInstClone.getNumThreadsVar())
11071222
numThreads = moduleTranslation.lookupValue(numThreadsVar);
11081223
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
1109-
if (auto bind = opInst.getProcBindVal())
1224+
if (auto bind = opInstClone.getProcBindVal())
11101225
pbKind = getProcBindKind(*bind);
11111226
// TODO: Is the Parallel construct cancellable?
11121227
bool isCancellable = false;
@@ -1119,6 +1234,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
11191234
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
11201235
ifCond, numThreads, pbKind, isCancellable));
11211236

1237+
opInstClone.erase();
11221238
return bodyGenStatus;
11231239
}
11241240

@@ -3009,12 +3125,13 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
30093125
.Case([&](omp::TargetOp) {
30103126
return convertOmpTarget(*op, builder, moduleTranslation);
30113127
})
3012-
.Case<omp::MapInfoOp, omp::DataBoundsOp>([&](auto op) {
3013-
// No-op, should be handled by relevant owning operations e.g.
3014-
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3015-
// discarded
3016-
return success();
3017-
})
3128+
.Case<omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
3129+
[&](auto op) {
3130+
// No-op, should be handled by relevant owning operations e.g.
3131+
// TargetOp, EnterDataOp, ExitDataOp, DataOp etc. and then
3132+
// discarded
3133+
return success();
3134+
})
30183135
.Default([&](Operation *inst) {
30193136
return inst->emitError("unsupported OpenMP operation: ")
30203137
<< inst->getName();
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Test code-gen for `omp.parallel` ops with delayed privatizers (i.e. using
2+
// `omp.private` ops).
3+
4+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
5+
6+
llvm.func @parallel_op_1_private(%arg0: !llvm.ptr) {
7+
omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr) {
8+
%0 = llvm.load %arg2 : !llvm.ptr -> f32
9+
omp.terminator
10+
}
11+
llvm.return
12+
}
13+
14+
// CHECK-LABEL: @parallel_op_1_private
15+
// CHECK-SAME: (ptr %[[ORIG:.*]]) {
16+
// CHECK: %[[OMP_PAR_ARG:.*]] = alloca { ptr }, align 8
17+
// CHECK: %[[ORIG_GEP:.*]] = getelementptr { ptr }, ptr %[[OMP_PAR_ARG]], i32 0, i32 0
18+
// CHECK: store ptr %[[ORIG]], ptr %[[ORIG_GEP]], align 8
19+
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @parallel_op_1_private..omp_par, ptr %[[OMP_PAR_ARG]])
20+
// CHECK: }
21+
22+
// CHECK-LABEL: void @parallel_op_1_private..omp_par
23+
// CHECK-SAME: (ptr noalias %{{.*}}, ptr noalias %{{.*}}, ptr %[[ARG:.*]])
24+
// CHECK: %[[ORIG_PTR_PTR:.*]] = getelementptr { ptr }, ptr %[[ARG]], i32 0, i32 0
25+
// CHECK: %[[ORIG_PTR:.*]] = load ptr, ptr %[[ORIG_PTR_PTR]], align 8
26+
27+
// Check that the privatizer alloc region was inlined properly.
28+
// CHECK: %[[PRIV_ALLOC:.*]] = alloca float, align 4
29+
// CHECK: %[[ORIG_VAL:.*]] = load float, ptr %[[ORIG_PTR]], align 4
30+
// CHECK: store float %[[ORIG_VAL]], ptr %[[PRIV_ALLOC]], align 4
31+
// CHECK-NEXT: br
32+
33+
// Check that the privatized value is used (rather than the original one).
34+
// CHECK: load float, ptr %[[PRIV_ALLOC]], align 4
35+
// CHECK: }
36+
37+
llvm.func @parallel_op_2_privates(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
38+
omp.parallel private(@x.privatizer %arg0 -> %arg2 : !llvm.ptr, @y.privatizer %arg1 -> %arg3 : !llvm.ptr) {
39+
%0 = llvm.load %arg2 : !llvm.ptr -> f32
40+
%1 = llvm.load %arg3 : !llvm.ptr -> i32
41+
omp.terminator
42+
}
43+
llvm.return
44+
}
45+
46+
// CHECK-LABEL: @parallel_op_2_privates
47+
// CHECK-SAME: (ptr %[[ORIG1:.*]], ptr %[[ORIG2:.*]]) {
48+
// CHECK: %[[OMP_PAR_ARG:.*]] = alloca { ptr, ptr }, align 8
49+
// CHECK: %[[ORIG1_GEP:.*]] = getelementptr { ptr, ptr }, ptr %[[OMP_PAR_ARG]], i32 0, i32 0
50+
// CHECK: store ptr %[[ORIG1]], ptr %[[ORIG1_GEP]], align 8
51+
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @parallel_op_2_privates..omp_par, ptr %[[OMP_PAR_ARG]])
52+
// CHECK: }
53+
54+
// CHECK-LABEL: void @parallel_op_2_privates..omp_par
55+
// CHECK-SAME: (ptr noalias %{{.*}}, ptr noalias %{{.*}}, ptr %[[ARG:.*]])
56+
// CHECK: %[[ORIG1_PTR_PTR:.*]] = getelementptr { ptr, ptr }, ptr %[[ARG]], i32 0, i32 0
57+
// CHECK: %[[ORIG1_PTR:.*]] = load ptr, ptr %[[ORIG1_PTR_PTR]], align 8
58+
// CHECK: %[[ORIG2_PTR_PTR:.*]] = getelementptr { ptr, ptr }, ptr %[[ARG]], i32 0, i32 1
59+
// CHECK: %[[ORIG2_PTR:.*]] = load ptr, ptr %[[ORIG2_PTR_PTR]], align 8
60+
61+
// Check that the privatizer alloc region was inlined properly.
62+
// CHECK: %[[PRIV1_ALLOC:.*]] = alloca float, align 4
63+
// CHECK: %[[ORIG1_VAL:.*]] = load float, ptr %[[ORIG1_PTR]], align 4
64+
// CHECK: store float %[[ORIG1_VAL]], ptr %[[PRIV1_ALLOC]], align 4
65+
// CHECK: %[[PRIV2_ALLOC:.*]] = alloca i32, align 4
66+
// CHECK: %[[ORIG2_VAL:.*]] = load i32, ptr %[[ORIG2_PTR]], align 4
67+
// CHECK: store i32 %[[ORIG2_VAL]], ptr %[[PRIV2_ALLOC]], align 4
68+
// CHECK-NEXT: br
69+
70+
// Check that the privatized value is used (rather than the original one).
71+
// CHECK: load float, ptr %[[PRIV1_ALLOC]], align 4
72+
// CHECK: load i32, ptr %[[PRIV2_ALLOC]], align 4
73+
// CHECK: }
74+
75+
omp.private {type = private} @x.privatizer : !llvm.ptr alloc {
76+
^bb0(%arg0: !llvm.ptr):
77+
%c1 = llvm.mlir.constant(1 : i32) : i32
78+
%0 = llvm.alloca %c1 x f32 : (i32) -> !llvm.ptr
79+
%1 = llvm.load %arg0 : !llvm.ptr -> f32
80+
llvm.store %1, %0 : f32, !llvm.ptr
81+
omp.yield(%0 : !llvm.ptr)
82+
}
83+
84+
omp.private {type = private} @y.privatizer : !llvm.ptr alloc {
85+
^bb0(%arg0: !llvm.ptr):
86+
%c1 = llvm.mlir.constant(1 : i32) : i32
87+
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
88+
%1 = llvm.load %arg0 : !llvm.ptr -> i32
89+
llvm.store %1, %0 : i32, !llvm.ptr
90+
omp.yield(%0 : !llvm.ptr)
91+
}

0 commit comments

Comments
 (0)