Skip to content

RFC: WIP: add support for compiler directives which apply to functions #75352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions flang/include/flang/Lower/PFTBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,12 @@ VariableList getDependentVariableList(const Fortran::semantics::Symbol &);

void dump(VariableList &, std::string s = {}); // `s` is an optional dump label

/// Things that can be nested inside of a module or function
/// TODO: add the rest
struct FunctionLikeUnit;
struct CompilerDirectiveUnit;
using NestedUnit = std::variant<FunctionLikeUnit, CompilerDirectiveUnit>;

/// Function-like units may contain evaluations (executable statements) and
/// nested function-like units (internal procedures and function statements).
struct FunctionLikeUnit : public ProgramUnit {
Expand Down Expand Up @@ -695,7 +701,7 @@ struct FunctionLikeUnit : public ProgramUnit {
EvaluationList evaluationList;
LabelEvalMap labelEvaluationMap;
SymbolLabelMap assignSymbolLabelMap;
std::list<FunctionLikeUnit> nestedFunctions;
std::list<NestedUnit> nestedUnits;
/// <Symbol, Evaluation> pairs for each entry point. The pair at index 0
/// is the primary entry point; remaining pairs are alternate entry points.
/// The primary entry point symbol is Null for an anonymous program.
Expand Down Expand Up @@ -741,7 +747,7 @@ struct ModuleLikeUnit : public ProgramUnit {

ModuleStatement beginStmt;
ModuleStatement endStmt;
std::list<FunctionLikeUnit> nestedFunctions;
std::list<NestedUnit> nestedUnits;
EvaluationList evaluationList;
};

Expand Down
3 changes: 2 additions & 1 deletion flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -2894,7 +2894,8 @@ struct ModuleSubprogram {
UNION_CLASS_BOILERPLATE(ModuleSubprogram);
std::variant<common::Indirection<FunctionSubprogram>,
common::Indirection<SubroutineSubprogram>,
common::Indirection<SeparateModuleSubprogram>>
common::Indirection<SeparateModuleSubprogram>,
common::Indirection<CompilerDirective>>
u;
};

Expand Down
103 changes: 89 additions & 14 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "flang/Semantics/runtime-type-info.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
Expand Down Expand Up @@ -303,9 +304,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
},
[&](Fortran::lower::pft::ModuleLikeUnit &m) {
lowerModuleDeclScope(m);
for (Fortran::lower::pft::FunctionLikeUnit &f :
m.nestedFunctions)
declareFunction(f);
for (Fortran::lower::pft::NestedUnit &unit :
m.nestedUnits) {
if (auto *f = std::get_if<
Fortran::lower::pft::FunctionLikeUnit>(&unit))
declareFunction(*f);
}
},
[&](Fortran::lower::pft::BlockDataUnit &b) {
if (!globalOmpRequiresSymbol)
Expand All @@ -322,13 +326,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&]() { createIntrinsicModuleDefinitions(pft); });

// Primary translation pass.
for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
std::list<Fortran::lower::pft::Program::Units> &units = pft.getUnits();
for (auto it = units.begin(); it != units.end(); it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
processSubprogramDirective(it, units.end(), d);
},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(bridge.getModule(),
bridge.getKindMap());
Expand All @@ -338,7 +345,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder = nullptr;
},
},
u);
*it);
}

// Once all the code has been translated, create global runtime type info
Expand Down Expand Up @@ -387,13 +394,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {

// Compute the set of host associated entities from the nested functions.
llvm::SetVector<const Fortran::semantics::Symbol *> escapeHost;
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
collectHostAssociatedVariables(f, escapeHost);
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
collectHostAssociatedVariables(*f, escapeHost);
}
funit.setHostAssociatedSymbols(escapeHost);

// Declare internal procedures
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
declareFunction(f);
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
declareFunction(*f);
}
}

/// Get the scope that is defining or using \p sym. The returned scope is not
Expand Down Expand Up @@ -4667,8 +4678,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
endNewFunction(funit);
}
funit.setActiveEntry(0);
for (Fortran::lower::pft::FunctionLikeUnit &f : funit.nestedFunctions)
lowerFunc(f); // internal procedure
for (Fortran::lower::pft::NestedUnit &nested : funit.nestedUnits) {
if (auto *f = std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&nested))
lowerFunc(*f); // internal procedure
}
}

