Skip to content

[Flang][OpenMP] Add Semantics support for Nested OpenMPLoopConstructs #145917

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 1, 2025
5 changes: 4 additions & 1 deletion flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ struct AccEndCombinedDirective;
struct OpenACCDeclarativeConstruct;
struct OpenACCRoutineConstruct;
struct OpenMPConstruct;
struct OpenMPLoopConstruct;
struct OpenMPDeclarativeConstruct;
struct OmpEndLoopDirective;
struct OmpMemoryOrderClause;
Expand Down Expand Up @@ -5021,11 +5022,13 @@ struct OpenMPBlockConstruct {
};

// OpenMP directives enclosing do loop
using NestedConstruct =
std::variant<DoConstruct, common::Indirection<OpenMPLoopConstruct>>;
struct OpenMPLoopConstruct {
TUPLE_CLASS_BOILERPLATE(OpenMPLoopConstruct);
OpenMPLoopConstruct(OmpBeginLoopDirective &&a)
: t({std::move(a), std::nullopt, std::nullopt}) {}
std::tuple<OmpBeginLoopDirective, std::optional<DoConstruct>,
std::tuple<OmpBeginLoopDirective, std::optional<NestedConstruct>,
std::optional<OmpEndLoopDirective>>
t;
};
Expand Down
10 changes: 10 additions & 0 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4231,6 +4231,16 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);

auto &optLoopCons =
std::get<std::optional<parser::NestedConstruct>>(loopConstruct.t);
if (optLoopCons.has_value()) {
if (auto *ompNestedLoopCons{
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>(
&*optLoopCons)}) {
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value());
}
}

