-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang] Add support for lowering directives at the CONTAINS level #95123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There is currently support for lowering directives that appear outside of a module or procedure, or inside the body of a module or procedure. Extend this to support directives at the CONTAINS level of a module or procedure, such as directives 3, 5, 7 9, and 10 in: !dir$ some directive 1 module m !dir$ some directive 2 contains !dir$ some directive 3 subroutine p !dir$ some directive 4 contains !dir$ some directive 5 subroutine s1 !dir$ some directive 6 end subroutine s1 !dir$ some directive 7 subroutine s2 !dir$ some directive 8 end subroutine s2 !dir$ some directive 9 end subroutine p !dir$ some directive 10 end module m !dir$ some directive 11 This is done by looking for CONTAINS statements at the module or procedure level, while ignoring CONTAINS statements at the derived type level.
@llvm/pr-subscribers-flang-fir-hlfir Author: None (vdonaldson) ChangesThere is currently support for lowering directives that appear outside of a module or procedure, or inside the body of a module or procedure. Extend this to support directives at the CONTAINS level of a module or procedure, such as directives 3, 5, 7 9, and 10 in:
This is done by looking for CONTAINS statements at the module or procedure level, while ignoring CONTAINS statements at the derived type level. Patch is 22.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95123.diff 4 Files Affected:
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 9913f584133fa..83200eb6351a8 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -31,11 +31,14 @@
namespace Fortran::lower::pft {
+struct CompilerDirectiveUnit;
struct Evaluation;
-struct Program;
-struct ModuleLikeUnit;
struct FunctionLikeUnit;
+struct ModuleLikeUnit;
+struct Program;
+using ContainedUnit = std::variant<CompilerDirectiveUnit, FunctionLikeUnit>;
+using ContainedUnitList = std::list<ContainedUnit>;
using EvaluationList = std::list<Evaluation>;
/// Provide a variant like container that can hold references. It can hold
@@ -594,8 +597,8 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);
void dump(VariableList &, std::string s = {}); // `s` is an optional dump label
-/// Function-like units may contain evaluations (executable statements) and
-/// nested function-like units (internal procedures and function statements).
+/// Function-like units may contain evaluations (executable statements),
+/// directives, and internal (nested) function-like units.
struct FunctionLikeUnit : public ProgramUnit {
// wrapper statements for function-like syntactic structures
using FunctionStatement =
@@ -697,10 +700,10 @@ struct FunctionLikeUnit : public ProgramUnit {
std::optional<FunctionStatement> beginStmt;
FunctionStatement endStmt;
const semantics::Scope *scope;
- EvaluationList evaluationList;
LabelEvalMap labelEvaluationMap;
SymbolLabelMap assignSymbolLabelMap;
- std::list<FunctionLikeUnit> nestedFunctions;
+ ContainedUnitList containedUnitList;
+ EvaluationList evaluationList;
/// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
/// is the primary entry point; remaining pairs are alternate entry points.
/// The primary entry point symbol is Null for an anonymous program.
@@ -746,7 +749,7 @@ struct ModuleLikeUnit : public ProgramUnit {
ModuleStatement beginStmt;
ModuleStatement endStmt;
- std::list<FunctionLikeUnit> nestedFunctions;
+ ContainedUnitList containedUnitList;
EvaluationList evaluationList;
};
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 202efa57d4a36..9ecbbc73dce07 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -302,28 +302,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
bool hasMainProgram = false;
const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
- std::visit(Fortran::common::visitors{
- [&](Fortran::lower::pft::FunctionLikeUnit &f) {
- if (f.isMainProgram())
- hasMainProgram = true;
- declareFunction(f);
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = f.getScope().symbol();
- },
- [&](Fortran::lower::pft::ModuleLikeUnit &m) {
- lowerModuleDeclScope(m);
- for (Fortran::lower::pft::FunctionLikeUnit &f :
- m.nestedFunctions)
- declareFunction(f);
- },
- [&](Fortran::lower::pft::BlockDataUnit &b) {
- if (!globalOmpRequiresSymbol)
- globalOmpRequiresSymbol = b.symTab.symbol();
- },
- [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
- [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
- },
- u);
+ std::visit(
+ Fortran::common::visitors{
+ [&](Fortran::lower::pft::FunctionLikeUnit &f) {
+ if (f.isMainProgram())
+ hasMainProgram = true;
+ declareFunction(f);
+ if (!globalOmpRequiresSymbol)
+ globalOmpRequiresSymbol = f.getScope().symbol();
+ },
+ [&](Fortran::lower::pft::ModuleLikeUnit &m) {
+ lowerModuleDeclScope(m);
+ for (Fortran::lower::pft::ContainedUnit &unit :
+ m.containedUnitList)
+ if (auto *f =
+ std::get_if<Fortran::lower::pft::FunctionLikeUnit>(
+ &unit))
+ declareFunction(*f);
+ },
+ [&](Fortran::lower::pft::BlockDataUnit &b) {
+ if (!globalOmpRequiresSymbol)
+ globalOmpRequiresSymbol = b.symTab.symbol();
+ },
+ [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+ [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
+ },
+ u);
}
// Create definitions of intrinsic module constants.
@@ -387,13 +391,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Compute the set of host associated entities from the nested functions.
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- collectHostAssociatedVariables(f, escapeHost);
+ for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+ collectHostAssociatedVariables(*f, escapeHost);
funit.setHostAssociatedSymbols(escapeHost);
// Declare internal procedures
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- declareFunction(f);
+ for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+ declareFunction(*f);
}
/// Get the scope that is defining or using \p sym. The returned scope is not
@@ -5356,8 +5362,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
endNewFunction(funit);
}
funit.setActiveEntry(0);
- for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
- lowerFunc(f); // internal procedure
+ for (Fortran::lower::pft::ContainedUnit &unit : funit.containedUnitList)
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+ lowerFunc(*f); // internal procedure
}
/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
@@ -5381,8 +5388,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
- for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
- lowerFunc(f);
+ for (Fortran::lower::pft::ContainedUnit &unit : mod.containedUnitList)
+ if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&unit))
+ lowerFunc(*f);
}
void setCurrentPosition(const Fortran::parser::CharBlock &position) {
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index f196b9c5a0cbc..df2c31381a0e7 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -209,6 +209,20 @@ class PFTBuilder {
}
}
+ bool Pre(const parser::SpecificationPart &) {
+ ++specificationPartLevel;
+ return true;
+ }
+ void Post(const parser::SpecificationPart &) { --specificationPartLevel; }
+
+ bool Pre(const parser::ContainsStmt &) {
+ if (!specificationPartLevel) {
+ assert(containsStmtStack.size() && "empty contains stack");
+ containsStmtStack.back() = true;
+ }
+ return false;
+ }
+
// Module like
bool Pre(const parser::Module &node) { return enterModule(node); }
bool Pre(const parser::Submodule &node) { return enterModule(node); }
@@ -249,15 +263,21 @@ class PFTBuilder {
whereBody.u);
}
- // CompilerDirective have special handling in case they are top level
- // directives (i.e. they do not belong to a ProgramUnit).
+ // A CompilerDirective may appear outside any program unit, after a module
+ // or function contains statement, or inside a module or function.
bool Pre(const parser::CompilerDirective &directive) {
- assert(pftParentStack.size() > 0 &&
- "At least the Program must be a parent");
- if (pftParentStack.back().isA<lower::pft::Program>()) {
- addUnit(
- lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
+ assert(pftParentStack.size() > 0 && "no program");
+ lower::pft::PftNode &node = pftParentStack.back();
+ if (node.isA<lower::pft::Program>()) {
+ addUnit(lower::pft::CompilerDirectiveUnit(directive, node));
return false;
+ } else if ((node.isA<lower::pft::ModuleLikeUnit>() ||
+ node.isA<lower::pft::FunctionLikeUnit>())) {
+ assert(containsStmtStack.size() && "empty contains stack");
+ if (containsStmtStack.back()) {
+ addContainedUnit(lower::pft::CompilerDirectiveUnit{directive, node});
+ return false;
+ }
}
return enterConstructOrDirective(directive);
}
@@ -277,9 +297,10 @@ class PFTBuilder {
/// Initialize a new module-like unit and make it the builder's focus.
template <typename A>
bool enterModule(const A &mod) {
- Fortran::lower::pft::ModuleLikeUnit &unit =
+ lower::pft::ModuleLikeUnit &unit =
addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
- functionList = &unit.nestedFunctions;
+ containsStmtStack.push_back(false);
+ containedUnitList = &unit.containedUnitList;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -287,6 +308,7 @@ class PFTBuilder {
}
void exitModule() {
+ containsStmtStack.pop_back();
if (!evaluationListStack.empty())
popEvaluationList();
pftParentStack.pop_back();
@@ -344,12 +366,13 @@ class PFTBuilder {
const semantics::SemanticsContext &semanticsContext) {
cleanModuleEvaluationList();
endFunctionBody(); // enclosing host subprogram body, if any
- Fortran::lower::pft::FunctionLikeUnit &unit =
- addFunction(lower::pft::FunctionLikeUnit{func, pftParentStack.back(),
- semanticsContext});
+ lower::pft::FunctionLikeUnit &unit =
+ addContainedUnit(lower::pft::FunctionLikeUnit{
+ func, pftParentStack.back(), semanticsContext});
labelEvaluationMap = &unit.labelEvaluationMap;
assignSymbolLabelMap = &unit.assignSymbolLabelMap;
- functionList = &unit.nestedFunctions;
+ containsStmtStack.push_back(false);
+ containedUnitList = &unit.containedUnitList;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
@@ -361,6 +384,7 @@ class PFTBuilder {
endFunctionBody();
analyzeBranches(nullptr, *evaluationListStack.back()); // add branch links
processEntryPoints();
+ containsStmtStack.pop_back();
popEvaluationList();
labelEvaluationMap = nullptr;
assignSymbolLabelMap = nullptr;
@@ -371,7 +395,7 @@ class PFTBuilder {
/// Initialize a new construct or directive and make it the builder's focus.
template <typename A>
bool enterConstructOrDirective(const A &constructOrDirective) {
- Fortran::lower::pft::Evaluation &eval = addEvaluation(
+ lower::pft::Evaluation &eval = addEvaluation(
lower::pft::Evaluation{constructOrDirective, pftParentStack.back()});
eval.evaluationList.reset(new lower::pft::EvaluationList);
pushEvaluationList(eval.evaluationList.get());
@@ -381,7 +405,7 @@ class PFTBuilder {
}
void exitConstructOrDirective() {
- auto isOpenMPLoopConstruct = [](Fortran::lower::pft::Evaluation *eval) {
+ auto isOpenMPLoopConstruct = [](lower::pft::Evaluation *eval) {
if (const auto *ompConstruct = eval->getIf<parser::OpenMPConstruct>())
if (std::holds_alternative<parser::OpenMPLoopConstruct>(
ompConstruct->u))
@@ -396,8 +420,7 @@ class PFTBuilder {
// construct region must have an exit target inside the region.
// This is not applicable to the OpenMP loop construct since the
// end of the loop is an available target inside the region.
- Fortran::lower::pft::EvaluationList &evaluationList =
- *eval->evaluationList;
+ lower::pft::EvaluationList &evaluationList = *eval->evaluationList;
if (!evaluationList.empty() && evaluationList.back().isConstruct()) {
static const parser::ContinueStmt exitTarget{};
addEvaluation(
@@ -413,15 +436,15 @@ class PFTBuilder {
void resetFunctionState() {
if (!pftParentStack.empty()) {
pftParentStack.back().visit(common::visitors{
+ [&](lower::pft::ModuleLikeUnit &p) {
+ containedUnitList = &p.containedUnitList;
+ },
[&](lower::pft::FunctionLikeUnit &p) {
- functionList = &p.nestedFunctions;
+ containedUnitList = &p.containedUnitList;
labelEvaluationMap = &p.labelEvaluationMap;
assignSymbolLabelMap = &p.assignSymbolLabelMap;
},
- [&](lower::pft::ModuleLikeUnit &p) {
- functionList = &p.nestedFunctions;
- },
- [&](auto &) { functionList = nullptr; },
+ [&](auto &) { containedUnitList = nullptr; },
});
}
}
@@ -433,12 +456,11 @@ class PFTBuilder {
}
template <typename A>
- A &addFunction(A &&func) {
- if (functionList) {
- functionList->emplace_back(std::move(func));
- return functionList->back();
- }
- return addUnit(std::move(func));
+ A &addContainedUnit(A &&unit) {
+ if (!containedUnitList)
+ return addUnit(std::move(unit));
+ containedUnitList->emplace_back(std::move(unit));
+ return std::get<A>(containedUnitList->back());
}
// ActionStmt has a couple of non-conforming cases, explicitly handled here.
@@ -459,7 +481,6 @@ class PFTBuilder {
/// Append an Evaluation to the end of the current list.
lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
- assert(functionList && "not in a function");
assert(!evaluationListStack.empty() && "empty evaluation list stack");
if (!constructAndDirectiveStack.empty())
eval.parentConstruct = constructAndDirectiveStack.back();
@@ -499,15 +520,15 @@ class PFTBuilder {
/// push a new list on the stack of Evaluation lists
void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
- assert(functionList && "not in a function");
assert(evaluationList && evaluationList->empty() &&
- "evaluation list isn't correct");
+ "invalid evaluation list");
evaluationListStack.emplace_back(evaluationList);
}
/// pop the current list and return to the last Evaluation list
void popEvaluationList() {
- assert(functionList && "not in a function");
+ assert(!evaluationListStack.empty() &&
+ "trying to pop an empty evaluationListStack");
evaluationListStack.pop_back();
}
@@ -1089,9 +1110,8 @@ class PFTBuilder {
std::vector<lower::pft::PftNode> pftParentStack;
const semantics::SemanticsContext &semanticsContext;
- /// functionList points to the internal or module procedure function list
- /// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
- std::list<lower::pft::FunctionLikeUnit> *functionList{};
+ llvm::SmallVector<bool> containsStmtStack{};
+ lower::pft::ContainedUnitList *containedUnitList{};
std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
std::vector<lower::pft::Evaluation *> doConstructStack{};
/// evaluationListStack is the current nested construct evaluationList state.
@@ -1099,6 +1119,7 @@ class PFTBuilder {
llvm::DenseMap<parser::Label, lower::pft::Evaluation *> *labelEvaluationMap{};
lower::pft::SymbolLabelMap *assignSymbolLabelMap{};
std::map<std::string, lower::pft::Evaluation *> constructNameMap{};
+ int specificationPartLevel{};
lower::pft::Evaluation *lastLexicalEvaluation{};
};
@@ -1201,11 +1222,15 @@ class PFTDumper {
outputStream << " -> " << eval.controlSuccessor->printIndex;
else if (eval.isA<parser::EntryStmt>() && eval.lexicalSuccessor)
outputStream << " -> " << eval.lexicalSuccessor->printIndex;
+ bool extraNewline = false;
if (!eval.position.empty())
outputStream << ": " << eval.position.ToString();
- else if (auto *dir = eval.getIf<Fortran::parser::CompilerDirective>())
+ else if (auto *dir = eval.getIf<parser::CompilerDirective>()) {
+ extraNewline = dir->source.ToString().back() == '\n';
outputStream << ": !" << dir->source.ToString();
- outputStream << '\n';
+ }
+ if (!extraNewline)
+ outputStream << '\n';
if (eval.hasNestedEvaluations()) {
dumpEvaluationList(outputStream, *eval.evaluationList, indent + 1);
outputStream << indentString << "<<End " << name << bang << ">>\n";
@@ -1265,13 +1290,7 @@ class PFTDumper {
outputStream << ": " << header;
outputStream << '\n';
dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
- if (!functionLikeUnit.nestedFunctions.empty()) {
- outputStream << "\nContains\n";
- for (const lower::pft::FunctionLikeUnit &func :
- functionLikeUnit.nestedFunctions)
- dumpFunctionLikeUnit(outputStream, func);
- outputStream << "End Contains\n";
- }
+ dumpContainedUnitList(outputStream, functionLikeUnit.containedUnitList);
outputStream << "End " << unitKind << ' ' << name << "\n\n";
}
@@ -1298,11 +1317,8 @@ class PFTDumper {
});
outputStream << unitKind << ' ' << name << ": " << header << '\n';
dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
- outputStream << "Contains\n";
- for (const lower::pft::FunctionLikeUnit &func :
- moduleLikeUnit.nestedFunctions)
- dumpFunctionLikeUnit(outputStream, func);
- outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
+ dumpContainedUnitList(outputStream, moduleLikeUnit.containedUnitList);
+ outputStream << "End " << unitKind << ' ' << name << "\n\n";
}
// Top level directives
@@ -1311,9 +1327,34 @@ class PFTDumper {
const lower::pft::CompilerDirectiveUnit &directive) {
outputStream << getNodeIndex(directive) << " ";
outputStream << "CompilerDirective: !";
- outputStream << directive.get<Fortran::parser::CompilerDirective>()
- .source.ToString();
- outputStream << "\nEnd CompilerDirective\n\n";
+ bool extraNewline =
+ directive.get<parser::CompilerDirective>().source.ToString().back() ==
+ 'n';
+ outputStream
+ << directive.get<parser::CompilerDirective>().source.ToString();
+ if (!extraNewline)
+ outputStream << "\n";
+ outputStream << "\n";
+ }
+
+ void dumpContainedUnitList(
+ llvm::raw_ostream &outputStream,
+ const lower::pft::ContainedUnitList &containedUnitList) {
+ if (containedUnitList.empty())
+ return;
+ outputStream << "\nContains\n";
+ for (const lower::pft::ContainedUnit &unit : containedUnitList)
+ if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&unit)) {
+ dumpFunctionLikeUnit(outputStream, *func);
+ } else if (const auto *dir =
+ std::get_if<lower::pft::CompilerDirectiveUnit>(&unit)) {
+ outputStream << getNodeIndex(*dir) << " ";
+ dumpEvaluation(outputStream,
+ lower::pft::Evaluation{
+ dir->get<parser::CompilerDirective>(), dir->parent});
+ outputStream << "\n";
+ }
+ outputStream << "End Contains\n";
}
void
@@ -1321,8 +1362,8 @@ class PFTDumper {
const lower::pft::OpenACCDirectiveUnit &directive) {
outputStream << getNodeIndex(directive) << " ";
outputStream << "OpenACCDirective: !$acc ";
- outputStream << directive.get<Fortran::parser::OpenACCRoutineConstruct>()
- .source.ToString();
+ outputStream
+ << directive.get<parser::OpenACCRoutineConstruct>().source.ToString();
outpu...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit, LGTM otherwise, thanks for adding support for this!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All builds and tests correctly and looks good.
There is currently support for lowering directives that appear outside of a module or procedure, or inside the body of a module or procedure. Extend this to support directives at the CONTAINS level of a module or procedure, such as directives 3, 5, 7 9, and 10 in:
This is done by looking for CONTAINS statements at the module or procedure level, while ignoring CONTAINS statements at the derived type level.