Skip to content

Commit aa0403a

Browse files
committed
[Flang][OpenMP] Prevent re-composition of composite constructs
After decomposition of OpenMP compound constructs and assignment of applicable clauses to each leaf construct, composite constructs are then combined again into a single element in the construct queue. This helped later lowering stages easily identify composite constructs. However, as a result of the re-composition stage, the same list of clauses is used to produce all MLIR operations corresponding to each leaf of the original composite construct. This undoes existing logic introducing implicit clauses and deciding to which leaf construct(s) each clause applies. This patch removes construct re-composition logic and updates Flang lowering to be able to identify composite constructs from a list of leaf constructs. As a result, the right set of clauses is produced for each operation representing a leaf of a composite construct.
1 parent ac12b48 commit aa0403a

File tree

7 files changed

+103
-488
lines changed

7 files changed

+103
-488
lines changed

flang/lib/Lower/OpenMP/Decomposer.cpp

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "llvm/ADT/STLExtras.h"
2323
#include "llvm/ADT/SmallVector.h"
2424
#include "llvm/Frontend/OpenMP/ClauseT.h"
25-
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
2625
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
2726
#include "llvm/Frontend/OpenMP/OMP.h"
2827
#include "llvm/Support/raw_ostream.h"
@@ -68,12 +67,6 @@ struct ConstructDecomposition {
6867
};
6968
} // namespace
7069

71-
static UnitConstruct mergeConstructs(uint32_t version,
72-
llvm::ArrayRef<UnitConstruct> units) {
73-
tomp::ConstructCompositionT compose(version, units);
74-
return compose.merged;
75-
}
76-
7770
namespace Fortran::lower::omp {
7871
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
7972
const UnitConstruct &uc) {
@@ -90,38 +83,33 @@ ConstructQueue buildConstructQueue(
9083
Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
9184
llvm::omp::Directive compound, const List<Clause> &clauses) {
9285

93-
List<UnitConstruct> constructs;
94-
9586
ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
9687
assert(!decompose.output.empty() && "Construct decomposition failed");
9788

98-
llvm::SmallVector<llvm::omp::Directive> loweringUnits;
99-
std::ignore =
100-
llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
101-
uint32_t version = getOpenMPVersionAttribute(modOp);
102-
103-
int leafIndex = 0;
104-
for (llvm::omp::Directive dir_id : loweringUnits) {
105-
llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
106-
llvm::omp::getLeafConstructsOrSelf(dir_id);
107-
size_t numLeafs = leafsOrSelf.size();
108-
109-
llvm::ArrayRef<UnitConstruct> toMerge{&decompose.output[leafIndex],
110-
numLeafs};
111-
auto &uc = constructs.emplace_back(mergeConstructs(version, toMerge));
112-
113-
if (!transferLocations(clauses, uc.clauses)) {
114-
// If some clauses are left without source information, use the
115-
// directive's source.
116-
for (auto &clause : uc.clauses) {
117-
if (clause.source.empty())
118-
clause.source = source;
119-
}
120-
}
121-
leafIndex += numLeafs;
89+
for (UnitConstruct &uc : decompose.output) {
90+
assert(getLeafConstructs(uc.id).empty() && "unexpected compound directive");
91+
// If some clauses are left without source information, use the directive's
92+
// source.
93+
for (auto &clause : uc.clauses)
94+
if (clause.source.empty())
95+
clause.source = source;
12296
}
12397

124-
return constructs;
98+
return decompose.output;
99+
}
100+
101+
bool matchLeafSequence(ConstructQueue::const_iterator item,
102+
const ConstructQueue &queue,
103+
llvm::ArrayRef<llvm::omp::Directive> directives) {
104+
for (auto [dir, leaf] :
105+
llvm::zip_longest(directives, llvm::make_range(item, queue.end()))) {
106+
if (!dir || !leaf)
107+
return false;
108+
109+
if (dir.value() != leaf.value().id)
110+
return false;
111+
}
112+
return true;
125113
}
126114

127115
bool isLastItemInQueue(ConstructQueue::const_iterator item,

flang/lib/Lower/OpenMP/Decomposer.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "Clauses.h"
1212
#include "mlir/IR/BuiltinOps.h"
13-
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
1413
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
1514
#include "llvm/Frontend/OpenMP/OMP.h"
1615
#include "llvm/Support/Compiler.h"
@@ -49,6 +48,12 @@ ConstructQueue buildConstructQueue(mlir::ModuleOp modOp,
4948

5049
bool isLastItemInQueue(ConstructQueue::const_iterator item,
5150
const ConstructQueue &queue);
51+
52+
/// Try to match a sequence of \c directives to the range of leaf constructs
53+
/// starting from \c item to the end of the \c queue.
54+
bool matchLeafSequence(ConstructQueue::const_iterator item,
55+
const ConstructQueue &queue,
56+
llvm::ArrayRef<llvm::omp::Directive> directives);
5257
} // namespace Fortran::lower::omp
5358

5459
#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2042,6 +2042,7 @@ static void genCompositeDistributeParallelDoSimd(
20422042
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
20432043
mlir::Location loc, const ConstructQueue &queue,
20442044
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
2045+
assert(std::distance(item, queue.end()) == 4 && "Invalid leaf constructs");
20452046
TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD");
20462047
}
20472048

@@ -2052,17 +2053,23 @@ static void genCompositeDistributeSimd(
20522053
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
20532054
lower::StatementContext stmtCtx;
20542055

2056+
assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
2057+
ConstructQueue::const_iterator distributeItem = item;
2058+
ConstructQueue::const_iterator simdItem = std::next(distributeItem);
2059+
20552060
// Clause processing.
20562061
mlir::omp::DistributeOperands distributeClauseOps;
2057-
genDistributeClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
2058-
distributeClauseOps);
2062+
genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
2063+
loc, distributeClauseOps);
20592064

20602065
mlir::omp::SimdOperands simdClauseOps;
2061-
genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
2066+
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
20622067

2068+
// Pass the innermost leaf construct's clauses because that's where COLLAPSE
2069+
// is placed by construct decomposition.
20632070
mlir::omp::LoopNestOperands loopNestClauseOps;
20642071
llvm::SmallVector<const semantics::Symbol *> iv;
2065-
genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
2072+
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
20662073
loopNestClauseOps, iv);
20672074

