Skip to content

Commit db2b677

Browse files
committed
[Flang][OpenMP] Add Semantics support for Nested OpenMPLoopConstructs
In OpenMP Version 5.1, the tile and unroll directives were added. When using these directives, it is possible to nest them within other OpenMP Loop Constructs. This patch enables the semantics to allow for this behaviour on these specific directives. Any nested loops will be stored within the initial Loop Construct until reaching the DoConstruct itself. Relevant tests have been added, and previous behaviour has been retained with no changes. See also, #110008
1 parent 1c35fe4 commit db2b677

File tree

11 files changed

+373
-65
lines changed

11 files changed

+373
-65
lines changed

flang/include/flang/Parser/parse-tree.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5025,7 +5025,7 @@ struct OpenMPLoopConstruct {
50255025
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
50265026
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
50275027
: t({std::move(a), std::nullopt, std::nullopt}) {}
5028-
std::tuple<OmpBeginLoopDirective, std::optional<DoConstruct>,
5028+
std::tuple<OmpBeginLoopDirective, std::optional<std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>>,
50295029
std::optional<OmpEndLoopDirective>>
50305030
t;
50315031
};

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4107,6 +4107,12 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
41074107
mlir::Location currentLocation =
41084108
converter.genLocation(beginLoopDirective.source);
41094109

4110+
auto &optLoopCons = std::get<1>(loopConstruct.t);
4111+
if(optLoopCons.has_value())
4112+
if(auto *ompNestedLoopCons{std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(&*optLoopCons)}) {
4113+
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value());
4114+
}
4115+
41104116
llvm::omp::Directive directive =
41114117
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v;
41124118
const parser::CharBlock &source =

