Skip to content

Commit 655dce5

Browse files
committed
[flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProcessor
Rename `findRepeatableClause` to `findRepeatableClause2`, and make the new `findRepeatableClause` operate on new `omp::Clause` objects. Leave `Map` unchanged, because it will require more changes for it to work.
1 parent fafbd98 commit 655dce5

File tree

10 files changed

+348
-345
lines changed

10 files changed

+348
-345
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
430430
}
431431
}
432432

433+
struct ExtractSubstringHelper {
434+
template <typename T> static std::optional<Substring> visit(T &&) {
435+
return std::nullopt;
436+
}
437+
438+
static std::optional<Substring> visit(const Substring &e) { return e; }
439+
440+
template <typename T>
441+
static std::optional<Substring> visit(const Designator<T> &e) {
442+
return std::visit([](auto &&s) { return visit(s); }, e.u);
443+
}
444+
445+
template <typename T>
446+
static std::optional<Substring> visit(const Expr<T> &e) {
447+
return std::visit([](auto &&s) { return visit(s); }, e.u);
448+
}
449+
};
450+
451+
template <typename A>
452+
std::optional<Substring> ExtractSubstring(const A &x) {
453+
return ExtractSubstringHelper::visit(x);
454+
}
455+
433456
// If an expression is simply a whole symbol data designator,
434457
// extract and return that symbol, else null.
435458
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 96 additions & 122 deletions
Large diffs are not rendered by default.

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,8 @@ class ClauseProcessor {
105105
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
106106
bool
107107
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
108-
bool
109-
processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
110-
mlir::Value &result) const;
108+
bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
109+
mlir::Value &result) const;
111110
bool
112111
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
113112