20682075
// Operation creation.
@@ -2084,7 +2091,7 @@ static void genCompositeDistributeSimd(
20842091

20852092
assert(wrapperArgs.empty() &&
20862093
"Block args for omp.simd and omp.distribute currently not expected");
2087-
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
2094+
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
20882095
loopNestClauseOps, iv, /*wrapperSyms=*/{}, wrapperArgs,
20892096
llvm::omp::Directive::OMPD_distribute_simd, dsp);
20902097
}
@@ -2098,19 +2105,25 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
20982105
DataSharingProcessor &dsp) {
20992106
lower::StatementContext stmtCtx;
21002107

2108+
assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
2109+
ConstructQueue::const_iterator doItem = item;
2110+
ConstructQueue::const_iterator simdItem = std::next(doItem);
2111+
21012112
// Clause processing.
21022113
mlir::omp::WsloopOperands wsloopClauseOps;
21032114
llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
21042115
llvm::SmallVector<mlir::Type> wsloopReductionTypes;
2105-
genWsloopClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
2116+
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
21062117
wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);
21072118

21082119
mlir::omp::SimdOperands simdClauseOps;
2109-
genSimdClauses(converter, semaCtx, item->clauses, loc, simdClauseOps);
2120+
genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps);
21102121

2122+
// Pass the innermost leaf construct's clauses because that's where COLLAPSE
2123+
// is placed by construct decomposition.
21112124
mlir::omp::LoopNestOperands loopNestClauseOps;
21122125
llvm::SmallVector<const semantics::Symbol *> iv;
2113-
genLoopNestClauses(converter, semaCtx, eval, item->clauses, loc,
2126+
genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
21142127
loopNestClauseOps, iv);
21152128

21162129
// Operation creation.
@@ -2131,7 +2144,7 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
21312144

21322145
assert(wsloopReductionSyms.size() == wrapperArgs.size() &&
21332146
"Number of symbols and wrapper block arguments must match");
2134-
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, item,
2147+
genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, simdItem,
21352148
loopNestClauseOps, iv, wsloopReductionSyms, wrapperArgs,
21362149
llvm::omp::Directive::OMPD_do_simd, dsp);
21372150
}
@@ -2141,13 +2154,50 @@ static void genCompositeTaskloopSimd(
21412154
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
21422155
mlir::Location loc, const ConstructQueue &queue,
21432156
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
2157+
assert(std::distance(item, queue.end()) == 2 && "Invalid leaf constructs");
21442158
TODO(loc, "Composite TASKLOOP SIMD");
21452159
}
21462160

21472161
//===----------------------------------------------------------------------===//
21482162
// Dispatch
21492163
//===----------------------------------------------------------------------===//
21502164

