Skip to content

Commit 65cb0ea

Browse files
authored
[Flang][OpenMP] Add Semantics support for Nested OpenMPLoopConstructs (#145917)
In OpenMP Version 5.1, the tile and unroll directives were added. When using these directives, it is possible to nest them within other OpenMP Loop Constructs. This patch enables the semantics to allow for this behaviour on these specific directives. Any nested loops will be stored within the initial Loop Construct until reaching the DoConstruct itself. Relevant tests have been added, and previous behaviour has been retained with no changes. See also, #110008
1 parent 372e332 commit 65cb0ea

File tree

11 files changed

+520
-96
lines changed

11 files changed

+520
-96
lines changed

flang/include/flang/Parser/parse-tree.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ struct AccEndCombinedDirective;
267267
struct OpenACCDeclarativeConstruct;
268268
struct OpenACCRoutineConstruct;
269269
struct OpenMPConstruct;
270+
struct OpenMPLoopConstruct;
270271
struct OpenMPDeclarativeConstruct;
271272
struct OmpEndLoopDirective;
272273
struct OmpMemoryOrderClause;
@@ -5021,11 +5022,13 @@ struct OpenMPBlockConstruct {
50215022
};
50225023

50235024
// OpenMP directives enclosing do loop
5025+
using NestedConstruct =
5026+
std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>;
50245027
struct OpenMPLoopConstruct {
50255028
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
50265029
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
50275030
: t({std::move(a), std::nullopt, std::nullopt}) {}
5028-
std::tuple<OmpBeginLoopDirective, std::optional<DoConstruct>,
5031+
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
50295032
std::optional<OmpEndLoopDirective>>
50305033
t;
50315034
};

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3769,6 +3769,16 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
37693769
mlir::Location currentLocation =
37703770
converter.genLocation(beginLoopDirective.source);
37713771

3772+
auto &optLoopCons =
3773+
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
3774+
if (optLoopCons.has_value()) {
3775+
if (auto *ompNestedLoopCons{
3776+
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
3777+
&*optLoopCons)}) {
3778+
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value());
3779+
}
3780+
}
3781+
37723782
llvm::omp::Directive directive =
37733783
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v;
37743784
const parser::CharBlock &source =