llvm::omp::Directive directive =
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v;
const parser::CharBlock &source =
Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Parser/unparse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2926,7 +2926,8 @@ class UnparseVisitor {
Walk(std::get<OmpBeginLoopDirective>(x.t));
Put("\n");
EndOpenMP();
Walk(std::get<std::optional<DoConstruct>>(x.t));
Walk(std::get<std::optional<std::variant<DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t));
Walk(std::get<std::optional<OmpEndLoopDirective>>(x.t));
}
void Unparse(const BasedPointer &x) {
Expand Down
97 changes: 86 additions & 11 deletions flang/lib/Semantics/canonicalize-omp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "canonicalize-omp.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"

// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
// Constructs more structured which provide explicit scopes for later
Expand Down Expand Up @@ -125,6 +126,16 @@ class CanonicalizationOfOmp {
parser::Block::iterator nextIt;
auto &beginDir{std::get<parser::OmpBeginLoopDirective>(x.t)};
auto &dir{std::get<parser::OmpLoopDirective>(beginDir.t)};
auto missingDoConstruct = [](auto &dir, auto &messages) {
messages.Say(dir.source,
"A DO loop must follow the %s directive"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
};
auto tileUnrollError = [](auto &dir, auto &messages) {
messages.Say(dir.source,
"If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
};

nextIt = it;
while (++nextIt != block.end()) {
Expand All @@ -135,31 +146,95 @@ class CanonicalizationOfOmp {
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
if (doCons->GetLoopControl()) {
// move DoConstruct
std::get<std::optional<parser::DoConstruct>>(x.t) =
std::get<std::optional<std::variant<parser::DoConstruct,
common::Indirection<parser::OpenMPLoopConstruct>>>>(x.t) =
std::move(*doCons);
nextIt = block.erase(nextIt);
// try to match OmpEndLoopDirective
if (nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
block.erase(nextIt);
}
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
nextIt = block.erase(nextIt);
}
} else {
messages_.Say(dir.source,
"DO loop after the %s directive must have loop control"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
}
} else if (auto *ompLoopCons{
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
// We should allow UNROLL and TILE constructs to be inserted between an
// OpenMP Loop Construct and the DO loop itself
auto &nestedBeginDirective =
std::get<parser::OmpBeginLoopDirective>(ompLoopCons->t);
auto &nestedBeginLoopDirective =
std::get<parser::OmpLoopDirective>(nestedBeginDirective.t);
if ((nestedBeginLoopDirective.v == llvm::omp::Directive::OMPD_unroll ||
nestedBeginLoopDirective.v ==
llvm::omp::Directive::OMPD_tile) &&
!(nestedBeginLoopDirective.v == llvm::omp::Directive::OMPD_unroll &&
dir.v == llvm::omp::Directive::OMPD_tile)) {
// iterate through the remaining block items to find the end directive
// for the unroll/tile directive.
parser::Block::iterator endIt;
endIt = nextIt;
while (endIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*endIt)}) {
auto &endLoopDirective =
std::get<parser::OmpLoopDirective>(endDir->t);
if (endLoopDirective.v == dir.v) {
std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
std::move(*endDir);
endIt = block.erase(endIt);
continue;
}
}
++endIt;
}
RewriteOpenMPLoopConstruct(*ompLoopCons, block, nextIt);
auto &ompLoop = std::get<std::optional<parser::NestedConstruct>>(x.t);
ompLoop =
std::optional<parser::NestedConstruct>{parser::NestedConstruct{
common::Indirection{std::move(*ompLoopCons)}}};
nextIt = block.erase(nextIt);
} else if (nestedBeginLoopDirective.v ==
llvm::omp::Directive::OMPD_unroll &&
dir.v == llvm::omp::Directive::OMPD_tile) {
// if a loop has been unrolled, the user can not then tile that loop
// as it has been unrolled
parser::OmpClauseList &unrollClauseList{
std::get<parser::OmpClauseList>(nestedBeginDirective.t)};
if (unrollClauseList.v.empty()) {
// if the clause list is empty for an unroll construct, we assume
// the loop is being fully unrolled
tileUnrollError(dir, messages_);
} else {
// parse the clauses for the unroll directive to find the full
// clause
for (auto clause{unrollClauseList.v.begin()};
clause != unrollClauseList.v.end(); ++clause) {
if (clause->Id() == llvm::omp::OMPC_full) {
tileUnrollError(dir, messages_);
}
}
}
} else {
messages_.Say(nestedBeginLoopDirective.source,
"Only Loop Transformation Constructs or Loop Nests can be nested within Loop Constructs"_err_en_US,
parser::ToUpperCaseLetters(
nestedBeginLoopDirective.source.ToString()));
}
} else {
messages_.Say(dir.source,
"A DO loop must follow the %s directive"_err_en_US,
parser::ToUpperCaseLetters(dir.source.ToString()));
missingDoConstruct(dir, messages_);
}
// If we get here, we either found a loop, or issued an error message.
return;
}
if (nextIt == block.end()) {
missingDoConstruct(dir, messages_);
}
}

void RewriteOmpAllocations(parser::ExecutionPart &body) {
Expand Down
100 changes: 56 additions & 44 deletions flang/lib/Semantics/check-omp-structure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -762,10 +762,13 @@ void OmpStructureChecker::Enter(const parser::OpenMPLoopConstruct &x) {
}
SetLoopInfo(x);

if (const auto &doConstruct{
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &doConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
const auto &doBlock{std::get<parser::Block>(doConstruct->t)};
CheckNoBranching(doBlock, beginDir.v, beginDir.source);
}
}
CheckLoopItrVariableIsInt(x);
CheckAssociatedLoopConstraints(x);
Expand All @@ -786,12 +789,15 @@ const parser::Name OmpStructureChecker::GetLoopIndex(
return std::get<Bounds>(x->GetLoopControl()->u).name.thing;
}
void OmpStructureChecker::SetLoopInfo(const parser::OpenMPLoopConstruct &x) {
if (const auto &loopConstruct{
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
const parser::DoConstruct *loop{&*loopConstruct};
if (loop && loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
SetLoopIv(itrVal.symbol);
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
const parser::DoConstruct *loop{&*loopConstruct};
if (loop && loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
SetLoopIv(itrVal.symbol);
}
}
}
}
Expand Down Expand Up @@ -857,27 +863,30 @@ void OmpStructureChecker::CheckIteratorModifier(const parser::OmpIterator &x) {

void OmpStructureChecker::CheckLoopItrVariableIsInt(
const parser::OpenMPLoopConstruct &x) {
if (const auto &loopConstruct{
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {

for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
const auto *type{itrVal.symbol->GetType()};
if (!type->IsNumeric(TypeCategory::Integer)) {
context_.Say(itrVal.source,
"The DO loop iteration"
" variable must be of the type integer."_err_en_US,
itrVal.ToString());
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
const auto *type{itrVal.symbol->GetType()};
if (!type->IsNumeric(TypeCategory::Integer)) {
context_.Say(itrVal.source,
"The DO loop iteration"
" variable must be of the type integer."_err_en_US,
itrVal.ToString());
}
}
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
}
}
Expand Down Expand Up @@ -1077,25 +1086,28 @@ void OmpStructureChecker::CheckDistLinear(

// Match the loop index variables with the collected symbols from linear
// clauses.
if (const auto &loopConstruct{
std::get<std::optional<parser::DoConstruct>>(x.t)}) {
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
// Remove the symbol from the collected set
indexVars.erase(&itrVal.symbol->GetUltimate());
}
collapseVal--;
if (collapseVal == 0) {
break;
auto &optLoopCons = std::get<std::optional<parser::NestedConstruct>>(x.t);
if (optLoopCons.has_value()) {
if (const auto &loopConstruct{
std::get_if<parser::DoConstruct>(&*optLoopCons)}) {
for (const parser::DoConstruct *loop{&*loopConstruct}; loop;) {
if (loop->IsDoNormal()) {
const parser::Name &itrVal{GetLoopIndex(loop)};
if (itrVal.symbol) {
// Remove the symbol from the collected set
indexVars.erase(&itrVal.symbol->GetUltimate());
}
collapseVal--;
if (collapseVal == 0) {
break;
}
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
// Get the next DoConstruct if block is not empty.
const auto &block{std::get<parser::Block>(loop->t)};
const auto it{block.begin()};
loop = it != block.end() ? parser::Unwrap<parser::DoConstruct>(*it)
: nullptr;
}
}

Expand Down
Loading