@@ -519,6 +519,11 @@ struct OpWithBodyGenInfo {
     return *this;
   }
 
+  OpWithBodyGenInfo &setGenSkeletonOnly(bool value) {
+    genSkeletonOnly = value;
+    return *this;
+  }
+
   /// [inout] converter to use for the clauses.
   lower::AbstractConverter &converter;
   /// [in] Symbol table
@@ -538,6 +543,9 @@ struct OpWithBodyGenInfo {
   /// [in] if provided, emits the op's region entry. Otherwise, an empty block
   /// is created in the region.
   GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
+  /// [in] if set to `true`, skip generating nested evaluations and dispatching
+  /// any further leaf constructs.
+  bool genSkeletonOnly = false;
 };
 
 /// Create the body (block) for an OpenMP Operation.
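The new flag slots into `OpWithBodyGenInfo`'s existing fluent-setter chain. As a rough sketch of a call site (the variable names stand in for values available inside the `gen*Op` lowering functions; the actual use is `.setGenSkeletonOnly(isComposite)` in `genParallelOp` below):

```cpp
// Sketch only: converter, symTable, semaCtx, loc, eval and item are assumed
// to be in scope, as they are in the gen*Op lowering functions.
auto genInfo =
    OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                      llvm::omp::Directive::OMPD_parallel)
        .setClauses(&item->clauses)
        // Only build the op and its entry block; the caller lowers the
        // remaining leaf constructs into the body itself.
        .setGenSkeletonOnly(/*value=*/true);
```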
@@ -600,20 +608,22 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
     }
   }
 
-  if (ConstructQueue::const_iterator next = std::next(item);
-      next != queue.end()) {
-    genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
-                   info.loc, queue, next);
-  } else {
-    // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
-    // a lot of complications for our approach if the terminator generation
-    // is delayed past this point. Insert a temporary terminator here, then
-    // delete it.
-    firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
-    auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
-    firOpBuilder.setInsertionPointAfter(marker);
-    genNestedEvaluations(info.converter, info.eval);
-    temp->erase();
+  if (!info.genSkeletonOnly) {
+    if (ConstructQueue::const_iterator next = std::next(item);
+        next != queue.end()) {
+      genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
+                     info.loc, queue, next);
+    } else {
+      // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
+      // a lot of complications for our approach if the terminator generation
+      // is delayed past this point. Insert a temporary terminator here, then
+      // delete it.
+      firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
+      auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
+      firOpBuilder.setInsertionPointAfter(marker);
+      genNestedEvaluations(info.converter, info.eval);
+      temp->erase();
+    }
   }
 
   // Get or create a unique exiting block from the given region, or
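A clarifying note on this restructuring: when `genSkeletonOnly` is set, both branches are skipped, so neither the next leaf construct is dispatched nor the nested evaluations lowered. The op is left with only its entry block, for the composite-construct codepath to fill in itself.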
@@ -1445,7 +1455,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               const ConstructQueue &queue, ConstructQueue::const_iterator item,
               mlir::omp::ParallelOperands &clauseOps,
               llvm::ArrayRef<const semantics::Symbol *> reductionSyms,
-              llvm::ArrayRef<mlir::Type> reductionTypes) {
+              llvm::ArrayRef<mlir::Type> reductionTypes,
+              DataSharingProcessor *dsp, bool isComposite = false) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   auto reductionCallback = [&](mlir::Operation *op) {
@@ -1457,17 +1468,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_parallel)
          .setClauses(&item->clauses)
-          .setGenRegionEntryCb(reductionCallback);
-
-  if (!enableDelayedPrivatization)
-    return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item,
-                                                clauseOps);
-
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue),
-                           /*useDelayedPrivatization=*/true, &symTable);
-  dsp.processStep1(&clauseOps);
+          .setGenRegionEntryCb(reductionCallback)
+          .setGenSkeletonOnly(isComposite);
+
+  if (!enableDelayedPrivatization) {
+    auto parallelOp =
+        genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+    parallelOp.setComposite(isComposite);
+    return parallelOp;
+  }
 
+  assert(dsp && "expected valid DataSharingProcessor");
   auto genRegionEntryCB = [&](mlir::Operation *op) {
     auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
 
@@ -1491,8 +1502,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                              allRegionArgLocs);
 
     llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
-    allSymbols.append(dsp.getDelayedPrivSymbols().begin(),
-                      dsp.getDelayedPrivSymbols().end());
+    allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
+                      dsp->getDelayedPrivSymbols().end());
 
     unsigned argIdx = 0;
     for (const semantics::Symbol *arg : allSymbols) {
@@ -1519,8 +1530,11 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     return allSymbols;
   };
 
-  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
-  return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(dsp);
+  auto parallelOp =
+      genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+  parallelOp.setComposite(isComposite);
+  return parallelOp;
 }
 
 /// This breaks the normal prototype of the gen*Op functions: adding the
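The net effect of these changes to `genParallelOp`: the `DataSharingProcessor` is no longer constructed locally but supplied by the caller, so a composite construct can share a single processor across its leaf constructs, and the resulting `omp.parallel` op is flagged composite when requested. With delayed privatization enabled, `dsp` must be non-null, as the `assert` above enforces; the two call sites below illustrate both usages.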
@@ -2005,8 +2019,16 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
   genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
                      reductionTypes, reductionSyms);
 
+  std::optional<DataSharingProcessor> dsp;
+  if (enableDelayedPrivatization) {
+    dsp.emplace(converter, semaCtx, item->clauses, eval,
+                lower::omp::isLastItemInQueue(item, queue),
+                /*useDelayedPrivatization=*/true, &symTable);
+    dsp->processStep1(&clauseOps);
+  }
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, clauseOps,
-                reductionSyms, reductionTypes);
+                reductionSyms, reductionTypes,
+                enableDelayedPrivatization ? &dsp.value() : nullptr);
 }
 
 static void genStandaloneSimd(lower::AbstractConverter &converter,
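`genStandaloneParallel` now constructs the processor conditionally via `std::optional`, which avoids requiring a default constructor or a dummy object when delayed privatization is off. A minimal, self-contained illustration of that pattern (hypothetical `Processor` type, not flang code):

```cpp
#include <optional>

struct Processor {
  explicit Processor(int mode) : mode(mode) {}
  int mode;
};

void run(bool enabled) {
  std::optional<Processor> p;
  if (enabled)
    p.emplace(/*mode=*/1); // constructed in place only when needed
  // Callees take a raw pointer; nullptr signals "no processor".
  Processor *ptr = enabled ? &p.value() : nullptr;
  (void)ptr;
}
```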
@@ -2058,8 +2080,69 @@ static void genCompositeDistributeParallelDo(
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
     ConstructQueue::const_iterator item) {
+  lower::StatementContext stmtCtx;
+
   assert(std::distance(item, queue.end()) == 3 && "Invalid leaf constructs");
-  TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
+  ConstructQueue::const_iterator distributeItem = item;
+  ConstructQueue::const_iterator parallelItem = std::next(distributeItem);
+  ConstructQueue::const_iterator doItem = std::next(parallelItem);
+
+  // Create parent omp.parallel first.
+  mlir::omp::ParallelOperands parallelClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
+  llvm::SmallVector<mlir::Type> parallelReductionTypes;
+  genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
+                     parallelClauseOps, parallelReductionTypes,
+                     parallelReductionSyms);
+
+  DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval,
+                           /*shouldCollectPreDeterminedSymbols=*/true,
+                           /*useDelayedPrivatization=*/true, &symTable);
+  dsp.processStep1(&parallelClauseOps);
+
+  genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
+                parallelClauseOps, parallelReductionSyms,
+                parallelReductionTypes, &dsp, /*isComposite=*/true);
+
+  // Clause processing.
+  mlir::omp::DistributeOperands distributeClauseOps;
+  genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
+                       loc, distributeClauseOps);
+
+  mlir::omp::WsloopOperands wsloopClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
+  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
+  genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
+                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc,
+                     loopNestClauseOps, iv);
+
+  // Operation creation.
+  // TODO: Populate entry block arguments with private variables.
+  auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
+      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+  distributeOp.setComposite(/*val=*/true);
+
+  // TODO: Add private variables to entry block arguments.
+  auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
+      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+  wsloopOp.setComposite(/*val=*/true);
+
+  // Construct wrapper entry block list and associated symbols. It is important
+  // that the symbol order and the block argument order match, so that the
+  // symbol-value bindings created are correct.
+  auto &wrapperSyms = wsloopReductionSyms;
+
+  auto wrapperArgs = llvm::to_vector(
+      llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
+                                        wsloopOp.getRegion().getArguments()));
+
+  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, doItem,
+                loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
+                llvm::omp::Directive::OMPD_distribute_parallel_do, dsp);
 }
 
 static void genCompositeDistributeParallelDoSimd(
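Taken together, the new `genCompositeDistributeParallelDo` lowers the composite construct as an `omp.parallel` op marked composite (with the shared `DataSharingProcessor` attached), containing composite `omp.distribute` and `omp.wsloop` wrapper ops, with the `omp.loop_nest` generated innermost by `genLoopNestOp`. Wrapper entry block arguments currently carry only the `omp.wsloop` reductions; wiring private variables into the wrappers' block arguments remains a TODO, as the inline comments note.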
0 commit comments