Skip to content

[flang][openacc] Allow acc routine at the top level #69936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flang/docs/OpenACC.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ local:
warning instead of an error as other compiler accepts it.
* The `if` clause accepts scalar integer expression in addition to scalar
logical expression.
* `!$acc routine` directive can be placed at the top level.
6 changes: 6 additions & 0 deletions flang/include/flang/Lower/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace Fortran {
namespace parser {
struct OpenACCConstruct;
struct OpenACCDeclarativeConstruct;
struct OpenACCRoutineConstruct;
} // namespace parser

namespace semantics {
Expand Down Expand Up @@ -71,6 +72,11 @@ void genOpenACCDeclarativeConstruct(AbstractConverter &,
StatementContext &,
const parser::OpenACCDeclarativeConstruct &,
AccRoutineInfoMappingList &);
void genOpenACCRoutineConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
mlir::ModuleOp &,
const parser::OpenACCRoutineConstruct &,
AccRoutineInfoMappingList &);

void finalizeOpenACCRoutineAttachment(mlir::ModuleOp &,
AccRoutineInfoMappingList &);
Expand Down
16 changes: 14 additions & 2 deletions flang/include/flang/Lower/PFTBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ using Constructs =

using Directives =
std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
parser::OpenACCRoutineConstruct,
parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;

Expand Down Expand Up @@ -360,7 +361,8 @@ using ProgramVariant =
ReferenceVariant<parser::MainProgram, parser::FunctionSubprogram,
parser::SubroutineSubprogram, parser::Module,
parser::Submodule, parser::SeparateModuleSubprogram,
parser::BlockData, parser::CompilerDirective>;
parser::BlockData, parser::CompilerDirective,
parser::OpenACCRoutineConstruct>;
/// A program is a list of program units.
/// These units can be function like, module like, or block data.
struct ProgramUnit : ProgramVariant {
Expand Down Expand Up @@ -763,10 +765,20 @@ struct CompilerDirectiveUnit : public ProgramUnit {
CompilerDirectiveUnit(const CompilerDirectiveUnit &) = delete;
};

// Top level OpenACC routine directives
struct OpenACCDirectiveUnit : public ProgramUnit {
OpenACCDirectiveUnit(const parser::OpenACCRoutineConstruct &directive,
const PftNode &parent)
: ProgramUnit{directive, parent}, routine{directive} {};
OpenACCDirectiveUnit(OpenACCDirectiveUnit &&) = default;
OpenACCDirectiveUnit(const OpenACCDirectiveUnit &) = delete;
const parser::OpenACCRoutineConstruct &routine;
};

/// A Program is the top-level root of the PFT.
struct Program {
using Units = std::variant<FunctionLikeUnit, ModuleLikeUnit, BlockDataUnit,
CompilerDirectiveUnit>;
CompilerDirectiveUnit, OpenACCDirectiveUnit>;

Program(semantics::CommonBlockList &&commonBlocks)
: commonBlocks{std::move(commonBlocks)} {}
Expand Down
4 changes: 3 additions & 1 deletion flang/include/flang/Parser/parse-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ struct PauseStmt;
struct OpenACCConstruct;
struct AccEndCombinedDirective;
struct OpenACCDeclarativeConstruct;
struct OpenACCRoutineConstruct;
struct OpenMPConstruct;
struct OpenMPDeclarativeConstruct;
struct OmpEndLoopDirective;
Expand Down Expand Up @@ -558,7 +559,8 @@ struct ProgramUnit {
common::Indirection<FunctionSubprogram>,
common::Indirection<SubroutineSubprogram>, common::Indirection<Module>,
common::Indirection<Submodule>, common::Indirection<BlockData>,
common::Indirection<CompilerDirective>>
common::Indirection<CompilerDirective>,
common::Indirection<OpenACCRoutineConstruct>>
u;
};

Expand Down
13 changes: 13 additions & 0 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
globalOmpRequiresSymbol = b.symTab.symbol();
},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
},
u);
}
Expand All @@ -328,6 +329,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
[&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
[&](Fortran::lower::pft::BlockDataUnit &b) {},
[&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
[&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
builder = new fir::FirOpBuilder(bridge.getModule(),
bridge.getKindMap());
Fortran::lower::genOpenACCRoutineConstruct(
*this, bridge.getSemanticsContext(), bridge.getModule(),
d.routine, accRoutineInfos);
builder = nullptr;
},
},
u);
}
Expand Down Expand Up @@ -2320,6 +2329,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genFIR(e);
}

void genFIR(const Fortran::parser::OpenACCRoutineConstruct &acc) {
// Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
}

void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
Expand Down
22 changes: 11 additions & 11 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3143,29 +3143,26 @@ static void attachRoutineInfo(mlir::func::FuncOp func,
mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
}

static void
genACC(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
void Fortran::lower::genOpenACCRoutineConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp &mod,
const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Location loc = converter.genLocation(routineConstruct.source);
std::optional<Fortran::parser::Name> name =
std::get<std::optional<Fortran::parser::Name>>(routineConstruct.t);
const auto &clauses =
std::get<Fortran::parser::AccClauseList>(routineConstruct.t);

mlir::ModuleOp mod = builder.getModule();
mlir::func::FuncOp funcOp;
std::string funcName;
if (name) {
funcName = converter.mangleName(*name->symbol);
funcOp = builder.getNamedFunction(funcName);
funcOp = builder.getNamedFunction(mod, funcName);
} else {
funcOp = builder.getFunction();
funcName = funcOp.getName();
}

bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
hasNohost = false;
std::optional<std::string> bindName = std::nullopt;
Expand Down Expand Up @@ -3381,8 +3378,11 @@ void Fortran::lower::genOpenACCDeclarativeConstruct(
},
[&](const Fortran::parser::OpenACCRoutineConstruct
&routineConstruct) {
genACC(converter, semanticsContext, routineConstruct,
accRoutineInfos);
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::ModuleOp mod = builder.getModule();
Fortran::lower::genOpenACCRoutineConstruct(
converter, semanticsContext, mod, routineConstruct,
accRoutineInfos);
},
},
accDeclConstruct.u);
Expand Down
24 changes: 24 additions & 0 deletions flang/lib/Lower/PFTBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,17 @@ class PFTBuilder {
return enterConstructOrDirective(directive);
}

bool Pre(const parser::OpenACCRoutineConstruct &directive) {
assert(pftParentStack.size() > 0 &&
"At least the Program must be a parent");
if (pftParentStack.back().isA<lower::pft::Program>()) {
addUnit(
lower::pft::OpenACCDirectiveUnit(directive, pftParentStack.back()));
return false;
}
return enterConstructOrDirective(directive);
}

private:
/// Initialize a new module-like unit and make it the builder's focus.
template <typename A>
Expand Down Expand Up @@ -1133,6 +1144,9 @@ class PFTDumper {
[&](const lower::pft::CompilerDirectiveUnit &unit) {
dumpCompilerDirectiveUnit(outputStream, unit);
},
[&](const lower::pft::OpenACCDirectiveUnit &unit) {
dumpOpenACCDirectiveUnit(outputStream, unit);
},
},
unit);
}
Expand Down Expand Up @@ -1280,6 +1294,16 @@ class PFTDumper {
outputStream << "\nEnd CompilerDirective\n\n";
}

void
dumpOpenACCDirectiveUnit(llvm::raw_ostream &outputStream,
const lower::pft::OpenACCDirectiveUnit &directive) {
outputStream << getNodeIndex(directive) << " ";
outputStream << "OpenACCDirective: !$acc ";
outputStream << directive.get<Fortran::parser::OpenACCRoutineConstruct>()
.source.ToString();
outputStream << "\nEnd OpenACCDirective\n\n";
}

