Skip to content

[flang][OpenMP] Handle pre-detemined lastprivate for simd #129507

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 7, 2025

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented Mar 3, 2025

This PR tries to fix lastprivate update issues in composite constructs. In particular, pre-determined lastprivate symbols are attached to the wrong leaf of the composite construct (the outermost one). When using delayed privatization (should be the default mode in the future), this results in trying to update the lastprivate symbol in the wrong construct (outside the omp.loop_nest op).

For example, given the following input:

!$omp target teams distribute parallel do simd collapse(2) private(y_max)
  do i=x_min,x_max
    do j=y_min,y_max
    enddo
  enddo

Without the fixes introduced in this PR, the DataSharingProcessor tries to generate the lastprivate update ops in the parallel op since this is the op for which the DSP instance is created.

The fix consists of 2 main parts:

  1. Instead of creating a single DSP instance, one instance is created for the leaf constructs that might need privatization (whether for explicit, implicit, or pre-determined symbols).
  2. When generating the lastprivate comparison ops, we don't directly use the SSA values of the UBs and steps. Instead, we regenerated these SSA values from the original loop bounds' expressions. We have to do this to avoid using host_eval values in the lastprivate comparison logic which is illegal.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp labels Mar 3, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 3, 2025

@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

This PR to tries to fix lastprivate update issues in composite constructs. In particular, pre-determined lastprivate symbols are attached to the wrong leaf of the composite construct (the outermost one). When using delayed privatization (should be the default mode in the future), this results in trying to update the lastprivate symbol in the wrong construct (outside the omp.loop_nest op).

For example, given the following input:

!$omp target teams distribute parallel do simd collapse(2) private(y_max)
  do i=x_min,x_max
    do j=y_min,y_max
    enddo
  enddo

Without the fixes introduced in this PR, the DataSharingProcessor tries to generate the lastprivate update ops in the parallel op since this is the op for which the DSP instance is created.

The fix consists of 2 main parts:

  1. Instead of creating a single DSP instance, one instance is created for the leaf constructs that might need privatization (whether for explicit, implicit, or pre-determined symbols).
  2. When generating the lastprivate comparison ops, we don't durectly use the SSA values of the UBs and steps. Instead, we regenerated these SSA values from the original loop bounds' expressions. We have to do this to avoid using host_eval values in the lastprivate comparison logic which is illegal.

Patch is 36.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129507.diff

11 Files Affected:

  • (added) flang/lib/Lower/OpenMP/ClauseFinder.h (+76)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+3-67)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+3-34)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+7-3)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+21-12)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+75)
  • (modified) flang/lib/Lower/OpenMP/Utils.h (+5)
  • (modified) flang/test/Lower/OpenMP/distribute-parallel-do-simd.f90 (+58-13)
  • (modified) flang/test/Lower/OpenMP/lastprivate-iv.f90 (+12-8)
  • (modified) flang/test/Lower/OpenMP/lastprivate-simd.f90 (+1-1)
  • (modified) flang/test/Lower/OpenMP/parallel-wsloop-lastpriv.f90 (+47-28)
