Skip to content

Commit 4b5cfe3

Browse files
committed
[flang][OpenMP] Try to find a proper solution for symbol scoping in DSP
1 parent 1429f12 commit 4b5cfe3

File tree

2 files changed

+63
-37
lines changed

2 files changed

+63
-37
lines changed

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,30 @@
2222
namespace Fortran {
2323
namespace lower {
2424
namespace omp {
25+
bool DataSharingProcessor::OMPConstructSymbolVisitor::isSymbolDefineBy(
26+
const semantics::Symbol *symbol, lower::pft::Evaluation &eval) const {
27+
return eval.visit(
28+
common::visitors{[&](const parser::OpenMPConstruct &functionParserNode) {
29+
return symDefMap.count(symbol) &&
30+
symDefMap.at(symbol) == &functionParserNode;
31+
},
32+
[](const auto &functionParserNode) { return false; }});
33+
}
34+
35+
DataSharingProcessor::DataSharingProcessor(
36+
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
37+
const List<Clause> &clauses, lower::pft::Evaluation &eval,
38+
bool shouldCollectPreDeterminedSymbols, bool useDelayedPrivatization,
39+
lower::SymMap *symTable)
40+
: hasLastPrivateOp(false), converter(converter), semaCtx(semaCtx),
41+
firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
42+
shouldCollectPreDeterminedSymbols(shouldCollectPreDeterminedSymbols),
43+
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable),
44+
visitor() {
45+
eval.visit([&](const auto &functionParserNode) {
46+
parser::Walk(functionParserNode, visitor);
47+
});
48+
}
2549

2650
void DataSharingProcessor::processStep1(
2751
mlir::omp::PrivateClauseOps *clauseOps,
@@ -285,38 +309,9 @@ void DataSharingProcessor::collectSymbolsInNestedRegions(
285309
// Recursively look for OpenMP constructs within `nestedEval`'s region
286310
collectSymbolsInNestedRegions(nestedEval, flag, symbolsInNestedRegions);
287311
else {
288-
bool isOrderedConstruct = [&]() {
289-
if (auto *ompConstruct =
290-
nestedEval.getIf<parser::OpenMPConstruct>()) {
291-
if (auto *ompBlockConstruct =
292-
std::get_if<parser::OpenMPBlockConstruct>(
293-
&ompConstruct->u)) {
294-
const auto &beginBlockDirective =
295-
std::get<parser::OmpBeginBlockDirective>(
296-
ompBlockConstruct->t);
297-
const auto origDirective =
298-
std::get<parser::OmpBlockDirective>(beginBlockDirective.t).v;
299-
300-
return origDirective == llvm::omp::Directive::OMPD_ordered;
301-
}
302-
}
303-
304-
return false;
305-
}();
306-
307-
bool isCriticalConstruct = [&]() {
308-
if (auto *ompConstruct =
309-
nestedEval.getIf<parser::OpenMPConstruct>()) {
310-
return std::get_if<parser::OpenMPCriticalConstruct>(
311-
&ompConstruct->u) != nullptr;
312-
}
313-
return false;
314-
}();
315-
316-
if (!isOrderedConstruct && !isCriticalConstruct)
317-
converter.collectSymbolSet(nestedEval, symbolsInNestedRegions, flag,
318-
/*collectSymbols=*/true,
319-
/*collectHostAssociatedSymbols=*/false);
312+
converter.collectSymbolSet(nestedEval, symbolsInNestedRegions, flag,
313+
/*collectSymbols=*/true,
314+
/*collectHostAssociatedSymbols=*/false);
320315
}
321316
}
322317
}
@@ -356,6 +351,11 @@ void DataSharingProcessor::collectSymbols(
356351

357352
llvm::SetVector<const semantics::Symbol *> symbolsInNestedRegions;
358353
collectSymbolsInNestedRegions(eval, flag, symbolsInNestedRegions);
354+
355+
for (auto *symbol : allSymbols)
356+
if (visitor.isSymbolDefineBy(symbol, eval))
357+
symbolsInNestedRegions.remove(symbol);
358+
359359
// Filter-out symbols that must not be privatized.
360360
bool collectImplicit = flag == semantics::Symbol::Flag::OmpImplicit;
361361
bool collectPreDetermined = flag == semantics::Symbol::Flag::OmpPreDetermined;

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,35 @@ namespace omp {
3232

3333
class DataSharingProcessor {
3434
private:
35+
struct OMPConstructSymbolVisitor {
36+
template <typename T>
37+
bool Pre(const T &) {
38+
return true;
39+
}
40+
template <typename T>
41+
void Post(const T &) {}
42+
43+
bool Pre(const parser::OpenMPConstruct &omp) {
44+
currentConstruct = &omp;
45+
return true;
46+
}
47+
48+
void Post(const parser::OpenMPConstruct &omp) {
49+
currentConstruct = nullptr;
50+
}
51+
52+
void Post(const parser::Name &name) {
53+
symDefMap.try_emplace(name.symbol, currentConstruct);
54+
}
55+
56+
const parser::OpenMPConstruct *currentConstruct = nullptr;
57+
llvm::DenseMap<semantics::Symbol *, const parser::OpenMPConstruct *>
58+
symDefMap;
59+
60+
bool isSymbolDefineBy(const semantics::Symbol *symbol,
61+
lower::pft::Evaluation &eval) const;
62+
};
63+
3564
bool hasLastPrivateOp;
3665
mlir::OpBuilder::InsertPoint lastPrivIP;
3766
mlir::OpBuilder::InsertPoint insPt;
@@ -53,6 +82,7 @@ class DataSharingProcessor {
5382
bool shouldCollectPreDeterminedSymbols;
5483
bool useDelayedPrivatization;
5584
lower::SymMap *symTable;
85+
OMPConstructSymbolVisitor visitor;
5686

5787
bool needBarrier();
5888
void collectSymbols(semantics::Symbol::Flag flag,
@@ -97,11 +127,7 @@ class DataSharingProcessor {
97127
lower::pft::Evaluation &eval,
98128
bool shouldCollectPreDeterminedSymbols,
99129
bool useDelayedPrivatization = false,
100-
lower::SymMap *symTable = nullptr)
101-
: hasLastPrivateOp(false), converter(converter), semaCtx(semaCtx),
102-
firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval),
103-
shouldCollectPreDeterminedSymbols(shouldCollectPreDeterminedSymbols),
104-
useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {}
130+
lower::SymMap *symTable = nullptr);
105131

106132
// Privatisation is split into two steps.
107133
// Step1 performs cloning of all privatisation clauses and copying for

0 commit comments

Comments
 (0)