Skip to content

Commit d9f36d8

Browse files
committed
Add region args.
1 parent 243465d commit d9f36d8

File tree

3 files changed

+164
-70
lines changed

3 files changed

+164
-70
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 132 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,14 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
147147
//===----------------------------------------------------------------------===//
148148

149149
class DataSharingProcessor {
150+
public:
151+
struct DelayedPrivatizationInfo {
152+
llvm::SetVector<mlir::SymbolRefAttr> privatizers;
153+
llvm::SetVector<mlir::Value> hostAddresses;
154+
llvm::SetVector<const Fortran::semantics::Symbol *> hostSymbols;
155+
};
156+
157+
private:
150158
bool hasLastPrivateOp;
151159
mlir::OpBuilder::InsertPoint lastPrivIP;
152160
mlir::OpBuilder::InsertPoint insPt;
@@ -163,8 +171,8 @@ class DataSharingProcessor {
163171

164172
bool useDelayedPrivatizationWhenPossible;
165173
Fortran::lower::SymMap *symTable;
166-
llvm::SetVector<mlir::SymbolRefAttr> privatizers;
167-
llvm::SetVector<mlir::Value> privateSymHostAddrsses;
174+
175+
DelayedPrivatizationInfo delayedPrivatizationInfo;
168176

169177
bool needBarrier();
170178
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
@@ -214,12 +222,8 @@ class DataSharingProcessor {
214222
loopIV = iv;
215223
}
216224

217-
const llvm::SetVector<mlir::SymbolRefAttr> &getPrivatizers() const {
218-
return privatizers;
219-
};
220-
221-
const llvm::SetVector<mlir::Value> &getPrivateSymHostAddrsses() const {
222-
return privateSymHostAddrsses;
225+
const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const {
226+
return delayedPrivatizationInfo;
223227
}
224228
};
225229

@@ -547,8 +551,10 @@ void DataSharingProcessor::privatize() {
547551
symTable->popScope();
548552
firOpBuilder.restoreInsertionPoint(ip);
549553

550-
privatizers.insert(mlir::SymbolRefAttr::get(privatizerOp));
551-
privateSymHostAddrsses.insert(hsb.getAddr());
554+
delayedPrivatizationInfo.privatizers.insert(
555+
mlir::SymbolRefAttr::get(privatizerOp));
556+
delayedPrivatizationInfo.hostAddresses.insert(hsb.getAddr());
557+
delayedPrivatizationInfo.hostSymbols.insert(sym);
552558
} else {
553559
cloneSymbol(sym);
554560
copyFirstPrivateSymbol(sym);
@@ -2322,7 +2328,9 @@ static void createBodyOfOp(
23222328
Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
23232329
Fortran::lower::pft::Evaluation &eval, bool genNested,
23242330
const Fortran::parser::OmpClauseList *clauses = nullptr,
2325-
const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
2331+
std::function<llvm::SmallVector<const Fortran::semantics::Symbol *>(
2332+
mlir::Operation *)>
2333+
genRegionEntryCB = nullptr,
23262334
bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
23272335
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
23282336

@@ -2336,27 +2344,15 @@ static void createBodyOfOp(
23362344
// argument. Also update the symbol's address with the mlir argument value.
23372345
// e.g. For loops the argument is the induction variable. And all further
23382346
// uses of the induction variable should use this mlir value.
2339-
if (args.size()) {
2340-
std::size_t loopVarTypeSize = 0;
2341-
for (const Fortran::semantics::Symbol *arg : args)
2342-
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
2343-
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
2344-
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
2345-
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
2346-
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
2347-
// The argument is not currently in memory, so make a temporary for the
2348-
// argument, and store it there, then bind that location to the argument.
2349-
mlir::Operation *storeOp = nullptr;
2350-
for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
2351-
mlir::Value indexVal =
2352-
fir::getBase(op.getRegion().front().getArgument(argIndex));
2353-
storeOp =
2354-
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
2347+
auto regionArgs =
2348+
[&]() -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
2349+
if (genRegionEntryCB != nullptr) {
2350+
return genRegionEntryCB(op);
23552351
}
2356-
firOpBuilder.setInsertionPointAfter(storeOp);
2357-
} else {
2352+
23582353
firOpBuilder.createBlock(&op.getRegion());
2359-
}
2354+
return {};
2355+
}();
23602356

23612357
// Mark the earliest insertion point.
23622358
mlir::Operation *marker = insertMarker(firOpBuilder);
@@ -2454,8 +2450,8 @@ static void createBodyOfOp(
24542450
assert(tempDsp.has_value());
24552451
tempDsp->processStep2(op, isLoop);
24562452
} else {
2457-
if (isLoop && args.size() > 0)
2458-
dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
2453+
if (isLoop && regionArgs.size() > 0)
2454+
dsp->setLoopIV(converter.getSymbolAddress(*regionArgs[0]));
24592455
dsp->processStep2(op, isLoop);
24602456
}
24612457
}
@@ -2531,41 +2527,44 @@ static void genBodyOfTargetDataOp(
25312527
}
25322528

25332529
template <typename OpTy, typename... Args>
2534-
static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
2535-
Fortran::lower::pft::Evaluation &eval, bool genNested,
2536-
mlir::Location currentLocation, bool outerCombined,
2537-
const Fortran::parser::OmpClauseList *clauseList,
2538-
DataSharingProcessor *dsp, Args &&...args) {
2530+
static OpTy genOpWithBody(
2531+
Fortran::lower::AbstractConverter &converter,
2532+
Fortran::lower::pft::Evaluation &eval, bool genNested,
2533+
mlir::Location currentLocation, bool outerCombined,
2534+
const Fortran::parser::OmpClauseList *clauseList,
2535+
std::function<llvm::SmallVector<const Fortran::semantics::Symbol *>(
2536+
mlir::Operation *)>
2537+
genRegionEntryCB,
2538+
DataSharingProcessor *dsp, Args &&...args) {
25392539
auto op = converter.getFirOpBuilder().create<OpTy>(
25402540
currentLocation, std::forward<Args>(args)...);
25412541
createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
2542-
clauseList,
2543-
/*args=*/{}, outerCombined, dsp);
2542+
clauseList, genRegionEntryCB, outerCombined, dsp);
25442543
return op;
25452544
}
25462545

25472546
static mlir::omp::MasterOp
25482547
genMasterOp(Fortran::lower::AbstractConverter &converter,
25492548
Fortran::lower::pft::Evaluation &eval, bool genNested,
25502549
mlir::Location currentLocation) {
2551-
return genOpWithBody<mlir::omp::MasterOp>(converter, eval, genNested,
2552-
currentLocation,
2553-
/*outerCombined=*/false,
2554-
/*clauseList=*/nullptr,
2555-
/*dsp=*/nullptr,
2556-
/*resultTypes=*/mlir::TypeRange());
2550+
return genOpWithBody<mlir::omp::MasterOp>(
2551+
converter, eval, genNested, currentLocation,
2552+
/*outerCombined=*/false,
2553+
/*clauseList=*/nullptr, /*genRegionEntryCB=*/nullptr,
2554+
/*dsp=*/nullptr,
2555+
/*resultTypes=*/mlir::TypeRange());
25572556
}
25582557

25592558
static mlir::omp::OrderedRegionOp
25602559
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
25612560
Fortran::lower::pft::Evaluation &eval, bool genNested,
25622561
mlir::Location currentLocation) {
2563-
return genOpWithBody<mlir::omp::OrderedRegionOp>(converter, eval, genNested,
2564-
currentLocation,
2565-
/*outerCombined=*/false,
2566-
/*clauseList=*/nullptr,
2567-
/*dsp=*/nullptr,
2568-
/*simd=*/false);
2562+
return genOpWithBody<mlir::omp::OrderedRegionOp>(
2563+
converter, eval, genNested, currentLocation,
2564+
/*outerCombined=*/false,
2565+
/*clauseList=*/nullptr, /*genRegionEntryCB=*/nullptr,
2566+
/*dsp=*/nullptr,
2567+
/*simd=*/false);
25692568
}
25702569

