Skip to content

Commit be7c9e3

Browse files
committed
[flang][OpenMP] Decompose compound constructs, do recursive lowering (#90098)
A compound construct with a list of clauses is broken up into individual leaf/composite constructs. Each such construct has the list of clauses that apply to it based on the OpenMP spec. Each lowering function (i.e. a function that generates MLIR ops) is now responsible for generating its body as described below. Functions that receive AST nodes extract the construct, and the clauses from the node. They then create a work queue consisting of individual constructs, and invoke a common dispatch function to process (lower) the queue. The dispatch function examines the current position in the queue, and invokes the appropriate lowering function. Each lowering function receives the queue as well, and once it needs to generate its body, it either invokes the dispatch function on the rest of the queue (if any), or processes nested evaluations if the work queue is at the end. Re-application of ca1bd59 with fixes for compilation errors.
1 parent 1f6f5bf commit be7c9e3

File tree

16 files changed

+3234
-453
lines changed

16 files changed

+3234
-453
lines changed

flang/lib/Lower/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_flang_library(FortranLower
2727
OpenMP/ClauseProcessor.cpp
2828
OpenMP/Clauses.cpp
2929
OpenMP/DataSharingProcessor.cpp
30+
OpenMP/Decomposer.cpp
3031
OpenMP/OpenMP.cpp
3132
OpenMP/ReductionProcessor.cpp
3233
OpenMP/Utils.cpp

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,4 +1227,27 @@ List<Clause> makeClauses(const parser::OmpClauseList &clauses,
12271227
return makeClause(s, semaCtx);
12281228
});
12291229
}
1230+
1231+
bool transferLocations(const List<Clause> &from, List<Clause> &to) {
1232+
bool allDone = true;
1233+
1234+
for (Clause &clause : to) {
1235+
if (!clause.source.empty())
1236+
continue;
1237+
auto found =
1238+
llvm::find_if(from, [&](const Clause &c) { return c.id == clause.id; });
1239+
// This is not completely accurate, but should be good enough for now.
1240+
// It can be improved in the future if necessary, but in cases of
1241+
// synthesized clauses getting accurate location may be impossible.
1242+
if (found != from.end()) {
1243+
clause.source = found->source;
1244+
} else {
1245+
// Found a clause that won't have "source".
1246+
allDone = false;
1247+
}
1248+
}
1249+
1250+
return allDone;
1251+
}
1252+
12301253
} // namespace Fortran::lower::omp

flang/lib/Lower/OpenMP/Clauses.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,15 @@
2323

2424
namespace Fortran::lower::omp {
2525
using namespace Fortran;
26-
using SomeType = evaluate::SomeType;
2726
using SomeExpr = semantics::SomeExpr;
2827
using MaybeExpr = semantics::MaybeExpr;
2928

30-
using TypeTy = SomeType;
29+
// evaluate::SomeType doesn't provide == operation. It's not really used in
30+
// flang's clauses so far, so a trivial implementation is sufficient.
31+
struct TypeTy : public evaluate::SomeType {
32+
bool operator==(const TypeTy &t) const { return true; }
33+
};
34+
3135
using IdTy = semantics::Symbol *;
3236
using ExprTy = SomeExpr;
3337

@@ -222,6 +226,8 @@ using When = tomp::clause::WhenT<TypeTy, IdTy, ExprTy>;
222226
using Write = tomp::clause::WriteT<TypeTy, IdTy, ExprTy>;
223227
} // namespace clause
224228

229+
using tomp::type::operator==;
230+
225231
struct CancellationConstructType {
226232
using EmptyTrait = std::true_type;
227233
};
@@ -244,20 +250,25 @@ using ClauseBase = tomp::ClauseT<TypeTy, IdTy, ExprTy,
244250
MemoryOrder, Threadprivate>;
245251

246252
struct Clause : public ClauseBase {
253+
Clause(ClauseBase &&base, const parser::CharBlock source = {})
254+
: ClauseBase(std::move(base)), source(source) {}
255+
// "source" will be ignored by tomp::type::operator==.
247256
parser::CharBlock source;
248257
};
249258

250259
template <typename Specific>
251260
Clause makeClause(llvm::omp::Clause id, Specific &&specific,
252261
parser::CharBlock source = {}) {
253-
return Clause{{id, specific}, source};
262+
return Clause(typename Clause::BaseT{id, specific}, source);
254263
}
255264

256265
Clause makeClause(const Fortran::parser::OmpClause &cls,
257266
semantics::SemanticsContext &semaCtx);
258267

259268
List<Clause> makeClauses(const parser::OmpClauseList &clauses,
260269
semantics::SemanticsContext &semaCtx);
270+
271+
bool transferLocations(const List<Clause> &from, List<Clause> &to);
261272
} // namespace Fortran::lower::omp
262273

263274
#endif // FORTRAN_LOWER_OPENMP_CLAUSES_H

flang/lib/Lower/OpenMP/Decomposer.cpp

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
//===-- Decomposer.cpp -- Compound directive decomposition ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "Decomposer.h"
14+
15+
#include "Clauses.h"
16+
#include "Utils.h"
17+
#include "flang/Lower/PFTBuilder.h"
18+
#include "flang/Semantics/semantics.h"
19+
#include "flang/Tools/CrossToolHelpers.h"
20+
#include "mlir/IR/BuiltinOps.h"
21+
#include "llvm/ADT/ArrayRef.h"
22+
#include "llvm/ADT/STLExtras.h"
23+
#include "llvm/ADT/SmallVector.h"
24+
#include "llvm/Frontend/OpenMP/ClauseT.h"
25+
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
26+
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
27+
#include "llvm/Frontend/OpenMP/OMP.h"
28+
#include "llvm/Support/raw_ostream.h"
29+
30+
#include <optional>
31+
#include <utility>
32+
#include <variant>
33+
34+
using namespace Fortran;
35+
36+
namespace {
37+
using namespace Fortran::lower::omp;
38+
39+
struct ConstructDecomposition {
40+
ConstructDecomposition(mlir::ModuleOp modOp,
41+
semantics::SemanticsContext &semaCtx,
42+
lower::pft::Evaluation &ev,
43+
llvm::omp::Directive compound,
44+
const List<Clause> &clauses)
45+
: semaCtx(semaCtx), mod(modOp), eval(ev) {
46+
tomp::ConstructDecompositionT decompose(getOpenMPVersionAttribute(modOp),
47+
*this, compound,
48+
llvm::ArrayRef(clauses));
49+
output = std::move(decompose.output);
50+
}
51+
52+
// Given an object, return its base object if one exists.
53+
std::optional<Object> getBaseObject(const Object &object) {
54+
return lower::omp::getBaseObject(object, semaCtx);
55+
}
56+
57+
// Return the iteration variable of the associated loop if any.
58+
std::optional<Object> getLoopIterVar() {
59+
if (semantics::Symbol *symbol = getIterationVariableSymbol(eval))
60+
return Object{symbol, /*designator=*/{}};
61+
return std::nullopt;
62+
}
63+
64+
semantics::SemanticsContext &semaCtx;
65+
mlir::ModuleOp mod;
66+
lower::pft::Evaluation &eval;
67+
List<UnitConstruct> output;
68+
};
69+
} // namespace
70+
71+
static UnitConstruct mergeConstructs(uint32_t version,
72+
llvm::ArrayRef<UnitConstruct> units) {
73+
tomp::ConstructCompositionT compose(version, units);
74+
return compose.merged;
75+
}
76+
77+
namespace Fortran::lower::omp {
78+
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
79+
const UnitConstruct &uc) {
80+
os << llvm::omp::getOpenMPDirectiveName(uc.id);
81+
for (auto [index, clause] : llvm::enumerate(uc.clauses)) {
82+
os << (index == 0 ? '\t' : ' ');
83+
os << llvm::omp::getOpenMPClauseName(clause.id);
84+
}
85+
return os;
86+
}
87+
88+
ConstructQueue buildConstructQueue(
89+
mlir::ModuleOp modOp, Fortran::semantics::SemanticsContext &semaCtx,
90+
Fortran::lower::pft::Evaluation &eval, const parser::CharBlock &source,
91+
llvm::omp::Directive compound, const List<Clause> &clauses) {
92+
93+
List<UnitConstruct> constructs;
94+
95+
ConstructDecomposition decompose(modOp, semaCtx, eval, compound, clauses);
96+
assert(!decompose.output.empty() && "Construct decomposition failed");
97+
98+
llvm::SmallVector<llvm::omp::Directive> loweringUnits;
99+
std::ignore =
100+
llvm::omp::getLeafOrCompositeConstructs(compound, loweringUnits);
101+
uint32_t version = getOpenMPVersionAttribute(modOp);
102+
103+
int leafIndex = 0;
104+
for (llvm::omp::Directive dir_id : loweringUnits) {
105+
llvm::ArrayRef<llvm::omp::Directive> leafsOrSelf =
106+
llvm::omp::getLeafConstructsOrSelf(dir_id);
107+
size_t numLeafs = leafsOrSelf.size();
108+
109+
llvm::ArrayRef<UnitConstruct> toMerge{&decompose.output[leafIndex],
110+
numLeafs};
111+
auto &uc = constructs.emplace_back(mergeConstructs(version, toMerge));
112+
113+
if (!transferLocations(clauses, uc.clauses)) {
114+
// If some clauses are left without source information, use the
115+
// directive's source.
116+
for (auto &clause : uc.clauses) {
117+
if (clause.source.empty())
118+
clause.source = source;
119+
}
120+
}
121+
leafIndex += numLeafs;
122+
}
123+
124+
return constructs;
125+
}
126+
} // namespace Fortran::lower::omp

