Skip to content

Commit d5635d6

Browse files
committed
Split the directive decomposer into its own file
1 parent 60201bb commit d5635d6

File tree

4 files changed

+179
-99
lines changed

4 files changed

+179
-99
lines changed

flang/lib/Lower/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_flang_library(FortranLower
2727
OpenMP/ClauseProcessor.cpp
2828
OpenMP/Clauses.cpp
2929
OpenMP/DataSharingProcessor.cpp
30+
OpenMP/Decomposer.cpp
3031
OpenMP/OpenMP.cpp
3132
OpenMP/ReductionProcessor.cpp
3233
OpenMP/Utils.cpp

flang/lib/Lower/OpenMP/Decomposer.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
//===-- Decomposer.cpp -- Compound directive decomposition ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "Decomposer.h"
14+
15+
#include "Clauses.h"
16+
#include "Utils.h"
17+
#include "flang/Lower/PFTBuilder.h"
18+
#include "flang/Semantics/semantics.h"
19+
#include "mlir/IR/BuiltinOps.h"
20+
#include "llvm/ADT/ArrayRef.h"
21+
#include "llvm/ADT/STLExtras.h"
22+
#include "llvm/ADT/SmallVector.h"
23+
#include "llvm/Frontend/OpenMP/ClauseT.h"
24+
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
25+
#include "llvm/Frontend/OpenMP/OMP.h"
26+
#include "llvm/Support/raw_ostream.h"
27+
28+
#include <optional>
29+
#include <utility>
30+
#include <variant>
31+
32+
using namespace Fortran;
33+
34+
namespace {
35+
using namespace Fortran::lower::omp;
36+
37+
struct ConstructDecomposition {
38+
ConstructDecomposition(mlir::ModuleOp modOp,
39+
semantics::SemanticsContext &semaCtx,
40+
lower::pft::Evaluation &ev,
41+
llvm::omp::Directive compound,
42+
const List<Clause> &clauses)
43+
: semaCtx(semaCtx), mod(modOp), eval(ev) {
44+
tomp::ConstructDecompositionT decompose(getOpenMPVersion(modOp), *this,
45+
compound, llvm::ArrayRef(clauses));
46+
output = std::move(decompose.output);
47+
}
48+
49+
// Given an object, return its base object if one exists.
50+
std::optional<Object> getBaseObject(const Object &object) {
51+
return lower::omp::getBaseObject(object, semaCtx);
52+
}
53+
54+
// Return the iteration variable of the associated loop if any.
55+
std::optional<Object> getLoopIterVar() {
56+
if (semantics::Symbol *symbol = getIterationVariableSymbol(eval))
57+
return Object{symbol, /*designator=*/{}};
58+
return std::nullopt;
59+
}
60+
61+
semantics::SemanticsContext &semaCtx;
62+
mlir::ModuleOp mod;
63+
lower::pft::Evaluation &eval;
64+
List<UnitConstruct> output;
65+
};
66+
} // namespace
67+
68+
namespace Fortran::lower::omp {
69+
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
70+
const UnitConstruct &uc) {
71+
os << llvm::omp::getOpenMPDirectiveName(uc.id);
72+
for (auto [index, clause] : llvm::enumerate(uc.clauses)) {
73+
os << (index == 0 ? '\t' : ' ');
74+
os << llvm::omp::getOpenMPClauseName(clause.id);
75+
}
76+
return os;
77+
}
78+
79+
ConstructQueue buildConstructQueue(
80+
mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
81+
Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive compound,
82+
const List<Clause> &clauses) {
83+
84+
List<UnitConstruct> constructs;
85+
86+
ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
87+
assert(!decompose.output.empty());
88+
89+
llvm::SmallVector<llvm::omp::Directive> loweringUnits;
90+
std::ignore =
91+
llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
92+
93+
int leafIndex = 0;
94+
for (llvm::omp::Directive dir_id : loweringUnits) {
95+
constructs.push_back(UnitConstruct{dir_id});
96+
UnitConstruct &uc = constructs.back();
97+
llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
98+
llvm::omp::getLeafConstructsOrSelf(dir_id);
99+
for (int i = 0, e = leafsOrSelf.size(); i != e; ++i) {
100+
uc.clauses.append(decompose.output[leafIndex].clauses);
101+
++leafIndex;
102+
}
103+
}
104+
105+
return constructs;
106+
}
107+
} // namespace Fortran::lower::omp

