Skip to content

Commit a7b8205

Browse files
committed
[WIP] Delayed privatization.
This is a PoC for delayed privatization in OpenMP. Instead of directly emitting privatization code in the frontend, we add a new op that outlines the privatization logic for a symbol, together with a call-like mapping from the host symbol to the outlined, function-like privatizer op. Later, we would inline the delayed privatizer function-like op into the OpenMP region to get essentially the same code that the frontend currently generates directly.
1 parent 0716d31 commit a7b8205

File tree

8 files changed

+326
-30
lines changed

8 files changed

+326
-30
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Common/Fortran.h"
1717
#include "flang/Lower/LoweringOptions.h"
1818
#include "flang/Lower/PFTDefs.h"
19+
#include "flang/Lower/SymbolMap.h"
1920
#include "flang/Optimizer/Builder/BoxValue.h"
2021
#include "flang/Semantics/symbol.h"
2122
#include "mlir/IR/Builders.h"
@@ -296,6 +297,9 @@ class AbstractConverter {
296297
return loweringOptions;
297298
}
298299

300+
virtual Fortran::lower::SymbolBox
301+
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
302+
299303
private:
300304
/// Options controlling lowering behavior.
301305
const Fortran::lower::LoweringOptions &loweringOptions;

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
10701070
/// Find the symbol in one level up of symbol map such as for host-association
10711071
/// in OpenMP code or return null.
10721072
Fortran::lower::SymbolBox
1073-
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) {
1073+
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) override {
10741074
if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
10751075
return v;
10761076
return {};

flang/lib/Lower/OpenMP.cpp

Lines changed: 108 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ class DataSharingProcessor {
161161
const Fortran::parser::OmpClauseList &opClauseList;
162162
Fortran::lower::pft::Evaluation &eval;
163163

164+
bool useDelayedPrivatizationWhenPossible;
165+
Fortran::lower::SymMap *symTable;
166+
llvm::SetVector<mlir::SymbolRefAttr> privateInitializers;
167+
llvm::SetVector<mlir::Value> privateSymHostAddrsses;
168+
164169
bool needBarrier();
165170
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
166171
void collectOmpObjectListSymbol(
@@ -182,10 +187,14 @@ class DataSharingProcessor {
182187
public:
183188
DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
184189
const Fortran::parser::OmpClauseList &opClauseList,
185-
Fortran::lower::pft::Evaluation &eval)
190+
Fortran::lower::pft::Evaluation &eval,
191+
bool useDelayedPrivatizationWhenPossible = false,
192+
Fortran::lower::SymMap *symTable = nullptr)
186193
: hasLastPrivateOp(false), converter(converter),
187194
firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
188-
eval(eval) {}
195+
eval(eval), useDelayedPrivatizationWhenPossible(
196+
useDelayedPrivatizationWhenPossible),
197+
symTable(symTable) {}
189198
// Privatisation is split into two steps.
190199
// Step1 performs cloning of all privatisation clauses and copying for
191200
// firstprivates. Step1 is performed at the place where process/processStep1
@@ -204,6 +213,14 @@ class DataSharingProcessor {
204213
assert(!loopIV && "Loop iteration variable already set");
205214
loopIV = iv;
206215
}
216+
217+
const llvm::SetVector<mlir::SymbolRefAttr> &getPrivateInitializers() const {
218+
return privateInitializers;
219+
};
220+
221+
const llvm::SetVector<mlir::Value> &getPrivateSymHostAddrsses() const {
222+
return privateSymHostAddrsses;
223+
}
207224
};
208225

209226
void DataSharingProcessor::processStep1() {
@@ -496,8 +513,46 @@ void DataSharingProcessor::privatize() {
496513
copyFirstPrivateSymbol(&*mem);
497514
}
498515
} else {
499-
cloneSymbol(sym);
500-
copyFirstPrivateSymbol(sym);
516+
if (useDelayedPrivatizationWhenPossible) {
517+
auto ip = firOpBuilder.saveInsertionPoint();
518+
519+
auto moduleOp = firOpBuilder.getInsertionBlock()
520+
->getParentOp()
521+
->getParentOfType<mlir::ModuleOp>();
522+
523+
firOpBuilder.setInsertionPoint(&moduleOp.getBodyRegion().front(),
524+
moduleOp.getBodyRegion().front().end());
525+
526+
Fortran::lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
527+
assert(hsb && "Host symbol box not found");
528+
529+
auto symType = hsb.getAddr().getType();
530+
auto symLoc = hsb.getAddr().getLoc();
531+
auto privatizerOp = firOpBuilder.create<mlir::omp::PrivateClauseOp>(
532+
symLoc, symType, sym->name().ToString());
533+
firOpBuilder.setInsertionPointToEnd(&privatizerOp.getBody().front());
534+
535+
symTable->pushScope();
536+
symTable->addSymbol(*sym, privatizerOp.getArgument(0));
537+
symTable->pushScope();
538+
539+
cloneSymbol(sym);
540+
copyFirstPrivateSymbol(sym);
541+
542+
firOpBuilder.create<mlir::omp::YieldOp>(
543+
hsb.getAddr().getLoc(),
544+
symTable->shallowLookupSymbol(*sym).getAddr());
545+
546+
symTable->popScope();
547+
symTable->popScope();
548+
firOpBuilder.restoreInsertionPoint(ip);
549+
550+
privateInitializers.insert(mlir::SymbolRefAttr::get(privatizerOp));
551+
privateSymHostAddrsses.insert(hsb.getAddr());
552+
} else {
553+
cloneSymbol(sym);
554+
copyFirstPrivateSymbol(sym);
555+
}
501556
}
502557
}
503558
}
@@ -2480,12 +2535,12 @@ static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
24802535
Fortran::lower::pft::Evaluation &eval, bool genNested,
24812536
mlir::Location currentLocation, bool outerCombined,
24822537
const Fortran::parser::OmpClauseList *clauseList,
2483-
Args &&...args) {
2538+
DataSharingProcessor *dsp, Args &&...args) {
24842539
auto op = converter.getFirOpBuilder().create<OpTy>(
24852540
currentLocation, std::forward<Args>(args)...);
24862541
createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
24872542
clauseList,
2488-
/*args=*/{}, outerCombined);
2543+
/*args=*/{}, outerCombined, dsp);
24892544
return op;
24902545
}
24912546