flang/lib/Parser/unparse.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2926,7 +2926,7 @@ class UnparseVisitor {
29262926
Walk(std::get<OmpBeginLoopDirective>(x.t));
29272927
Put("\n");
29282928
EndOpenMP();
2929-
Walk(std::get<std::optional<DoConstruct>>(x.t));
2929+
Walk(std::get<std::optional<std::variant<DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
29302930
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
29312931
}
29322932
void Unparse(const BasedPointer &x) {

flang/lib/Semantics/canonicalize-omp.cpp

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "canonicalize-omp.h"
1010
#include "flang/Parser/parse-tree-visitor.h"
11+
#include "flang/Parser/parse-tree.h"
1112

1213
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
1314
// Constructs more structured which provide explicit scopes for later
@@ -106,6 +107,12 @@ class CanonicalizationOfOmp {
106107
return nullptr;
107108
}
108109

110+
void missingDoConstruct(parser::OmpLoopDirective &dir) {
111+
messages_.Say(dir.source,
112+
"A DO loop must follow the %s directive"_err_en_US,
113+
parser::ToUpperCaseLetters(dir.source.ToString()));
114+
}
115+
109116
void RewriteOpenMPLoopConstruct(parser::OpenMPLoopConstruct &x,
110117
parser::Block &block, parser::Block::iterator it) {
111118
// Check the sequence of DoConstruct and OmpEndLoopDirective
@@ -135,31 +142,62 @@ class CanonicalizationOfOmp {
135142
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
136143
if (doCons->GetLoopControl()) {
137144
// move DoConstruct
138-
std::get<std::optional<parser::DoConstruct>>(x.t) =
145+
std::get<std::optional<std::variant<parser::DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
139146
std::move(*doCons);
140147
nextIt = block.erase(nextIt);
141148
// try to match OmpEndLoopDirective
142-
if (nextIt != block.end()) {
143-
if (auto *endDir{
144-
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
145-
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
146-
std::move(*endDir);
147-
block.erase(nextIt);
148-
}
149+
if (auto *endDir{
150+
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
151+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
152+
std::move(*endDir);
153+
nextIt = block.erase(nextIt);
149154
}
150155
} else {
151156
messages_.Say(dir.source,
152157
"DO loop after the %s directive must have loop control"_err_en_US,
153158
parser::ToUpperCaseLetters(dir.source.ToString()));
154159
}
160+
} else if (auto *ompLoopCons{
161+
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
162+
// We should allow UNROLL and TILE constructs to be inserted between an OpenMP Loop Construct and the DO loop itself
163+
auto &beginDirective =
164+
std::get<parser::OmpBeginLoopDirective>(ompLoopCons->t);
165+
auto &beginLoopDirective =
166+
std::get<parser::OmpLoopDirective>(beginDirective.t);
167+
// iterate through the remaining block items to find the end directive for the unroll/tile directive.
168+
parser::Block::iterator endIt;
169+
endIt = nextIt;
170+
while(endIt != block.end()) {
171+
if (auto *endDir{
172+
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
173+
auto &endLoopDirective = std::get<parser::OmpLoopDirective>(endDir->t);
174+
if(endLoopDirective.v == dir.v) {
175+
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
176+
std::move(*endDir);
177+
endIt = block.erase(endIt);
178+
continue;
179+
}
180+
}
181+
++endIt;
182+
}
183+
if ((beginLoopDirective.v == llvm::omp::Directive::OMPD_unroll ||
184+
beginLoopDirective.v == llvm::omp::Directive::OMPD_tile)) {
185+
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
186+
auto &ompLoop = std::get<std::optional<std::variant<parser::DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t);
187+
ompLoop = std::optional<std::variant<parser::DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>>{
188+
std::variant<parser::DoConstruct, common::Indirection<parser::OpenMPLoopConstruct>>{
189+
common::Indirection{std::move(*ompLoopCons)}}};
190+
nextIt = block.erase(nextIt);
191+
}
155192
} else {
156-
messages_.Say(dir.source,
157-
"A DO loop must follow the %s directive"_err_en_US,
158-
parser::ToUpperCaseLetters(dir.source.ToString()));
193+
missingDoConstruct(dir);
159194
}
160195
// If we get here, we either found a loop, or issued an error message.
161196
return;
162197
}
198+
if (nextIt == block.end()) {
199+
missingDoConstruct(dir);
200+
}
163201
}
164202

165203
void RewriteOmpAllocations(parser::ExecutionPart &body) {

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -761,10 +761,13 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
761761
}
762762
SetLoopInfo(x);
763763

764-
if (const auto &doConstruct{
765-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
766-
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
767-
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
764+
auto &optLoopCons = std::get<1>(x.t);
765+
if(optLoopCons.has_value()) {
766+
if (const auto &doConstruct{
767+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
768+
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
769+
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
770+
}
768771
}
769772
CheckLoopItrVariableIsInt(x);
770773
CheckAssociatedLoopConstraints(x);
@@ -778,19 +781,28 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
778781
(beginDir.v == llvm::omp::Directive::OMPD_distribute_simd)) {
779782
CheckDistLinear(x);
780783
}
784+
if (beginDir.v == llvm::omp::Directive::OMPD_tile) {
785+
const auto &clauses{std::get<parser::OmpClauseList>(beginLoopDir.t)};
786+
for (auto &clause : clauses.v) {
787+
788+
}
789+
}
781790
}
782791
const parser::Name OmpStructureChecker::GetLoopIndex(
783792
const parser::DoConstruct *x) {
784793
using Bounds = parser::LoopControl::Bounds;
785794
return std::get<Bounds>(x->GetLoopControl()->u).name.thing;
786795
}
787796
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
788-
if (const auto &loopConstruct{
789-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
790-
const parser::DoConstruct *loop{&*loopConstruct};
791-
if (loop && loop->IsDoNormal()) {
792-
const parser::Name &itrVal{GetLoopIndex(loop)};
793-
SetLoopIv(itrVal.symbol);
797+
auto &optLoopCons = std::get<1>(x.t);
798+
if (optLoopCons.has_value()) {
799+
if (const auto &loopConstruct{
800+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
801+
const parser::DoConstruct *loop{&*loopConstruct};
802+
if (loop && loop->IsDoNormal()) {
803+
const parser::Name &itrVal{GetLoopIndex(loop)};
804+
SetLoopIv(itrVal.symbol);
805+
}
794806
}
795807
}
796808
}
@@ -856,8 +868,10 @@ void OmpStructureChecker::CheckIteratorModifier(const parser::OmpIterator &x) {
856868

857869
void OmpStructureChecker::CheckLoopItrVariableIsInt(
858870
const parser::OpenMPLoopConstruct &x) {
859-
if (const auto &loopConstruct{
860-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
871+
auto &optLoopCons = std::get<1>(x.t);
872+
if (optLoopCons.has_value()) {
873+
if (const auto &loopConstruct{
874+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
861875

862876
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
863877
if (loop->IsDoNormal()) {
@@ -877,6 +891,7 @@ void OmpStructureChecker::CheckLoopItrVariableIsInt(
877891
const auto it{block.begin()};
878892
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
879893
: nullptr;
894+
}
880895
}
881896
}
882897
}
@@ -1076,8 +1091,10 @@ void OmpStructureChecker::CheckDistLinear(
10761091

10771092
// Match the loop index variables with the collected symbols from linear
10781093
// clauses.
1094+
auto &optLoopCons = std::get<1>(x.t);
1095+
if (optLoopCons.has_value()) {
10791096
if (const auto &loopConstruct{
1080-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
1097+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
10811098
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
10821099
if (loop->IsDoNormal()) {
10831100
const parser::Name &itrVal{GetLoopIndex(loop)};
@@ -1095,6 +1112,7 @@ void OmpStructureChecker::CheckDistLinear(
10951112
const auto it{block.begin()};
10961113
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
10971114
: nullptr;
1115+
}
10981116
}
10991117
}
11001118

flang/lib/Semantics/resolve-directives.cpp

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1796,10 +1796,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPLoopConstruct &x) {
17961796
SetContextAssociatedLoopLevel(GetAssociatedLoopLevelFromClauses(clauseList));
17971797

17981798
if (beginDir.v == llvm::omp::Directive::OMPD_do) {
1799-
if (const auto &doConstruct{
1800-
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
1801-
if (doConstruct.value().IsDoWhile()) {
1802-
return true;
1799+
auto &optLoopCons = std::get<1>(x.t);
1800+
if (optLoopCons.has_value()) {
1801+
if (const auto &doConstruct{
1802+
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
1803+
if (doConstruct->IsDoWhile()) {
1804+
return true;
1805+
}
18031806
}
18041807
}
18051808
}
@@ -1962,48 +1965,64 @@ void OmpAttributeVisitor::PrivatizeAssociatedLoopIndexAndCheckLoopLevel(
19621965
bool hasCollapseClause{
19631966
clause ? (clause->Id() == llvm::omp::OMPC_collapse) : false};
19641967

1965-
const auto &outer{std::get<std::optional<parser::DoConstruct>>(x.t)};
1966-
if (outer.has_value()) {
1967-
for (const parser::DoConstruct *loop{&*outer}; loop && level > 0; --level) {
1968-
if (loop->IsDoConcurrent()) {
1969-
// DO CONCURRENT is explicitly allowed for the LOOP construct so long as
1970-
// there isn't a COLLAPSE clause
1971-
if (isLoopConstruct) {
1972-
if (hasCollapseClause) {
1973-
// hasCollapseClause implies clause != nullptr
1974-
context_.Say(clause->source,
1975-
"DO CONCURRENT loops cannot be used with the COLLAPSE clause."_err_en_US);
1968+
auto &optLoopCons = std::get<1>(x.t);
1969+
if (optLoopCons.has_value()) {
1970+
if (const auto &outer{std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
1971+
for (const parser::DoConstruct *loop{&*outer}; loop && level > 0; --level) {
1972+
if (loop->IsDoConcurrent()) {
1973+
// DO CONCURRENT is explicitly allowed for the LOOP construct so long as
1974+
// there isn't a COLLAPSE clause
1975+
if (isLoopConstruct) {
1976+
if (hasCollapseClause) {
1977+
// hasCollapseClause implies clause != nullptr
1978+
context_.Say(clause->source,
1979+
"DO CONCURRENT loops cannot be used with the COLLAPSE clause."_err_en_US);
1980+
}
1981+
} else {
1982+
auto &stmt =
1983+
std::get<parser::Statement<parser::NonLabelDoStmt>>(loop->t);
1984+
context_.Say(stmt.source,
1985+
"DO CONCURRENT loops cannot form part of a loop nest."_err_en_US);
19761986
}
1977-
} else {
1978-
auto &stmt =
1979-
std::get<parser::Statement<parser::NonLabelDoStmt>>(loop->t);
1980-
context_.Say(stmt.source,
1981-
"DO CONCURRENT loops cannot form part of a loop nest."_err_en_US);
1982-
}
1983-
}
1984-
// go through all the nested do-loops and resolve index variables
1985-
const parser::Name *iv{GetLoopIndex(*loop)};
1986-
if (iv) {
1987-
if (auto *symbol{ResolveOmp(*iv, ivDSA, currScope())}) {
1988-
SetSymbolDSA(*symbol, {Symbol::Flag::OmpPreDetermined, ivDSA});
1989-
iv->symbol = symbol; // adjust the symbol within region
1990-
AddToContextObjectWithDSA(*symbol, ivDSA);
19911987
}
1988+
// go through all the nested do-loops and resolve index variables
1989+
const parser::Name *iv{GetLoopIndex(*loop)};
1990+
if (iv) {
1991+
if (auto *symbol{ResolveOmp(*iv, ivDSA, currScope())}) {
1992+
SetSymbolDSA(*symbol, {Symbol::Flag::OmpPreDetermined, ivDSA});
1993+
iv->symbol = symbol; // adjust the symbol within region
1994+
AddToContextObjectWithDSA(*symbol, ivDSA);
1995+
}
19921996

1993-
const auto &block{std::get<parser::Block>(loop->t)};
1994-
const auto it{block.begin()};
1995-
loop = it != block.end() ? GetDoConstructIf(*it) : nullptr;
1997+
const auto &block{std::get<parser::Block>(loop->t)};
1998+
const auto it{block.begin()};
1999+
loop = it != block.end() ? GetDoConstructIf(*it) : nullptr;
2000+
}
2001+
}
2002+
CheckAssocLoopLevel(level, GetAssociatedClause());
2003+
} else if (const auto &loop{std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(&*optLoopCons)}) {
2004+
auto &beginDirective =
2005+
std::get<parser::OmpBeginLoopDirective>(loop->value().t);
2006+
auto &beginLoopDirective =
2007+
std::get<parser::OmpLoopDirective>(beginDirective.t);
2008+
if ((beginLoopDirective.v != llvm::omp::Directive::OMPD_unroll &&
2009+
beginLoopDirective.v != llvm::omp::Directive::OMPD_tile)) {
2010+
context_.Say(GetContext().directiveSource,
2011+
"Only UNROLL or TILE constructs are allowed between an OpenMP Loop Construct and a DO construct"_err_en_US,
2012+
parser::ToUpperCaseLetters(llvm::omp::getOpenMPDirectiveName(GetContext().directive, version).str()));
2013+
} else {
2014+
PrivatizeAssociatedLoopIndexAndCheckLoopLevel(loop->value());
19962015
}
2016+
} else {
2017+
context_.Say(GetContext().directiveSource,
2018+
"A DO loop must follow the %s directive"_err_en_US,
2019+
parser::ToUpperCaseLetters(
2020+
llvm::omp::getOpenMPDirectiveName(GetContext().directive, version)
2021+
.str()));
19972022
}
1998-
CheckAssocLoopLevel(level, GetAssociatedClause());
1999-
} else {
2000-
context_.Say(GetContext().directiveSource,
2001-
"A DO loop must follow the %s directive"_err_en_US,
2002-
parser::ToUpperCaseLetters(
2003-
llvm::omp::getOpenMPDirectiveName(GetContext().directive, version)
2004-
.str()));
20052023
}
20062024
}
2025+
20072026
void OmpAttributeVisitor::CheckAssocLoopLevel(
20082027
std::int64_t level, const parser::OmpClause *clause) {
20092028
if (clause && level != 0) {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
! Test to ensure TODO message is emitted for tile OpenMP 5.1 Directives when they are nested.
2+
3+
!RUN: not %flang -fopenmp -fopenmp-version=51 %s 2<&1 | FileCheck %s
4+
5+
subroutine loop_transformation_construct
6+
implicit none
7+
integer :: I = 10
8+
integer :: x
9+
integer :: y(I)
10+
11+
!$omp do
12+
!$omp tile
13+
do i = 1, I
14+
y(i) = y(i) * 5
15+
end do
16+
!$omp end tile
17+
!$omp end do
18+
end subroutine
19+
20+
!CHECK: not yet implemented: Unhandled loop directive (tile)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
! Test to ensure TODO message is emitted for unroll OpenMP 5.1 Directives when they are nested.
2+
3+
!RUN: not %flang -fopenmp -fopenmp-version=51 %s 2<&1 | FileCheck %s
4+
5+
program loop_transformation_construct
6+
implicit none
7+
integer, parameter :: I = 10
8+
integer :: x
9+
integer :: y(I)
10+
11+
!$omp do
12+
!$omp unroll
13+
do x = 1, I
14+
y(x) = y(x) * 5
15+
end do
16+
!$omp end unroll
17+
!$omp end do
18+
end program loop_transformation_construct
19+
20+
!CHECK: not yet implemented: Unhandled loop directive (unroll)

0 commit comments

Comments
 (0)