flang/lib/Lower/OpenMP/Decomposer.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//===-- Decomposer.h -- Compound directive decomposition ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef FORTRAN_LOWER_OPENMP_DECOMPOSER_H
9+
#define FORTRAN_LOWER_OPENMP_DECOMPOSER_H
10+
11+
#include "Clauses.h"
12+
#include "mlir/IR/BuiltinOps.h"
13+
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
14+
#include "llvm/Frontend/OpenMP/OMP.h"
15+
#include "llvm/Support/Compiler.h"
16+
17+
namespace llvm {
18+
class raw_ostream;
19+
}
20+
21+
namespace Fortran {
22+
namespace semantics {
23+
class SemanticsContext;
24+
}
25+
namespace lower::pft {
26+
struct Evaluation;
27+
}
28+
} // namespace Fortran
29+
30+
namespace Fortran::lower::omp {
31+
using UnitConstruct = tomp::DirectiveWithClauses<lower::omp::Clause>;
32+
using ConstructQueue = List<UnitConstruct>;
33+
34+
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
35+
const UnitConstruct &uc);
36+
37+
ConstructQueue
38+
buildConstructQueue(mlir::ModuleOp modOp,
39+
Fortran::semantics::SemanticsContext &semaCtx,
40+
Fortran::lower::pft::Evaluation &eval,
41+
llvm::omp::Directive compound, const List<Clause> &clauses);
42+
} // namespace Fortran::lower::omp
43+
44+
#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 27 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "ClauseProcessor.h"
1616
#include "Clauses.h"
1717
#include "DataSharingProcessor.h"
18+
#include "Decomposer.h"
1819
#include "DirectivesCommon.h"
1920
#include "ReductionProcessor.h"
2021
#include "Utils.h"
@@ -36,7 +37,6 @@
3637
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3738
#include "mlir/Transforms/RegionUtils.h"
3839
#include "llvm/ADT/STLExtras.h"
39-
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
4040
#include "llvm/Frontend/OpenMP/OMPConstants.h"
4141

4242
using namespace Fortran::lower::omp;
@@ -45,6 +45,13 @@ using namespace Fortran::lower::omp;
4545
// Code generation helper functions
4646
//===----------------------------------------------------------------------===//
4747

48+
static void genOMPDispatch(Fortran::lower::AbstractConverter &converter,
49+
Fortran::lower::SymMap &symTable,
50+
Fortran::semantics::SemanticsContext &semaCtx,
51+
Fortran::lower::pft::Evaluation &eval,
52+
mlir::Location loc, const ConstructQueue &queue,
53+
ConstructQueue::iterator item);
54+
4855
static Fortran::lower::pft::Evaluation *
4956
getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
5057
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -73,89 +80,6 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
7380
converter.genEval(e);
7481
}
7582

76-
//===----------------------------------------------------------------------===//
77-
// Directive decomposition
78-
//===----------------------------------------------------------------------===//
79-
80-
namespace {
81-
using DirectiveWithClauses = tomp::DirectiveWithClauses<lower::omp::Clause>;
82-
using ConstructQueue = List<DirectiveWithClauses>;
83-
} // namespace
84-
85-
static void genOMPDispatch(Fortran::lower::AbstractConverter &converter,
86-
Fortran::lower::SymMap &symTable,
87-
Fortran::semantics::SemanticsContext &semaCtx,
88-
Fortran::lower::pft::Evaluation &eval,
89-
mlir::Location loc, const ConstructQueue &queue,
90-
ConstructQueue::iterator item);
91-
92-
namespace {
93-
struct ConstructDecomposition {
94-
ConstructDecomposition(mlir::ModuleOp modOp,
95-
semantics::SemanticsContext &semaCtx,
96-
lower::pft::Evaluation &ev,
97-
llvm::omp::Directive construct,
98-
const List<Clause> &clauses)
99-
: semaCtx(semaCtx), mod(modOp), eval(ev) {
100-
tomp::ConstructDecompositionT decompose(getOpenMPVersion(modOp), *this,
101-
construct, llvm::ArrayRef(clauses));
102-
output = std::move(decompose.output);
103-
}
104-
105-
// Given an object, return its base object if one exists.
106-
std::optional<Object> getBaseObject(const Object &object) {
107-
return lower::omp::getBaseObject(object, semaCtx);
108-
}
109-
110-
// Return the iteration variable of the associated loop if any.
111-
std::optional<Object> getLoopIterVar() {
112-
if (semantics::Symbol *symbol = getIterationVariableSymbol(eval))
113-
return Object{symbol, /*designator=*/{}};
114-
return std::nullopt;
115-
}
116-
117-
semantics::SemanticsContext &semaCtx;
118-
mlir::ModuleOp mod;
119-
lower::pft::Evaluation &eval;
120-
List<DirectiveWithClauses> output;
121-
};
122-
} // namespace
123-
124-
LLVM_DUMP_METHOD static llvm::raw_ostream &
125-
operator<<(llvm::raw_ostream &os, const DirectiveWithClauses &dwc) {
126-
os << llvm::omp::getOpenMPDirectiveName(dwc.id);
127-
for (auto [index, clause] : llvm::enumerate(dwc.clauses)) {
128-
os << (index == 0 ? '\t' : ' ');
129-
os << llvm::omp::getOpenMPClauseName(clause.id);
130-
}
131-
return os;
132-
}
133-
134-
static void splitCompoundConstruct(
135-
mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
136-
Fortran::lower::pft::Evaluation &eval, llvm::omp::Directive construct,
137-
const List<Clause> &clauses, List<DirectiveWithClauses> &directives) {
138-
139-
ConstructDecomposition decompose(modOp, semaCtx, eval, construct, clauses);
140-
assert(!decompose.output.empty());
141-
142-
llvm::SmallVector<llvm::omp::Directive> loweringUnits;
143-
std::ignore =
144-
llvm::omp::getLeafOrCompositeConstructs(construct, loweringUnits);
145-
146-
int leafIndex = 0;
147-
for (llvm::omp::Directive dir_id : loweringUnits) {
148-
directives.push_back(DirectiveWithClauses{dir_id});
149-
DirectiveWithClauses &dwc = directives.back();
150-
llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
151-
llvm::omp::getLeafConstructsOrSelf(dir_id);
152-
for (int i = 0, e = leafsOrSelf.size(); i != e; ++i) {
153-
dwc.clauses.append(decompose.output[leafIndex].clauses);
154-
++leafIndex;
155-
}
156-
}
157-
}
158-
15983
static fir::GlobalOp globalInitialization(
16084
Fortran::lower::AbstractConverter &converter,
16185
fir::FirOpBuilder &firOpBuilder, const Fortran::semantics::Symbol &sym,
@@ -2170,7 +2094,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
21702094
semaCtx);
21712095
mlir::Location currentLocation = converter.genLocation(directive.source);
21722096

