@@ -574,8 +574,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
574
574
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
575
575
576
576
ClauseProcessor cp (converter, semaCtx, clauseList);
577
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
578
- ifClauseOperand);
577
+ cp.processIf (clause::If::DirectiveNameModifier::Parallel, ifClauseOperand);
579
578
cp.processNumThreads (stmtCtx, numThreadsClauseOperand);
580
579
cp.processProcBind (procBindKindAttr);
581
580
cp.processDefault ();
@@ -751,8 +750,7 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
751
750
dependOperands;
752
751
753
752
ClauseProcessor cp (converter, semaCtx, clauseList);
754
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Task,
755
- ifClauseOperand);
753
+ cp.processIf (clause::If::DirectiveNameModifier::Task, ifClauseOperand);
756
754
cp.processAllocate (allocatorOperands, allocateOperands);
757
755
cp.processDefault ();
758
756
cp.processFinal (stmtCtx, finalClauseOperand);
@@ -865,8 +863,7 @@ genDataOp(Fortran::lower::AbstractConverter &converter,
865
863
llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSymbols;
866
864
867
865
ClauseProcessor cp (converter, semaCtx, clauseList);
868
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetData,
869
- ifClauseOperand);
866
+ cp.processIf (clause::If::DirectiveNameModifier::TargetData, ifClauseOperand);
870
867
cp.processDevice (stmtCtx, deviceOperand);
871
868
cp.processUseDevicePtr (devicePtrOperands, useDeviceTypes, useDeviceLocs,
872
869
useDeviceSymbols);
@@ -911,20 +908,17 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
911
908
llvm::SmallVector<mlir::Value> mapOperands, dependOperands;
912
909
llvm::SmallVector<mlir::Attribute> dependTypeOperands;
913
910
914
- Fortran::parser::OmpIfClause ::DirectiveNameModifier directiveName;
911
+ clause::If ::DirectiveNameModifier directiveName;
915
912
// GCC 9.3.0 emits a (probably) bogus warning about an unused variable.
916
913
[[maybe_unused]] llvm::omp::Directive directive;
917
914
if constexpr (std::is_same_v<OpTy, mlir::omp::EnterDataOp>) {
918
- directiveName =
919
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetEnterData;
915
+ directiveName = clause::If::DirectiveNameModifier::TargetEnterData;
920
916
directive = llvm::omp::Directive::OMPD_target_enter_data;
921
917
} else if constexpr (std::is_same_v<OpTy, mlir::omp::ExitDataOp>) {
922
- directiveName =
923
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetExitData;
918
+ directiveName = clause::If::DirectiveNameModifier::TargetExitData;
924
919
directive = llvm::omp::Directive::OMPD_target_exit_data;
925
920
} else if constexpr (std::is_same_v<OpTy, mlir::omp::UpdateDataOp>) {
926
- directiveName =
927
- Fortran::parser::OmpIfClause::DirectiveNameModifier::TargetUpdate;
921
+ directiveName = clause::If::DirectiveNameModifier::TargetUpdate;
928
922
directive = llvm::omp::Directive::OMPD_target_update;
929
923
} else {
930
924
return nullptr ;
@@ -1126,8 +1120,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
1126
1120
llvm::SmallVector<const Fortran::semantics::Symbol *> mapSymbols;
1127
1121
1128
1122
ClauseProcessor cp (converter, semaCtx, clauseList);
1129
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Target,
1130
- ifClauseOperand);
1123
+ cp.processIf (clause::If::DirectiveNameModifier::Target, ifClauseOperand);
1131
1124
cp.processDevice (stmtCtx, deviceOperand);
1132
1125
cp.processThreadLimit (stmtCtx, threadLimitOperand);
1133
1126
cp.processDepend (dependTypeOperands, dependOperands);
@@ -1258,8 +1251,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
1258
1251
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
1259
1252
1260
1253
ClauseProcessor cp (converter, semaCtx, clauseList);
1261
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Teams,
1262
- ifClauseOperand);
1254
+ cp.processIf (clause::If::DirectiveNameModifier::Teams, ifClauseOperand);
1263
1255
cp.processAllocate (allocatorOperands, allocateOperands);
1264
1256
cp.processDefault ();
1265
1257
cp.processNumTeams (stmtCtx, numTeamsClauseOperand);
@@ -1298,8 +1290,9 @@ static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
1298
1290
1299
1291
if (const auto *objectList{
1300
1292
Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u )}) {
1293
+ ObjectList objects{makeList (*objectList, semaCtx)};
1301
1294
// Case: declare target(func, var1, var2)
1302
- gatherFuncAndVarSyms (*objectList , mlir::omp::DeclareTargetCaptureClause::to,
1295
+ gatherFuncAndVarSyms (objects , mlir::omp::DeclareTargetCaptureClause::to,
1303
1296
symbolAndClause);
1304
1297
} else if (const auto *clauseList{
1305
1298
Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
@@ -1438,7 +1431,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
1438
1431
if (const auto &ompObjectList =
1439
1432
std::get<std::optional<Fortran::parser::OmpObjectList>>(
1440
1433
flushConstruct.t ))
1441
- genObjectList (*ompObjectList, converter, operandRange);
1434
+ genObjectList2 (*ompObjectList, converter, operandRange);
1442
1435
const auto &memOrderClause =
1443
1436
std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>(
1444
1437
flushConstruct.t );
@@ -1600,8 +1593,7 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
1600
1593
loopVarTypeSize);
1601
1594
cp.processScheduleChunk (stmtCtx, scheduleChunkClauseOperand);
1602
1595
cp.processReduction (loc, reductionVars, reductionDeclSymbols);
1603
- cp.processIf (Fortran::parser::OmpIfClause::DirectiveNameModifier::Simd,
1604
- ifClauseOperand);
1596
+ cp.processIf (clause::If::DirectiveNameModifier::Simd, ifClauseOperand);
1605
1597
cp.processSimdlen (simdlenClauseOperand);
1606
1598
cp.processSafelen (safelenClauseOperand);
1607
1599
cp.processTODO <Fortran::parser::OmpClause::Aligned,
@@ -2419,106 +2411,100 @@ void Fortran::lower::genOpenMPReduction(
2419
2411
const Fortran::parser::OmpClauseList &clauseList) {
2420
2412
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
2421
2413
2422
- for (const Fortran::parser::OmpClause &clause : clauseList.v ) {
2414
+ List<Clause> clauses{makeList (clauseList, semaCtx)};
2415
+
2416
+ for (const Clause &clause : clauses) {
2423
2417
if (const auto &reductionClause =
2424
- std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u )) {
2425
- const auto &redOperator{std::get<Fortran::parser::OmpReductionOperator>(
2426
- reductionClause->v .t )};
2427
- const auto &objectList{
2428
- std::get<Fortran::parser::OmpObjectList>(reductionClause->v .t )};
2418
+ std::get_if<clause::Reduction>(&clause.u )) {
2419
+ const auto &redOperator{
2420
+ std::get<clause::ReductionOperator>(reductionClause->t )};
2421
+ const auto &objects{std::get<ObjectList>(reductionClause->t )};
2429
2422
if (const auto *reductionOp =
2430
- std::get_if<Fortran::parser ::DefinedOperator>(&redOperator.u )) {
2423
+ std::get_if<clause ::DefinedOperator>(&redOperator.u )) {
2431
2424
const auto &intrinsicOp{
2432
- std::get<Fortran::parser ::DefinedOperator::IntrinsicOperator>(
2425
+ std::get<clause ::DefinedOperator::IntrinsicOperator>(
2433
2426
reductionOp->u )};
2434
2427
2435
2428
switch (intrinsicOp) {
2436
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::Add:
2437
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::Multiply:
2438
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::AND:
2439
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::EQV:
2440
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::OR:
2441
- case Fortran::parser ::DefinedOperator::IntrinsicOperator::NEQV:
2429
+ case clause ::DefinedOperator::IntrinsicOperator::Add:
2430
+ case clause ::DefinedOperator::IntrinsicOperator::Multiply:
2431
+ case clause ::DefinedOperator::IntrinsicOperator::AND:
2432
+ case clause ::DefinedOperator::IntrinsicOperator::EQV:
2433
+ case clause ::DefinedOperator::IntrinsicOperator::OR:
2434
+ case clause ::DefinedOperator::IntrinsicOperator::NEQV:
2442
2435
break ;
2443
2436
default :
2444
2437
continue ;
2445
2438
}
2446
- for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
2447
- if (const auto *name{
2448
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2449
- if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
2450
- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2451
- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2452
- reductionVal = declOp.getBase ();
2453
- mlir::Type reductionType =
2454
- reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
2455
- if (!reductionType.isa <fir::LogicalType>()) {
2456
- if (!reductionType.isIntOrIndexOrFloat ())
2457
- continue ;
2458
- }
2459
- for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
2460
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2461
- reductionValUse.getOwner ())) {
2462
- mlir::Value loadVal = loadOp.getRes ();
2463
- if (reductionType.isa <fir::LogicalType>()) {
2464
- mlir::Operation *reductionOp = findReductionChain (loadVal);
2465
- fir::ConvertOp convertOp =
2466
- getConvertFromReductionOp (reductionOp, loadVal);
2467
- updateReduction (reductionOp, firOpBuilder, loadVal,
2468
- reductionVal, &convertOp);
2469
- removeStoreOp (reductionOp, reductionVal);
2470
- } else if (mlir::Operation *reductionOp =
2471
- findReductionChain (loadVal, &reductionVal)) {
2472
- updateReduction (reductionOp, firOpBuilder, loadVal,
2473
- reductionVal);
2474
- }
2439
+ for (const Object &object : objects) {
2440
+ if (const Fortran::semantics::Symbol *symbol = object.id ()) {
2441
+ mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2442
+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2443
+ reductionVal = declOp.getBase ();
2444
+ mlir::Type reductionType =
2445
+ reductionVal.getType ().cast <fir::ReferenceType>().getEleTy ();
2446
+ if (!reductionType.isa <fir::LogicalType>()) {
2447
+ if (!reductionType.isIntOrIndexOrFloat ())
2448
+ continue ;
2449
+ }
2450
+ for (mlir::OpOperand &reductionValUse : reductionVal.getUses ()) {
2451
+ if (auto loadOp =
2452
+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
2453
+ mlir::Value loadVal = loadOp.getRes ();
2454
+ if (reductionType.isa <fir::LogicalType>()) {
2455
+ mlir::Operation *reductionOp = findReductionChain (loadVal);
2456
+ fir::ConvertOp convertOp =
2457
+ getConvertFromReductionOp (reductionOp, loadVal);
2458
+ updateReduction (reductionOp, firOpBuilder, loadVal,
2459
+ reductionVal, &convertOp);
2460
+ removeStoreOp (reductionOp, reductionVal);
2461
+ } else if (mlir::Operation *reductionOp =
2462
+ findReductionChain (loadVal, &reductionVal)) {
2463
+ updateReduction (reductionOp, firOpBuilder, loadVal,
2464
+ reductionVal);
2475
2465
}
2476
2466
}
2477
2467
}
2478
2468
}
2479
2469
}
2480
2470
} else if (const auto *reductionIntrinsic =
2481
- std::get_if<Fortran::parser::ProcedureDesignator>(
2482
- &redOperator.u )) {
2471
+ std::get_if<clause::ProcedureDesignator>(&redOperator.u )) {
2483
2472
if (!ReductionProcessor::supportedIntrinsicProcReduction (
2484
2473
*reductionIntrinsic))
2485
2474
continue ;
2486
2475
ReductionProcessor::ReductionIdentifier redId =
2487
2476
ReductionProcessor::getReductionType (*reductionIntrinsic);
2488
- for (const Fortran::parser::OmpObject &ompObject : objectList.v ) {
2489
- if (const auto *name{
2490
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
2491
- if (const Fortran::semantics::Symbol * symbol{name->symbol }) {
2492
- mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2493
- if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2494
- reductionVal = declOp.getBase ();
2495
- for (const mlir::OpOperand &reductionValUse :
2496
- reductionVal.getUses ()) {
2497
- if (auto loadOp = mlir::dyn_cast<fir::LoadOp>(
2498
- reductionValUse.getOwner ())) {
2499
- mlir::Value loadVal = loadOp.getRes ();
2500
- // Max is lowered as a compare -> select.
2501
- // Match the pattern here.
2502
- mlir::Operation *reductionOp =
2503
- findReductionChain (loadVal, &reductionVal);
2504
- if (reductionOp == nullptr )
2505
- continue ;
2506
-
2507
- if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2508
- redId == ReductionProcessor::ReductionIdentifier::MIN) {
2509
- assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2510
- " Selection Op not found in reduction intrinsic" );
2511
- mlir::Operation *compareOp =
2512
- getCompareFromReductionOp (reductionOp, loadVal);
2513
- updateReduction (compareOp, firOpBuilder, loadVal,
2514
- reductionVal);
2515
- }
2516
- if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2517
- redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2518
- redId == ReductionProcessor::ReductionIdentifier::IAND) {
2519
- updateReduction (reductionOp, firOpBuilder, loadVal,
2520
- reductionVal);
2521
- }
2477
+ for (const Object &object : objects) {
2478
+ if (const Fortran::semantics::Symbol *symbol = object.id ()) {
2479
+ mlir::Value reductionVal = converter.getSymbolAddress (*symbol);
2480
+ if (auto declOp = reductionVal.getDefiningOp <hlfir::DeclareOp>())
2481
+ reductionVal = declOp.getBase ();
2482
+ for (const mlir::OpOperand &reductionValUse :
2483
+ reductionVal.getUses ()) {
2484
+ if (auto loadOp =
2485
+ mlir::dyn_cast<fir::LoadOp>(reductionValUse.getOwner ())) {
2486
+ mlir::Value loadVal = loadOp.getRes ();
2487
+ // Max is lowered as a compare -> select.
2488
+ // Match the pattern here.
2489
+ mlir::Operation *reductionOp =
2490
+ findReductionChain (loadVal, &reductionVal);
2491
+ if (reductionOp == nullptr )
2492
+ continue ;
2493
+
2494
+ if (redId == ReductionProcessor::ReductionIdentifier::MAX ||
2495
+ redId == ReductionProcessor::ReductionIdentifier::MIN) {
2496
+ assert (mlir::isa<mlir::arith::SelectOp>(reductionOp) &&
2497
+ " Selection Op not found in reduction intrinsic" );
2498
+ mlir::Operation *compareOp =
2499
+ getCompareFromReductionOp (reductionOp, loadVal);
2500
+ updateReduction (compareOp, firOpBuilder, loadVal,
2501
+ reductionVal);
2502
+ }
2503
+ if (redId == ReductionProcessor::ReductionIdentifier::IOR ||
2504
+ redId == ReductionProcessor::ReductionIdentifier::IEOR ||
2505
+ redId == ReductionProcessor::ReductionIdentifier::IAND) {
2506
+ updateReduction (reductionOp, firOpBuilder, loadVal,
2507
+ reductionVal);
2522
2508
}
2523
2509
}
2524
2510
}
0 commit comments