25712570
static mlir::omp::ParallelOp
@@ -2601,16 +2600,44 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
26012600
dsp.processStep1();
26022601
}
26032602

2604-
llvm::SmallVector<mlir::Attribute> privatizers(dsp.getPrivatizers().begin(),
2605-
dsp.getPrivatizers().end());
2603+
const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo();
2604+
llvm::SmallVector<mlir::Attribute> privatizers(
2605+
delayedPrivatizationInfo.privatizers.begin(),
2606+
delayedPrivatizationInfo.privatizers.end());
26062607

26072608
llvm::SmallVector<mlir::Value> privateSymAddresses(
2608-
dsp.getPrivateSymHostAddrsses().begin(),
2609-
dsp.getPrivateSymHostAddrsses().end());
2609+
delayedPrivatizationInfo.hostAddresses.begin(),
2610+
delayedPrivatizationInfo.hostAddresses.end());
2611+
2612+
auto genRegionEntryCB = [&](mlir::Operation *op) {
2613+
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
2614+
auto privateVars = parallelOp.getPrivateVars();
2615+
auto &region = parallelOp.getRegion();
2616+
llvm::SmallVector<mlir::Type> privateVarTypes;
2617+
llvm::SmallVector<mlir::Location> privateVarLocs;
2618+
2619+
for (auto privateVar : privateVars) {
2620+
privateVarTypes.push_back(privateVar.getType());
2621+
privateVarLocs.push_back(privateVar.getLoc());
2622+
}
2623+
2624+
converter.getFirOpBuilder().createBlock(&region, {}, privateVarTypes,
2625+
privateVarLocs);
2626+
2627+
int argIdx = 0;
2628+
for (const auto *sym : delayedPrivatizationInfo.hostSymbols) {
2629+
converter.bindSymbol(*sym, region.getArgument(argIdx));
2630+
++argIdx;
2631+
}
2632+
2633+
return llvm::SmallVector<const Fortran::semantics::Symbol *>(
2634+
delayedPrivatizationInfo.hostSymbols.begin(),
2635+
delayedPrivatizationInfo.hostSymbols.end());
2636+
};
26102637

