Skip to content

[flang][OpenMP] Privatize vars referenced in statement functions #103390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 26, 2024

Conversation

luporl
Copy link
Contributor

@luporl luporl commented Aug 13, 2024

Variables referenced in the body of statement functions need to be
handled as if they are explicitly referenced. Otherwise, they are
skipped during implicit privatization, because statement functions
are represented as procedures in the parse tree.

To avoid missing symbols referenced only in statement functions
during implicit privatization, new symbols, associated with them,
are created and inserted into the context of the directive that
privatizes them. They are later collected and processed in
lowering. To avoid confusing these new symbols with regular ones,
they are tagged with the new OmpFromStmtFunction flag.

Fixes #74273

@luporl luporl marked this pull request as ready for review August 13, 2024 19:11
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp flang:semantics labels Aug 13, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 13, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-openmp

Author: Leandro Lupori (luporl)

Changes

Variables referenced in the body of statement functions need to be
handled as if they are explicitly referenced. Otherwise, they are
skipped during implicit privatization, because statement functions
are represented as procedures in the parse tree.

To avoid missing symbols referenced only in statement functions
during implicit privatization, new symbols, associated with them,
are created and inserted into the context of the directive that
privatizes them. They are later collected and processed in
lowering. To avoid confusing these new symbols with regular ones,
they are tagged with the new OmpFromStmtFunction flag.

Fixes #74273


Full diff: https://github.com/llvm/llvm-project/pull/103390.diff

4 Files Affected:

  • (modified) flang/include/flang/Semantics/symbol.h (+1-1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+9)
  • (modified) flang/lib/Semantics/resolve-directives.cpp (+87-25)
  • (added) flang/test/Lower/OpenMP/statement-function.f90 (+43)
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index cf0350735b5b9..b4db6689a9427 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -755,7 +755,7 @@ class Symbol {
       OmpDeclarativeAllocateDirective, OmpExecutableAllocateDirective,
       OmpDeclareSimd, OmpDeclareTarget, OmpThreadprivate, OmpDeclareReduction,
       OmpFlushed, OmpCriticalLock, OmpIfSpecified, OmpNone, OmpPreDetermined,
-      OmpImplicit);
+      OmpImplicit, OmpFromStmtFunction);
   using Flags = common::EnumSet<Flag, Flag_enumSize>;
 
   const Scope &owner() const { return *owner_; }
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index e1a193edc004a..1b2f926e21bed 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -402,6 +402,15 @@ void DataSharingProcessor::collectSymbols(
                              /*collectSymbols=*/true,
                              /*collectHostAssociatedSymbols=*/true);
 
+  // Add implicitly referenced symbols from statement functions.
+  if (curScope) {
+    for (const auto &sym : curScope->GetSymbols()) {
+      if (sym->test(semantics::Symbol::Flag::OmpFromStmtFunction) &&
+          sym->test(flag))
+        allSymbols.insert(&*sym);
+    }
+  }
+
   llvm::SetVector<const semantics::Symbol *> symbolsInNestedRegions;
   collectSymbolsInNestedRegions(eval, flag, symbolsInNestedRegions);
 
diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp
index d635a7b8b7874..04642aa4e0279 100644
--- a/flang/lib/Semantics/resolve-directives.cpp
+++ b/flang/lib/Semantics/resolve-directives.cpp
@@ -91,11 +91,12 @@ template <typename T> class DirectiveAttributeVisitor {
   void SetContextAssociatedLoopLevel(std::int64_t level) {
     GetContext().associatedLoopLevel = level;
   }
-  Symbol &MakeAssocSymbol(const SourceName &name, Symbol &prev, Scope &scope) {
+  Symbol &MakeAssocSymbol(
+      const SourceName &name, const Symbol &prev, Scope &scope) {
     const auto pair{scope.try_emplace(name, Attrs{}, HostAssocDetails{prev})};
     return *pair.first->second;
   }
-  Symbol &MakeAssocSymbol(const SourceName &name, Symbol &prev) {
+  Symbol &MakeAssocSymbol(const SourceName &name, const Symbol &prev) {
     return MakeAssocSymbol(name, prev, currScope());
   }
   void AddDataSharingAttributeObject(SymbolRef object) {
@@ -108,6 +109,7 @@ template <typename T> class DirectiveAttributeVisitor {
   const parser::Name *GetLoopIndex(const parser::DoConstruct &);
   const parser::DoConstruct *GetDoConstructIf(
       const parser::ExecutionPartConstruct &);
+  Symbol *DeclareNewPrivateAccessEntity(const Symbol &, Symbol::Flag, Scope &);
   Symbol *DeclarePrivateAccessEntity(
       const parser::Name &, Symbol::Flag, Scope &);
   Symbol *DeclarePrivateAccessEntity(Symbol &, Symbol::Flag, Scope &);
@@ -771,6 +773,19 @@ const parser::DoConstruct *DirectiveAttributeVisitor<T>::GetDoConstructIf(
   return parser::Unwrap<parser::DoConstruct>(x);
 }
 
+template <typename T>
+Symbol *DirectiveAttributeVisitor<T>::DeclareNewPrivateAccessEntity(
+    const Symbol &object, Symbol::Flag flag, Scope &scope) {
+  assert(object.owner() != currScope());
+  auto &symbol{MakeAssocSymbol(object.name(), object, scope)};
+  symbol.set(flag);
+  if (flag == Symbol::Flag::OmpCopyIn) {
+    // The symbol in copyin clause must be threadprivate entity.
+    symbol.set(Symbol::Flag::OmpThreadprivate);
+  }
+  return &symbol;
+}
+
 template <typename T>
 Symbol *DirectiveAttributeVisitor<T>::DeclarePrivateAccessEntity(
     const parser::Name &name, Symbol::Flag flag, Scope &scope) {
@@ -785,13 +800,7 @@ template <typename T>
 Symbol *DirectiveAttributeVisitor<T>::DeclarePrivateAccessEntity(
     Symbol &object, Symbol::Flag flag, Scope &scope) {
   if (object.owner() != currScope()) {
-    auto &symbol{MakeAssocSymbol(object.name(), object, scope)};
-    symbol.set(flag);
-    if (flag == Symbol::Flag::OmpCopyIn) {
-      // The symbol in copyin clause must be threadprivate entity.
-      symbol.set(Symbol::Flag::OmpThreadprivate);
-    }
-    return &symbol;
+    return DeclareNewPrivateAccessEntity(object, flag, scope);
   } else {
     object.set(flag);
     return &object;
@@ -2075,13 +2084,30 @@ void OmpAttributeVisitor::Post(const parser::Name &name) {
       if (found->test(semantics::Symbol::Flag::OmpThreadprivate))
         return;
     }
-    if (!IsPrivatizable(symbol)) {
+
+    std::set<const Symbol *> stmtFunctionSymbols;
+    if (auto *stmtFunction{symbol->detailsIf<semantics::SubprogramDetails>()};
+        stmtFunction && stmtFunction->stmtFunction()) {
+      // Each non-dummy argument from a statement function must be handled too,
+      // as if it was explicitly referenced.
+      semantics::UnorderedSymbolSet symbols{
+          CollectSymbols(stmtFunction->stmtFunction().value())};
+      for (const auto &sym : symbols) {
+        if (!IsStmtFunctionDummy(sym) && IsPrivatizable(&*sym) &&
+            !IsObjectWithDSA(*sym)) {
+          stmtFunctionSymbols.insert(&*sym);
+        }
+      }
+      if (stmtFunctionSymbols.empty()) {
+        return;
+      }
+    } else if (!IsPrivatizable(symbol)) {
       return;
     }
 
     // Implicitly determined DSAs
     // OMP 5.2 5.1.1 - Variables Referenced in a Construct
-    Symbol *lastDeclSymbol = nullptr;
+    std::vector<const Symbol *> lastDeclSymbols;
     std::optional<Symbol::Flag> prevDSA;
     for (int dirDepth{0}; dirDepth < (int)dirContext_.size(); ++dirDepth) {
       DirContext &dirContext = dirContext_[dirDepth];
@@ -2126,23 +2152,59 @@ void OmpAttributeVisitor::Post(const parser::Name &name) {
       // it would have the private flag set.
       // This would make x appear to be defined in p2, causing it to be
       // privatized in p2 and its privatization in p1 to be skipped.
-      auto makePrivateSymbol = [&](Symbol::Flag flag) {
-        Symbol *hostSymbol =
-            lastDeclSymbol ? lastDeclSymbol : &symbol->GetUltimate();
-        lastDeclSymbol = DeclarePrivateAccessEntity(
-            *hostSymbol, flag, context_.FindScope(dirContext.directiveSource));
-        return lastDeclSymbol;
+      // TODO Move the lambda functions below to a separate class.
+      auto hostSymbol = [&](const Symbol *sym, int index = 0) {
+        if (lastDeclSymbols.empty())
+          return &sym->GetUltimate();
+        return lastDeclSymbols[index];
+      };
+      auto declNewPrivateSymbol = [&](const Symbol *sym, Symbol::Flag flag,
+                                      bool implicit) {
+        Symbol *newSym = DeclareNewPrivateAccessEntity(
+            *sym, flag, context_.FindScope(dirContext.directiveSource));
+        if (implicit)
+          newSym->set(Symbol::Flag::OmpImplicit);
+        return newSym;
+      };
+      auto makePrivateSymbol = [&](Symbol::Flag flag, bool implicit = false) {
+        bool hasLastDeclSymbols = !lastDeclSymbols.empty();
+        auto updateLastDeclSymbols = [&](const Symbol *sym, int index = 0) {
+          if (hasLastDeclSymbols)
+            lastDeclSymbols[index] = sym;
+          else
+            lastDeclSymbols.push_back(sym);
+        };
+
+        if (stmtFunctionSymbols.empty()) {
+          const Symbol *newSym =
+              declNewPrivateSymbol(hostSymbol(symbol), flag, implicit);
+          updateLastDeclSymbols(newSym);
+          return;
+        }
+
+        int i = 0;
+        for (const auto *sym : stmtFunctionSymbols) {
+          Symbol *newSym =
+              declNewPrivateSymbol(hostSymbol(sym, i), flag, implicit);
+          newSym->set(Symbol::Flag::OmpFromStmtFunction);
+          updateLastDeclSymbols(newSym, i++);
+        }
       };
       auto makeSharedSymbol = [&]() {
-        Symbol *hostSymbol =
-            lastDeclSymbol ? lastDeclSymbol : &symbol->GetUltimate();
-        MakeAssocSymbol(symbol->name(), *hostSymbol,
-            context_.FindScope(dirContext.directiveSource));
+        if (stmtFunctionSymbols.empty()) {
+          MakeAssocSymbol(symbol->name(), *hostSymbol(symbol),
+              context_.FindScope(dirContext.directiveSource));
+        } else {
+          int i = 0;
+          for (const auto *sym : stmtFunctionSymbols) {
+            MakeAssocSymbol(sym->name(), *hostSymbol(sym, i++),
+                context_.FindScope(dirContext.directiveSource));
+          }
+        }
       };
       auto useLastDeclSymbol = [&]() {
-        if (lastDeclSymbol)
-          MakeAssocSymbol(symbol->name(), *lastDeclSymbol,
-              context_.FindScope(dirContext.directiveSource));
+        if (!lastDeclSymbols.empty())
+          makeSharedSymbol();
       };
 
       bool taskGenDir = llvm::omp::taskGeneratingSet.test(dirContext.directive);
@@ -2190,7 +2252,7 @@ void OmpAttributeVisitor::Post(const parser::Name &name) {
         } else {
           // 7) firstprivate
           dsa = Symbol::Flag::OmpFirstPrivate;
-          makePrivateSymbol(*dsa)->set(Symbol::Flag::OmpImplicit);
+          makePrivateSymbol(*dsa, /*implicit=*/true);
         }
       }
       prevDSA = dsa;
diff --git a/flang/test/Lower/OpenMP/statement-function.f90 b/flang/test/Lower/OpenMP/statement-function.f90
new file mode 100644
index 0000000000000..6cdbcb6e141c7
--- /dev/null
+++ b/flang/test/Lower/OpenMP/statement-function.f90
@@ -0,0 +1,43 @@
+! Test privatization within OpenMP constructs containing statement functions.
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
+
+!CHECK-LABEL: func @_QPtest_implicit_use
+!CHECK:         %[[IEXP:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_implicit_useEiexp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:         %[[IIMP:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_implicit_useEiimp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:         omp.parallel private({{.*firstprivate.*}} %[[IEXP]]#0 -> %[[PRIV_IEXP:.*]] : !fir.ref<i32>,
+!CHECK-SAME:                         {{.*firstprivate.*}} %[[IIMP]]#0 -> %[[PRIV_IIMP:.*]] : !fir.ref<i32>)
+!CHECK:           %{{.*}}:2 = hlfir.declare %[[PRIV_IEXP]] {uniq_name = "_QFtest_implicit_useEiexp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:           %{{.*}}:2 = hlfir.declare %[[PRIV_IIMP]] {uniq_name = "_QFtest_implicit_useEiimp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine test_implicit_use()
+  implicit none
+  integer :: iexp, iimp
+  integer, external :: ifun
+  integer :: sf
+
+  sf(iexp)=ifun(iimp)+iexp
+  !$omp parallel default(firstprivate)
+      iexp = sf(iexp)
+  !$omp end parallel
+end subroutine
+
+!CHECK-LABEL: func @_QPtest_implicit_use2
+!CHECK:         %[[IEXP:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_implicit_use2Eiexp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:         %[[IIMP:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_implicit_use2Eiimp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:         omp.task
+!CHECK:           %[[PRIV_IEXP:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_implicit_use2Eiexp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:           %[[TEMP0:.*]] = fir.load %[[IEXP]]#0 : !fir.ref<i32>
+!CHECK:           hlfir.assign %[[TEMP0]] to %[[PRIV_IEXP]]#0 temporary_lhs : i32, !fir.ref<i32>
+!CHECK:           %[[PRIV_IIMP:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFtest_implicit_use2Eiimp"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK:           %[[TEMP1:.*]] = fir.load %[[IIMP]]#0 : !fir.ref<i32>
+!CHECK:           hlfir.assign %[[TEMP1]] to %[[PRIV_IIMP]]#0 temporary_lhs : i32, !fir.ref<i32>
+subroutine test_implicit_use2()
+  implicit none
+  integer :: iexp, iimp
+  integer, external :: ifun
+  integer :: sf
+
+  sf(iexp)=ifun(iimp)
+  !$omp task
+      iexp = sf(iexp)
+  !$omp end task
+end subroutine

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for the fix!

luporl added 2 commits August 20, 2024 18:49
Variables referenced in the body of statement functions need to be
handled as if they are explicitly referenced. Otherwise, they are
skipped during implicit privatization, because statement functions
are represented as procedures in the parse tree.

To avoid missing symbols referenced only in statement functions
during implicit privatization, new symbols, associated with them,
are created and inserted into the context of the directive that
privatizes them. They are later collected and processed in
lowering. To avoid confusing these new symbols with regular ones,
they are tagged with the new OmpFromStmtFunction flag.

Fixes llvm#74273
Move the handling of symbols with implicitly determined DSAs to a
separate function, that can be called for each non-dummy symbol of
a statement function, so that they can be handled in the same way
as other types of symbols.
Copy link
Contributor

@mjklemm mjklemm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@luporl luporl force-pushed the luporl-fix-priv-stmt-func branch from d789c59 to 5d453c6 Compare August 21, 2024 12:11
@luporl
Copy link
Contributor Author

luporl commented Aug 21, 2024

Thanks for the reviews @tblah and @mjklemm.

As OmpAttributeVisitor::Post(const parser::Name &name) was becoming quite complex, I have refactored it in the last commit. The main changes were:

  • Move IsPrivatizable to a separate function.
  • Move the code that handles symbols with implicitly determined DSAs to a separate function, CreateImplicitSymbols, undoing most changes from the first commit in this part.

This allowed the logic that handles statement function symbols to be simplified. Now we just need to collect non-dummy argument symbols from statement functions and call CreateImplicitSymbols for each of them.

@luporl luporl merged commit 216ba6b into llvm:main Aug 26, 2024
8 checks passed
@luporl luporl deleted the luporl-fix-priv-stmt-func branch August 26, 2024 11:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Flang][OpenMP] Incorrect execution result of using statement function in task construct
4 participants