Skip to content

Commit 64adf93

Browse files
committed
Address review comments
1 parent 9cf7001 commit 64adf93

File tree

1 file changed

+49
-37
lines changed

1 file changed

+49
-37
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,15 @@ static LogicalResult checkImplementationStatus(Operation &op) {
174174
if (op.getHint())
175175
op.emitWarning("hint clause discarded");
176176
};
177+
auto checkHostEval = [](auto op, LogicalResult &result) {
178+
// Host evaluated clauses are supported, except for loop bounds.
179+
for (BlockArgument arg :
180+
cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
181+
for (Operation *user : arg.getUsers())
182+
if (isa<omp::LoopNestOp>(user))
183+
result = op.emitError("not yet implemented: host evaluation of loop "
184+
"bounds in omp.target operation");
185+
};
177186
auto checkIf = [&todo](auto op, LogicalResult &result) {
178187
if (op.getIfExpr())
179188
result = todo("if");
@@ -212,8 +221,24 @@ static LogicalResult checkImplementationStatus(Operation &op) {
212221
result = todo("priority");
213222
};
214223
auto checkPrivate = [&todo](auto op, LogicalResult &result) {
215-
if (!op.getPrivateVars().empty() || op.getPrivateSyms())
216-
result = todo("privatization");
224+
if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
225+
// Privatization clauses are supported, except on some situations, so we
226+
// need to check here whether any of these unsupported cases are being
227+
// translated.
228+
if (std::optional<ArrayAttr> privateSyms = op.getPrivateSyms()) {
229+
for (Attribute privatizerNameAttr : *privateSyms) {
230+
omp::PrivateClauseOp privatizer = findPrivatizer(
231+
op.getOperation(), cast<SymbolRefAttr>(privatizerNameAttr));
232+
233+
if (privatizer.getDataSharingType() ==
234+
omp::DataSharingClauseType::FirstPrivate)
235+
result = todo("firstprivate");
236+
}
237+
}
238+
} else {
239+
if (!op.getPrivateVars().empty() || op.getPrivateSyms())
240+
result = todo("privatization");
241+
}
217242
};
218243
auto checkReduction = [&todo](auto op, LogicalResult &result) {
219244
if (!op.getReductionVars().empty() || op.getReductionByref() ||
@@ -281,32 +306,11 @@ static LogicalResult checkImplementationStatus(Operation &op) {
281306
checkBare(op, result);
282307
checkDevice(op, result);
283308
checkHasDeviceAddr(op, result);
284-
285-
// Host evaluated clauses are supported, except for target SPMD loop
286-
// bounds.
287-
for (BlockArgument arg :
288-
cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
289-
for (Operation *user : arg.getUsers())
290-
if (isa<omp::LoopNestOp>(user))
291-
result = op.emitError("not yet implemented: host evaluation of "
292-
"loop bounds in omp.target operation");
293-
309+
checkHostEval(op, result);
294310
checkIf(op, result);
295311
checkInReduction(op, result);
296312
checkIsDevicePtr(op, result);
297-
// Privatization clauses are supported, except on some situations, so we
298-
// need to check here whether any of these unsupported cases are being
299-
// translated.
300-
if (std::optional<ArrayAttr> privateSyms = op.getPrivateSyms()) {
301-
for (Attribute privatizerNameAttr : *privateSyms) {
302-
omp::PrivateClauseOp privatizer = findPrivatizer(
303-
op.getOperation(), cast<SymbolRefAttr>(privatizerNameAttr));
304-
305-
if (privatizer.getDataSharingType() ==
306-
omp::DataSharingClauseType::FirstPrivate)
307-
result = todo("firstprivate");
308-
}
309-
}
313+
checkPrivate(op, result);
310314
})
311315
.Default([](Operation &) {
312316
// Assume all clauses for an operation can be translated unless they are
@@ -3923,7 +3927,11 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
39233927
llvm_unreachable("unsupported host_eval use");
39243928
})
39253929
.Case([&](omp::LoopNestOp loopOp) {
3926-
// TODO: Extract bounds and step values.
3930+
// TODO: Extract bounds and step values. Currently, this cannot be
3931+
// reached because translation would have been stopped earlier as a
3932+
// result of `checkImplementationStatus` detecting and reporting
3933+
// this situation.
3934+
llvm_unreachable("unsupported host_eval use");
39273935
})
39283936
.Default([](Operation *) {
39293937
llvm_unreachable("unsupported host_eval use");
@@ -3953,6 +3961,20 @@ static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
39533961
return op->getParentOfType<OpTy>();
39543962
}
39553963

3964+
/// If the given \p value is defined by an \c llvm.mlir.constant operation and
3965+
/// it is of an integer type, return its value.
3966+
static std::optional<int64_t> extractConstInteger(Value value) {
3967+
if (!value)
3968+
return std::nullopt;
3969+
3970+
if (auto constOp =
3971+
dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
3972+
if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3973+
return constAttr.getInt();
3974+
3975+
return std::nullopt;
3976+
}
3977+
39563978
/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
39573979
/// values as stated by the corresponding clauses, if constant.
39583980
///
@@ -3984,15 +4006,6 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
39844006
numThreads = parallelOp.getNumThreads();
39854007
}
39864008

3987-
auto extractConstInteger = [](Value value) -> std::optional<int64_t> {
3988-
if (auto constOp =
3989-
dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
3990-
if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
3991-
return constAttr.getInt();
3992-
3993-
return std::nullopt;
3994-
};
3995-
39964009
// Handle clauses impacting the number of teams.
39974010

39984011
int32_t minTeamsVal = 1, maxTeamsVal = -1;
@@ -4016,8 +4029,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
40164029

40174030
// Handle clauses impacting the number of threads.
40184031

4019-
auto setMaxValueFromClause = [&extractConstInteger](Value clauseValue,
4020-
int32_t &result) {
4032+
auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
40214033
if (!clauseValue)
40224034
return;
40234035

0 commit comments

Comments
 (0)