@@ -572,8 +572,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
572
572
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
573
573
574
574
ClauseProcessor cp (converter, semaCtx, clauseList);
575
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
576
- ifClauseOperand);
575
+ cp.processIf (clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
577
576
cp.processNumThreads (stmtCtx, numThreadsClauseOperand);
578
577
cp.processProcBind (procBindKindAttr);
579
578
cp.processDefault ();
@@ -676,8 +675,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
676
675
dependOperands;
677
676
678
677
ClauseProcessor cp (converter, semaCtx, clauseList);
679
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
680
- ifClauseOperand);
678
+ cp.processIf (clause::If::DirectiveNameModifier::Task, ifClauseOperand);
681
679
cp.processAllocate (allocatorOperands, allocateOperands);
682
680
cp.processDefault ();
683
681
cp.processFinal (stmtCtx, finalClauseOperand);
@@ -738,7 +736,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
738
736
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
739
737
740
738
ClauseProcessor cp (converter, semaCtx, clauseList);
741
- cp.processIf (Fortran::parser::OmpIfClause ::DirectiveNameModifier::TargetData,
739
+ cp.processIf (clause::If ::DirectiveNameModifier::TargetData,
742
740
ifClauseOperand);
743
741
cp.processDevice (stmtCtx, deviceOperand);
744
742
cp.processUseDevicePtr (devicePtrOperands, useDeviceTypes, useDeviceLocs,
@@ -770,19 +768,16 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
770
768
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
771
769
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
772
770
773
- Fortran::parser::OmpIfClause ::DirectiveNameModifier directiveName;
771
+ clause::If ::DirectiveNameModifier directiveName;
774
772
llvm::omp::Directive directive;
775
773
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
776
- directiveName =
777
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
774
+ directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
778
775
directive = llvm::omp::Directive::OMPD_target_enter_data;
779
776
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
780
- directiveName =
781
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
777
+ directiveName = clause::If::DirectiveNameModifier::TargetExitData;
782
778
directive = llvm::omp::Directive::OMPD_target_exit_data;
783
779
} else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
784
- directiveName =
785
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
780
+ directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
786
781
directive = llvm::omp::Directive::OMPD_target_update;
787
782
} else {
788
783
return nullptr ;
@@ -984,8 +979,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
984
979
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
985
980
986
981
ClauseProcessor cp (converter, semaCtx, clauseList);
987
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
988
- ifClauseOperand);
982
+ cp.processIf (clause::If::DirectiveNameModifier::Target, ifClauseOperand);
989
983
cp.processDevice (stmtCtx, deviceOperand);
990
984
cp.processThreadLimit (stmtCtx, threadLimitOperand);
991
985
cp.processDepend (dependTypeOperands, dependOperands);
@@ -1102,8 +1096,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
1102
1096
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
1103
1097
1104
1098
ClauseProcessor cp (converter, semaCtx, clauseList);
1105
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
1106
- ifClauseOperand);
1099
+ cp.processIf (clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
1107
1100
cp.processAllocate (allocatorOperands, allocateOperands);
1108
1101
cp.processDefault ();
1109
1102
cp.processNumTeams (stmtCtx, numTeamsClauseOperand);
@@ -1142,8 +1135,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
1142
1135
1143
1136
if (const auto *objectList{
1144
1137
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u )}) {
1138
+ ObjectList objects{makeList (*objectList, semaCtx)};
1145
1139
// Case: declare target(func, var1, var2)
1146
- gatherFuncAndVarSyms (*objectList , mlir::omp::DeclareTargetCaptureClause::to,
1140
+ gatherFuncAndVarSyms (objects , mlir::omp::DeclareTargetCaptureClause::to,
1147
1141
symbolAndClause);
1148
1142
} else if (const auto *clauseList{
1149
1143
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -1257,7 +1251,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
1257
1251
if (const auto &ompObjectList =
1258
1252
std::get<std::optional<Fortran::parser::OmpObjectList>>(
1259
1253
flushConstruct.t ))
1260
- genObjectList (*ompObjectList, converter, operandRange);
1254
+ genObjectList2 (*ompObjectList, converter, operandRange);
1261
1255
const auto &memOrderClause =
1262
1256
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
1263
1257
flushConstruct.t );
@@ -1419,8 +1413,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
1419
1413
loopVarTypeSize);
1420
1414
cp.processScheduleChunk (stmtCtx, scheduleChunkClauseOperand);
1421
1415
cp.processReduction (loc, reductionVars, reductionDeclSymbols);
1422
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
1423
- ifClauseOperand);
1416
+ cp.processIf (clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
1424
1417
cp.processSimdlen (simdlenClauseOperand);
1425
1418
cp.processSafelen (safelenClauseOperand);
1426
1419
cp.processTODO <Fortran::parser::OmpClause::Aligned,
@@ -2223,106 +2216,99 @@ void Fortran::lower::genOpenMPReduction(
2223
2216
const Fortran::parser::OmpClauseList &clauseList) {
2224
2217
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2225
2218
2226
- for (const Fortran::parser::OmpClause &clause : clauseList.v ) {
2219
+ List<Clause> clauses{makeList (clauseList, semaCtx)};
2220
+
2221
+ for (const Clause &clause : clauses) {
2227
2222
if (const auto &reductionClause =
2228
- std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u )) {
2229
- const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
2230
- reductionClause->v .t )};
2231
- const auto &objectList{
2232
- std::get<Fortran::parser::OmpObjectList>(reductionClause->v .t )};
2223
+ std::get_if<clause::Reduction>(&clause.u )) {
2224
+ const auto &redOperator{
2225
+ std::get<clause::ReductionOperator>(reductionClause->t )};
2226
+ const auto &objects{std::get<ObjectList>(reductionClause->t )};
2233
2227
if (const auto *reductionOp =
2234
- std::get_if<Fortran::parser ::DefinedOperator>(&redOperator.u )) {
2228
+ std::get_if<clause ::DefinedOperator>(&redOperator.u )) {
2235
2229
const auto &intrinsicOp{
2236
- std::get<Fortran::parser ::DefinedOperator::IntrinsicOperator>(
2230
+ std::get<clause ::DefinedOperator::IntrinsicOperator>(
2237
2231
reductionOp->u )};
2238
2232
2239
2233
switch (intrinsicOp) {
2240
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::Add:
2241
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::Multiply:
2242
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::AND:
2243
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::EQV:
2244
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::OR:
2245
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::NEQV:
2234
+ case clause ::DefinedOperator::IntrinsicOperator::Add:
2235
+ case clause ::DefinedOperator::IntrinsicOperator::Multiply:
2236
+ case clause ::DefinedOperator::IntrinsicOperator::AND:
2237
+ case clause ::DefinedOperator::IntrinsicOperator::EQV:
2238
+ case clause ::DefinedOperator::IntrinsicOperator::OR:
2239
+ case clause ::DefinedOperator::IntrinsicOperator::NEQV:
2246
2240
break ;
2247
2241
default :
2248
2242
continue ;
2249
2243
}
2250
- for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
2251
- if (const auto *name{
2252
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2253
- if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
2254
- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2255
- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2256
- reductionVal = declOp.getBase ();
2257
- mlir::Type reductionType =
2258
- reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
2259
- if (!reductionType.isa <fir::LogicalType>()) {
2260
- if (!reductionType.isIntOrIndexOrFloat ())
2261
- continue ;
2262
- }
2263
- for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
2264
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2265
- reductionValUse.getOwner ())) {
2266
- mlir::Value loadVal = loadOp.getRes ();
2267
- if (reductionType.isa <fir::LogicalType>()) {
2268
- mlir::Operation *reductionOp = findReductionChain (loadVal);
2269
- fir::ConvertOp convertOp =
2270
- getConvertFromReductionOp (reductionOp, loadVal);
2271
- updateReduction (reductionOp, firOpBuilder, loadVal,
2272
- reductionVal, &convertOp);
2273
- removeStoreOp (reductionOp, reductionVal);
2274
- } else if (mlir::Operation *reductionOp =
2275
- findReductionChain (loadVal, &reductionVal)) {
2276
- updateReduction (reductionOp, firOpBuilder, loadVal,
2277
- reductionVal);
2278
- }
2244
+ for (const Object &object : objects) {
2245
+ if (const Fortran::semantics::Symbol *symbol = object.id ()) {
2246
+ mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2247
+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2248
+ reductionVal = declOp.getBase ();
2249
+ mlir::Type reductionType =
2250
+ reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
2251
+ if (!reductionType.isa <fir::LogicalType>()) {
2252
+ if (!reductionType.isIntOrIndexOrFloat ())
2253
+ continue ;
2254
+ }
2255
+ for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
2256
+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
2257
+ mlir::Value loadVal = loadOp.getRes ();
2258
+ if (reductionType.isa <fir::LogicalType>()) {
2259
+ mlir::Operation *reductionOp = findReductionChain (loadVal);
2260
+ fir::ConvertOp convertOp =
2261
+ getConvertFromReductionOp (reductionOp, loadVal);
2262
+ updateReduction (reductionOp, firOpBuilder, loadVal,
2263
+ reductionVal, &convertOp);
2264
+ removeStoreOp (reductionOp, reductionVal);
2265
+ } else if (mlir::Operation *reductionOp =
2266
+ findReductionChain (loadVal, &reductionVal)) {
2267
+ updateReduction (reductionOp, firOpBuilder, loadVal,
2268
+ reductionVal);
2279
2269
}
2280
2270
}
2281
2271
}
2282
2272
}
2283
2273
}
2284
2274
} else if (const auto *reductionIntrinsic =
2285
- std::get_if<Fortran::parser ::ProcedureDesignator>(
2275
+ std::get_if<clause ::ProcedureDesignator>(
2286
2276
&redOperator.u )) {
2287
2277
if (!ReductionProcessor::supportedIntrinsicProcReduction (
2288
2278
*reductionIntrinsic))
2289
2279
continue ;
2290
2280
ReductionProcessor::ReductionIdentifier redId =
2291
2281
ReductionProcessor::getReductionType (*reductionIntrinsic);
2292
- for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
2293
- if (const auto *name{
2294
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2295
- if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
2296
- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2297
- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2298
- reductionVal = declOp.getBase ();
2299
- for (const mlir::OpOperand &reductionValUse :
2300
- reductionVal.getUses ()) {
2301
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2302
- reductionValUse.getOwner ())) {
2303
- mlir::Value loadVal = loadOp.getRes ();
2304
- // Max is lowered as a compare -> select.
2305
- // Match the pattern here.
2306
- mlir::Operation *reductionOp =
2307
- findReductionChain (loadVal, &reductionVal);
2308
- if (reductionOp == nullptr )
2309
- continue ;
2310
-
2311
- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2312
- redId == ReductionProcessor::ReductionIdentifier::MIN) {
2313
- assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2314
- " Selection Op not found in reduction intrinsic" );
2315
- mlir::Operation *compareOp =
2316
- getCompareFromReductionOp (reductionOp, loadVal);
2317
- updateReduction (compareOp, firOpBuilder, loadVal,
2318
- reductionVal);
2319
- }
2320
- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2321
- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2322
- redId == ReductionProcessor::ReductionIdentifier::IAND) {
2323
- updateReduction (reductionOp, firOpBuilder, loadVal,
2324
- reductionVal);
2325
- }
2282
+ for (const Object &object : objects) {
2283
+ if (const Fortran::semantics::Symbol *symbol = object.id ()) {
2284
+ mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2285
+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2286
+ reductionVal = declOp.getBase ();
2287
+ for (const mlir::OpOperand &reductionValUse :
2288
+ reductionVal.getUses ()) {
2289
+ if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
2290
+ mlir::Value loadVal = loadOp.getRes ();
2291
+ // Max is lowered as a compare -> select.
2292
+ // Match the pattern here.
2293
+ mlir::Operation *reductionOp =
2294
+ findReductionChain (loadVal, &reductionVal);
2295
+ if (reductionOp == nullptr )
2296
+ continue ;
2297
+
2298
+ if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2299
+ redId == ReductionProcessor::ReductionIdentifier::MIN) {
2300
+ assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2301
+ " Selection Op not found in reduction intrinsic" );
2302
+ mlir::Operation *compareOp =
2303
+ getCompareFromReductionOp (reductionOp, loadVal);
2304
+ updateReduction (compareOp, firOpBuilder, loadVal,
2305
+ reductionVal);
2306
+ }
2307
+ if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2308
+ redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2309
+ redId == ReductionProcessor::ReductionIdentifier::IAND) {
2310
+ updateReduction (reductionOp, firOpBuilder, loadVal,
2311
+ reductionVal);
2326
2312
}
2327
2313
}
2328
2314
}
0 commit comments