diff --git a/flang/lib/Lower/OpenMP/ClauseFinder.h b/flang/lib/Lower/OpenMP/ClauseFinder.h
new file mode 100644
index 0000000000000..3b77f2ca1d4cb
--- /dev/null
+++ b/flang/lib/Lower/OpenMP/ClauseFinder.h
@@ -0,0 +1,76 @@
+//===-- Lower/OpenMP/ClauseFinder.h --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
+//
+//===----------------------------------------------------------------------===//
+#ifndef FORTRAN_LOWER_CLAUSEFINDER_H
+#define FORTRAN_LOWER_CLAUSEFINDER_H
+
+#include "Clauses.h"
+
+namespace Fortran {
+namespace lower {
+namespace omp {
+
+class ClauseFinder {
+  using ClauseIterator = List<Clause>::const_iterator;
+
+public:
+  /// Utility to find a clause within a range in the clause list.
+  template <typename T>
+  static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end) {
+    for (ClauseIterator it = begin; it != end; ++it) {
+      if (std::get_if<T>(&it->u))
+        return it;
+    }
+
+    return end;
+  }
+
+  /// Return the first instance of the given clause found in the clause list or
+  /// `nullptr` if not present. If more than one instance is expected, use
+  /// `findRepeatableClause` instead.
+  template <typename T>
+  static const T *findUniqueClause(const List<Clause> &clauses,
+                                   const parser::CharBlock **source = nullptr) {
+    ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
+    if (it != clauses.end()) {
+      if (source)
+        *source = &it->source;
+      return &std::get<T>(it->u);
+    }
+    return nullptr;
+  }
+
+  /// Call `callbackFn` for each occurrence of the given clause. Return `true`
+  /// if at least one instance was found.
+  template <typename T>
+  static bool findRepeatableClause(
+      const List<Clause> &clauses,
+      std::function<void(const T &, const parser::CharBlock &source)>
+          callbackFn) {
+    bool found = false;
+    ClauseIterator nextIt, endIt = clauses.end();
+    for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
+      nextIt = findClause<T>(it, endIt);
+
+      if (nextIt != endIt) {
+        callbackFn(std::get<T>(nextIt->u), nextIt->source);
+        found = true;
+        ++nextIt;
+      }
+    }
+    return found;
+  }
+};
+} // namespace omp
+} // namespace lower
+} // namespace Fortran
+
+#endif // FORTRAN_LOWER_CLAUSEFINDER_H
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index e21d299570b86..98a2bb7583d98 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -12,6 +12,7 @@
 
 #include "ClauseProcessor.h"
 #include "Clauses.h"
+#include "Utils.h"
 
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Parser/tools.h"
@@ -201,24 +202,6 @@ static void addUseDeviceClause(
     useDeviceSyms.push_back(object.sym());
 }
 
-static void convertLoopBounds(lower::AbstractConverter &converter,
-                              mlir::Location loc,
-                              mlir::omp::LoopRelatedClauseOps &result,
-                              std::size_t loopVarTypeSize) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  // The types of lower bound, upper bound, and step are converted into the
-  // type of the loop variable if necessary.
-  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
-  for (unsigned it = 0; it < (unsigned)result.loopLowerBounds.size(); it++) {
-    result.loopLowerBounds[it] = firOpBuilder.createConvert(
-        loc, loopVarType, result.loopLowerBounds[it]);
-    result.loopUpperBounds[it] = firOpBuilder.createConvert(
-        loc, loopVarType, result.loopUpperBounds[it]);
-    result.loopSteps[it] =
-        firOpBuilder.createConvert(loc, loopVarType, result.loopSteps[it]);
-  }
-}
-
 //===----------------------------------------------------------------------===//
 // ClauseProcessor unique clauses
 //===----------------------------------------------------------------------===//