@@ -2497,21 +2552,25 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
24972552
currentLocation,
24982553
/*outerCombined=*/false,
24992554
/*clauseList=*/nullptr,
2555+
/*dsp=*/nullptr,
25002556
/*resultTypes=*/mlir::TypeRange());
25012557
}
25022558

25032559
static mlir::omp::OrderedRegionOp
25042560
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
25052561
Fortran::lower::pft::Evaluation &eval, bool genNested,
25062562
mlir::Location currentLocation) {
2507-
return genOpWithBody<mlir::omp::OrderedRegionOp>(
2508-
converter, eval, genNested, currentLocation,
2509-
/*outerCombined=*/false,
2510-
/*clauseList=*/nullptr, /*simd=*/false);
2563+
return genOpWithBody<mlir::omp::OrderedRegionOp>(converter, eval, genNested,
2564+
currentLocation,
2565+
/*outerCombined=*/false,
2566+
/*clauseList=*/nullptr,
2567+
/*dsp=*/nullptr,
2568+
/*simd=*/false);
25112569
}
25122570

25132571
static mlir::omp::ParallelOp
25142572
genParallelOp(Fortran::lower::AbstractConverter &converter,
2573+
Fortran::lower::SymMap &symTable,
25152574
Fortran::lower::pft::Evaluation &eval, bool genNested,
25162575
mlir::Location currentLocation,
25172576
const Fortran::parser::OmpClauseList &clauseList,
@@ -2533,16 +2592,37 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25332592
if (!outerCombined)
25342593
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
25352594

2595+
bool privatize = !outerCombined;
2596+
DataSharingProcessor dsp(converter, clauseList, eval,
2597+
/*useDelayedPrivatizationWhenPossible=*/true,
2598+
&symTable);
2599+
2600+
if (privatize) {
2601+
dsp.processStep1();
2602+
}
2603+
2604+
llvm::SmallVector<mlir::Attribute> privateInits(
2605+
dsp.getPrivateInitializers().begin(), dsp.getPrivateInitializers().end());
2606+
2607+
llvm::SmallVector<mlir::Value> privateSymAddresses(
2608+
dsp.getPrivateSymHostAddrsses().begin(),
2609+
dsp.getPrivateSymHostAddrsses().end());
2610+
25362611
return genOpWithBody<mlir::omp::ParallelOp>(
25372612
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
2613+
&dsp,
25382614
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
25392615
numThreadsClauseOperand, allocateOperands, allocatorOperands,
2540-
reductionVars,
2616+
reductionVars, privateSymAddresses,
25412617
reductionDeclSymbols.empty()
25422618
? nullptr
25432619
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
25442620
reductionDeclSymbols),
2545-
procBindKindAttr);
2621+
procBindKindAttr,
2622+
privateInits.empty()
2623+
? nullptr
2624+
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
2625+
privateInits));
25462626
}
25472627

