Skip to content

Commit 63e70c0

Browse files
authored
[flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProc… (#81623)
…essor Rename `findRepeatableClause` to `findRepeatableClause2`, and make the new `findRepeatableClause` operate on new `omp::Clause` objects. Leave `Map` unchanged, because it will require more changes for it to work. [Clause representation 3/6]
1 parent 03bad4b commit 63e70c0

File tree

10 files changed

+358
-366
lines changed

10 files changed

+358
-366
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,28 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
430430
}
431431
}
432432

433+
struct ExtractSubstringHelper {
434+
template <typename T> static std::optional<Substring> visit(T &&) {
435+
return std::nullopt;
436+
}
437+
438+
static std::optional<Substring> visit(const Substring &e) { return e; }
439+
440+
template <typename T>
441+
static std::optional<Substring> visit(const Designator<T> &e) {
442+
return std::visit([](auto &&s) { return visit(s); }, e.u);
443+
}
444+
445+
template <typename T>
446+
static std::optional<Substring> visit(const Expr<T> &e) {
447+
return std::visit([](auto &&s) { return visit(s); }, e.u);
448+
}
449+
};
450+
451+
template <typename A> std::optional<Substring> ExtractSubstring(const A &x) {
452+
return ExtractSubstringHelper::visit(x);
453+
}
454+
433455
// If an expression is simply a whole symbol data designator,
434456
// extract and return that symbol, else null.
435457
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 96 additions & 122 deletions
Large diffs are not rendered by default.

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,8 @@ class ClauseProcessor {
105105
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
106106
bool
107107
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
108-
bool
109-
processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
110-
mlir::Value &result) const;
108+
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
109+
mlir::Value &result) const;
111110
bool
112111
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
113112