2165+
static bool genOMPCompositeDispatch(
2166+
lower::AbstractConverter &converter, lower::SymMap &symTable,
2167+
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
2168+
mlir::Location loc, const ConstructQueue &queue,
2169+
ConstructQueue::const_iterator item, DataSharingProcessor &dsp) {
2170+
using llvm::omp::Directive;
2171+
using llvm::omp::getLeafConstructs, lower::omp::matchLeafSequence;
2172+
2173+
if (matchLeafSequence(
2174+
item, queue,
2175+
getLeafConstructs(Directive::OMPD_distribute_parallel_do)))
2176+
genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
2177+
queue, item, dsp);
2178+
else if (matchLeafSequence(
2179+
item, queue,
2180+
getLeafConstructs(Directive::OMPD_distribute_parallel_do_simd)))
2181+
genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
2182+
loc, queue, item, dsp);
2183+
else if (matchLeafSequence(
2184+
item, queue, getLeafConstructs(Directive::OMPD_distribute_simd)))
2185+
genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
2186+
item, dsp);
2187+
else if (matchLeafSequence(item, queue,
2188+
getLeafConstructs(Directive::OMPD_do_simd)))
2189+
genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
2190+
dsp);
2191+
else if (matchLeafSequence(item, queue,
2192+
getLeafConstructs(Directive::OMPD_taskloop_simd)))
2193+
genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
2194+
item, dsp);
2195+
else
2196+
return false;
2197+
2198+
return true;
2199+
}
2200+
21512201
static void genOMPDispatch(lower::AbstractConverter &converter,
21522202
lower::SymMap &symTable,
21532203
semantics::SemanticsContext &semaCtx,
@@ -2161,10 +2211,18 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
21612211
llvm::omp::Association::Loop;
21622212
if (loopLeaf) {
21632213
symTable.pushScope();
2214+
// TODO: Use one DataSharingProcessor for each leaf of a composite
2215+
// construct.
21642216
loopDsp.emplace(converter, semaCtx, item->clauses, eval,
21652217
/*shouldCollectPreDeterminedSymbols=*/true,
21662218
/*useDelayedPrivatization=*/false, &symTable);
21672219
loopDsp->processStep1();
2220+
2221+
if (genOMPCompositeDispatch(converter, symTable, semaCtx, eval, loc, queue,
2222+
item, *loopDsp)) {
2223+
symTable.popScope();
2224+
return;
2225+
}
21682226
}
21692227

21702228
switch (llvm::omp::Directive dir = item->id) {
@@ -2263,24 +2321,13 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
22632321

22642322
// Composite constructs
22652323
case llvm::omp::Directive::OMPD_distribute_parallel_do:
2266-
genCompositeDistributeParallelDo(converter, symTable, semaCtx, eval, loc,
2267-
queue, item, *loopDsp);
2268-
break;
22692324
case llvm::omp::Directive::OMPD_distribute_parallel_do_simd:
2270-
genCompositeDistributeParallelDoSimd(converter, symTable, semaCtx, eval,
2271-
loc, queue, item, *loopDsp);
2272-
break;
22732325
case llvm::omp::Directive::OMPD_distribute_simd:
2274-
genCompositeDistributeSimd(converter, symTable, semaCtx, eval, loc, queue,
2275-
item, *loopDsp);
2276-
break;
22772326
case llvm::omp::Directive::OMPD_do_simd:
2278-
genCompositeDoSimd(converter, symTable, semaCtx, eval, loc, queue, item,
2279-
*loopDsp);
2280-
break;
22812327
case llvm::omp::Directive::OMPD_taskloop_simd:
2282-
genCompositeTaskloopSimd(converter, symTable, semaCtx, eval, loc, queue,
2283-
item, *loopDsp);
2328+
// Composite constructs should have been split into a sequence of leaf
2329+
// constructs and lowered by genOMPCompositeDispatch().
2330+
llvm_unreachable("Unexpected composite construct.");
22842331
break;
22852332
default:
22862333
break;

flang/test/Lower/OpenMP/Todo/omp-do-simd-linear.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
55
subroutine testDoSimdLinear(int_array)
66
integer :: int_array(*)
7-
!CHECK: not yet implemented: Unhandled clause LINEAR in DO construct
7+
!CHECK: not yet implemented: Unhandled clause LINEAR in SIMD construct
88
!$omp do simd linear(int_array)
99
do index_ = 1, 10
1010
end do

flang/test/Lower/OpenMP/default-clause-byref.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ subroutine nested_default_clause_tests
197197
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
198198
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_testsEz"}
199199
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
200-
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
201-
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
200+
!CHECK: omp.parallel private({{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
202201
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_testsEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
202+
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_testsEy"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
203203
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_testsEz"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
204204
!CHECK: %[[PRIVATE_K_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_K]] {uniq_name = "_QFnested_default_clause_testsEk"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
205205
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_X:.*]] : {{.*}}) {

flang/test/Lower/OpenMP/default-clause.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ end program default_clause_lowering
134134
!CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
135135
!CHECK: %[[Z:.*]] = fir.alloca i32 {bindc_name = "z", uniq_name = "_QFnested_default_clause_test1Ez"}
136136
!CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
137-
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
138-
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
137+
!CHECK: omp.parallel private({{.*firstprivate.*}} {{.*}}#0 -> %[[PRIVATE_X:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_Z:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[PRIVATE_K:.*]] : {{.*}}) {
139138
!CHECK: %[[PRIVATE_X_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_X]] {uniq_name = "_QFnested_default_clause_test1Ex"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
139+
!CHECK: %[[PRIVATE_Y_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Y]] {uniq_name = "_QFnested_default_clause_test1Ey"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
140140
!CHECK: %[[PRIVATE_Z_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_Z]] {uniq_name = "_QFnested_default_clause_test1Ez"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
141141
!CHECK: %[[PRIVATE_K_DECL:.*]]:2 = hlfir.declare %[[PRIVATE_K]] {uniq_name = "_QFnested_default_clause_test1Ek"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
142142
!CHECK: omp.parallel private({{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_Y:.*]] : {{.*}}, {{.*}} {{.*}}#0 -> %[[INNER_PRIVATE_X:.*]] : {{.*}}) {

0 commit comments

Comments
 (0)