2173-
ConstructQueue queue{{DirectiveWithClauses{directive.v, clauses}}};
2097+
ConstructQueue queue{
2098+
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
2099+
eval, directive.v, clauses)};
21742100

21752101
switch (directive.v) {
21762102
default:
@@ -2234,7 +2160,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
22342160
mlir::Location currentLocation = converter.genLocation(verbatim.source);
22352161

22362162
ConstructQueue queue{
2237-
DirectiveWithClauses{llvm::omp::Directive::OMPD_flush, clauses}};
2163+
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
2164+
eval, llvm::omp::Directive::OMPD_flush, clauses)};
22382165
genFlushOp(converter, symTable, semaCtx, eval, currentLocation, objects,
22392166
clauses, queue, queue.begin());
22402167
}
@@ -2381,9 +2308,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
23812308

23822309
llvm::omp::Directive directive =
23832310
std::get<parser::OmpBlockDirective>(beginBlockDirective.t).v;
2384-
ConstructQueue queue;
2385-
splitCompoundConstruct(converter.getFirOpBuilder().getModule(), semaCtx, eval,
2386-
directive, clauses, queue);
2311+
ConstructQueue queue{
2312+
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
2313+
eval, directive, clauses)};
23872314
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
23882315
queue.begin());
23892316
}
@@ -2399,9 +2326,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
23992326
List<Clause> clauses =
24002327
makeClauses(std::get<Fortran::parser::OmpClauseList>(cd.t), semaCtx);
24012328

2402-
ConstructQueue queue;
2403-
splitCompoundConstruct(converter.getFirOpBuilder().getModule(), semaCtx, eval,
2404-
llvm::omp::Directive::OMPD_critical, clauses, queue);
2329+
ConstructQueue queue{
2330+
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
2331+
eval, llvm::omp::Directive::OMPD_critical, clauses)};
24052332

24062333
const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t);
24072334
mlir::Location currentLocation = converter.getCurrentLocation();
@@ -2440,9 +2367,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
24402367

24412368
llvm::omp::Directive directive =
24422369
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v;
2443-
ConstructQueue queue;
2444-
splitCompoundConstruct(converter.getFirOpBuilder().getModule(), semaCtx, eval,
2445-
directive, clauses, queue);
2370+
ConstructQueue queue{
2371+
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
2372+
eval, directive, clauses)};
24462373
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
24472374
queue.begin());
24482375
}
@@ -2455,7 +2382,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
24552382
const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) {
24562383
mlir::Location loc = converter.getCurrentLocation();
24572384
ConstructQueue queue{
2458-
DirectiveWithClauses{llvm::omp::Directive::OMPD_section}};
2385+
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
2386+
eval, llvm::omp::Directive::OMPD_section, {})};
24592387
genSectionOp(converter, symTable, semaCtx, eval, loc,
24602388
/*clauses=*/{}, queue, queue.begin());
24612389
}
@@ -2480,9 +2408,9 @@ genOMP(Fortran::lower::AbstractConverter &converter,
24802408

24812409
llvm::omp::Directive directive =
24822410
std::get<parser::OmpSectionsDirective>(beginSectionsDirective.t).v;
2483-
ConstructQueue queue;
2484-
splitCompoundConstruct(converter.getFirOpBuilder().getModule(), semaCtx, eval,
2485-
directive, clauses, queue);
2411+
ConstructQueue queue{
2412+
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
2413+
eval, directive, clauses)};
24862414
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
24872415
queue.begin());
24882416
}

0 commit comments

Comments
 (0)