26112638
return genOpWithBody<mlir::omp::ParallelOp>(
26122639
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
2613-
&dsp,
2640+
genRegionEntryCB, &dsp,
26142641
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
26152642
numThreadsClauseOperand, allocateOperands, allocatorOperands,
26162643
reductionVars,
@@ -2635,6 +2662,7 @@ genSectionOp(Fortran::lower::AbstractConverter &converter,
26352662
return genOpWithBody<mlir::omp::SectionOp>(
26362663
converter, eval, genNested, currentLocation,
26372664
/*outerCombined=*/false, &sectionsClauseList,
2665+
/*genRegionEntryCB=*/nullptr,
26382666
/*dsp=*/nullptr);
26392667
}
26402668

@@ -2656,8 +2684,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
26562684

26572685
return genOpWithBody<mlir::omp::SingleOp>(
26582686
converter, eval, genNested, currentLocation,
2659-
/*outerCombined=*/false, &beginClauseList, /*dsp=*/nullptr,
2660-
allocateOperands, allocatorOperands, nowaitAttr);
2687+
/*outerCombined=*/false, &beginClauseList, /*genRegionEntryCB=*/nullptr,
2688+
/*dsp=*/nullptr, allocateOperands, allocatorOperands, nowaitAttr);
26612689
}
26622690

26632691
static mlir::omp::TaskOp
@@ -2689,8 +2717,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
26892717

26902718
return genOpWithBody<mlir::omp::TaskOp>(
26912719
converter, eval, genNested, currentLocation,
2692-
/*outerCombined=*/false, &clauseList, /*dsp=*/nullptr, ifClauseOperand,
2693-
finalClauseOperand, untiedAttr, mergeableAttr,
2720+
/*outerCombined=*/false, &clauseList, /*genRegionEntryCB=*/nullptr,
2721+
/*dsp=*/nullptr, ifClauseOperand, finalClauseOperand, untiedAttr,
2722+
mergeableAttr,
26942723
/*in_reduction_vars=*/mlir::ValueRange(),
26952724
/*in_reductions=*/nullptr, priorityClauseOperand,
26962725
dependTypeOperands.empty()
@@ -2712,7 +2741,7 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
27122741
currentLocation, llvm::omp::Directive::OMPD_taskgroup);
27132742
return genOpWithBody<mlir::omp::TaskGroupOp>(
27142743
converter, eval, genNested, currentLocation,
2715-
/*outerCombined=*/false, &clauseList,
2744+
/*outerCombined=*/false, &clauseList, /*genRegionEntryCB=*/nullptr,
27162745
/*dsp=*/nullptr,
27172746
/*task_reduction_vars=*/mlir::ValueRange(),
27182747
/*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
@@ -3097,6 +3126,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
30973126

30983127
return genOpWithBody<mlir::omp::TeamsOp>(
30993128
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
3129+
/*genRegionEntryCB=*/nullptr,
31003130
/*dsp=*/nullptr,
31013131
/*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
31023132
threadLimitClauseOperand, allocateOperands, allocatorOperands,
@@ -3294,6 +3324,33 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
32943324
}
32953325
}
32963326

3327+
static llvm::SmallVector<const Fortran::semantics::Symbol *> genCodeForIterVar(
3328+
mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3329+
mlir::Location &loc,
3330+
const llvm::SmallVector<const Fortran::semantics::Symbol *> &args) {
3331+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3332+
auto &region = op->getRegion(0);
3333+
3334+
std::size_t loopVarTypeSize = 0;
3335+
for (const Fortran::semantics::Symbol *arg : args)
3336+
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
3337+
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
3338+
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
3339+
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
3340+
firOpBuilder.createBlock(&region, {}, tiv, locs);
3341+
// The argument is not currently in memory, so make a temporary for the
3342+
// argument, and store it there, then bind that location to the argument.
3343+
mlir::Operation *storeOp = nullptr;
3344+
for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
3345+
mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex));
3346+
storeOp =
3347+
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
3348+
}
3349+
firOpBuilder.setInsertionPointAfter(storeOp);
3350+
3351+
return args;
3352+
}
3353+
32973354
static void
32983355
createSimdLoop(Fortran::lower::AbstractConverter &converter,
32993356
Fortran::lower::pft::Evaluation &eval,
@@ -3341,9 +3398,14 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
33413398

33423399
auto *nestedEval = getCollapsedLoopEval(
33433400
eval, Fortran::lower::getCollapseValue(loopOpClauseList));
3401+
3402+
auto ivCallback = [&](mlir::Operation *op) {
3403+
return genCodeForIterVar(op, converter, loc, iv);
3404+
};
3405+
33443406
createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, *nestedEval,
33453407
/*genNested=*/true, &loopOpClauseList,
3346-
iv, /*outer=*/false, &dsp);
3408+
ivCallback, /*outer=*/false, &dsp);
33473409
}
33483410