@@ -240,55 +223,8 @@ bool ClauseProcessor::processCollapse(
     mlir::Location currentLocation, lower::pft::Evaluation &eval,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
-  bool found = false;
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
-  if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
-    TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
-  }
-
-  std::int64_t collapseValue = 1l;
-  if (auto *clause = findUniqueClause<omp::clause::Collapse>()) {
-    collapseValue = evaluate::ToInt64(clause->v).value();
-    found = true;
-  }
-
-  std::size_t loopVarTypeSize = 0;
-  do {
-    lower::pft::Evaluation *doLoop =
-        &doConstructEval->getFirstNestedEvaluation();
-    auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
-    assert(doStmt && "Expected do loop to be in the nested evaluation");
-    const auto &loopControl =
-        std::get<std::optional<parser::LoopControl>>(doStmt->t);
-    const parser::LoopControl::Bounds *bounds =
-        std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
-    assert(bounds && "Expected bounds for worksharing do loop");
-    lower::StatementContext stmtCtx;
-    result.loopLowerBounds.push_back(fir::getBase(
-        converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
-    result.loopUpperBounds.push_back(fir::getBase(
-        converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
-    if (bounds->step) {
-      result.loopSteps.push_back(fir::getBase(
-          converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
-    } else { // If `step` is not present, assume it as `1`.
-      result.loopSteps.push_back(firOpBuilder.createIntegerConstant(
-          currentLocation, firOpBuilder.getIntegerType(32), 1));
-    }
-    iv.push_back(bounds->name.thing.symbol);
-    loopVarTypeSize = std::max(loopVarTypeSize,
-                               bounds->name.thing.symbol->GetUltimate().size());
-    collapseValue--;
-    doConstructEval =
-        &*std::next(doConstructEval->getNestedEvaluations().begin());
-  } while (collapseValue > 0);
-
-  convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
-
-  return found;
+  return collectLoopRelatedInfo(converter, currentLocation, eval, clauses,
+                                result, iv);
 }
 
 bool ClauseProcessor::processDevice(lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 889a09a8f0cd8..968c6724023e8 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,6 +13,7 @@
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
 #include "Clauses.h"
+#include "ClauseFinder.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"
@@ -148,10 +149,6 @@ class ClauseProcessor {
 private:
   using ClauseIterator = List<Clause>::const_iterator;
 
-  /// Utility to find a clause within a range in the clause list.
-  template <typename T>
-  static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
-
   /// Return the first instance of the given clause found in the clause list or
   /// `nullptr` if not present. If more than one instance is expected, use
   /// `findRepeatableClause` instead.
@@ -199,45 +196,17 @@ void ClauseProcessor::processTODO(mlir::Location currentLocation,
     (checkUnhandledClause(it->id, std::get_if<Ts>(&it->u)), ...);
 }
 
-template <typename T>
-ClauseProcessor::ClauseIterator
-ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
-  for (ClauseIterator it = begin; it != end; ++it) {
-    if (std::get_if<T>(&it->u))
-      return it;
-  }
-
-  return end;
-}
-
 template <typename T>
 const T *
 ClauseProcessor::findUniqueClause(const parser::CharBlock **source) const {
-  ClauseIterator it = findClause<T>(clauses.begin(), clauses.end());
-  if (it != clauses.end()) {
-    if (source)
-      *source = &it->source;
-    return &std::get<T>(it->u);
-  }
-  return nullptr;
+    return ClauseFinder::findUniqueClause<T>(clauses, source);
 }
 
 template <typename T>
 bool ClauseProcessor::findRepeatableClause(
     std::function<void(const T &, const parser::CharBlock &source)> callbackFn)
     const {
-  bool found = false;
-  ClauseIterator nextIt, endIt = clauses.end();
-  for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
-    nextIt = findClause<T>(it, endIt);
-
-    if (nextIt != endIt) {
-      callbackFn(std::get<T>(nextIt->u), nextIt->source);
-      found = true;
-      ++nextIt;
-    }
-  }
-  return found;
+  return ClauseFinder::findRepeatableClause<T>(clauses, callbackFn);
 }
 
 template <typename T>
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 6b0d783fb8c30..34716d2f6fcb2 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -257,6 +257,11 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
     return;
 
   if (mlir::isa<mlir::omp::WsloopOp>(op) || mlir::isa<mlir::omp::SimdOp>(op)) {
+    mlir::omp::LoopRelatedClauseOps result;
+    llvm::SmallVector<const semantics::Symbol *> iv;
+    collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval,
+                           clauses, result, iv);
+
     // Update the original variable just before exiting the worksharing
     // loop. Conversion as follows:
     //
@@ -280,9 +285,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
     mlir::Value cmpOp;
     llvm::SmallVector<mlir::Value> vs;
     vs.reserve(loopOp.getIVs().size());
-    for (auto [iv, ub, step] :
-         llvm::zip_equal(loopOp.getIVs(), loopOp.getLoopUpperBounds(),
-                         loopOp.getLoopSteps())) {
+    for (auto [iv, ub, step] : llvm::zip_equal(
+             loopOp.getIVs(), result.loopUpperBounds, result.loopSteps)) {
       // v = iv + step
       // cmp = step < 0 ? v < ub : v > ub
       mlir::Value v = firOpBuilder.create<mlir::arith::AddIOp>(loc, iv, step);
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 97b4778d488fa..b9357e04aa393 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1207,24 +1207,24 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
       // nested loop as a unit, so we need to pass the top level wrapper (if
       // present). Otherwise, these operations will be inserted within a
       // wrapper region.
-      mlir::Operation *privatizationTopLevelOp = &op;
+      mlir::Operation *privatizationBottomLevelOp = &op;
       if (auto loopNest = llvm::dyn_cast<mlir::omp::LoopNestOp>(op)) {
         llvm::SmallVector<mlir::omp::LoopWrapperInterface> wrappers;
         loopNest.gatherWrappers(wrappers);
         if (!wrappers.empty())
-          privatizationTopLevelOp = &*wrappers.back();
+          privatizationBottomLevelOp = &*wrappers.front();
       }
 
       if (!info.dsp) {
         assert(tempDsp.has_value());
-        tempDsp->processStep2(privatizationTopLevelOp, isLoop);
+        tempDsp->processStep2(privatizationBottomLevelOp, isLoop);
       } else {
         if (isLoop && regionArgs.size() > 0) {
           for (const auto &regionArg : regionArgs) {
             info.dsp->pushLoopIV(info.converter.getSymbolAddress(*regionArg));
           }
         }
-        info.dsp->processStep2(privatizationTopLevelOp, isLoop);
+        info.dsp->processStep2(privatizationBottomLevelOp, isLoop);
       }
     }
   }
@@ -2737,18 +2737,20 @@ static void genCompositeDistributeParallelDoSimd(
   genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
                      parallelClauseOps, parallelReductionSyms);
 
-  DataSharingProcessor dsp(converter, semaCtx, simdItem->clauses, eval,
-                           /*shouldCollectPreDeterminedSymbols=*/true,
-                           /*useDelayedPrivatization=*/true, symTable);
-  dsp.processStep1(&parallelClauseOps);
+  DataSharingProcessor parallelItemDSP(
+      converter, semaCtx, parallelItem->clauses, eval,
+      /*shouldCollectPreDeterminedSymbols=*/false,
+      /*useDelayedPrivatization=*/true, symTable);
+  parallelItemDSP.processStep1(&parallelClauseOps);
 
   EntryBlockArgs parallelArgs;
-  parallelArgs.priv.syms = dsp.getDelayedPrivSymbols();
+  parallelArgs.priv.syms = parallelItemDSP.getDelayedPrivSymbols();
   parallelArgs.priv.vars = parallelClauseOps.privateVars;
   parallelArgs.reduction.syms = parallelReductionSyms;
   parallelArgs.reduction.vars = parallelClauseOps.reductionVars;
   genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
-                parallelClauseOps, parallelArgs, &dsp, /*isComposite=*/true);
+                parallelClauseOps, parallelArgs, &parallelItemDSP,
+                /*isComposite=*/true);
 
   // Clause processing.
   mlir::omp::DistributeOperands distributeClauseOps;
@@ -2765,6 +2767,11 @@ static void genCompositeDistributeParallelDoSimd(
   genSimdClauses(converter, semaCtx, simdItem->clauses, loc, simdClauseOps,
                  simdReductionSyms);
 
+  DataSharingProcessor simdItemDSP(converter, semaCtx, simdItem->clauses, eval,
+                                   /*shouldCollectPreDeterminedSymbols=*/true,
+                                   /*useDelayedPrivatization=*/true, symTable);
+  simdItemDSP.processStep1(&simdClauseOps);
+
   mlir::omp::LoopNestOperands loopNestClauseOps;
   llvm::SmallVector<const semantics::Symbol *> iv;
   genLoopNestClauses(converter, semaCtx, eval, simdItem->clauses, loc,
@@ -2786,7 +2793,8 @@ static void genCompositeDistributeParallelDoSimd(
   wsloopOp.setComposite(/*val=*/true);
 
   EntryBlockArgs simdArgs;
-  // TODO: Add private syms and vars.
+  simdArgs.priv.syms = simdItemDSP.getDelayedPrivSymbols();
+  simdArgs.priv.vars = simdClauseOps.privateVars;
   simdArgs.reduction.syms = simdReductionSyms;
   simdArgs.reduction.vars = simdClauseOps.reductionVars;
   auto simdOp =
@@ -2798,7 +2806,8 @@ static void genCompositeDistributeParallelDoSimd(
                 {{distributeOp, distributeArgs},
                  {wsloopOp, wsloopArgs},
                  {simdOp, simdArgs}},
-                llvm::omp::Directive::OMPD_distribute_parallel_do_simd, dsp);
+                llvm::omp::Directive::OMPD_distribute_parallel_do_simd,
+                simdItemDSP);
 }
 
 static void genCompositeDistributeSimd(lower::AbstractConverter &converter,
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 48bcf492fd368..744c3bd04a0a7 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -14,6 +14,7 @@
 
 #include "Clauses.h"
 
+#include "ClauseFinder.h"
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
 #include <flang/Lower/DirectivesCommon.h>
@@ -595,6 +596,80 @@ void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
   }
 }
 
+static void convertLoopBounds(lower::AbstractConverter &converter,
+                              mlir::Location loc,
+                              mlir::omp::LoopRelatedClauseOps &result,
+                              std::size_t loopVarTypeSize) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  // The types of lower bound, upper bound, and step are converted into the
+  // type of the loop variable if necessary.
+  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+  for (unsigned it = 0; it < (unsigned)result.loopLowerBounds.size(); it++) {
+    result.loopLowerBounds[it] = firOpBuilder.createConvert(
+        loc, loopVarType, result.loopLowerBounds[it]);
+    result.loopUpperBounds[it] = firOpBuilder.createConvert(
+        loc, loopVarType, result.loopUpperBounds[it]);
+    result.loopSteps[it] =
+        firOpBuilder.createConvert(loc, loopVarType, result.loopSteps[it]);
+  }
+}
+
+bool collectLoopRelatedInfo(
+    lower::AbstractConverter &converter, mlir::Location currentLocation,
+    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    mlir::omp::LoopRelatedClauseOps &result,
+    llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
+  bool found = false;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  // Collect the loops to collapse.
+  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
+    TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
+  }
+
+  std::int64_t collapseValue = 1l;
+  if (auto *clause =
+          ClauseFinder::findUniqueClause<omp::clause::Collapse>(clauses)) {
+    collapseValue = evaluate::ToInt64(clause->v).value();
+    found = true;
+  }
+
+  std::size_t loopVarTypeSize = 0;
+  do {
+    lower::pft::Evaluation *doLoop =
+        &doConstructEval->getFirstNestedEvaluation();
+    auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
+    assert(doStmt && "Expected do loop to be in the nested evaluation");
+    const auto &loopControl =
+        std::get<std::optional<parser::LoopControl>>(doStmt->t);
+    const parser::LoopControl::Bounds *bounds =
+        std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+    assert(bounds && "Expected bounds for worksharing do loop");
+    lower::StatementContext stmtCtx;
+    result.loopLowerBounds.push_back(fir::getBase(
+        converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)));
+    result.loopUpperBounds.push_back(fir::getBase(
+        converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)));
+    if (bounds->step) {
+      result.loopSteps.push_back(fir::getBase(
+          converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)));
+    } else { // If `step` is not present, assume it as `1`.
+      result.loopSteps.push_back(firOpBuilder.createIntegerConstant(
+          currentLocation, firOpBuilder.getIntegerType(32), 1));
+    }
+    iv.push_back(bounds->name.thing.symbol);
+    loopVarTypeSize = std::max(loopVarTypeSize,
+                               bounds->name.thing.symbol->GetUltimate().size());
+    collapseValue--;
+    doConstructEval =
+        &*std::next(doConstructEval->getNestedEvaluations().begin());
+  } while (collapseValue > 0);
+
+  convertLoopBounds(converter, currentLocation, result, loopVarTypeSize);
+
+  return found;
+}
 } // namespace omp
 } // namespace lower
 } // namespace Fortran
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 3943eb633b04e..30b4613837b9a 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -163,6 +163,11 @@ void genObjectList(const ObjectList &objects,
 void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
                                      mlir::Location loc);
 