template <typename T>
std::size_t getNodeIndex(const T &node) {
auto addr = static_cast<const void *>(&node);
Expand Down
7 changes: 6 additions & 1 deletion flang/lib/Parser/program-parsers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ static constexpr auto normalProgramUnit{StartNewSubprogram{} >> programUnit /
static constexpr auto globalCompilerDirective{
construct<ProgramUnit>(indirect(compilerDirective))};

static constexpr auto globalOpenACCCompilerDirective{
construct<ProgramUnit>(indirect(skipStuffBeforeStatement >>
"!$ACC "_sptok >> Parser<OpenACCRoutineConstruct>{}))};

// R501 program -> program-unit [program-unit]...
// This is the top-level production for the Fortran language.
// F'2018 6.3.1 defines a program unit as a sequence of one or more lines,
Expand All @@ -58,7 +62,8 @@ TYPE_PARSER(
"nonstandard usage: empty source file"_port_en_US,
skipStuffBeforeStatement >> !nextCh >>
pure<std::list<ProgramUnit>>()) ||
some(globalCompilerDirective || normalProgramUnit) /
some(globalCompilerDirective || globalOpenACCCompilerDirective ||
normalProgramUnit) /
skipStuffBeforeStatement))

// R504 specification-part ->
Expand Down
4 changes: 4 additions & 0 deletions flang/lib/Semantics/program-tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ ProgramTree ProgramTree::Build(const parser::CompilerDirective &) {
DIE("ProgramTree::Build() called for CompilerDirective");
}

ProgramTree ProgramTree::Build(const parser::OpenACCRoutineConstruct &) {
DIE("ProgramTree::Build() called for OpenACCRoutineConstruct");
}

const parser::ParentIdentifier &ProgramTree::GetParentId() const {
const auto *stmt{
std::get<const parser::Statement<parser::SubmoduleStmt> *>(stmt_)};
Expand Down
1 change: 1 addition & 0 deletions flang/lib/Semantics/program-tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ProgramTree {
static ProgramTree Build(const parser::Submodule &);
static ProgramTree Build(const parser::BlockData &);
static ProgramTree Build(const parser::CompilerDirective &);
static ProgramTree Build(const parser::OpenACCRoutineConstruct &);

ENUM_CLASS(Kind, // kind of node
Program, Function, Subroutine, MpSubprogram, Module, Submodule, BlockData)
Expand Down
29 changes: 19 additions & 10 deletions flang/lib/Semantics/resolve-directives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ template <typename T> class DirectiveAttributeVisitor {
? std::nullopt
: std::make_optional<DirContext>(dirContext_.back());
}
void PushContext(const parser::CharBlock &source, T dir, Scope &scope) {
dirContext_.emplace_back(source, dir, scope);
}
void PushContext(const parser::CharBlock &source, T dir) {
dirContext_.emplace_back(source, dir, context_.FindScope(source));
}
Expand Down Expand Up @@ -115,8 +118,8 @@ template <typename T> class DirectiveAttributeVisitor {

class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
public:
explicit AccAttributeVisitor(SemanticsContext &context)
: DirectiveAttributeVisitor(context) {}
explicit AccAttributeVisitor(SemanticsContext &context, Scope *topScope)
: DirectiveAttributeVisitor(context), topScope_(topScope) {}

template <typename A> void Walk(const A &x) { parser::Walk(x, *this); }
template <typename A> bool Pre(const A &) { return true; }
Expand Down Expand Up @@ -281,6 +284,7 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
const llvm::acc::Clause clause, const parser::AccObjectList &objectList);
void AddRoutineInfoToSymbol(
Symbol &, const parser::OpenACCRoutineConstruct &);
Scope *topScope_;
};

// Data-sharing and Data-mapping attributes for data-refs in OpenMP construct
Expand Down Expand Up @@ -802,10 +806,6 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCDeclarativeConstruct &x) {
const auto &declDir{
std::get<parser::AccDeclarativeDirective>(declConstruct->t)};
PushContext(declDir.source, llvm::acc::Directive::ACCD_declare);
} else if (const auto *routineConstruct{
std::get_if<parser::OpenACCRoutineConstruct>(&x.u)}) {
const auto &verbatim{std::get<parser::Verbatim>(routineConstruct->t)};
PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine);
}
ClearDataSharingAttributeObjects();
return true;
Expand Down Expand Up @@ -994,6 +994,13 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
}

bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
const auto &verbatim{std::get<parser::Verbatim>(x.t)};
if (topScope_) {
PushContext(
verbatim.source, llvm::acc::Directive::ACCD_routine, *topScope_);
} else {
PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine);
}
const auto &optName{std::get<std::optional<parser::Name>>(x.t)};
if (optName) {
if (Symbol *sym = ResolveFctName(*optName)) {
Expand All @@ -1005,7 +1012,9 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
(*optName).source);
}
} else {
AddRoutineInfoToSymbol(*currScope().symbol(), x);
if (currScope().symbol()) {
AddRoutineInfoToSymbol(*currScope().symbol(), x);
}
}
return true;
}
Expand Down Expand Up @@ -2190,10 +2199,10 @@ void OmpAttributeVisitor::CheckMultipleAppearances(
}
}

void ResolveAccParts(
SemanticsContext &context, const parser::ProgramUnit &node) {
void ResolveAccParts(SemanticsContext &context, const parser::ProgramUnit &node,
Scope *topScope) {
if (context.IsEnabled(common::LanguageFeature::OpenACC)) {
AccAttributeVisitor{context}.Walk(node);
AccAttributeVisitor{context, topScope}.Walk(node);
}
}

Expand Down
5 changes: 3 additions & 2 deletions flang/lib/Semantics/resolve-directives.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ struct ProgramUnit;
} // namespace Fortran::parser

namespace Fortran::semantics {

class Scope;
class SemanticsContext;

// Name resolution for OpenACC and OpenMP directives
void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &);
void ResolveAccParts(
SemanticsContext &, const parser::ProgramUnit &, Scope *topScope = {});
void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);

Expand Down
8 changes: 7 additions & 1 deletion flang/lib/Semantics/resolve-names.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8323,6 +8323,11 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {
// TODO: global directives
return true;
}
if (std::holds_alternative<
common::Indirection<parser::OpenACCRoutineConstruct>>(x.u)) {
ResolveAccParts(context(), x, &topScope_);
return false;
}
auto root{ProgramTree::Build(x)};
SetScope(topScope_);
ResolveSpecificationParts(root);
Expand All @@ -8335,7 +8340,8 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {

template <typename A> std::set<SourceName> GetUses(const A &x) {
std::set<SourceName> uses;
if constexpr (!std::is_same_v<A, parser::CompilerDirective>) {
if constexpr (!std::is_same_v<A, parser::CompilerDirective> &&
!std::is_same_v<A, parser::OpenACCRoutineConstruct>) {
const auto &spec{std::get<parser::SpecificationPart>(x.t)};
const auto &unitUses{std::get<
std::list<parser::Statement<common::Indirection<parser::UseStmt>>>>(
Expand Down
21 changes: 21 additions & 0 deletions flang/test/Lower/OpenACC/acc-routine02.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
! This test checks lowering of OpenACC routine directive.

! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s

subroutine sub1(a, n)
integer :: n
real :: a(n)
end subroutine sub1

!$acc routine(sub1)

program test
integer, parameter :: N = 10
real :: a(N)
call sub1(a, N)
end program

! CHECK-LABEL: acc.routine @acc_routine_0 func(@_QPsub1)

! CHECK: func.func @_QPsub1(%ar{{.*}}: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>}
Loading