33493411
static void createWsLoop(Fortran::lower::AbstractConverter &converter,
@@ -3416,8 +3478,14 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
34163478

34173479
auto *nestedEval = getCollapsedLoopEval(
34183480
eval, Fortran::lower::getCollapseValue(beginClauseList));
3481+
3482+
auto ivCallback = [&](mlir::Operation *op) {
3483+
return genCodeForIterVar(op, converter, loc, iv);
3484+
};
3485+
34193486
createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, *nestedEval,
3420-
/*genNested=*/true, &beginClauseList, iv,
3487+
/*genNested=*/true, &beginClauseList,
3488+
ivCallback,
34213489
/*outer=*/false, &dsp);
34223490
}
34233491

@@ -3746,6 +3814,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
37463814
/*genNested=*/false, currentLocation,
37473815
/*outerCombined=*/false,
37483816
/*clauseList=*/nullptr,
3817+
/*genRegionEntryCB=*/nullptr,
37493818
/*dsp=*/nullptr,
37503819
/*reduction_vars=*/mlir::ValueRange(),
37513820
/*reductions=*/nullptr, allocateOperands,

flang/test/Lower/OpenMP/FIR/delayed_privatization.f90

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ subroutine delayed_privatization()
2929
! %c222_i32 = arith.constant 222 : i32
3030
! fir.store %c222_i32 to %1 : !fir.ref<i32>
3131
! omp.parallel private(@var1.privatizer %0, @var2.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>) {
32-
! %2 = fir.load %0 : !fir.ref<i32>
33-
! %3 = fir.load %1 : !fir.ref<i32>
32+
! ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>):
33+
! %2 = fir.load %arg0 : !fir.ref<i32>
34+
! %3 = fir.load %arg1 : !fir.ref<i32>
3435
! %4 = arith.addi %2, %3 : i32
3536
! %c2_i32 = arith.constant 2 : i32
3637
! %5 = arith.addi %4, %c2_i32 : i32
37-
! fir.store %5 to %0 : !fir.ref<i32>
38+
! fir.store %5 to %arg0 : !fir.ref<i32>
3839
! omp.terminator
3940
! }
4041
! return
@@ -53,7 +54,6 @@ subroutine delayed_privatization()
5354
! fir.store %1 to %0 : !fir.ref<i32>
5455
! omp.yield(%0 : !fir.ref<i32>)
5556
! }) : () -> ()
56-
!}
5757
!
5858
! -----------------------------
5959
! ### Conversion to LLVM + OMP:
@@ -69,12 +69,13 @@ subroutine delayed_privatization()
6969
! %5 = llvm.mlir.constant(222 : i32) : i32
7070
! llvm.store %5, %3 : i32, !llvm.ptr
7171
! omp.parallel private(@var1.privatizer %1, @var2.privatizer %3 : !llvm.ptr, !llvm.ptr) {
72-
! %6 = llvm.load %1 : !llvm.ptr -> i32
73-
! %7 = llvm.load %3 : !llvm.ptr -> i32
72+
! ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
73+
! %6 = llvm.load %arg0 : !llvm.ptr -> i32
74+
! %7 = llvm.load %arg1 : !llvm.ptr -> i32
7475
! %8 = llvm.add %6, %7 : i32
7576
! %9 = llvm.mlir.constant(2 : i32) : i32
7677
! %10 = llvm.add %8, %9 : i32
77-
! llvm.store %10, %1 : i32, !llvm.ptr
78+
! llvm.store %10, %arg0 : i32, !llvm.ptr
7879
! omp.terminator
7980
! }
8081
! llvm.return

0 commit comments

Comments
 (0)