@@ -519,6 +519,11 @@ struct OpWithBodyGenInfo {
     return *this;
   }

+  OpWithBodyGenInfo &setGenSkeletonOnly(bool value) {
+    genSkeletonOnly = value;
+    return *this;
+  }
+
   /// [inout] converter to use for the clauses.
   lower::AbstractConverter &converter;
   /// [in] Symbol table
@@ -538,6 +543,9 @@ struct OpWithBodyGenInfo {
   /// [in] if provided, emits the op's region entry. Otherwise, an empty block
   /// is created in the region.
   GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
+  /// [in] if set to `true`, skip generating nested evaluations and dispatching
+  /// any further leaf constructs.
+  bool genSkeletonOnly = false;
 };

 /// Create the body (block) for an OpenMP Operation.
@@ -600,20 +608,22 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
     }
   }

-  if (ConstructQueue::const_iterator next = std::next(item);
-      next != queue.end()) {
-    genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
-                   info.loc, queue, next);
-  } else {
-    // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
-    // a lot of complications for our approach if the terminator generation
-    // is delayed past this point. Insert a temporary terminator here, then
-    // delete it.
-    firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
-    auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
-    firOpBuilder.setInsertionPointAfter(marker);
-    genNestedEvaluations(info.converter, info.eval);
-    temp->erase();
+  if (!info.genSkeletonOnly) {
+    if (ConstructQueue::const_iterator next = std::next(item);
+        next != queue.end()) {
+      genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
+                     info.loc, queue, next);
+    } else {
+      // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
+      // a lot of complications for our approach if the terminator generation
+      // is delayed past this point. Insert a temporary terminator here, then
+      // delete it.
+      firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
+      auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
+      firOpBuilder.setInsertionPointAfter(marker);
+      genNestedEvaluations(info.converter, info.eval);
+      temp->erase();
+    }
   }

   // Get or create a unique exiting block from the given region, or
@@ -1445,7 +1455,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               const ConstructQueue &queue, ConstructQueue::const_iterator item,
               mlir::omp::ParallelOperands &clauseOps,
               llvm::ArrayRef<const semantics::Symbol *> reductionSyms,
-              llvm::ArrayRef<mlir::Type> reductionTypes) {
+              llvm::ArrayRef<mlir::Type> reductionTypes,
+              DataSharingProcessor *dsp, bool isComposite = false) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

   auto reductionCallback = [&](mlir::Operation *op) {
@@ -1457,17 +1468,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_parallel)
           .setClauses(&item->clauses)
-          .setGenRegionEntryCb(reductionCallback);
-
-  if (!enableDelayedPrivatization)
-    return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item,
-                                                clauseOps);
-
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue),
-                           /*useDelayedPrivatization=*/true, &symTable);
-  dsp.processStep1(&clauseOps);
+          .setGenRegionEntryCb(reductionCallback)
+          .setGenSkeletonOnly(isComposite);
+
+  if (!enableDelayedPrivatization) {
+    auto parallelOp =
+        genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+    parallelOp.setComposite(isComposite);
+    return parallelOp;
+  }

+  assert(dsp && "expected valid DataSharingProcessor");
   auto genRegionEntryCB = [&](mlir::Operation *op) {
     auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
@@ -1491,8 +1502,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                              allRegionArgLocs);

     llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
-    allSymbols.append(dsp.getAllSymbolsToPrivatize().begin(),
-                      dsp.getAllSymbolsToPrivatize().end());
+    allSymbols.append(dsp->getAllSymbolsToPrivatize().begin(),
+                      dsp->getAllSymbolsToPrivatize().end());

     unsigned argIdx = 0;
     for (const semantics::Symbol *arg : allSymbols) {
@@ -1519,8 +1530,11 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     return allSymbols;
   };

-  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
-  return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(dsp);
+  auto parallelOp =
+      genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+  parallelOp.setComposite(isComposite);
+  return parallelOp;
 }

 /// This breaks the normal prototype of the gen*Op functions: adding the
@@ -1999,8 +2013,16 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
   genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
                      reductionTypes, reductionSyms);

+  std::optional<DataSharingProcessor> dsp;
+  if (enableDelayedPrivatization) {
+    dsp.emplace(converter, semaCtx, item->clauses, eval,
+                lower::omp::isLastItemInQueue(item, queue),
+                /*useDelayedPrivatization=*/true, &symTable);
+    dsp->processStep1(&clauseOps);
+  }
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, clauseOps,
-                reductionSyms, reductionTypes);
+                reductionSyms, reductionTypes,
+                enableDelayedPrivatization ? &dsp.value() : nullptr);
 }

 static void genStandaloneSimd(lower::AbstractConverter &converter,
@@ -2052,8 +2074,69 @@ static void genCompositeDistributeParallelDo(
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
     ConstructQueue::const_iterator item) {
+  lower::StatementContext stmtCtx;
+
   assert(std::distance(item, queue.end()) == 3 && "Invalid leaf constructs");
-  TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
+  ConstructQueue::const_iterator distributeItem = item;
+  ConstructQueue::const_iterator parallelItem = std::next(distributeItem);
+  ConstructQueue::const_iterator doItem = std::next(parallelItem);
+
+  // Create parent omp.parallel first.
+  mlir::omp::ParallelOperands parallelClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
+  llvm::SmallVector<mlir::Type> parallelReductionTypes;
+  genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
+                     parallelClauseOps, parallelReductionTypes,
+                     parallelReductionSyms);
+
+  DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval,
+                           /*shouldCollectPreDeterminedSymbols=*/true,
+                           /*useDelayedPrivatization=*/true, &symTable);
+  dsp.processStep1(&parallelClauseOps);
+
+  genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
+                parallelClauseOps, parallelReductionSyms,
+                parallelReductionTypes, &dsp, /*isComposite=*/true);
+
+  // Clause processing.
+  mlir::omp::DistributeOperands distributeClauseOps;
+  genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
+                       loc, distributeClauseOps);
+
+  mlir::omp::WsloopOperands wsloopClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
+  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
+  genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
+                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc,
+                     loopNestClauseOps, iv);
+
+  // Operation creation.
+  // TODO: Populate entry block arguments with private variables.
+  auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
+      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+  distributeOp.setComposite(/*val=*/true);
+
+  // TODO: Add private variables to entry block arguments.
+  auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
+      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+  wsloopOp.setComposite(/*val=*/true);
+
+  // Construct wrapper entry block list and associated symbols. It is important
+  // that the symbol order and the block argument order match, so that the
+  // symbol-value bindings created are correct.
+  auto &wrapperSyms = wsloopReductionSyms;
+
+  auto wrapperArgs = llvm::to_vector(
+      llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
+                                        wsloopOp.getRegion().getArguments()));
+
+  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, doItem,
+                loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
+                llvm::omp::Directive::OMPD_distribute_parallel_do, dsp);
 }

 static void genCompositeDistributeParallelDoSimd(
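
For context, a minimal Fortran kernel that would exercise the new genCompositeDistributeParallelDo() path might look like the sketch below. This is illustrative only, not part of the patch; the subroutine name and loop body are hypothetical.

! Illustrative only (not from this patch): composite DISTRIBUTE PARALLEL DO,
! nested in TEAMS as the OpenMP spec requires for DISTRIBUTE.
subroutine saxpy(a, x, y, n)
  integer :: n, i
  real :: a, x(n), y(n)
  !$omp teams
  !$omp distribute parallel do
  do i = 1, n
    y(i) = a * x(i) + y(i)
  end do
  !$omp end distribute parallel do
  !$omp end teams
end subroutine saxpy

Per the lowering above, the loop is expected to become an omp.loop_nest wrapped by omp.distribute and omp.wsloop ops marked composite, all inside the omp.parallel that genParallelOp() creates first in skeleton-only mode.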