/// Lower module variable definitions to fir::globalOp and OpenMP/OpenACC
Expand All @@ -4692,8 +4705,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {

/// Lower functions contained in a module.
void lowerMod(Fortran::lower::pft::ModuleLikeUnit &mod) {
for (Fortran::lower::pft::FunctionLikeUnit &f : mod.nestedFunctions)
lowerFunc(f);
for (auto it = mod.nestedUnits.begin(); it != mod.nestedUnits.end();
it = std::next(it)) {
std::visit(
Fortran::common::visitors{
[&](Fortran::lower::pft::FunctionLikeUnit &f) { lowerFunc(f); },
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {
processSubprogramDirective(it, mod.nestedUnits.end(), d);
}},
*it);
}
}

void setCurrentPosition(const Fortran::parser::CharBlock &position) {
Expand Down Expand Up @@ -5001,6 +5022,60 @@ class FirConverter : public Fortran::lower::AbstractConverter {
globalOmpRequiresSymbol);
}

/// Process compiler directives that apply to subprograms
template <typename ITERATOR>
void
processSubprogramDirective(ITERATOR it, ITERATOR endIt,
Fortran::lower::pft::CompilerDirectiveUnit &d) {
auto *parserDirective = d.getIf<Fortran::parser::CompilerDirective>();
if (!parserDirective)
return;
auto *nvList =
std::get_if<std::list<Fortran::parser::CompilerDirective::NameValue>>(
&parserDirective->u);
if (!nvList)
return;

// get the function the directive applies to (hopefully the next unit)
mlir::func::FuncOp mlirFunc;
it = std::next(it);
if (it != endIt) {
auto *pftFunction =
std::get_if<Fortran::lower::pft::FunctionLikeUnit>(&*it);
if (pftFunction) {
Fortran::lower::CalleeInterface callee{*pftFunction, *this};
mlirFunc = callee.getFuncOp();
}
}

for (const Fortran::parser::CompilerDirective::NameValue &nv : *nvList) {
std::string name = std::get<Fortran::parser::Name>(nv.t).ToString();

// arm streaming sve directives
auto streamingMode = mlir::arm_sme::ArmStreamingMode::Disabled;
if (name == "arm_streaming")
streamingMode = mlir::arm_sme::ArmStreamingMode::Streaming;
else if (name == "arm_locally_streaming")
streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingLocally;
else if (name == "arm_streaming_compatible")
streamingMode = mlir::arm_sme::ArmStreamingMode::StreamingCompatible;
if (streamingMode != mlir::arm_sme::ArmStreamingMode::Disabled) {
if (!mlirFunc) {
// TODO: share diagnostic code with warnings elsewhere
// TODO: source location is printed as loc<"file.f90":line:col>
mlir::Location loc = genLocation(parserDirective->source);
llvm::errs() << loc << ": warning: ignoring directive '" << name
<< "' because it has no associated subprogram\n";
continue;
}
llvm::StringRef attrName =
mlir::arm_sme::stringifyArmStreamingMode(streamingMode);
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(mlirFunc.getContext());
mlirFunc->setAttr(attrName, unitAttr);
}
}
}

//===--------------------------------------------------------------------===//

Fortran::lower::LoweringBridge &bridge;
Expand Down
63 changes: 42 additions & 21 deletions flang/lib/Lower/PFTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,12 @@ class PFTBuilder {
lower::pft::CompilerDirectiveUnit(directive, pftParentStack.back()));
return false;
}
if (auto *mod = pftParentStack.back().getIf<lower::pft::ModuleLikeUnit>()) {
assert(nestedUnitList && "Modules have a nested units list");
lower::pft::CompilerDirectiveUnit unit{directive, pftParentStack.back()};
addNestedUnit(std::move(unit));
return false;
}
return enterConstructOrDirective(directive);
}