@@ -178,6 +177,10 @@ class ClauseProcessor {
178177
/// if at least one instance was found.
179178
template <typename T>
180179
bool findRepeatableClause(
180+
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
181+
callbackFn) const;
182+
template <typename T>
183+
bool findRepeatableClause2(
181184
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
182185
callbackFn) const;
183186

@@ -195,7 +198,7 @@ template <typename T>
195198
bool ClauseProcessor::processMotionClauses(
196199
Fortran::lower::StatementContext &stmtCtx,
197200
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
198-
return findRepeatableClause<T>(
201+
return findRepeatableClause2<T>(
199202
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
200203
mlir::Location clauseLocation = converter.genLocation(source);
201204
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -295,6 +298,24 @@ const T *ClauseProcessor::findUniqueClause(
295298

296299
template <typename T>
297300
bool ClauseProcessor::findRepeatableClause(
301+
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
302+
callbackFn) const {
303+
bool found = false;
304+
ClauseIterator nextIt, endIt = clauses.end();
305+
for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
306+
nextIt = findClause<T>(it, endIt);
307+
308+
if (nextIt != endIt) {
309+
callbackFn(std::get<T>(nextIt->u), nextIt->source);
310+
found = true;
311+
++nextIt;
312+
}
313+
}
314+
return found;
315+
}
316+
317+
template <typename T>
318+
bool ClauseProcessor::findRepeatableClause2(
298319
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
299320
callbackFn) const {
300321
bool found = false;

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,6 @@ namespace clause {
210210
#undef EMPTY_CLASS
211211
#undef WRAPPER_CLASS
212212

213-
using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
214-
using ProcedureDesignator =
215-
tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
216-
using ReductionOperator =
217-
tomp::clause::ReductionOperatorT<SymIdent, SymReference>;
218-
219213
DefinedOperator makeDefinedOperator(const parser::DefinedOperator &inp,
220214
semantics::SemanticsContext &semaCtx) {
221215
return std::visit(

flang/lib/Lower/OpenMP/Clauses.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,12 @@ namespace clause {
124124
#undef EMPTY_CLASS
125125
#undef WRAPPER_CLASS
126126

127+
using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
128+
using ProcedureDesignator =
129+
tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
130+
using ReductionOperator =
131+
tomp::clause::ReductionOperatorT<SymIdent, SymReference>;
132+
127133
// "Requires" clauses are handled early on, and the aggregated information
128134
// is stored in the Symbol details of modules, programs, and subprograms.
129135
// These clauses are still handled here to cover all alternatives in the

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 86 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -574,8 +574,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
574574
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
575575

576576
ClauseProcessor cp(converter, semaCtx, clauseList);
577-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
578-
ifClauseOperand);
577+
cp.processIf(clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
579578
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
580579
cp.processProcBind(procBindKindAttr);
581580
cp.processDefault();
@@ -751,8 +750,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
751750
dependOperands;
752751

753752
ClauseProcessor cp(converter, semaCtx, clauseList);
754-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
755-
ifClauseOperand);
753+
cp.processIf(clause::If::DirectiveNameModifier::Task, ifClauseOperand);
756754
cp.processAllocate(allocatorOperands, allocateOperands);
757755
cp.processDefault();
758756
cp.processFinal(stmtCtx, finalClauseOperand);
@@ -865,8 +863,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
865863
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
866864

867865
ClauseProcessor cp(converter, semaCtx, clauseList);
868-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
869-
ifClauseOperand);
866+
cp.processIf(clause::If::DirectiveNameModifier::TargetData, ifClauseOperand);
870867
cp.processDevice(stmtCtx, deviceOperand);
871868
cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
872869
useDeviceSymbols);
@@ -911,20 +908,17 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
911908
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
912909
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
913910

914-
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
911+
clause::If::DirectiveNameModifier directiveName;
915912
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
916913
[[maybe_unused]] llvm::omp::Directive directive;
917914
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
918-
directiveName =
919-
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
915+
directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
920916
directive = llvm::omp::Directive::OMPD_target_enter_data;
921917
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
922-
directiveName =
923-
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
918+
directiveName = clause::If::DirectiveNameModifier::TargetExitData;
924919
directive = llvm::omp::Directive::OMPD_target_exit_data;
925920
} else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
926-
directiveName =
927-
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
921+
directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
928922
directive = llvm::omp::Directive::OMPD_target_update;
929923
} else {
930924
return nullptr;
@@ -1126,8 +1120,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
11261120
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
11271121

11281122
ClauseProcessor cp(converter, semaCtx, clauseList);
1129-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
1130-
ifClauseOperand);
1123+
cp.processIf(clause::If::DirectiveNameModifier::Target, ifClauseOperand);
11311124
cp.processDevice(stmtCtx, deviceOperand);
11321125
cp.processThreadLimit(stmtCtx, threadLimitOperand);
11331126
cp.processDepend(dependTypeOperands, dependOperands);
@@ -1258,8 +1251,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
12581251
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
12591252

12601253
ClauseProcessor cp(converter, semaCtx, clauseList);
1261-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
1262-
ifClauseOperand);
1254+
cp.processIf(clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
12631255
cp.processAllocate(allocatorOperands, allocateOperands);
12641256
cp.processDefault();
12651257
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
@@ -1298,8 +1290,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
12981290

12991291
if (const auto *objectList{
13001292
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
1293+
ObjectList objects{makeList(*objectList, semaCtx)};
13011294
// Case: declare target(func, var1, var2)
1302-
gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
1295+
gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
13031296
symbolAndClause);
13041297
} else if (const auto *clauseList{
13051298
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -1438,7 +1431,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
14381431
if (const auto &ompObjectList =
14391432
std::get<std::optional<Fortran::parser::OmpObjectList>>(
14401433
flushConstruct.t))
1441-
genObjectList(*ompObjectList, converter, operandRange);
1434+
genObjectList2(*ompObjectList, converter, operandRange);
14421435
const auto &memOrderClause =
14431436
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
14441437
flushConstruct.t);
@@ -1600,8 +1593,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
16001593
loopVarTypeSize);
16011594
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
16021595
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
1603-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
1604-
ifClauseOperand);
1596+
cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
16051597
cp.processSimdlen(simdlenClauseOperand);
16061598
cp.processSafelen(safelenClauseOperand);
16071599
cp.processTODO<Fortran::parser::OmpClause::Aligned,
@@ -2419,106 +2411,100 @@ void Fortran::lower::genOpenMPReduction(
24192411
const Fortran::parser::OmpClauseList &clauseList) {
24202412
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
24212413

2422-
for (const Fortran::parser::OmpClause &clause : clauseList.v) {
2414+
List<Clause> clauses{makeList(clauseList, semaCtx)};
2415+
2416+
for (const Clause &clause : clauses) {
24232417
if (const auto &reductionClause =
2424-
std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
2425-
const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
2426-
reductionClause->v.t)};
2427-
const auto &objectList{
2428-
std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
2418+
std::get_if<clause::Reduction>(&clause.u)) {
2419+
const auto &redOperator{
2420+
std::get<clause::ReductionOperator>(reductionClause->t)};
2421+
const auto &objects{std::get<ObjectList>(reductionClause->t)};
24292422
if (const auto *reductionOp =
2430-
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
2423+
std::get_if<clause::DefinedOperator>(&redOperator.u)) {
24312424
const auto &intrinsicOp{
2432-
std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
2425+
std::get<clause::DefinedOperator::IntrinsicOperator>(
24332426
reductionOp->u)};
24342427

24352428
switch (intrinsicOp) {
2436-
case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
2437-
case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
2438-
case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
2439-
case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
2440-
case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
2441-
case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
2429+
case clause::DefinedOperator::IntrinsicOperator::Add:
2430+
case clause::DefinedOperator::IntrinsicOperator::Multiply:
2431+
case clause::DefinedOperator::IntrinsicOperator::AND:
2432+
case clause::DefinedOperator::IntrinsicOperator::EQV:
2433+
case clause::DefinedOperator::IntrinsicOperator::OR:
2434+
case clause::DefinedOperator::IntrinsicOperator::NEQV:
24422435
break;
24432436
default:
24442437
continue;
24452438
}
2446-
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
2447-
if (const auto *name{
2448-
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2449-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
2450-
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
2451-
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
2452-
reductionVal = declOp.getBase();
2453-
mlir::Type reductionType =
2454-
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
2455-
if (!reductionType.isa<fir::LogicalType>()) {
2456-
if (!reductionType.isIntOrIndexOrFloat())
2457-
continue;
2458-
}
2459-
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
2460-
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2461-
reductionValUse.getOwner())) {
2462-
mlir::Value loadVal = loadOp.getRes();
2463-
if (reductionType.isa<fir::LogicalType>()) {
2464-
mlir::Operation *reductionOp = findReductionChain(loadVal);
2465-
fir::ConvertOp convertOp =
2466-
getConvertFromReductionOp(reductionOp, loadVal);
2467-
updateReduction(reductionOp, firOpBuilder, loadVal,
2468-
reductionVal, &convertOp);
2469-
removeStoreOp(reductionOp, reductionVal);
2470-
} else if (mlir::Operation *reductionOp =
2471-
findReductionChain(loadVal, &reductionVal)) {
2472-
updateReduction(reductionOp, firOpBuilder, loadVal,
2473-
reductionVal);
2474-
}
2439+
for (const Object &object : objects) {
2440+
if (const Fortran::semantics::Symbol *symbol = object.id()) {
2441+
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
2442+
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
2443+
reductionVal = declOp.getBase();
2444+
mlir::Type reductionType =
2445+
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
2446+
if (!reductionType.isa<fir::LogicalType>()) {
2447+
if (!reductionType.isIntOrIndexOrFloat())
2448+
continue;
2449+
}
2450+
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
2451+
if (auto loadOp =
2452+
mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
2453+
mlir::Value loadVal = loadOp.getRes();
2454+
if (reductionType.isa<fir::LogicalType>()) {
2455+
mlir::Operation *reductionOp = findReductionChain(loadVal);
2456+
fir::ConvertOp convertOp =
2457+
getConvertFromReductionOp(reductionOp, loadVal);
2458+
updateReduction(reductionOp, firOpBuilder, loadVal,
2459+
reductionVal, &convertOp);
2460+
removeStoreOp(reductionOp, reductionVal);
2461+
} else if (mlir::Operation *reductionOp =
2462+
findReductionChain(loadVal, &reductionVal)) {
2463+
updateReduction(reductionOp, firOpBuilder, loadVal,
2464+
reductionVal);
24752465
}
24762466
}
24772467
}
24782468
}
24792469
}
24802470
} else if (const auto *reductionIntrinsic =
2481-
std::get_if<Fortran::parser::ProcedureDesignator>(
2482-
&redOperator.u)) {
2471+
std::get_if<clause::ProcedureDesignator>(&redOperator.u)) {
24832472
if (!ReductionProcessor::supportedIntrinsicProcReduction(
24842473
*reductionIntrinsic))
24852474
continue;
24862475
ReductionProcessor::ReductionIdentifier redId =
24872476
ReductionProcessor::getReductionType(*reductionIntrinsic);
2488-
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
2489-
if (const auto *name{
2490-
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2491-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
2492-
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
2493-
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
2494-
reductionVal = declOp.getBase();
2495-
for (const mlir::OpOperand &reductionValUse :
2496-
reductionVal.getUses()) {
2497-
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2498-
reductionValUse.getOwner())) {
2499-
mlir::Value loadVal = loadOp.getRes();
2500-
// Max is lowered as a compare -> select.
2501-
// Match the pattern here.
2502-
mlir::Operation *reductionOp =
2503-
findReductionChain(loadVal, &reductionVal);
2504-
if (reductionOp == nullptr)
2505-
continue;
2506-
2507-
if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2508-
redId == ReductionProcessor::ReductionIdentifier::MIN) {
2509-
assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2510-
"Selection Op not found in reduction intrinsic");
2511-
mlir::Operation *compareOp =
2512-
getCompareFromReductionOp(reductionOp, loadVal);
2513-
updateReduction(compareOp, firOpBuilder, loadVal,
2514-
reductionVal);
2515-
}
2516-
if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2517-
redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2518-
redId == ReductionProcessor::ReductionIdentifier::IAND) {
2519-
updateReduction(reductionOp, firOpBuilder, loadVal,
2520-
reductionVal);
2521-
}
2477+
for (const Object &object : objects) {
2478+
if (const Fortran::semantics::Symbol *symbol = object.id()) {
2479+
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
2480+
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
2481+
reductionVal = declOp.getBase();
2482+
for (const mlir::OpOperand &reductionValUse :
2483+
reductionVal.getUses()) {
2484+
if (auto loadOp =
2485+
mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
2486+
mlir::Value loadVal = loadOp.getRes();
2487+
// Max is lowered as a compare -> select.
2488+
// Match the pattern here.
2489+
mlir::Operation *reductionOp =
2490+
findReductionChain(loadVal, &reductionVal);
2491+
if (reductionOp == nullptr)
2492+
continue;
2493+
2494+
if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2495+
redId == ReductionProcessor::ReductionIdentifier::MIN) {
2496+
assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2497+
"Selection Op not found in reduction intrinsic");
2498+
mlir::Operation *compareOp =
2499+
getCompareFromReductionOp(reductionOp, loadVal);
2500+
updateReduction(compareOp, firOpBuilder, loadVal,
2501+
reductionVal);
2502+
}
2503+
if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2504+
redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2505+
redId == ReductionProcessor::ReductionIdentifier::IAND) {
2506+
updateReduction(reductionOp, firOpBuilder, loadVal,
2507+
reductionVal);
25222508
}
25232509
}
25242510
}

0 commit comments

Comments
 (0)