flang/lib/Lower/OpenMP/Decomposer.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//===-- Decomposer.h -- Compound directive decomposition ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef FORTRAN_LOWER_OPENMP_DECOMPOSER_H
9+
#define FORTRAN_LOWER_OPENMP_DECOMPOSER_H
10+
11+
#include "Clauses.h"
12+
#include "mlir/IR/BuiltinOps.h"
13+
#include "llvm/Frontend/OpenMP/ConstructCompositionT.h"
14+
#include "llvm/Frontend/OpenMP/ConstructDecompositionT.h"
15+
#include "llvm/Frontend/OpenMP/OMP.h"
16+
#include "llvm/Support/Compiler.h"
17+
18+
namespace llvm {
19+
class raw_ostream;
20+
}
21+
22+
namespace Fortran {
23+
namespace semantics {
24+
class SemanticsContext;
25+
}
26+
namespace lower::pft {
27+
struct Evaluation;
28+
}
29+
} // namespace Fortran
30+
31+
namespace Fortran::lower::omp {
32+
using UnitConstruct = tomp::DirectiveWithClauses<lower::omp::Clause>;
33+
using ConstructQueue = List<UnitConstruct>;
34+
35+
LLVM_DUMP_METHOD llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
36+
const UnitConstruct &uc);
37+
38+
// Given a potentially compound construct with a list of clauses that
39+
// apply to it, break it up into individual sub-constructs each with
40+
// the subset of applicable clauses (plus implicit clauses, if any).
41+
// From that create a work queue where each work item corresponds to
42+
// the sub-construct with its clauses.
43+
ConstructQueue buildConstructQueue(mlir::ModuleOp modOp,
44+
semantics::SemanticsContext &semaCtx,
45+
lower::pft::Evaluation &eval,
46+
const parser::CharBlock &source,
47+
llvm::omp::Directive compound,
48+
const List<Clause> &clauses);
49+
} // namespace Fortran::lower::omp
50+
51+
#endif // FORTRAN_LOWER_OPENMP_DECOMPOSER_H

0 commit comments

Comments
 (0)