25482628
static mlir::omp::SectionOp
@@ -2554,7 +2634,8 @@ genSectionOp(Fortran::lower::AbstractConverter &converter,
25542634
// all privatization is done within `omp.section` operations.
25552635
return genOpWithBody<mlir::omp::SectionOp>(
25562636
converter, eval, genNested, currentLocation,
2557-
/*outerCombined=*/false, &sectionsClauseList);
2637+
/*outerCombined=*/false, &sectionsClauseList,
2638+
/*dsp=*/nullptr);
25582639
}
25592640

25602641
static mlir::omp::SingleOp
@@ -2575,8 +2656,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
25752656

25762657
return genOpWithBody<mlir::omp::SingleOp>(
25772658
converter, eval, genNested, currentLocation,
2578-
/*outerCombined=*/false, &beginClauseList, allocateOperands,
2579-
allocatorOperands, nowaitAttr);
2659+
/*outerCombined=*/false, &beginClauseList, /*dsp=*/nullptr,
2660+
allocateOperands, allocatorOperands, nowaitAttr);
25802661
}
25812662

25822663
static mlir::omp::TaskOp
@@ -2608,8 +2689,8 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
26082689

26092690
return genOpWithBody<mlir::omp::TaskOp>(
26102691
converter, eval, genNested, currentLocation,
2611-
/*outerCombined=*/false, &clauseList, ifClauseOperand, finalClauseOperand,
2612-
untiedAttr, mergeableAttr,
2692+
/*outerCombined=*/false, &clauseList, /*dsp=*/nullptr, ifClauseOperand,
2693+
finalClauseOperand, untiedAttr, mergeableAttr,
26132694
/*in_reduction_vars=*/mlir::ValueRange(),
26142695
/*in_reductions=*/nullptr, priorityClauseOperand,
26152696
dependTypeOperands.empty()
@@ -2632,6 +2713,7 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
26322713
return genOpWithBody<mlir::omp::TaskGroupOp>(
26332714
converter, eval, genNested, currentLocation,
26342715
/*outerCombined=*/false, &clauseList,
2716+
/*dsp=*/nullptr,
26352717
/*task_reduction_vars=*/mlir::ValueRange(),
26362718
/*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
26372719
}
@@ -3015,6 +3097,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
30153097

30163098
return genOpWithBody<mlir::omp::TeamsOp>(
30173099
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
3100+
/*dsp=*/nullptr,
30183101
/*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
30193102
threadLimitClauseOperand, allocateOperands, allocatorOperands,
30203103
reductionVars,
@@ -3413,8 +3496,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
34133496
if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
34143497
.test(ompDirective)) {
34153498
validDirective = true;
3416-
genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
3417-
loopOpClauseList,
3499+
genParallelOp(converter, symTable, eval, /*genNested=*/false,
3500+
currentLocation, loopOpClauseList,
34183501
/*outerCombined=*/true);
34193502
}
34203503
}
@@ -3502,8 +3585,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
35023585
genOrderedRegionOp(converter, eval, /*genNested=*/true, currentLocation);
35033586
break;
35043587
case llvm::omp::Directive::OMPD_parallel:
3505-
genParallelOp(converter, eval, /*genNested=*/true, currentLocation,
3506-
beginClauseList);
3588+
genParallelOp(converter, symTable, eval, /*genNested=*/true,
3589+
currentLocation, beginClauseList);
35073590
break;
35083591
case llvm::omp::Directive::OMPD_single:
35093592
genSingleOp(converter, eval, /*genNested=*/true, currentLocation,
@@ -3562,8 +3645,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
35623645
.test(directive.v)) {
35633646
bool outerCombined =
35643647
directive.v != llvm::omp::Directive::OMPD_target_parallel;
3565-
genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
3566-
beginClauseList, outerCombined);
3648+
genParallelOp(converter, symTable, eval, /*genNested=*/false,
3649+
currentLocation, beginClauseList, outerCombined);
35673650
combinedDirective = true;
35683651
}
35693652
if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
@@ -3646,7 +3729,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
36463729

36473730
// Parallel wrapper of PARALLEL SECTIONS construct
36483731
if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
3649-
genParallelOp(converter, eval,
3732+
genParallelOp(converter, symTable, eval,
36503733
/*genNested=*/false, currentLocation, sectionsClauseList,
36513734
/*outerCombined=*/true);
36523735
} else {
@@ -3663,6 +3746,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
36633746
/*genNested=*/false, currentLocation,
36643747
/*outerCombined=*/false,
36653748
/*clauseList=*/nullptr,
3749+
/*dsp=*/nullptr,
36663750
/*reduction_vars=*/mlir::ValueRange(),
36673751
/*reductions=*/nullptr, allocateOperands,
36683752
allocatorOperands, nowaitClauseOperand);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
subroutine delayed_privatization()
2+
integer :: var1
3+
integer :: var2
4+
5+
!$OMP PARALLEL FIRSTPRIVATE(var1, var2)
6+
var1 = var1 + var2 + 2
7+
!$OMP END PARALLEL
8+
9+
end subroutine
10+
11+
! This is what flang emits with the PoC:
12+
! --------------------------------------
13+
!
14+
!func.func @_QPdelayed_privatization() {
15+
! %0 = fir.alloca i32 {bindc_name = "var1", uniq_name = "_QFdelayed_privatizationEvar1"}
16+
! %1 = fir.alloca i32 {bindc_name = "var2", uniq_name = "_QFdelayed_privatizationEvar2"}
17+
! omp.parallel private(@var1.privatizer %0, @var2.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>) {
18+
! %2 = fir.load %0 : !fir.ref<i32>
19+
! %3 = fir.load %1 : !fir.ref<i32>
20+
! %4 = arith.addi %2, %3 : i32
21+
! %c2_i32 = arith.constant 2 : i32
22+
! %5 = arith.addi %4, %c2_i32 : i32
23+
! fir.store %5 to %0 : !fir.ref<i32>
24+
! omp.terminator
25+
! }
26+
! return
27+
!}
28+
!
29+
!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var1.privatizer"}> ({
30+
!^bb0(%arg0: !fir.ref<i32>):
31+
! %0 = fir.alloca i32 {bindc_name = "var1", pinned, uniq_name = "_QFdelayed_privatizationEvar1"}
32+
! %1 = fir.load %arg0 : !fir.ref<i32>
33+
! fir.store %1 to %0 : !fir.ref<i32>
34+
! omp.yield(%0 : !fir.ref<i32>)
35+
!}) : () -> ()
36+
!
37+
!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var2.privatizer"}> ({
38+
!^bb0(%arg0: !fir.ref<i32>):
39+
! %0 = fir.alloca i32 {bindc_name = "var2", pinned, uniq_name = "_QFdelayed_privatizationEvar2"}
40+
! %1 = fir.load %arg0 : !fir.ref<i32>
41+
! fir.store %1 to %0 : !fir.ref<i32>
42+
! omp.yield(%0 : !fir.ref<i32>)
43+
!}) : () -> ()

0 commit comments

Comments
 (0)