@@ -485,6 +485,33 @@ mergePrivateVarsInfo(OMPOp op, llvm::ArrayRef<InfoTy> currentList,
485
485
infoAccessor);
486
486
}
487
487
488
+ static void
489
+ bindSymbolsToRegionArgs (lower::AbstractConverter &converter, mlir::Location loc,
490
+ llvm::ArrayRef<const semantics::Symbol *> symbols,
491
+ mlir::Region ®ion, unsigned regionBeginArgIdx) {
492
+ assert (regionBeginArgIdx + symbols.size () <= region.getNumArguments ());
493
+ for (const semantics::Symbol *arg : symbols) {
494
+ auto bind = [&](const semantics::Symbol *sym) {
495
+ mlir::BlockArgument blockArg = region.getArgument (regionBeginArgIdx);
496
+ ++regionBeginArgIdx;
497
+ converter.bindSymbol (
498
+ *sym,
499
+ hlfir::translateToExtendedValue (
500
+ loc, converter.getFirOpBuilder (), hlfir::Entity{blockArg},
501
+ /* contiguousHint=*/
502
+ evaluate::IsSimplyContiguous (*sym, converter.getFoldingContext ()))
503
+ .first );
504
+ };
505
+
506
+ if (const auto *commonDet =
507
+ arg->detailsIf <semantics::CommonBlockDetails>()) {
508
+ for (const auto &mem : commonDet->objects ())
509
+ bind (&*mem);
510
+ } else
511
+ bind (arg);
512
+ }
513
+ }
514
+
488
515
// ===----------------------------------------------------------------------===//
489
516
// Op body generation helper structures and functions
490
517
// ===----------------------------------------------------------------------===//
@@ -1493,28 +1520,7 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1493
1520
llvm::SmallVector<const semantics::Symbol *> allSymbols (reductionSyms);
1494
1521
allSymbols.append (dsp->getDelayedPrivSymbols ().begin (),
1495
1522
dsp->getDelayedPrivSymbols ().end ());
1496
-
1497
- unsigned argIdx = 0 ;
1498
- for (const semantics::Symbol *arg : allSymbols) {
1499
- auto bind = [&](const semantics::Symbol *sym) {
1500
- mlir::BlockArgument blockArg = region.getArgument (argIdx);
1501
- ++argIdx;
1502
- converter.bindSymbol (*sym,
1503
- hlfir::translateToExtendedValue (
1504
- loc, firOpBuilder, hlfir::Entity{blockArg},
1505
- /* contiguousHint=*/
1506
- evaluate::IsSimplyContiguous (
1507
- *sym, converter.getFoldingContext ()))
1508
- .first );
1509
- };
1510
-
1511
- if (const auto *commonDet =
1512
- arg->detailsIf <semantics::CommonBlockDetails>()) {
1513
- for (const auto &mem : commonDet->objects ())
1514
- bind (&*mem);
1515
- } else
1516
- bind (arg);
1517
- }
1523
+ bindSymbolsToRegionArgs (converter, loc, allSymbols, region, 0 );
1518
1524
1519
1525
return allSymbols;
1520
1526
};
@@ -1681,7 +1687,6 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
1681
1687
mapTypes, deviceAddrSyms, deviceAddrLocs, deviceAddrTypes,
1682
1688
devicePtrSyms, devicePtrLocs, devicePtrTypes);
1683
1689
1684
- llvm::SmallVector<const semantics::Symbol *> privateSyms;
1685
1690
DataSharingProcessor dsp (converter, semaCtx, item->clauses , eval,
1686
1691
/* shouldCollectPreDeterminedSymbols=*/
1687
1692
lower::omp::isLastItemInQueue (item, queue),
@@ -1932,22 +1937,49 @@ static void genStandaloneDistribute(lower::AbstractConverter &converter,
1932
1937
ConstructQueue::const_iterator item) {
1933
1938
lower::StatementContext stmtCtx;
1934
1939
1940
+ auto teamsOp = mlir::cast<mlir::omp::TeamsOp>(
1941
+ converter.getFirOpBuilder ().getInsertionBlock ()->getParentOp ());
1935
1942
mlir::omp::DistributeOperands distributeClauseOps;
1936
1943
genDistributeClauses (converter, semaCtx, stmtCtx, item->clauses , loc,
1937
1944
distributeClauseOps);
1938
1945
1939
- // TODO: Support delayed privatization.
1946
+ // Privatization for a `distribute` directive is done in the `teams` region to
1947
+ // which the directive binds. Therefore, all privatization logic (delayed as
1948
+ // well as early) happens **before** the `distribute` op is generated (i.e.
1949
+ // inside the parent `teams` op).
1940
1950
DataSharingProcessor dsp (converter, semaCtx, item->clauses , eval,
1941
1951
/* shouldCollectPreDeterminedSymbols=*/ true ,
1942
- /* useDelayedPrivatization=*/ false , &symTable);
1943
- dsp.processStep1 ();
1952
+ enableDelayedPrivatizationStaging, &symTable);
1953
+ mlir::omp::PrivateClauseOps privateClauseOps;
1954
+ dsp.processStep1 (&privateClauseOps);
1955
+
1956
+ if (enableDelayedPrivatizationStaging) {
1957
+ mlir::Region &teamsRegion = teamsOp.getRegion ();
1958
+ unsigned privateVarsArgIdx = teamsRegion.getNumArguments ();
1959
+ llvm::SmallVector<mlir::Type> privateVarTypes;
1960
+ llvm::SmallVector<mlir::Location> privateVarLocs;
1961
+
1962
+ for (mlir::Value privateVar : privateClauseOps.privateVars ) {
1963
+ privateVarTypes.push_back (privateVar.getType ());
1964
+ privateVarLocs.push_back (privateVar.getLoc ());
1965
+ teamsOp.getPrivateVarsMutable ().append (privateVar);
1966
+ }
1967
+
1968
+ teamsOp.setPrivateSymsAttr (
1969
+ converter.getFirOpBuilder ().getArrayAttr (privateClauseOps.privateSyms ));
1970
+ teamsRegion.addArguments (privateVarTypes, privateVarLocs);
1971
+
1972
+ llvm::ArrayRef<const semantics::Symbol *> delayedPrivSyms =
1973
+ dsp.getDelayedPrivSymbols ();
1974
+ bindSymbolsToRegionArgs (converter, loc, delayedPrivSyms, teamsRegion,
1975
+ privateVarsArgIdx);
1976
+ }
1944
1977
1945
1978
mlir::omp::LoopNestOperands loopNestClauseOps;
1946
1979
llvm::SmallVector<const semantics::Symbol *> iv;
1947
1980
genLoopNestClauses (converter, semaCtx, eval, item->clauses , loc,
1948
1981
loopNestClauseOps, iv);
1949
1982
1950
- // TODO: Populate entry block arguments with private variables.
1951
1983
auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
1952
1984
converter, loc, distributeClauseOps, /* blockArgTypes=*/ {});
1953
1985
0 commit comments