@@ -178,6 +177,10 @@ class ClauseProcessor {
178177
/// if at least one instance was found.
179178
template <typename T>
180179
bool findRepeatableClause(
180+
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
181+
callbackFn) const;
182+
template <typename T>
183+
bool findRepeatableClause2(
181184
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
182185
callbackFn) const;
183186

@@ -195,7 +198,7 @@ template <typename T>
195198
bool ClauseProcessor::processMotionClauses(
196199
Fortran::lower::StatementContext &stmtCtx,
197200
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
198-
return findRepeatableClause<T>(
201+
return findRepeatableClause2<T>(
199202
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
200203
mlir::Location clauseLocation = converter.genLocation(source);
201204
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -295,6 +298,24 @@ const T *ClauseProcessor::findUniqueClause(
295298

296299
template <typename T>
297300
bool ClauseProcessor::findRepeatableClause(
301+
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
302+
callbackFn) const {
303+
bool found = false;
304+
ClauseIterator nextIt, endIt = clauses.end();
305+
for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
306+
nextIt = findClause<T>(it, endIt);
307+
308+
if (nextIt != endIt) {
309+
callbackFn(std::get<T>(nextIt->u), nextIt->source);
310+
found = true;
311+
++nextIt;
312+
}
313+
}
314+
return found;
315+
}
316+
317+
template <typename T>
318+
bool ClauseProcessor::findRepeatableClause2(
298319
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
299320
callbackFn) const {
300321
bool found = false;

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,12 +205,6 @@ namespace clause {
205205
#undef EMPTY_CLASS
206206
#undef WRAPPER_CLASS
207207

208-
using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
209-
using ProcedureDesignator =
210-
tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
211-
using ReductionOperator =
212-
tomp::clause::ReductionOperatorT<SymIdent, SymReference>;
213-
214208
DefinedOperator makeDefOp(const parser::DefinedOperator &inp,
215209
semantics::SemanticsContext &semaCtx) {
216210
return DefinedOperator{

flang/lib/Lower/OpenMP/Clauses.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ getBaseObject(const Object &object,
106106
Fortran::semantics::SemanticsContext &semaCtx);
107107

108108
namespace clause {
109+
using DefinedOperator = tomp::clause::DefinedOperatorT<SymIdent, SymReference>;
110+
using ProcedureDesignator =
111+
tomp::clause::ProcedureDesignatorT<SymIdent, SymReference>;
112+
using ReductionOperator =
113+
tomp::clause::ReductionOperatorT<SymIdent, SymReference>;
114+
109115
#ifdef EMPTY_CLASS
110116
#undef EMPTY_CLASS
111117
#endif

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 84 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
572572
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
573573

574574
ClauseProcessor cp(converter, semaCtx, clauseList);
575-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
576-
ifClauseOperand);
575+
cp.processIf(clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
577576
cp.processNumThreads(stmtCtx, numThreadsClauseOperand);
578577
cp.processProcBind(procBindKindAttr);
579578
cp.processDefault();
@@ -676,8 +675,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
676675
dependOperands;
677676

678677
ClauseProcessor cp(converter, semaCtx, clauseList);
679-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
680-
ifClauseOperand);
678+
cp.processIf(clause::If::DirectiveNameModifier::Task, ifClauseOperand);
681679
cp.processAllocate(allocatorOperands, allocateOperands);
682680
cp.processDefault();
683681
cp.processFinal(stmtCtx, finalClauseOperand);
@@ -738,7 +736,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
738736
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
739737

740738
ClauseProcessor cp(converter, semaCtx, clauseList);
741-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
739+
cp.processIf(clause::If::DirectiveNameModifier::TargetData,
742740
ifClauseOperand);
743741
cp.processDevice(stmtCtx, deviceOperand);
744742
cp.processUseDevicePtr(devicePtrOperands, useDeviceTypes, useDeviceLocs,
@@ -770,19 +768,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
770768
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
771769
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
772770

773-
Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName;
771+
clause::If::DirectiveNameModifier directiveName;
774772
llvm::omp::Directive directive;
775773
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
776-
directiveName =
777-
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
774+
directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
778775
directive = llvm::omp::Directive::OMPD_target_enter_data;
779776
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
780-
directiveName =
781-
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
777+
directiveName = clause::If::DirectiveNameModifier::TargetExitData;
782778
directive = llvm::omp::Directive::OMPD_target_exit_data;
783779
} else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
784-
directiveName =
785-
Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
780+
directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
786781
directive = llvm::omp::Directive::OMPD_target_update;
787782
} else {
788783
return nullptr;
@@ -984,8 +979,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
984979
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
985980

986981
ClauseProcessor cp(converter, semaCtx, clauseList);
987-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
988-
ifClauseOperand);
982+
cp.processIf(clause::If::DirectiveNameModifier::Target, ifClauseOperand);
989983
cp.processDevice(stmtCtx, deviceOperand);
990984
cp.processThreadLimit(stmtCtx, threadLimitOperand);
991985
cp.processDepend(dependTypeOperands, dependOperands);
@@ -1102,8 +1096,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
11021096
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
11031097

11041098
ClauseProcessor cp(converter, semaCtx, clauseList);
1105-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
1106-
ifClauseOperand);
1099+
cp.processIf(clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
11071100
cp.processAllocate(allocatorOperands, allocateOperands);
11081101
cp.processDefault();
11091102
cp.processNumTeams(stmtCtx, numTeamsClauseOperand);
@@ -1142,8 +1135,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
11421135

11431136
if (const auto *objectList{
11441137
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
1138+
ObjectList objects{makeList(*objectList, semaCtx)};
11451139
// Case: declare target(func, var1, var2)
1146-
gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
1140+
gatherFuncAndVarSyms(objects, mlir::omp::DeclareTargetCaptureClause::to,
11471141
symbolAndClause);
11481142
} else if (const auto *clauseList{
11491143
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -1257,7 +1251,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
12571251
if (const auto &ompObjectList =
12581252
std::get<std::optional<Fortran::parser::OmpObjectList>>(
12591253
flushConstruct.t))
1260-
genObjectList(*ompObjectList, converter, operandRange);
1254+
genObjectList2(*ompObjectList, converter, operandRange);
12611255
const auto &memOrderClause =
12621256
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
12631257
flushConstruct.t);
@@ -1419,8 +1413,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
14191413
loopVarTypeSize);
14201414
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
14211415
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
1422-
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
1423-
ifClauseOperand);
1416+
cp.processIf(clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
14241417
cp.processSimdlen(simdlenClauseOperand);
14251418
cp.processSafelen(safelenClauseOperand);
14261419
cp.processTODO<Fortran::parser::OmpClause::Aligned,
@@ -2223,106 +2216,99 @@ void Fortran::lower::genOpenMPReduction(
22232216
const Fortran::parser::OmpClauseList &clauseList) {
22242217
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
22252218

2226-
for (const Fortran::parser::OmpClause &clause : clauseList.v) {
2219+
List<Clause> clauses{makeList(clauseList, semaCtx)};
2220+
2221+
for (const Clause &clause : clauses) {
22272222
if (const auto &reductionClause =
2228-
std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u)) {
2229-
const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
2230-
reductionClause->v.t)};
2231-
const auto &objectList{
2232-
std::get<Fortran::parser::OmpObjectList>(reductionClause->v.t)};
2223+
std::get_if<clause::Reduction>(&clause.u)) {
2224+
const auto &redOperator{
2225+
std::get<clause::ReductionOperator>(reductionClause->t)};
2226+
const auto &objects{std::get<ObjectList>(reductionClause->t)};
22332227
if (const auto *reductionOp =
2234-
std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
2228+
std::get_if<clause::DefinedOperator>(&redOperator.u)) {
22352229
const auto &intrinsicOp{
2236-
std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
2230+
std::get<clause::DefinedOperator::IntrinsicOperator>(
22372231
reductionOp->u)};
22382232

22392233
switch (intrinsicOp) {
2240-
case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
2241-
case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
2242-
case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
2243-
case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
2244-
case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
2245-
case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
2234+
case clause::DefinedOperator::IntrinsicOperator::Add:
2235+
case clause::DefinedOperator::IntrinsicOperator::Multiply:
2236+
case clause::DefinedOperator::IntrinsicOperator::AND:
2237+
case clause::DefinedOperator::IntrinsicOperator::EQV:
2238+
case clause::DefinedOperator::IntrinsicOperator::OR:
2239+
case clause::DefinedOperator::IntrinsicOperator::NEQV:
22462240
break;
22472241
default:
22482242
continue;
22492243
}
2250-
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
2251-
if (const auto *name{
2252-
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2253-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
2254-
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
2255-
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
2256-
reductionVal = declOp.getBase();
2257-
mlir::Type reductionType =
2258-
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
2259-
if (!reductionType.isa<fir::LogicalType>()) {
2260-
if (!reductionType.isIntOrIndexOrFloat())
2261-
continue;
2262-
}
2263-
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
2264-
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2265-
reductionValUse.getOwner())) {
2266-
mlir::Value loadVal = loadOp.getRes();
2267-
if (reductionType.isa<fir::LogicalType>()) {
2268-
mlir::Operation *reductionOp = findReductionChain(loadVal);
2269-
fir::ConvertOp convertOp =
2270-
getConvertFromReductionOp(reductionOp, loadVal);
2271-
updateReduction(reductionOp, firOpBuilder, loadVal,
2272-
reductionVal, &convertOp);
2273-
removeStoreOp(reductionOp, reductionVal);
2274-
} else if (mlir::Operation *reductionOp =
2275-
findReductionChain(loadVal, &reductionVal)) {
2276-
updateReduction(reductionOp, firOpBuilder, loadVal,
2277-
reductionVal);
2278-
}
2244+
for (const Object &object : objects) {
2245+
if (const Fortran::semantics::Symbol *symbol = object.id()) {
2246+
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
2247+
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
2248+
reductionVal = declOp.getBase();
2249+
mlir::Type reductionType =
2250+
reductionVal.getType().cast<fir::ReferenceType>().getEleTy();
2251+
if (!reductionType.isa<fir::LogicalType>()) {
2252+
if (!reductionType.isIntOrIndexOrFloat())
2253+
continue;
2254+
}
2255+
for (mlir::OpOperand &reductionValUse : reductionVal.getUses()) {
2256+
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
2257+
mlir::Value loadVal = loadOp.getRes();
2258+
if (reductionType.isa<fir::LogicalType>()) {
2259+
mlir::Operation *reductionOp = findReductionChain(loadVal);
2260+
fir::ConvertOp convertOp =
2261+
getConvertFromReductionOp(reductionOp, loadVal);
2262+
updateReduction(reductionOp, firOpBuilder, loadVal,
2263+
reductionVal, &convertOp);
2264+
removeStoreOp(reductionOp, reductionVal);
2265+
} else if (mlir::Operation *reductionOp =
2266+
findReductionChain(loadVal, &reductionVal)) {
2267+
updateReduction(reductionOp, firOpBuilder, loadVal,
2268+
reductionVal);
22792269
}
22802270
}
22812271
}
22822272
}
22832273
}
22842274
} else if (const auto *reductionIntrinsic =
2285-
std::get_if<Fortran::parser::ProcedureDesignator>(
2275+
std::get_if<clause::ProcedureDesignator>(
22862276
&redOperator.u)) {
22872277
if (!ReductionProcessor::supportedIntrinsicProcReduction(
22882278
*reductionIntrinsic))
22892279
continue;
22902280
ReductionProcessor::ReductionIdentifier redId =
22912281
ReductionProcessor::getReductionType(*reductionIntrinsic);
2292-
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
2293-
if (const auto *name{
2294-
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2295-
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
2296-
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
2297-
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
2298-
reductionVal = declOp.getBase();
2299-
for (const mlir::OpOperand &reductionValUse :
2300-
reductionVal.getUses()) {
2301-
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2302-
reductionValUse.getOwner())) {
2303-
mlir::Value loadVal = loadOp.getRes();
2304-
// Max is lowered as a compare -> select.
2305-
// Match the pattern here.
2306-
mlir::Operation *reductionOp =
2307-
findReductionChain(loadVal, &reductionVal);
2308-
if (reductionOp == nullptr)
2309-
continue;
2310-
2311-
if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2312-
redId == ReductionProcessor::ReductionIdentifier::MIN) {
2313-
assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2314-
"Selection Op not found in reduction intrinsic");
2315-
mlir::Operation *compareOp =
2316-
getCompareFromReductionOp(reductionOp, loadVal);
2317-
updateReduction(compareOp, firOpBuilder, loadVal,
2318-
reductionVal);
2319-
}
2320-
if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2321-
redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2322-
redId == ReductionProcessor::ReductionIdentifier::IAND) {
2323-
updateReduction(reductionOp, firOpBuilder, loadVal,
2324-
reductionVal);
2325-
}
2282+
for (const Object &object : objects) {
2283+
if (const Fortran::semantics::Symbol *symbol = object.id()) {
2284+
mlir::Value reductionVal = converter.getSymbolAddress(*symbol);
2285+
if (auto declOp = reductionVal.getDefiningOp<hlfir::DeclareOp>())
2286+
reductionVal = declOp.getBase();
2287+
for (const mlir::OpOperand &reductionValUse :
2288+
reductionVal.getUses()) {
2289+
if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner())) {
2290+
mlir::Value loadVal = loadOp.getRes();
2291+
// Max is lowered as a compare -> select.
2292+
// Match the pattern here.
2293+
mlir::Operation *reductionOp =
2294+
findReductionChain(loadVal, &reductionVal);
2295+
if (reductionOp == nullptr)
2296+
continue;
2297+
2298+
if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2299+
redId == ReductionProcessor::ReductionIdentifier::MIN) {
2300+
assert(mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2301+
"Selection Op not found in reduction intrinsic");
2302+
mlir::Operation *compareOp =
2303+
getCompareFromReductionOp(reductionOp, loadVal);
2304+
updateReduction(compareOp, firOpBuilder, loadVal,
2305+
reductionVal);
2306+
}
2307+
if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2308+
redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2309+
redId == ReductionProcessor::ReductionIdentifier::IAND) {
2310+
updateReduction(reductionOp, firOpBuilder, loadVal,
2311+
reductionVal);
23262312
}
23272313
}
23282314
}

0 commit comments

Comments
 (0)