flang/lib/Parser/unparse.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2926,7 +2926,8 @@ class UnparseVisitor {
29262926
Walk(std::get<OmpBeginLoopDirective>(x.t));
29272927
Put("\n");
29282928
EndOpenMP();
2929-
Walk(std::get<std::optional<DoConstruct>>(x.t));
2929+
Walk(std::get<std::optional<std::variant<DoConstruct,
2930+
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
29302931
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
29312932
}
29322933
void Unparse(const BasedPointer &x) {

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "canonicalize-omp.h"
1010
#include "flang/Parser/parse-tree-visitor.h"
11+
#include "flang/Parser/parse-tree.h"
1112

1213
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1314
// Constructs more structured which provide explicit scopes for later
@@ -125,6 +126,16 @@ class CanonicalizationOfOmp {
125126
parser::Block::iterator nextIt;
126127
auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
127128
auto &dir{std::get<parser::OmpLoopDirective>(beginDir.t)};
129+
auto missingDoConstruct = [](auto &dir, auto &messages) {
130+
messages.Say(dir.source,
131+
"A DO loop must follow the %s directive"_err_en_US,
132+
parser::ToUpperCaseLetters(dir.source.ToString()));
133+
};
134+
auto tileUnrollError = [](auto &dir, auto &messages) {
135+
messages.Say(dir.source,
136+
"If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
137+
parser::ToUpperCaseLetters(dir.source.ToString()));
138+
};
128139

129140
nextIt = it;
130141
while (++nextIt != block.end()) {
@@ -135,31 +146,95 @@ class CanonicalizationOfOmp {
135146
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
136147
if (doCons->GetLoopControl()) {
137148
// move DoConstruct
138-
std::get<std::optional<parser::DoConstruct>>(x.t) =
149+
std::get<std::optional<std::variant<parser::DoConstruct,
150+
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
139151
std::move(*doCons);
140152
nextIt = block.erase(nextIt);
141153
// try to match OmpEndLoopDirective
142-
if (nextIt != block.end()) {
143-
if (auto *endDir{
144-
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
145-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
146-
std::move(*endDir);
147-
block.erase(nextIt);
148-
}
154+
if (auto *endDir{
155+
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
156+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
157+
std::move(*endDir);
158+
nextIt = block.erase(nextIt);
149159
}
150160
} else {
151161
messages_.Say(dir.source,
152162
"DO loop after the %s directive must have loop control"_err_en_US,
153163
parser::ToUpperCaseLetters(dir.source.ToString()));
154164
}
165+
} else if (auto *ompLoopCons{
166+
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
167+
// We should allow UNROLL and TILE constructs to be inserted between an
168+
// OpenMP Loop Construct and the DO loop itself
169+
auto &nestedBeginDirective =
170+
std::get<parser::OmpBeginLoopDirective>(ompLoopCons->t);
171+
auto &nestedBeginLoopDirective =
172+
std::get<parser::OmpLoopDirective>(nestedBeginDirective.t);
173+
if ((nestedBeginLoopDirective.v == llvm::omp::Directive::OMPD_unroll ||
174+
nestedBeginLoopDirective.v ==
175+
llvm::omp::Directive::OMPD_tile) &&
176+
!(nestedBeginLoopDirective.v == llvm::omp::Directive::OMPD_unroll &&
177+
dir.v == llvm::omp::Directive::OMPD_tile)) {
178+
// iterate through the remaining block items to find the end directive
179+
// for the unroll/tile directive.
180+
parser::Block::iterator endIt;
181+
endIt = nextIt;
182+
while (endIt != block.end()) {
183+
if (auto *endDir{
184+
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
185+
auto &endLoopDirective =
186+
std::get<parser::OmpLoopDirective>(endDir->t);
187+
if (endLoopDirective.v == dir.v) {
188+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
189+
std::move(*endDir);
190+
endIt = block.erase(endIt);
191+
continue;
192+
}
193+
}
194+
++endIt;
195+
}
196+
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
197+
auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
198+
ompLoop =
199+
std::optional<parser::NestedConstruct>{parser::NestedConstruct{
200+
common::Indirection{std::move(*ompLoopCons)}}};
201+
nextIt = block.erase(nextIt);
202+
} else if (nestedBeginLoopDirective.v ==
203+
llvm::omp::Directive::OMPD_unroll &&
204+
dir.v == llvm::omp::Directive::OMPD_tile) {
205+
// if a loop has been unrolled, the user can not then tile that loop
206+
// as it has been unrolled
207+
parser::OmpClauseList &unrollClauseList{
208+
std::get<parser::OmpClauseList>(nestedBeginDirective.t)};
209+
if (unrollClauseList.v.empty()) {
210+
// if the clause list is empty for an unroll construct, we assume
211+
// the loop is being fully unrolled
212+
tileUnrollError(dir, messages_);
213+
} else {
214+
// parse the clauses for the unroll directive to find the full
215+
// clause
216+
for (auto clause{unrollClauseList.v.begin()};
217+
clause != unrollClauseList.v.end(); ++clause) {
218+
if (clause->Id() == llvm::omp::OMPC_full) {
219+
tileUnrollError(dir, messages_);
220+
}
221+
}
222+
}
223+
} else {
224+
messages_.Say(nestedBeginLoopDirective.source,
225+
"Only Loop Transformation Constructs or Loop Nests can be nested within Loop Constructs"_err_en_US,
226+
parser::ToUpperCaseLetters(
227+
nestedBeginLoopDirective.source.ToString()));
228+
}
155229
} else {
156-
messages_.Say(dir.source,
157-
"A DO loop must follow the %s directive"_err_en_US,
158-
parser::ToUpperCaseLetters(dir.source.ToString()));
230+
missingDoConstruct(dir, messages_);
159231
}
160232
// If we get here, we either found a loop, or issued an error message.
161233
return;
162234
}
235+
if (nextIt == block.end()) {
236+
missingDoConstruct(dir, messages_);
237+
}
163238
}
164239

165240
void RewriteOmpAllocations(parser::ExecutionPart &body) {

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 56 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -762,10 +762,13 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
762762
}
763763
SetLoopInfo(x);
764764

765-
if (const auto &doConstruct{
766-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
767-
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
768-
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
765+
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
766+
if (optLoopCons.has_value()) {
767+
if (const auto &doConstruct{
768+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
769+
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
770+
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
771+
}
769772
}
770773
CheckLoopItrVariableIsInt(x);
771774
CheckAssociatedLoopConstraints(x);
@@ -786,12 +789,15 @@ const parser::Name OmpStructureChecker::GetLoopIndex(
786789
return std::get<Bounds>(x->GetLoopControl()->u).name.thing;
787790
}
788791
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
789-
if (const auto &loopConstruct{
790-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
791-
const parser::DoConstruct *loop{&*loopConstruct};
792-
if (loop && loop->IsDoNormal()) {
793-
const parser::Name &itrVal{GetLoopIndex(loop)};
794-
SetLoopIv(itrVal.symbol);
792+
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
793+
if (optLoopCons.has_value()) {
794+
if (const auto &loopConstruct{
795+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
796+
const parser::DoConstruct *loop{&*loopConstruct};
797+
if (loop && loop->IsDoNormal()) {
798+
const parser::Name &itrVal{GetLoopIndex(loop)};
799+
SetLoopIv(itrVal.symbol);
800+
}
795801
}
796802
}
797803
}
@@ -857,27 +863,30 @@ void OmpStructureChecker::CheckIteratorModifier(const parser::OmpIterator &x) {
857863

858864
void OmpStructureChecker::CheckLoopItrVariableIsInt(
859865
const parser::OpenMPLoopConstruct &x) {
860-
if (const auto &loopConstruct{
861-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
866+
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
867+
if (optLoopCons.has_value()) {
868+
if (const auto &loopConstruct{
869+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
862870

863-
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
864-
if (loop->IsDoNormal()) {
865-
const parser::Name &itrVal{GetLoopIndex(loop)};
866-
if (itrVal.symbol) {
867-
const auto *type{itrVal.symbol->GetType()};
868-
if (!type->IsNumeric(TypeCategory::Integer)) {
869-
context_.Say(itrVal.source,
870-
"The DO loop iteration"
871-
" variable must be of the type integer."_err_en_US,
872-
itrVal.ToString());
871+
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
872+
if (loop->IsDoNormal()) {
873+
const parser::Name &itrVal{GetLoopIndex(loop)};
874+
if (itrVal.symbol) {
875+
const auto *type{itrVal.symbol->GetType()};
876+
if (!type->IsNumeric(TypeCategory::Integer)) {
877+
context_.Say(itrVal.source,
878+
"The DO loop iteration"
879+
" variable must be of the type integer."_err_en_US,
880+
itrVal.ToString());
881+
}
873882
}
874883
}
884+
// Get the next DoConstruct if block is not empty.
885+
const auto &block{std::get<parser::Block>(loop->t)};
886+
const auto it{block.begin()};
887+
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
888+
: nullptr;
875889
}
876-
// Get the next DoConstruct if block is not empty.
877-
const auto &block{std::get<parser::Block>(loop->t)};
878-
const auto it{block.begin()};
879-
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
880-
: nullptr;
881890
}
882891
}
883892
}
@@ -1077,25 +1086,28 @@ void OmpStructureChecker::CheckDistLinear(
10771086

10781087
// Match the loop index variables with the collected symbols from linear
10791088
// clauses.
1080-
if (const auto &loopConstruct{
1081-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
1082-
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
1083-
if (loop->IsDoNormal()) {
1084-
const parser::Name &itrVal{GetLoopIndex(loop)};
1085-
if (itrVal.symbol) {
1086-
// Remove the symbol from the collected set
1087-
indexVars.erase(&itrVal.symbol->GetUltimate());
1088-
}
1089-
collapseVal--;
1090-
if (collapseVal == 0) {
1091-
break;
1089+
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
1090+
if (optLoopCons.has_value()) {
1091+
if (const auto &loopConstruct{
1092+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
1093+
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
1094+
if (loop->IsDoNormal()) {
1095+
const parser::Name &itrVal{GetLoopIndex(loop)};
1096+
if (itrVal.symbol) {
1097+
// Remove the symbol from the collected set
1098+
indexVars.erase(&itrVal.symbol->GetUltimate());
1099+
}
1100+
collapseVal--;
1101+
if (collapseVal == 0) {
1102+
break;
1103+
}
10921104
}
1105+
// Get the next DoConstruct if block is not empty.
1106+
const auto &block{std::get<parser::Block>(loop->t)};
1107+
const auto it{block.begin()};
1108+
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
1109+
: nullptr;
10931110
}
1094-
// Get the next DoConstruct if block is not empty.
1095-
const auto &block{std::get<parser::Block>(loop->t)};
1096-
const auto it{block.begin()};
1097-
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
1098-
: nullptr;
10991111
}
11001112
}
11011113

0 commit comments

Comments
 (0)