+bool collectLoopRelatedInfo(
+    lower::AbstractConverter &converter, mlir::Location currentLocation,
+    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    mlir::omp::LoopRelatedClauseOps &result,
+    llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
 } // namespace omp
 } // namespace lower
 } // namespace Fortran
diff --git a/flang/test/Lower/OpenMP/distribute-parallel-do-simd.f90 b/flang/test/Lower/OpenMP/distribute-parallel-do-simd.f90
index bea7f037cecf3..142bc02ae8c1d 100644
--- a/flang/test/Lower/OpenMP/distribute-parallel-do-simd.f90
+++ b/flang/test/Lower/OpenMP/distribute-parallel-do-simd.f90
@@ -8,10 +8,10 @@
 subroutine distribute_parallel_do_simd_num_threads()
   !$omp teams
 
-  ! CHECK:      omp.parallel num_threads({{.*}}) private({{.*}}) {
+  ! CHECK:      omp.parallel num_threads({{.*}})...
[truncated]

Copy link

github-actions bot commented Mar 3, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@ergawy ergawy force-pushed the users/ergawy/fix_simd_lastprivate_issue branch from 07ea43b to 93c59a5 Compare March 3, 2025 10:54
@ergawy ergawy force-pushed the users/ergawy/fix_simd_lastprivate_issue branch from 93c59a5 to a9c5795 Compare March 4, 2025 07:10
@ergawy ergawy requested a review from kiranchandramohan March 4, 2025 08:28
@ergawy
Copy link
Member Author

ergawy commented Mar 5, 2025

Ping! Please take a look when you have time.

@ronlieb ronlieb self-requested a review March 5, 2025 09:19
Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with one nit. Thanks

@ergawy ergawy force-pushed the users/ergawy/fix_simd_lastprivate_issue branch from a9c5795 to b99612c Compare March 5, 2025 13:27
This PR to tries to fix `lastprivate` update issues in composite
constructs. In particular, pre-determined `lastprivate` symbols are
attached to the wrong leaf of the composite construct (the outermost
one). When using delayed privatization (should be the default mode in
the future), this results in trying to update the `lastprivate` symbol
in the wrong construct (outside the `omp.loop_nest` op).

For example, given the following input:
```fortran
!$omp target teams distribute parallel do simd collapse(2) private(y_max)
  do i=x_min,x_max
    do j=y_min,y_max
    enddo
  enddo
```

Without the fixes introduced in this PR, the `DataSharingProcessor` tries to
generate the `lastprivate` update ops in the `parallel` op since this
is the op for which the DSP instance is created.

The fix consists of 2 main parts:
1. Instead of creating a single DSP instance, one instance is created
   for the leaf constructs that might need privatization (whether for
   explicit, implicit, or pre-determined symbols).
2. When generating the `lastprivate` comparison ops, we don't durectly
   use the SSA values of the UBs and steps. Instead, we regenerated
   these SSA values from the original loop bounds' expressions. We have
   to do this to avoid using `host_eval` values in the `lastprivate`
   comparison logic which is illegal.
@ergawy ergawy force-pushed the users/ergawy/fix_simd_lastprivate_issue branch from b99612c to 1b370ef Compare March 6, 2025 13:29
@ergawy ergawy merged commit 9543e9e into main Mar 7, 2025
11 checks passed
@ergawy ergawy deleted the users/ergawy/fix_simd_lastprivate_issue branch March 7, 2025 04:44
@ergawy ergawy restored the users/ergawy/fix_simd_lastprivate_issue branch March 7, 2025 08:18
jph-13 pushed a commit to jph-13/llvm-project that referenced this pull request Mar 21, 2025
…29507)

This PR tries to fix `lastprivate` update issues in composite
constructs. In particular, pre-determined `lastprivate` symbols are
attached to the wrong leaf of the composite construct (the outermost
one). When using delayed privatization (should be the default mode in
the future), this results in trying to update the `lastprivate` symbol
in the wrong construct (outside the `omp.loop_nest` op).

For example, given the following input:
```fortran
!$omp target teams distribute parallel do simd collapse(2) private(y_max)
  do i=x_min,x_max
    do j=y_min,y_max
    enddo
  enddo
```

Without the fixes introduced in this PR, the `DataSharingProcessor`
tries to generate the `lastprivate` update ops in the `parallel` op
since this is the op for which the DSP instance is created.

The fix consists of 2 main parts:
1. Instead of creating a single DSP instance, one instance is created
for the leaf constructs that might need privatization (whether for
explicit, implicit, or pre-determined symbols).
2. When generating the `lastprivate` comparison ops, we don't directly
use the SSA values of the UBs and steps. Instead, we regenerated these
SSA values from the original loop bounds' expressions. We have to do
this to avoid using `host_eval` values in the `lastprivate` comparison
logic which is illegal.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants