@@ -1091,14 +1091,16 @@ static void genParallelClauses(
1091
1091
cp.processReduction (loc, clauseOps, &reductionTypes, &reductionSyms);
1092
1092
}
1093
1093
1094
- static void genSectionsClauses (lower::AbstractConverter &converter,
1095
- semantics::SemanticsContext &semaCtx,
1096
- const List<Clause> &clauses, mlir::Location loc,
1097
- mlir::omp::SectionsClauseOps &clauseOps) {
1094
+ static void genSectionsClauses (
1095
+ lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
1096
+ const List<Clause> &clauses, mlir::Location loc,
1097
+ mlir::omp::SectionsClauseOps &clauseOps,
1098
+ llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
1099
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
1098
1100
ClauseProcessor cp (converter, semaCtx, clauses);
1099
1101
cp.processAllocate (clauseOps);
1100
- cp.processSectionsReduction (loc, clauseOps);
1101
1102
cp.processNowait (clauseOps);
1103
+ cp.processReduction (loc, clauseOps, &reductionTypes, &reductionSyms);
1102
1104
// TODO Support delayed privatization.
1103
1105
}
1104
1106
@@ -1481,27 +1483,20 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1481
1483
return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
1482
1484
}
1483
1485
1484
- static mlir::omp::SectionOp
1485
- genSectionOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
1486
- semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
1487
- mlir::Location loc, const ConstructQueue &queue,
1488
- ConstructQueue::iterator item) {
1489
- // Currently only private/firstprivate clause is handled, and
1490
- // all privatization is done within `omp.section` operations.
1491
- return genOpWithBody<mlir::omp::SectionOp>(
1492
- OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
1493
- llvm::omp::Directive::OMPD_section)
1494
- .setClauses (&item->clauses ),
1495
- queue, item);
1496
- }
1497
-
1486
+ // / This breaks the normal prototype of the gen*Op functions: adding the
1487
+ // / sectionBlocks argument so that the enclosed section constructs can be
1488
+ // / lowered here with correct reduction symbol remapping.
1498
1489
static mlir::omp::SectionsOp
1499
1490
genSectionsOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
1500
1491
semantics::SemanticsContext &semaCtx,
1501
1492
lower::pft::Evaluation &eval, mlir::Location loc,
1502
- const ConstructQueue &queue, ConstructQueue::iterator item) {
1493
+ const ConstructQueue &queue, ConstructQueue::iterator item,
1494
+ const parser::OmpSectionBlocks §ionBlocks) {
1495
+ llvm::SmallVector<mlir::Type> reductionTypes;
1496
+ llvm::SmallVector<const semantics::Symbol *> reductionSyms;
1503
1497
mlir::omp::SectionsClauseOps clauseOps;
1504
- genSectionsClauses (converter, semaCtx, item->clauses , loc, clauseOps);
1498
+ genSectionsClauses (converter, semaCtx, item->clauses , loc, clauseOps,
1499
+ reductionTypes, reductionSyms);
1505
1500
1506
1501
auto &builder = converter.getFirOpBuilder ();
1507
1502
@@ -1530,11 +1525,46 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1530
1525
}
1531
1526
1532
1527
// SECTIONS construct.
1533
- mlir::omp::SectionsOp sectionsOp = genOpWithBody<mlir::omp::SectionsOp>(
1534
- OpWithBodyGenInfo (converter, symTable, semaCtx, loc, eval,
1535
- llvm::omp::Directive::OMPD_sections)
1536
- .setClauses (&nonDsaClauses),
1537
- queue, item, clauseOps);
1528
+ auto sectionsOp = builder.create <mlir::omp::SectionsOp>(loc, clauseOps);
1529
+
1530
+ // create entry block with reduction variables as arguments
1531
+ llvm::SmallVector<mlir::Location> blockArgLocs (reductionSyms.size (), loc);
1532
+ builder.createBlock (§ionsOp->getRegion (0 ), {}, reductionTypes,
1533
+ blockArgLocs);
1534
+ mlir::Operation *terminator =
1535
+ lower::genOpenMPTerminator (builder, sectionsOp, loc);
1536
+
1537
+ auto reductionCallback = [&](mlir::Operation *op) {
1538
+ genReductionVars (op, converter, loc, reductionSyms, reductionTypes);
1539
+ return reductionSyms;
1540
+ };
1541
+
1542
+ // Generate nested SECTION constructs.
1543
+ // This is done here rather than in genOMP([...], OpenMPSectionConstruct )
1544
+ // because we need to run genReductionVars on each omp.section so that the
1545
+ // reduction variable gets mapped to the private version
1546
+ for (auto [construct, nestedEval] :
1547
+ llvm::zip (sectionBlocks.v , eval.getNestedEvaluations ())) {
1548
+ const auto *sectionConstruct =
1549
+ std::get_if<parser::OpenMPSectionConstruct>(&construct.u );
1550
+ if (!sectionConstruct) {
1551
+ assert (false &&
1552
+ " unexpected construct nested inside of SECTIONS construct" );
1553
+ continue ;
1554
+ }
1555
+
1556
+ ConstructQueue sectionQueue{buildConstructQueue (
1557
+ converter.getFirOpBuilder ().getModule (), semaCtx, nestedEval,
1558
+ sectionConstruct->source , llvm::omp::Directive::OMPD_section, {})};
1559
+
1560
+ builder.setInsertionPoint (terminator);
1561
+ genOpWithBody<mlir::omp::SectionOp>(
1562
+ OpWithBodyGenInfo (converter, symTable, semaCtx, loc, nestedEval,
1563
+ llvm::omp::Directive::OMPD_section)
1564
+ .setClauses (§ionQueue.begin ()->clauses )
1565
+ .setGenRegionEntryCb (reductionCallback),
1566
+ sectionQueue, sectionQueue.begin ());
1567
+ }
1538
1568
1539
1569
if (!lastprivates.empty ()) {
1540
1570
mlir::Region §ionsBody = sectionsOp.getRegion ();
@@ -2120,10 +2150,14 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
2120
2150
genStandaloneParallel (converter, symTable, semaCtx, eval, loc, queue, item);
2121
2151
break ;
2122
2152
case llvm::omp::Directive::OMPD_section:
2123
- genSectionOp (converter, symTable, semaCtx, eval, loc, queue, item);
2153
+ llvm_unreachable (" genOMPDispatch: OMPD_section" );
2154
+ // Lowered in the enclosing genSectionsOp.
2124
2155
break ;
2125
2156
case llvm::omp::Directive::OMPD_sections:
2126
- genSectionsOp (converter, symTable, semaCtx, eval, loc, queue, item);
2157
+ // Called directly from genOMP([...], OpenMPSectionsConstruct) because it
2158
+ // has a different prototype.
2159
+ // This code path is still taken when iterating through the construct queue
2160
+ // in genBodyOfOp
2127
2161
break ;
2128
2162
case llvm::omp::Directive::OMPD_simd:
2129
2163
genStandaloneSimd (converter, symTable, semaCtx, eval, loc, queue, item,
@@ -2536,11 +2570,7 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
2536
2570
semantics::SemanticsContext &semaCtx,
2537
2571
lower::pft::Evaluation &eval,
2538
2572
const parser::OpenMPSectionConstruct §ionConstruct) {
2539
- mlir::Location loc = converter.getCurrentLocation ();
2540
- ConstructQueue queue{buildConstructQueue (
2541
- converter.getFirOpBuilder ().getModule (), semaCtx, eval,
2542
- sectionConstruct.source , llvm::omp::Directive::OMPD_section, {})};
2543
- genOMPDispatch (converter, symTable, semaCtx, eval, loc, queue, queue.begin ());
2573
+ // Do nothing here. SECTION is lowered inside of the lowering for Sections
2544
2574
}
2545
2575
2546
2576
static void genOMP (lower::AbstractConverter &converter, lower::SymMap &symTable,
@@ -2553,6 +2583,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
2553
2583
std::get<parser::OmpClauseList>(beginSectionsDirective.t ), semaCtx);
2554
2584
const auto &endSectionsDirective =
2555
2585
std::get<parser::OmpEndSectionsDirective>(sectionsConstruct.t );
2586
+ const auto §ionBlocks =
2587
+ std::get<parser::OmpSectionBlocks>(sectionsConstruct.t );
2556
2588
clauses.append (makeClauses (
2557
2589
std::get<parser::OmpClauseList>(endSectionsDirective.t ), semaCtx));
2558
2590
mlir::Location currentLocation = converter.getCurrentLocation ();
@@ -2564,8 +2596,22 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
2564
2596
ConstructQueue queue{
2565
2597
buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
2566
2598
eval, source, directive, clauses)};
2567
- genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
2568
- queue.begin ());
2599
+ ConstructQueue::iterator next = queue.begin ();
2600
+ // Generate constructs that come first e.g. Parallel
2601
+ while (next != queue.end () &&
2602
+ next->id != llvm::omp::Directive::OMPD_sections) {
2603
+ genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
2604
+ next);
2605
+ next = std::next (next);
2606
+ }
2607
+
2608
+ // call genSectionsOp directly (not via genOMPDispatch) so that we can add the
2609
+ // sectionBlocks argument
2610
+ assert (next != queue.end ());
2611
+ assert (next->id == llvm::omp::Directive::OMPD_sections);
2612
+ genSectionsOp (converter, symTable, semaCtx, eval, currentLocation, queue,
2613
+ next, sectionBlocks);
2614
+ assert (std::next (next) == queue.end ());
2569
2615
}
2570
2616
2571
2617
static void genOMP (lower::AbstractConverter &converter, lower::SymMap &symTable,
0 commit comments