Skip to content

[flang][openacc] Allow acc routine at the top level #69936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 24, 2023

Conversation

clementval
Copy link
Contributor

Some compilers allow the $acc routine(<name>) to be placed at the program unit level. To be compatible, this patch enables the use of acc routine at this level. These acc routine directives must have a name.

Some compilers allow the !$acc routine to be placed at the program unit level.
To be comptabile, this patch enables the use of acc routine at this level.
These acc routine directives must have a name.
@llvmbot
Copy link
Member

llvmbot commented Oct 23, 2023

@llvm/pr-subscribers-flang-semantics
@llvm/pr-subscribers-flang-parser
@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-flang-openmp

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Some compilers allow the $acc routine(&lt;name&gt;) to be placed at the program unit level. To be compatible, this patch enables the use of acc routine at this level. These acc routine directives must have a name.


Full diff: https://github.com/llvm/llvm-project/pull/69936.diff

15 Files Affected:

  • (modified) flang/docs/OpenACC.md (+1)
  • (modified) flang/include/flang/Lower/OpenACC.h (+6)
  • (modified) flang/include/flang/Lower/PFTBuilder.h (+14-2)
  • (modified) flang/include/flang/Parser/parse-tree.h (+3-1)
  • (modified) flang/lib/Lower/Bridge.cpp (+13)
  • (modified) flang/lib/Lower/OpenACC.cpp (+11-11)
  • (modified) flang/lib/Lower/PFTBuilder.cpp (+24)
  • (modified) flang/lib/Parser/program-parsers.cpp (+6-1)
  • (modified) flang/lib/Semantics/program-tree.cpp (+4)
  • (modified) flang/lib/Semantics/program-tree.h (+1)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+19-10)
  • (modified) flang/lib/Semantics/resolve-directives.h (+3-2)
  • (modified) flang/lib/Semantics/resolve-names.cpp (+7-1)
  • (added) flang/test/Lower/OpenACC/acc-routine02.f90 (+21)
  • (added) flang/test/Semantics/OpenACC/acc-routine-validity02.f90 (+17)
diff --git a/flang/docs/OpenACC.md b/flang/docs/OpenACC.md
index e29c5f89f9b2482..41c974f837421b9 100644
--- a/flang/docs/OpenACC.md
+++ b/flang/docs/OpenACC.md
@@ -23,3 +23,4 @@ local:
   warning instead of an error as other compiler accepts it.
 * The `if` clause accepts scalar integer expression in addition to scalar
   logical expression.
+* `!$acc routine` directive can be placed at the top level. 
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index c73af0a6eb0f874..409956f0ecb309f 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -37,6 +37,7 @@ namespace Fortran {
 namespace parser {
 struct OpenACCConstruct;
 struct OpenACCDeclarativeConstruct;
+struct OpenACCRoutineConstruct;
 } // namespace parser
 
 namespace semantics {
@@ -71,6 +72,11 @@ void genOpenACCDeclarativeConstruct(AbstractConverter &,
                                     StatementContext &,
                                     const parser::OpenACCDeclarativeConstruct &,
                                     AccRoutineInfoMappingList &);
+void genOpenACCRoutineConstruct(AbstractConverter &,
+                                Fortran::semantics::SemanticsContext &,
+                                mlir::ModuleOp &,
+                                const parser::OpenACCRoutineConstruct &,
+                                AccRoutineInfoMappingList &);
 
 void finalizeOpenACCRoutineAttachment(mlir::ModuleOp &,
                                       AccRoutineInfoMappingList &);
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index 5927fc1915ae34d..6f68dc7c9f525f1 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -135,6 +135,7 @@ using Constructs =
 
 using Directives =
     std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
+               parser::OpenACCRoutineConstruct,
                parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
                parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;
 
@@ -360,7 +361,8 @@ using ProgramVariant =
     ReferenceVariant<parser::MainProgram, parser::FunctionSubprogram,
                      parser::SubroutineSubprogram, parser::Module,
                      parser::Submodule, parser::SeparateModuleSubprogram,
-                     parser::BlockData, parser::CompilerDirective>;
+                     parser::BlockData, parser::CompilerDirective,
+                     parser::OpenACCRoutineConstruct>;
 /// A program is a list of program units.
 /// These units can be function like, module like, or block data.
 struct ProgramUnit : ProgramVariant {
@@ -763,10 +765,20 @@ struct CompilerDirectiveUnit : public ProgramUnit {
   CompilerDirectiveUnit(const CompilerDirectiveUnit &) = delete;
 };
 
+// Top level OpenACC routine directives
+struct OpenACCDirectiveUnit : public ProgramUnit {
+  OpenACCDirectiveUnit(const parser::OpenACCRoutineConstruct &directive,
+                       const PftNode &parent)
+      : ProgramUnit{directive, parent}, routine{directive} {};
+  OpenACCDirectiveUnit(OpenACCDirectiveUnit &&) = default;
+  OpenACCDirectiveUnit(const OpenACCDirectiveUnit &) = delete;
+  const parser::OpenACCRoutineConstruct &routine;
+};
+
 /// A Program is the top-level root of the PFT.
 struct Program {
   using Units = std::variant<FunctionLikeUnit, ModuleLikeUnit, BlockDataUnit,
-                             CompilerDirectiveUnit>;
+                             CompilerDirectiveUnit, OpenACCDirectiveUnit>;
 
   Program(semantics::CommonBlockList &&commonBlocks)
       : commonBlocks{std::move(commonBlocks)} {}
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index a51921f2562b8a9..83c8db936934a03 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -262,6 +262,7 @@ struct PauseStmt;
 struct OpenACCConstruct;
 struct AccEndCombinedDirective;
 struct OpenACCDeclarativeConstruct;
+struct OpenACCRoutineConstruct;
 struct OpenMPConstruct;
 struct OpenMPDeclarativeConstruct;
 struct OmpEndLoopDirective;
@@ -558,7 +559,8 @@ struct ProgramUnit {
       common::Indirection<FunctionSubprogram>,
       common::Indirection<SubroutineSubprogram>, common::Indirection<Module>,
       common::Indirection<Submodule>, common::Indirection<BlockData>,
-      common::Indirection<CompilerDirective>>
+      common::Indirection<CompilerDirective>,
+      common::Indirection<OpenACCRoutineConstruct>>
       u;
 };
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index f26a1aaf0236fa5..ccd54e349786e51 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -316,6 +316,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
                          globalOmpRequiresSymbol = b.symTab.symbol();
                      },
                      [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+                     [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {},
                  },
                  u);
     }
@@ -328,6 +329,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
               [&](Fortran::lower::pft::ModuleLikeUnit &m) { lowerMod(m); },
               [&](Fortran::lower::pft::BlockDataUnit &b) {},
               [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
+              [&](Fortran::lower::pft::OpenACCDirectiveUnit &d) {
+                builder = new fir::FirOpBuilder(bridge.getModule(),
+                                                bridge.getKindMap());
+                Fortran::lower::genOpenACCRoutineConstruct(
+                    *this, bridge.getSemanticsContext(), bridge.getModule(),
+                    d.routine, accRoutineInfos);
+                builder = nullptr;
+              },
           },
           u);
     }
@@ -2320,6 +2329,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
       genFIR(e);
   }
 
+  void genFIR(const Fortran::parser::OpenACCRoutineConstruct &acc) {
+    // Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
+  }
+
   void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4fafcebc30d116a..bfef5e2226ed4be 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3143,29 +3143,26 @@ static void attachRoutineInfo(mlir::func::FuncOp func,
       mlir::acc::RoutineInfoAttr::get(func.getContext(), routines));
 }
 
-static void
-genACC(Fortran::lower::AbstractConverter &converter,
-       Fortran::semantics::SemanticsContext &semanticsContext,
-       const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
-       Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
+void Fortran::lower::genOpenACCRoutineConstruct(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::semantics::SemanticsContext &semanticsContext, mlir::ModuleOp &mod,
+    const Fortran::parser::OpenACCRoutineConstruct &routineConstruct,
+    Fortran::lower::AccRoutineInfoMappingList &accRoutineInfos) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
   mlir::Location loc = converter.genLocation(routineConstruct.source);
   std::optional<Fortran::parser::Name> name =
       std::get<std::optional<Fortran::parser::Name>>(routineConstruct.t);
   const auto &clauses =
       std::get<Fortran::parser::AccClauseList>(routineConstruct.t);
-
-  mlir::ModuleOp mod = builder.getModule();
   mlir::func::FuncOp funcOp;
   std::string funcName;
   if (name) {
     funcName = converter.mangleName(*name->symbol);
-    funcOp = builder.getNamedFunction(funcName);
+    funcOp = builder.getNamedFunction(mod, funcName);
   } else {
     funcOp = builder.getFunction();
     funcName = funcOp.getName();
   }
-
   bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
        hasNohost = false;
   std::optional<std::string> bindName = std::nullopt;
@@ -3381,8 +3378,11 @@ void Fortran::lower::genOpenACCDeclarativeConstruct(
           },
           [&](const Fortran::parser::OpenACCRoutineConstruct
                   &routineConstruct) {
-            genACC(converter, semanticsContext, routineConstruct,
-                   accRoutineInfos);
+            fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+            mlir::ModuleOp mod = builder.getModule();
+            Fortran::lower::genOpenACCRoutineConstruct(
+                converter, semanticsContext, mod, routineConstruct,
+                accRoutineInfos);
           },
       },
       accDeclConstruct.u);
diff --git a/flang/lib/Lower/PFTBuilder.cpp b/flang/lib/Lower/PFTBuilder.cpp
index 97afdaf49b672a9..0946a85dcaddd88 100644
--- a/flang/lib/Lower/PFTBuilder.cpp
+++ b/flang/lib/Lower/PFTBuilder.cpp
@@ -241,6 +241,17 @@ class PFTBuilder {
     return enterConstructOrDirective(directive);
   }
 
+  bool Pre(const parser::OpenACCRoutineConstruct &directive) {
+    assert(pftParentStack.size() > 0 &&
+           "At least the Program must be a parent");
+    if (pftParentStack.back().isA<lower::pft::Program>()) {
+      addUnit(
+          lower::pft::OpenACCDirectiveUnit(directive, pftParentStack.back()));
+      return false;
+    }
+    return enterConstructOrDirective(directive);
+  }
+
 private:
   /// Initialize a new module-like unit and make it the builder's focus.
   template <typename A>
@@ -1133,6 +1144,9 @@ class PFTDumper {
                      [&](const lower::pft::CompilerDirectiveUnit &unit) {
                        dumpCompilerDirectiveUnit(outputStream, unit);
                      },
+                     [&](const lower::pft::OpenACCDirectiveUnit &unit) {
+                       dumpOpenACCDirectiveUnit(outputStream, unit);
+                     },
                  },
                  unit);
     }
@@ -1280,6 +1294,16 @@ class PFTDumper {
     outputStream << "\nEnd CompilerDirective\n\n";
   }
 
+  void
+  dumpOpenACCDirectiveUnit(llvm::raw_ostream &outputStream,
+                           const lower::pft::OpenACCDirectiveUnit &directive) {
+    outputStream << getNodeIndex(directive) << " ";
+    outputStream << "OpenACCDirective: !$acc ";
+    outputStream << directive.get<Fortran::parser::OpenACCRoutineConstruct>()
+                        .source.ToString();
+    outputStream << "\nEnd OpenACCDirective\n\n";
+  }
+
   template <typename T>
   std::size_t getNodeIndex(const T &node) {
     auto addr = static_cast<const void *>(&node);
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index 521ae43097adc6a..e24559bf14f7c92 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -46,6 +46,10 @@ static constexpr auto normalProgramUnit{StartNewSubprogram{} >> programUnit /
 static constexpr auto globalCompilerDirective{
     construct<ProgramUnit>(indirect(compilerDirective))};
 
+static constexpr auto globalOpenACCCompilerDirective{
+    construct<ProgramUnit>(indirect(skipStuffBeforeStatement >>
+        "!$ACC "_sptok >> Parser<OpenACCRoutineConstruct>{}))};
+
 // R501 program -> program-unit [program-unit]...
 // This is the top-level production for the Fortran language.
 // F'2018 6.3.1 defines a program unit as a sequence of one or more lines,
@@ -58,7 +62,8 @@ TYPE_PARSER(
                            "nonstandard usage: empty source file"_port_en_US,
                            skipStuffBeforeStatement >> !nextCh >>
                                pure<std::list<ProgramUnit>>()) ||
-        some(globalCompilerDirective || normalProgramUnit) /
+        some(globalCompilerDirective || globalOpenACCCompilerDirective ||
+            normalProgramUnit) /
             skipStuffBeforeStatement))
 
 // R504 specification-part ->
diff --git a/flang/lib/Semantics/program-tree.cpp b/flang/lib/Semantics/program-tree.cpp
index cd631da93d698be..bf773f3810c847b 100644
--- a/flang/lib/Semantics/program-tree.cpp
+++ b/flang/lib/Semantics/program-tree.cpp
@@ -200,6 +200,10 @@ ProgramTree ProgramTree::Build(const parser::CompilerDirective &) {
   DIE("ProgramTree::Build() called for CompilerDirective");
 }
 
+ProgramTree ProgramTree::Build(const parser::OpenACCRoutineConstruct &) {
+  DIE("ProgramTree::Build() called for OpenACCRoutineConstruct");
+}
+
 const parser::ParentIdentifier &ProgramTree::GetParentId() const {
   const auto *stmt{
       std::get<const parser::Statement<parser::SubmoduleStmt> *>(stmt_)};
diff --git a/flang/lib/Semantics/program-tree.h b/flang/lib/Semantics/program-tree.h
index 4bf6567c6adfe8b..d49b0405d8b122d 100644
--- a/flang/lib/Semantics/program-tree.h
+++ b/flang/lib/Semantics/program-tree.h
@@ -43,6 +43,7 @@ class ProgramTree {
   static ProgramTree Build(const parser::Submodule &);
   static ProgramTree Build(const parser::BlockData &);
   static ProgramTree Build(const parser::CompilerDirective &);
+  static ProgramTree Build(const parser::OpenACCRoutineConstruct &);
 
   ENUM_CLASS(Kind, // kind of node
       Program, Function, Subroutine, MpSubprogram, Module, Submodule, BlockData)
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index f7720fcf43e5768..2a6523f579f2773 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -61,6 +61,9 @@ template <typename T> class DirectiveAttributeVisitor {
         ? std::nullopt
         : std::make_optional<DirContext>(dirContext_.back());
   }
+  void PushContext(const parser::CharBlock &source, T dir, Scope &scope) {
+    dirContext_.emplace_back(source, dir, scope);
+  }
   void PushContext(const parser::CharBlock &source, T dir) {
     dirContext_.emplace_back(source, dir, context_.FindScope(source));
   }
@@ -115,8 +118,8 @@ template <typename T> class DirectiveAttributeVisitor {
 
 class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
 public:
-  explicit AccAttributeVisitor(SemanticsContext &context)
-      : DirectiveAttributeVisitor(context) {}
+  explicit AccAttributeVisitor(SemanticsContext &context, Scope *topScope)
+      : DirectiveAttributeVisitor(context), topScope_(topScope) {}
 
   template <typename A> void Walk(const A &x) { parser::Walk(x, *this); }
   template <typename A> bool Pre(const A &) { return true; }
@@ -281,6 +284,7 @@ class AccAttributeVisitor : DirectiveAttributeVisitor<llvm::acc::Directive> {
       const llvm::acc::Clause clause, const parser::AccObjectList &objectList);
   void AddRoutineInfoToSymbol(
       Symbol &, const parser::OpenACCRoutineConstruct &);
+  Scope *topScope_;
 };
 
 // Data-sharing and Data-mapping attributes for data-refs in OpenMP construct
@@ -802,10 +806,6 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCDeclarativeConstruct &x) {
     const auto &declDir{
         std::get<parser::AccDeclarativeDirective>(declConstruct->t)};
     PushContext(declDir.source, llvm::acc::Directive::ACCD_declare);
-  } else if (const auto *routineConstruct{
-                 std::get_if<parser::OpenACCRoutineConstruct>(&x.u)}) {
-    const auto &verbatim{std::get<parser::Verbatim>(routineConstruct->t)};
-    PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine);
   }
   ClearDataSharingAttributeObjects();
   return true;
@@ -994,6 +994,13 @@ void AccAttributeVisitor::AddRoutineInfoToSymbol(
 }
 
 bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
+  const auto &verbatim{std::get<parser::Verbatim>(x.t)};
+  if (topScope_) {
+    PushContext(
+        verbatim.source, llvm::acc::Directive::ACCD_routine, *topScope_);
+  } else {
+    PushContext(verbatim.source, llvm::acc::Directive::ACCD_routine);
+  }
   const auto &optName{std::get<std::optional<parser::Name>>(x.t)};
   if (optName) {
     if (Symbol *sym = ResolveFctName(*optName)) {
@@ -1005,7 +1012,9 @@ bool AccAttributeVisitor::Pre(const parser::OpenACCRoutineConstruct &x) {
           (*optName).source);
     }
   } else {
-    AddRoutineInfoToSymbol(*currScope().symbol(), x);
+    if (currScope().symbol()) {
+      AddRoutineInfoToSymbol(*currScope().symbol(), x);
+    }
   }
   return true;
 }
