
Commit 5b654d9

[Flang][OpenMP] DISTRIBUTE PARALLEL DO lowering
This patch adds PFT-to-MLIR lowering support for `distribute parallel do` composite constructs.
Parent: 23983c7
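For context, the composite construct handled by this patch looks like the following in Fortran source. This is an illustrative sketch, not code from the commit: the subroutine name and loop body are made up, and `distribute parallel do` must appear inside a `teams` region. After this patch, the loop lowers to an omp.parallel op wrapping composite omp.distribute and omp.wsloop ops around a single omp.loop_nest.

! Illustrative sketch, not part of this commit.
subroutine example_dist_par_do(n, a)
  integer :: n, i
  real :: a(n)
  !$omp teams
  ! Iterations are distributed across teams, then workshared among the
  ! threads of the parallel region each team runs.
  !$omp distribute parallel do
  do i = 1, n
    a(i) = 2.0 * a(i)
  end do
  !$omp end distribute parallel do
  !$omp end teams
end subroutine example_dist_par_do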

File tree: 4 files changed (+493, −35 lines)


flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 114 additions & 31 deletions
@@ -519,6 +519,11 @@ struct OpWithBodyGenInfo {
     return *this;
   }
 
+  OpWithBodyGenInfo &setGenSkeletonOnly(bool value) {
+    genSkeletonOnly = value;
+    return *this;
+  }
+
   /// [inout] converter to use for the clauses.
   lower::AbstractConverter &converter;
   /// [in] Symbol table
@@ -538,6 +543,9 @@ struct OpWithBodyGenInfo {
   /// [in] if provided, emits the op's region entry. Otherwise, an empty block
   /// is created in the region.
   GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
+  /// [in] if set to `true`, skip generating nested evaluations and dispatching
+  /// any further leaf constructs.
+  bool genSkeletonOnly = false;
 };
 
 /// Create the body (block) for an OpenMP Operation.
@@ -600,20 +608,22 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
     }
   }
 
-  if (ConstructQueue::const_iterator next = std::next(item);
-      next != queue.end()) {
-    genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
-                   info.loc, queue, next);
-  } else {
-    // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
-    // a lot of complications for our approach if the terminator generation
-    // is delayed past this point. Insert a temporary terminator here, then
-    // delete it.
-    firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
-    auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
-    firOpBuilder.setInsertionPointAfter(marker);
-    genNestedEvaluations(info.converter, info.eval);
-    temp->erase();
+  if (!info.genSkeletonOnly) {
+    if (ConstructQueue::const_iterator next = std::next(item);
+        next != queue.end()) {
+      genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
+                     info.loc, queue, next);
+    } else {
+      // genFIR(Evaluation&) tries to patch up unterminated blocks, causing
+      // a lot of complications for our approach if the terminator generation
+      // is delayed past this point. Insert a temporary terminator here, then
+      // delete it.
+      firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
+      auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
+      firOpBuilder.setInsertionPointAfter(marker);
+      genNestedEvaluations(info.converter, info.eval);
+      temp->erase();
+    }
   }
 
   // Get or create a unique exiting block from the given region, or
@@ -1445,7 +1455,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               const ConstructQueue &queue, ConstructQueue::const_iterator item,
               mlir::omp::ParallelOperands &clauseOps,
               llvm::ArrayRef<const semantics::Symbol *> reductionSyms,
-              llvm::ArrayRef<mlir::Type> reductionTypes) {
+              llvm::ArrayRef<mlir::Type> reductionTypes,
+              DataSharingProcessor *dsp, bool isComposite = false) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   auto reductionCallback = [&](mlir::Operation *op) {
@@ -1457,17 +1468,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
                         llvm::omp::Directive::OMPD_parallel)
           .setClauses(&item->clauses)
-          .setGenRegionEntryCb(reductionCallback);
-
-  if (!enableDelayedPrivatization)
-    return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item,
-                                                clauseOps);
-
-  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
-                           lower::omp::isLastItemInQueue(item, queue),
-                           /*useDelayedPrivatization=*/true, &symTable);
-  dsp.processStep1(&clauseOps);
+          .setGenRegionEntryCb(reductionCallback)
+          .setGenSkeletonOnly(isComposite);
+
+  if (!enableDelayedPrivatization) {
+    auto parallelOp =
+        genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+    parallelOp.setComposite(isComposite);
+    return parallelOp;
+  }
 
+  assert(dsp && "expected valid DataSharingProcessor");
   auto genRegionEntryCB = [&](mlir::Operation *op) {
     auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
 
@@ -1491,8 +1502,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                                  allRegionArgLocs);
 
     llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
-    allSymbols.append(dsp.getAllSymbolsToPrivatize().begin(),
-                      dsp.getAllSymbolsToPrivatize().end());
+    allSymbols.append(dsp->getAllSymbolsToPrivatize().begin(),
+                      dsp->getAllSymbolsToPrivatize().end());
 
     unsigned argIdx = 0;
     for (const semantics::Symbol *arg : allSymbols) {
@@ -1519,8 +1530,11 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
     return allSymbols;
   };
 
-  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
-  return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+  genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(dsp);
+  auto parallelOp =
+      genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
+  parallelOp.setComposite(isComposite);
+  return parallelOp;
 }
 
 /// This breaks the normal prototype of the gen*Op functions: adding the
@@ -1999,8 +2013,16 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
   genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
                      reductionTypes, reductionSyms);
 
+  std::optional<DataSharingProcessor> dsp;
+  if (enableDelayedPrivatization) {
+    dsp.emplace(converter, semaCtx, item->clauses, eval,
+                lower::omp::isLastItemInQueue(item, queue),
+                /*useDelayedPrivatization=*/true, &symTable);
+    dsp->processStep1(&clauseOps);
+  }
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, clauseOps,
-                reductionSyms, reductionTypes);
+                reductionSyms, reductionTypes,
+                enableDelayedPrivatization ? &dsp.value() : nullptr);
 }
 
 static void genStandaloneSimd(lower::AbstractConverter &converter,
@@ -2052,8 +2074,69 @@ static void genCompositeDistributeParallelDo(
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
     mlir::Location loc, const ConstructQueue &queue,
     ConstructQueue::const_iterator item) {
+  lower::StatementContext stmtCtx;
+
   assert(std::distance(item, queue.end()) == 3 && "Invalid leaf constructs");
-  TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
+  ConstructQueue::const_iterator distributeItem = item;
+  ConstructQueue::const_iterator parallelItem = std::next(distributeItem);
+  ConstructQueue::const_iterator doItem = std::next(parallelItem);
+
+  // Create parent omp.parallel first.
+  mlir::omp::ParallelOperands parallelClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
+  llvm::SmallVector<mlir::Type> parallelReductionTypes;
+  genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
+                     parallelClauseOps, parallelReductionTypes,
+                     parallelReductionSyms);
+
+  DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval,
+                           /*shouldCollectPreDeterminedSymbols=*/true,
+                           /*useDelayedPrivatization=*/true, &symTable);
+  dsp.processStep1(&parallelClauseOps);
+
+  genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
+                parallelClauseOps, parallelReductionSyms,
+                parallelReductionTypes, &dsp, /*isComposite=*/true);
+
+  // Clause processing.
+  mlir::omp::DistributeOperands distributeClauseOps;
+  genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
+                       loc, distributeClauseOps);
+
+  mlir::omp::WsloopOperands wsloopClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
+  llvm::SmallVector<mlir::Type> wsloopReductionTypes;
+  genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
+                   wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
+
+  mlir::omp::LoopNestOperands loopNestClauseOps;
+  llvm::SmallVector<const semantics::Symbol *> iv;
+  genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc,
+                     loopNestClauseOps, iv);
+
+  // Operation creation.
+  // TODO: Populate entry block arguments with private variables.
+  auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
+      converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
+  distributeOp.setComposite(/*val=*/true);
+
+  // TODO: Add private variables to entry block arguments.
+  auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
+      converter, loc, wsloopClauseOps, wsloopReductionTypes);
+  wsloopOp.setComposite(/*val=*/true);
+
+  // Construct wrapper entry block list and associated symbols. It is important
+  // that the symbol order and the block argument order match, so that the
+  // symbol-value bindings created are correct.
+  auto &wrapperSyms = wsloopReductionSyms;
+
+  auto wrapperArgs = llvm::to_vector(
+      llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
+                                        wsloopOp.getRegion().getArguments()));
+
+  genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, doItem,
+                loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
+                llvm::omp::Directive::OMPD_distribute_parallel_do, dsp);
 }
 
 static void genCompositeDistributeParallelDoSimd(
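The body of genCompositeDistributeParallelDo splits the clauses of the composite directive across the ops it creates: parallel clauses such as num_threads land on the parent omp.parallel, dist_schedule on the omp.distribute wrapper, worksharing clauses such as ordered on omp.wsloop, and the loop bounds on omp.loop_nest. A hypothetical Fortran fragment (not from this commit) combining clauses that end up on different ops:

! Hypothetical example, not part of this commit: num_threads(8) is lowered
! onto omp.parallel and dist_schedule(static, 16) onto omp.distribute, while
! the loop bounds feed omp.loop_nest.
subroutine example_clause_split(n, a)
  integer :: n, i
  real :: a(n)
  !$omp teams
  !$omp distribute parallel do num_threads(8) dist_schedule(static, 16)
  do i = 1, n
    a(i) = a(i) + 1.0
  end do
  !$omp end distribute parallel do
  !$omp end teams
end subroutine example_clause_split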
New test file

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
+! This test checks lowering of OpenMP DISTRIBUTE PARALLEL DO composite
+! constructs.
+
+! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - | FileCheck %s
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_num_threads(
+subroutine distribute_parallel_do_num_threads()
+!$omp teams
+
+! CHECK: omp.parallel num_threads({{.*}}) private({{.*}}) {
+! CHECK: omp.distribute {
+! CHECK-NEXT: omp.wsloop {
+! CHECK-NEXT: omp.loop_nest
+!$omp distribute parallel do num_threads(10)
+do index_ = 1, 10
+end do
+!$omp end distribute parallel do
+
+!$omp end teams
+end subroutine distribute_parallel_do_num_threads
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_dist_schedule(
+subroutine distribute_parallel_do_dist_schedule()
+!$omp teams
+
+! CHECK: omp.parallel private({{.*}}) {
+! CHECK: omp.distribute dist_schedule_static dist_schedule_chunk_size({{.*}}) {
+! CHECK-NEXT: omp.wsloop {
+! CHECK-NEXT: omp.loop_nest
+!$omp distribute parallel do dist_schedule(static, 4)
+do index_ = 1, 10
+end do
+!$omp end distribute parallel do
+
+!$omp end teams
+end subroutine distribute_parallel_do_dist_schedule
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_ordered(
+subroutine distribute_parallel_do_ordered()
+!$omp teams
+
+! CHECK: omp.parallel private({{.*}}) {
+! CHECK: omp.distribute {
+! CHECK-NEXT: omp.wsloop ordered(1) {
+! CHECK-NEXT: omp.loop_nest
+!$omp distribute parallel do ordered(1)
+do index_ = 1, 10
+end do
+!$omp end distribute parallel do
+
+!$omp end teams
end subroutine distribute_parallel_do_ordered
+
+! CHECK-LABEL: func.func @_QPdistribute_parallel_do_private(
+subroutine distribute_parallel_do_private()
+! CHECK: %[[INDEX_ALLOC:.*]] = fir.alloca i32
+! CHECK: %[[INDEX:.*]]:2 = hlfir.declare %[[INDEX_ALLOC]]
+! CHECK: %[[X_ALLOC:.*]] = fir.alloca i64
+! CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_ALLOC]]
+integer(8) :: x
+
+! CHECK: omp.teams {
+!$omp teams
+
+! CHECK: omp.parallel private(@{{.*}} %[[X]]#0 -> %[[X_ARG:.*]] : !fir.ref<i64>,
+! CHECK-SAME: @{{.*}} %[[INDEX]]#0 -> %[[INDEX_ARG:.*]] : !fir.ref<i32>) {
+! CHECK: %[[X_PRIV:.*]]:2 = hlfir.declare %[[X_ARG]]
+! CHECK: %[[INDEX_PRIV:.*]]:2 = hlfir.declare %[[INDEX_ARG]]
+! CHECK: omp.distribute {
+! CHECK-NEXT: omp.wsloop {
+! CHECK-NEXT: omp.loop_nest
+!$omp distribute parallel do private(x)
+do index_ = 1, 10
+end do
+!$omp end distribute parallel do
+
+!$omp end teams
+end subroutine distribute_parallel_do_private
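These tests mirror the clause splitting done in genCompositeDistributeParallelDo: num_threads is checked on omp.parallel, dist_schedule on omp.distribute, ordered on omp.wsloop, and privatized variables become entry block arguments of omp.parallel. Outside of lit, the RUN lines above can be invoked directly by substituting %s with the test file's path (file name here is a hypothetical stand-in): bbc -fopenmp -emit-hlfir distribute-parallel-do.f90 -o - | FileCheck distribute-parallel-do.f90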