Expand All @@ -279,7 +285,7 @@ class PFTBuilder {
bool enterModule(const A &mod) {
Fortran::lower::pft::ModuleLikeUnit &unit =
addUnit(lower::pft::ModuleLikeUnit{mod, pftParentStack.back()});
functionList = &unit.nestedFunctions;
nestedUnitList = &unit.nestedUnits;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
Expand Down Expand Up @@ -349,7 +355,7 @@ class PFTBuilder {
semanticsContext});
labelEvaluationMap = &unit.labelEvaluationMap;
assignSymbolLabelMap = &unit.assignSymbolLabelMap;
functionList = &unit.nestedFunctions;
nestedUnitList = &unit.nestedUnits;
pushEvaluationList(&unit.evaluationList);
pftParentStack.emplace_back(unit);
LLVM_DEBUG(dumpScope(&unit.getScope()));
Expand Down Expand Up @@ -414,14 +420,14 @@ class PFTBuilder {
if (!pftParentStack.empty()) {
pftParentStack.back().visit(common::visitors{
[&](lower::pft::FunctionLikeUnit &p) {
functionList = &p.nestedFunctions;
nestedUnitList = &p.nestedUnits;
labelEvaluationMap = &p.labelEvaluationMap;
assignSymbolLabelMap = &p.assignSymbolLabelMap;
},
[&](lower::pft::ModuleLikeUnit &p) {
functionList = &p.nestedFunctions;
nestedUnitList = &p.nestedUnits;
},
[&](auto &) { functionList = nullptr; },
[&](auto &) { nestedUnitList = nullptr; },
});
}
}
Expand All @@ -432,11 +438,16 @@ class PFTBuilder {
return std::get<A>(pgm->getUnits().back());
}

template <typename A>
void addNestedUnit(A &&source) {
nestedUnitList->emplace_back(lower::pft::NestedUnit{std::move(source)});
}

template <typename A>
A &addFunction(A &&func) {
if (functionList) {
functionList->emplace_back(std::move(func));
return functionList->back();
if (nestedUnitList) {
addNestedUnit(func);
return std::get<A>(nestedUnitList->back());
}
return addUnit(std::move(func));
}
Expand All @@ -459,7 +470,7 @@ class PFTBuilder {

/// Append an Evaluation to the end of the current list.
lower::pft::Evaluation &addEvaluation(lower::pft::Evaluation &&eval) {
assert(functionList && "not in a function");
assert(nestedUnitList && "not in a function");
assert(!evaluationListStack.empty() && "empty evaluation list stack");
if (!constructAndDirectiveStack.empty())
eval.parentConstruct = constructAndDirectiveStack.back();
Expand Down Expand Up @@ -499,15 +510,15 @@ class PFTBuilder {

/// push a new list on the stack of Evaluation lists
void pushEvaluationList(lower::pft::EvaluationList *evaluationList) {
assert(functionList && "not in a function");
assert(nestedUnitList && "not in a function");
assert(evaluationList && evaluationList->empty() &&
"evaluation list isn't correct");
evaluationListStack.emplace_back(evaluationList);
}

/// pop the current list and return to the last Evaluation list
void popEvaluationList() {
assert(functionList && "not in a function");
assert(nestedUnitList && "not in a function");
evaluationListStack.pop_back();
}

Expand Down Expand Up @@ -1088,9 +1099,9 @@ class PFTBuilder {
std::vector<lower::pft::PftNode> pftParentStack;
const semantics::SemanticsContext &semanticsContext;

/// functionList points to the internal or module procedure function list
/// of a FunctionLikeUnit or a ModuleLikeUnit. It may be null.
std::list<lower::pft::FunctionLikeUnit> *functionList{};
/// nestedUnitList points to the internal or module procedure unit list
/// of nested units (e.g. functions). It may be null.
std::list<lower::pft::NestedUnit> *nestedUnitList{};
std::vector<lower::pft::Evaluation *> constructAndDirectiveStack{};
std::vector<lower::pft::Evaluation *> doConstructStack{};
/// evaluationListStack is the current nested construct evaluationList state.
Expand Down Expand Up @@ -1264,11 +1275,17 @@ class PFTDumper {
outputStream << ": " << header;
outputStream << '\n';
dumpEvaluationList(outputStream, functionLikeUnit.evaluationList);
if (!functionLikeUnit.nestedFunctions.empty()) {
if (!functionLikeUnit.nestedUnits.empty()) {
outputStream << "\nContains\n";
for (const lower::pft::FunctionLikeUnit &func :
functionLikeUnit.nestedFunctions)
dumpFunctionLikeUnit(outputStream, func);
for (const lower::pft::NestedUnit &nested :
functionLikeUnit.nestedUnits) {
if (const auto *func =
std::get_if<lower::pft::FunctionLikeUnit>(&nested))
dumpFunctionLikeUnit(outputStream, *func);
if (const auto *directive =
std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
dumpCompilerDirectiveUnit(outputStream, *directive);
}
outputStream << "End Contains\n";
}
outputStream << "End " << unitKind << ' ' << name << "\n\n";
Expand Down Expand Up @@ -1298,9 +1315,13 @@ class PFTDumper {
outputStream << unitKind << ' ' << name << ": " << header << '\n';
dumpEvaluationList(outputStream, moduleLikeUnit.evaluationList);
outputStream << "Contains\n";
for (const lower::pft::FunctionLikeUnit &func :
moduleLikeUnit.nestedFunctions)
dumpFunctionLikeUnit(outputStream, func);
for (const lower::pft::NestedUnit &nested : moduleLikeUnit.nestedUnits) {
if (const auto *func = std::get_if<lower::pft::FunctionLikeUnit>(&nested))
dumpFunctionLikeUnit(outputStream, *func);
if (const auto *directive =
std::get_if<lower::pft::CompilerDirectiveUnit>(&nested))
dumpCompilerDirectiveUnit(outputStream, *directive);
}
outputStream << "End Contains\nEnd " << unitKind << ' ' << name << "\n\n";
}

Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Parser/program-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ TYPE_CONTEXT_PARSER("module subprogram part"_en_US,
// separate-module-subprogram
TYPE_PARSER(construct<ModuleSubprogram>(indirect(functionSubprogram)) ||
construct<ModuleSubprogram>(indirect(subroutineSubprogram)) ||
construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})))
construct<ModuleSubprogram>(indirect(Parser<SeparateModuleSubprogram>{})) ||
construct<ModuleSubprogram>(indirect(compilerDirective)))

// R1410 module-nature -> INTRINSIC | NON_INTRINSIC
constexpr auto moduleNature{
Expand Down
6 changes: 5 additions & 1 deletion flang/lib/Semantics/program-tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ static ProgramTree BuildModuleTree(const parser::Name &name, const T &x) {
for (const auto &subp :
std::get<std::list<parser::ModuleSubprogram>>(subps->t)) {
common::visit(
[&](const auto &y) { node.AddChild(ProgramTree::Build(y.value())); },
common::visitors{
[&](const common::Indirection<parser::CompilerDirective> &) {},
[&](const auto &y) {
node.AddChild(ProgramTree::Build(y.value()));
}},
subp.u);
}
}
Expand Down
Loading