@@ -2190,10 +2199,10 @@ void OmpAttributeVisitor::CheckMultipleAppearances(
   }
 }
 
-void ResolveAccParts(
-    SemanticsContext &context, const parser::ProgramUnit &node) {
+void ResolveAccParts(SemanticsContext &context, const parser::ProgramUnit &node,
+    Scope *topScope) {
   if (context.IsEnabled(common::LanguageFeature::OpenACC)) {
-    AccAttributeVisitor{context}.Walk(node);
+    AccAttributeVisitor{context, topScope}.Walk(node);
   }
 }
 
diff --git a/flang/lib/Semantics/resolve-directives.h b/flang/lib/Semantics/resolve-directives.h
index 839165aaf30eb81..4aef8ad6c40081a 100644
--- a/flang/lib/Semantics/resolve-directives.h
+++ b/flang/lib/Semantics/resolve-directives.h
@@ -16,11 +16,12 @@ struct ProgramUnit;
 } // namespace Fortran::parser
 
 namespace Fortran::semantics {
-
+class Scope;
 class SemanticsContext;
 
 // Name resolution for OpenACC and OpenMP directives
-void ResolveAccParts(SemanticsContext &, const parser::ProgramUnit &);
+void ResolveAccParts(
+    SemanticsContext &, const parser::ProgramUnit &, Scope *topScope = {});
 void ResolveOmpParts(SemanticsContext &, const parser::ProgramUnit &);
 void ResolveOmpTopLevelParts(SemanticsContext &, const parser::Program &);
 
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index 90c14806afbf82d..9f7a59b0b454a8a 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -8323,6 +8323,11 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {
     // TODO: global directives
     return true;
   }
+  if (std::holds_alternative<
+          common::Indirection<parser::OpenACCRoutineConstruct>>(x.u)) {
+    ResolveAccParts(context(), x, &topScope_);
+    return false;
+  }
   auto root{ProgramTree::Build(x)};
   SetScope(topScope_);
   ResolveSpecificationParts(root);
@@ -8335,7 +8340,8 @@ bool ResolveNamesVisitor::Pre(const parser::ProgramUnit &x) {
 
 template <typename A> std::set<SourceName> GetUses(const A &x) {
   std::set<SourceName> uses;
-  if constexpr (!std::is_same_v<A, parser::CompilerDirective>) {
+  if constexpr (!std::is_same_v<A, parser::CompilerDirective> &&
+      !std::is_same_v<A, parser::OpenACCRoutineConstruct>) {
     const auto &spec{std::get<parser::SpecificationPart>(x.t)};
     const auto &unitUses{std::get<
         std::list<parser::Statement<common::Indirection<parser::UseStmt>>>>(
diff --git a/flang/test/Lower/OpenACC/acc-routine02.f90 b/flang/test/Lower/OpenACC/acc-routine02.f90
new file mode 100644
index 000000000000000..d93ece88235ed98
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-routine02.f90
@@ -0,0 +1,21 @@
+! This test checks lowering of OpenACC routine directive.
+
+! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1(a, n)
+  integer :: n
+  real :: a(n)
+end subroutine sub1
+
+!$acc routine(sub1)
+
+program test
+  integer, parameter :: N = 10
+  real :: a(N)
+  call sub1(a, N)
+end program
+
+! CHECK-LABEL: acc.routine @acc_routine_0 func(@_QPsub1)
+
+! CHECK: func.func @_QPsub1(%ar{{.*}}: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_0]>} 
diff --git a/flang/test/Semantics/OpenACC/acc-routine-validity02.f90 b/flang/test/Semantics/OpenACC/acc-routine-validity02.f90
new file mode 100644
index 000000000000000..9410a5b17745dd4
--- /dev/null
+++ b/flang/test/Semantics/OpenACC/acc-routine-validity02.f90
@@ -0,0 +1,17 @@
+! RUN: %python %S/../test_errors.py %s %flang -fopenacc
+
+! Check acc routine in the top level.
+
+subroutine sub1(a, n)
+  integer :: n
+  real :: a(n)
+end subroutine sub1
+
+!$acc routine(sub1)
+
+!dir$ value=1
+program test
+  integer, parameter :: N = 10
+  real :: a(N)
+  call sub1(a, N)
+end program

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow! Looks like it was pretty complicated to support this case. Nice job!

@clementval clementval merged commit 8286743 into llvm:main Oct 24, 2023
@clementval clementval deleted the acc_routine_top_level branch October 24, 2023 16:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants