15
15
#include " ClauseProcessor.h"
16
16
#include " Clauses.h"
17
17
#include " DataSharingProcessor.h"
18
+ #include " Decomposer.h"
18
19
#include " DirectivesCommon.h"
19
20
#include " ReductionProcessor.h"
20
21
#include " Utils.h"
36
37
#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
37
38
#include " mlir/Transforms/RegionUtils.h"
38
39
#include " llvm/ADT/STLExtras.h"
39
- #include " llvm/Frontend/OpenMP/ConstructDecompositionT.h"
40
40
#include " llvm/Frontend/OpenMP/OMPConstants.h"
41
41
42
42
using namespace Fortran ::lower::omp;
@@ -45,6 +45,13 @@ using namespace Fortran::lower::omp;
45
45
// Code generation helper functions
46
46
// ===----------------------------------------------------------------------===//
47
47
48
+ static void genOMPDispatch (Fortran::lower::AbstractConverter &converter,
49
+ Fortran::lower::SymMap &symTable,
50
+ Fortran::semantics::SemanticsContext &semaCtx,
51
+ Fortran::lower::pft::Evaluation &eval,
52
+ mlir::Location loc, const ConstructQueue &queue,
53
+ ConstructQueue::iterator item);
54
+
48
55
static Fortran::lower::pft::Evaluation *
49
56
getCollapsedLoopEval (Fortran::lower::pft::Evaluation &eval, int collapseValue) {
50
57
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -73,89 +80,6 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
73
80
converter.genEval (e);
74
81
}
75
82
76
- // ===----------------------------------------------------------------------===//
77
- // Directive decomposition
78
- // ===----------------------------------------------------------------------===//
79
-
80
- namespace {
81
- using DirectiveWithClauses = tomp::DirectiveWithClauses<lower::omp::Clause>;
82
- using ConstructQueue = List<DirectiveWithClauses>;
83
- } // namespace
84
-
85
- static void genOMPDispatch (Fortran::lower::AbstractConverter &converter,
86
- Fortran::lower::SymMap &symTable,
87
- Fortran::semantics::SemanticsContext &semaCtx,
88
- Fortran::lower::pft::Evaluation &eval,
89
- mlir::Location loc, const ConstructQueue &queue,
90
- ConstructQueue::iterator item);
91
-
92
- namespace {
93
- struct ConstructDecomposition {
94
- ConstructDecomposition (mlir::ModuleOp modOp,
95
- semantics::SemanticsContext &semaCtx,
96
- lower::pft::Evaluation &ev,
97
- llvm::omp::Directive construct,
98
- const List<Clause> &clauses)
99
- : semaCtx(semaCtx), mod(modOp), eval(ev) {
100
- tomp::ConstructDecompositionT decompose (getOpenMPVersion (modOp), *this ,
101
- construct, llvm::ArrayRef (clauses));
102
- output = std::move (decompose.output );
103
- }
104
-
105
- // Given an object, return its base object if one exists.
106
- std::optional<Object> getBaseObject (const Object &object) {
107
- return lower::omp::getBaseObject (object, semaCtx);
108
- }
109
-
110
- // Return the iteration variable of the associated loop if any.
111
- std::optional<Object> getLoopIterVar () {
112
- if (semantics::Symbol *symbol = getIterationVariableSymbol (eval))
113
- return Object{symbol, /* designator=*/ {}};
114
- return std::nullopt;
115
- }
116
-
117
- semantics::SemanticsContext &semaCtx;
118
- mlir::ModuleOp mod;
119
- lower::pft::Evaluation &eval;
120
- List<DirectiveWithClauses> output;
121
- };
122
- } // namespace
123
-
124
- LLVM_DUMP_METHOD static llvm::raw_ostream &
125
- operator <<(llvm::raw_ostream &os, const DirectiveWithClauses &dwc) {
126
- os << llvm::omp::getOpenMPDirectiveName (dwc.id );
127
- for (auto [index, clause] : llvm::enumerate (dwc.clauses )) {
128
- os << (index == 0 ? ' \t ' : ' ' );
129
- os << llvm::omp::getOpenMPClauseName (clause.id );
130
- }
131
- return os;
132
- }
133
-
134
- static void splitCompoundConstruct (
135
- mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
136
- Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive construct,
137
- const List<Clause> &clauses, List<DirectiveWithClauses> &directives) {
138
-
139
- ConstructDecomposition decompose (modOp, semaCtx, eval, construct, clauses);
140
- assert (!decompose.output .empty ());
141
-
142
- llvm::SmallVector<llvm::omp::Directive> loweringUnits;
143
- std::ignore =
144
- llvm::omp::getLeafOrCompositeConstructs (construct, loweringUnits);
145
-
146
- int leafIndex = 0 ;
147
- for (llvm::omp::Directive dir_id : loweringUnits) {
148
- directives.push_back (DirectiveWithClauses{dir_id});
149
- DirectiveWithClauses &dwc = directives.back ();
150
- llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
151
- llvm::omp::getLeafConstructsOrSelf (dir_id);
152
- for (int i = 0 , e = leafsOrSelf.size (); i != e; ++i) {
153
- dwc.clauses .append (decompose.output [leafIndex].clauses );
154
- ++leafIndex;
155
- }
156
- }
157
- }
158
-
159
83
static fir::GlobalOp globalInitialization (
160
84
Fortran::lower::AbstractConverter &converter,
161
85
fir::FirOpBuilder &firOpBuilder, const Fortran::semantics::Symbol &sym,
@@ -2170,7 +2094,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
2170
2094
semaCtx);
2171
2095
mlir::Location currentLocation = converter.genLocation (directive.source );
2172
2096
2173
- ConstructQueue queue{{DirectiveWithClauses{directive.v , clauses}}};
2097
+ ConstructQueue queue{
2098
+ buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
2099
+ eval, directive.v , clauses)};
2174
2100
2175
2101
switch (directive.v ) {
2176
2102
default :
@@ -2234,7 +2160,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
2234
2160
mlir::Location currentLocation = converter.genLocation (verbatim.source );
2235
2161
2236
2162
ConstructQueue queue{
2237
- DirectiveWithClauses{llvm::omp::Directive::OMPD_flush, clauses}};
2163
+ buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
2164
+ eval, llvm::omp::Directive::OMPD_flush, clauses)};
2238
2165
genFlushOp (converter, symTable, semaCtx, eval, currentLocation, objects,
2239
2166
clauses, queue, queue.begin ());
2240
2167
}
@@ -2381,9 +2308,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
2381
2308
2382
2309
llvm::omp::Directive directive =
2383
2310
std::get<parser::OmpBlockDirective>(beginBlockDirective.t ).v ;
2384
- ConstructQueue queue;
2385
- splitCompoundConstruct (converter.getFirOpBuilder ().getModule (), semaCtx, eval ,
2386
- directive, clauses, queue) ;
2311
+ ConstructQueue queue{
2312
+ buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
2313
+ eval, directive, clauses)} ;
2387
2314
genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
2388
2315
queue.begin ());
2389
2316
}
@@ -2399,9 +2326,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
2399
2326
List<Clause> clauses =
2400
2327
makeClauses (std::get<Fortran::parser::OmpClauseList>(cd.t ), semaCtx);
2401
2328
2402
- ConstructQueue queue;
2403
- splitCompoundConstruct (converter.getFirOpBuilder ().getModule (), semaCtx, eval ,
2404
- llvm::omp::Directive::OMPD_critical, clauses, queue) ;
2329
+ ConstructQueue queue{
2330
+ buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
2331
+ eval, llvm::omp::Directive::OMPD_critical, clauses)} ;
2405
2332
2406
2333
const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t );
2407
2334
mlir::Location currentLocation = converter.getCurrentLocation ();
@@ -2440,9 +2367,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
2440
2367
2441
2368
llvm::omp::Directive directive =
2442
2369
std::get<parser::OmpLoopDirective>(beginLoopDirective.t ).v ;
2443
- ConstructQueue queue;
2444
- splitCompoundConstruct (converter.getFirOpBuilder ().getModule (), semaCtx, eval ,
2445
- directive, clauses, queue) ;
2370
+ ConstructQueue queue{
2371
+ buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
2372
+ eval, directive, clauses)} ;
2446
2373
genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
2447
2374
queue.begin ());
2448
2375
}
@@ -2455,7 +2382,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
2455
2382
const Fortran::parser::OpenMPSectionConstruct §ionConstruct) {
2456
2383
mlir::Location loc = converter.getCurrentLocation ();
2457
2384
ConstructQueue queue{
2458
- DirectiveWithClauses{llvm::omp::Directive::OMPD_section}};
2385
+ buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
2386
+ eval, llvm::omp::Directive::OMPD_section, {})};
2459
2387
genSectionOp (converter, symTable, semaCtx, eval, loc,
2460
2388
/* clauses=*/ {}, queue, queue.begin ());
2461
2389
}
@@ -2480,9 +2408,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
2480
2408
2481
2409
llvm::omp::Directive directive =
2482
2410
std::get<parser::OmpSectionsDirective>(beginSectionsDirective.t ).v ;
2483
- ConstructQueue queue;
2484
- splitCompoundConstruct (converter.getFirOpBuilder ().getModule (), semaCtx, eval ,
2485
- directive, clauses, queue) ;
2411
+ ConstructQueue queue{
2412
+ buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
2413
+ eval, directive, clauses)} ;
2486
2414
genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
2487
2415
queue.begin ());
2488
2416
}
0 commit comments