[flang][OpenMP][NFC] remove globals with mlir::StateStack #144898

Merged
tblah merged 2 commits into main from users/tblah/stack-frame-1
Jun 24, 2025

Conversation

tblah
Contributor

@tblah tblah commented Jun 19, 2025

Idea suggested by @skatrak

@tblah tblah requested review from ergawy and skatrak June 19, 2025 14:05
@llvmbot llvmbot added the mlir:core, mlir, flang, and flang:fir-hlfir labels Jun 19, 2025
@llvmbot
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-flang-fir-hlfir

Author: Tom Eccles (tblah)

Changes

Idea suggested by @skatrak


Full diff: https://github.com/llvm/llvm-project/pull/144898.diff

4 Files Affected:

  • (modified) flang/include/flang/Lower/AbstractConverter.h (+3)
  • (modified) flang/lib/Lower/Bridge.cpp (+6)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+71-31)
  • (modified) mlir/include/mlir/Support/StateStack.h (+11)
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 8ae68e143cd2f..de3e833f60699 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -26,6 +26,7 @@
 
 namespace mlir {
 class SymbolTable;
+class StateStack;
 }
 
 namespace fir {
@@ -361,6 +362,8 @@ class AbstractConverter {
   /// functions in order to be in sync).
   virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
 
+  virtual mlir::StateStack &getStateStack() = 0;
+
 private:
   /// Options controlling lowering behavior.
   const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 64b16b3abe991..462ceb8dff736 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -78,6 +78,7 @@
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/Path.h"
 #include "llvm/Target/TargetMachine.h"
+#include "mlir/Support/StateStack.h"
 #include <optional>
 
 #define DEBUG_TYPE "flang-lower-bridge"
@@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
   mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
 
+  mlir::StateStack &getStateStack() override { return stateStack; }
+
   /// Add the symbol to the local map and return `true`. If the symbol is
   /// already in the map and \p forced is `false`, the map is not updated.
   /// Instead the value `false` is returned.
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   /// attribute since mlirSymbolTable must pro-actively be maintained when
   /// new Symbol operations are created.
   mlir::SymbolTable mlirSymbolTable;
+
+  /// Used to store context while recursing into regions during lowering.
+  mlir::StateStack stateStack;
 };
 
 } // namespace
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7ad8869597274..bff3321af2814 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -38,6 +38,7 @@
 #include "flang/Support/OpenMP-utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Support/StateStack.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -200,9 +201,41 @@ class HostEvalInfo {
 /// the handling of the outer region by keeping a stack of information
 /// structures, but it will probably still require some further work to support
 /// reverse offloading.
-static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
-static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0>
-    sectionsStack;
+class HostEvalInfoStackFrame
+    : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame)
+
+  HostEvalInfo info;
+};
+
+static HostEvalInfo *
+getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
+  HostEvalInfoStackFrame *frame =
+      converter.getStateStack().getStackTop<HostEvalInfoStackFrame>();
+  return frame ? &frame->info : nullptr;
+}
+
+/// Stack frame for storing the OpenMPSectionsConstruct currently being
+/// processed so that it can be refered to when lowering the construct.
+class SectionsConstructStackFrame
+    : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame)
+
+  explicit SectionsConstructStackFrame(
+      const parser::OpenMPSectionsConstruct &sectionsConstruct)
+      : sectionsConstruct{sectionsConstruct} {}
+
+  const parser::OpenMPSectionsConstruct &sectionsConstruct;
+};
+
+static const parser::OpenMPSectionsConstruct *
+getSectionsConstructStackTop(lower::AbstractConverter &converter) {
+  SectionsConstructStackFrame *frame =
+      converter.getStateStack().getStackTop<SectionsConstructStackFrame>();
+  return frame ? &frame->sectionsConstruct : nullptr;
+}
 
 /// Bind symbols to their corresponding entry block arguments.
 ///
@@ -537,31 +570,32 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
     if (!ompEval)
       return;
 
-    HostEvalInfo &hostInfo = hostEvalInfo.back();
+    HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
+    assert(hostInfo && "expected HOST_EVAL info structure");
 
     switch (extractOmpDirective(*ompEval)) {
     case OMPD_teams_distribute_parallel_do:
     case OMPD_teams_distribute_parallel_do_simd:
-      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_target_teams_distribute_parallel_do:
     case OMPD_target_teams_distribute_parallel_do_simd:
-      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_distribute_parallel_do:
     case OMPD_distribute_parallel_do_simd:
-      cp.processNumThreads(stmtCtx, hostInfo.ops);
+      cp.processNumThreads(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_distribute:
     case OMPD_distribute_simd:
-      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
       break;
 
     case OMPD_teams:
-      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_target_teams:
-      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
       processSingleNestedIf([](Directive nestedDir) {
         return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
       });
@@ -569,22 +603,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
 
     case OMPD_teams_distribute:
     case OMPD_teams_distribute_simd:
-      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_target_teams_distribute:
     case OMPD_target_teams_distribute_simd:
-      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
-      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
       break;
 
     case OMPD_teams_loop:
-      cp.processThreadLimit(stmtCtx, hostInfo.ops);
+      cp.processThreadLimit(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_target_teams_loop:
-      cp.processNumTeams(stmtCtx, hostInfo.ops);
+      cp.processNumTeams(stmtCtx, hostInfo->ops);
       [[fallthrough]];
     case OMPD_loop:
-      cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+      cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
       break;
 
     // Standalone 'target' case.
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
     }
   };
 
-  assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
-
   const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
   assert(ompEval &&
          llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
@@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp(
   mlir::Region &region = targetOp.getRegion();
   mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
   bindEntryBlockArgs(converter, targetOp, args);
-  if (!hostEvalInfo.empty())
-    hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
+  if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter))
+    hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs());
 
   // Check if cloning the bounds introduced any dependency on the outer region.
   // If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
                    llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
   ClauseProcessor cp(converter, semaCtx, clauses);
 
-  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
+  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+  if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
     cp.processCollapse(loc, eval, clauseOps, iv);
 
   clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
@@ -1753,7 +1786,8 @@ static void genParallelClauses(
   cp.processAllocate(clauseOps);
   cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
 
-  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
+  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+  if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps))
     cp.processNumThreads(stmtCtx, clauseOps);
 
   cp.processProcBind(clauseOps);
@@ -1818,16 +1852,17 @@ static void genTargetClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
     llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
     llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
+  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processBare(clauseOps);
   cp.processDefaultMap(stmtCtx, defaultMaps);
   cp.processDepend(symTable, stmtCtx, clauseOps);
   cp.processDevice(stmtCtx, clauseOps);
   cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
-  if (!hostEvalInfo.empty()) {
+  if (hostEvalInfo) {
     // Only process host_eval if compiling for the host device.
     processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
-    hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
+    hostEvalInfo->collectValues(clauseOps.hostEvalVars);
   }
   cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
   cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
@@ -1963,7 +1998,8 @@ static void genTeamsClauses(
   cp.processAllocate(clauseOps);
   cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
 
-  if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
+  HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+  if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) {
     cp.processNumTeams(stmtCtx, clauseOps);
     cp.processThreadLimit(stmtCtx, clauseOps);
   }
@@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               lower::pft::Evaluation &eval, mlir::Location loc,
               const ConstructQueue &queue,
               ConstructQueue::const_iterator item) {
-  assert(!sectionsStack.empty());
+  const parser::OpenMPSectionsConstruct *sectionsConstruct =
+      getSectionsConstructStackTop(converter);
+  assert(sectionsConstruct);
+
   const auto &sectionBlocks =
-      std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
-  sectionsStack.pop_back();
+      std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
+  converter.getStateStack().stackPop();
   mlir::omp::SectionsOperands clauseOps;
   llvm::SmallVector<const semantics::Symbol *> reductionSyms;
   genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
@@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   // Introduce a new host_eval information structure for this target region.
   if (!isTargetDevice)
-    hostEvalInfo.emplace_back();
+    converter.getStateStack().stackPush<HostEvalInfoStackFrame>();
 
   mlir::omp::TargetOperands clauseOps;
   DefaultMapsTy defaultMaps;
@@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   // Remove the host_eval information structure created for this target region.
   if (!isTargetDevice)
-    hostEvalInfo.pop_back();
+    converter.getStateStack().stackPop();
   return targetOp;
 }
 
@@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
       buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
                           eval, source, directive, clauses)};
 
-  sectionsStack.push_back(&sectionsConstruct);
+  converter.getStateStack().stackPush<SectionsConstructStackFrame>(
+      sectionsConstruct);
   genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
                  queue.begin());
 }
diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h
index aca2375028246..9641a22c47776 100644
--- a/mlir/include/mlir/Support/StateStack.h
+++ b/mlir/include/mlir/Support/StateStack.h
@@ -83,6 +83,17 @@ class StateStack {
     return WalkResult::advance();
   }
 
+  /// Get the top instance of frame type `T` or nullptr if none are found
+  template <typename T>
+  T *getStackTop() {
+    T *top = nullptr;
+    stackWalk<T>([&](T &frame) -> mlir::WalkResult {
+      top = &frame;
+      return mlir::WalkResult::interrupt();
+    });
+    return top;
+  }
+
 private:
   SmallVector<std::unique_ptr<StateStackFrame>> stack;
 };
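
For readers unfamiliar with the pattern, the typed lookup that `getStackTop` performs can be sketched without MLIR. The sketch below is an illustration only: `Frame`, `FrameStack`, `HostEvalFrame`, `SectionsFrame`, and `demoTopLookup` are hypothetical stand-ins, and `dynamic_cast` replaces MLIR's TypeID mechanism. It assumes the walk goes from the most recently pushed frame outward and stops at the first frame of the requested type, which is what the `interrupt()` in the diff suggests.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical base class; stands in for a type-tagged stack frame.
struct Frame {
  virtual ~Frame() = default;
};

// Two unrelated frame kinds that share one stack.
struct HostEvalFrame : Frame {
  int numTeams = 0;
};
struct SectionsFrame : Frame {
  explicit SectionsFrame(int id) : id(id) {}
  int id;
};

class FrameStack {
public:
  // Construct a frame of type T on top of the stack and return it.
  template <typename T, typename... Args>
  T &push(Args &&...args) {
    auto frame = std::make_unique<T>(std::forward<Args>(args)...);
    T &ref = *frame;
    frames.push_back(std::move(frame));
    return ref;
  }

  // Remove whatever frame is on top, regardless of its type.
  void pop() { frames.pop_back(); }

  // Return the innermost frame of type T, or nullptr if there is none.
  template <typename T>
  T *top() {
    for (auto it = frames.rbegin(); it != frames.rend(); ++it)
      if (auto *match = dynamic_cast<T *>(it->get()))
        return match;
    return nullptr;
  }

private:
  std::vector<std::unique_ptr<Frame>> frames;
};

// The HostEvalFrame is still found although a SectionsFrame sits above it.
int demoTopLookup() {
  FrameStack stack;
  stack.push<HostEvalFrame>().numTeams = 4;
  stack.push<SectionsFrame>(7);
  HostEvalFrame *host = stack.top<HostEvalFrame>();
  return host ? host->numTeams : -1;
}
```

Note that in this sketch `pop()` is untyped while `top<T>()` is typed, so a lookup can skip over frames of other kinds but a pop always removes the top one. Callers therefore need to keep pushes and pops strictly nested.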

@llvmbot
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-mlir-core


github-actions bot commented Jun 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@tblah tblah force-pushed the users/tblah/stack-frame-1 branch from 280e55d to 8b0e0a3 Compare June 19, 2025 14:14
Member

@skatrak skatrak left a comment

Thank you Tom for working on this, I think it's a very nice improvement! LGTM.

}

/// Stack frame for storing the OpenMPSectionsConstruct currently being
/// processed so that it can be refered to when lowering the construct.
Member

Suggested change
-/// processed so that it can be refered to when lowering the construct.
+/// processed so that it can be referred to when lowering the construct.

ClauseProcessor cp(converter, semaCtx, clauses);
cp.processBare(clauseOps);
cp.processDefaultMap(stmtCtx, defaultMaps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
-if (!hostEvalInfo.empty()) {
+if (hostEvalInfo) {
Member

Nit: For consistency with genBodyOfTargetOp, and general conciseness

Suggested change
-if (hostEvalInfo) {
+if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) {

-assert(!sectionsStack.empty());
+const parser::OpenMPSectionsConstruct *sectionsConstruct =
+    getSectionsConstructStackTop(converter);
+assert(sectionsConstruct);
Member

Nit: Add small message to the assert.

-    std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
-sectionsStack.pop_back();
+    std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
+converter.getStateStack().stackPop();
Member

Nit: Wouldn't it be possible to let this call be handled by the same function that pushes the stack frame? We could potentially use SaveStateStack then.
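
The general idea behind this nit, pairing the push and the pop in one scope, might look roughly like the sketch below. It is not the MLIR `SaveStateStack` class mentioned above, whose exact interface is not shown in this PR. `Frame`, `SectionsFrame`, `FrameStack`, `ScopedFrame`, and `depthInsideScope` are hypothetical names used only to illustrate an RAII guard that pushes a frame in its constructor and pops it in its destructor.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical frame types, unrelated to the real MLIR classes.
struct Frame {
  virtual ~Frame() = default;
};
struct SectionsFrame : Frame {
  explicit SectionsFrame(int id) : id(id) {}
  int id;
};

// Minimal stack of owned frames.
class FrameStack {
public:
  template <typename T, typename... Args>
  void push(Args &&...args) {
    frames.push_back(std::make_unique<T>(std::forward<Args>(args)...));
  }
  void pop() { frames.pop_back(); }
  std::size_t depth() const { return frames.size(); }

private:
  std::vector<std::unique_ptr<Frame>> frames;
};

// RAII guard: pushes a frame of type T on construction and pops it on
// destruction, so the push and the pop cannot drift apart.
template <typename T>
class ScopedFrame {
public:
  template <typename... Args>
  explicit ScopedFrame(FrameStack &stack, Args &&...args) : owner(stack) {
    owner.push<T>(std::forward<Args>(args)...);
  }
  ~ScopedFrame() { owner.pop(); }

  ScopedFrame(const ScopedFrame &) = delete;
  ScopedFrame &operator=(const ScopedFrame &) = delete;

private:
  FrameStack &owner;
};

// The frame exists only while the guard is alive.
std::size_t depthInsideScope(FrameStack &stack) {
  ScopedFrame<SectionsFrame> guard(stack, 1);
  return stack.depth();
}
```

With a guard like this, the function that pushes the frame also owns its removal, including on early returns, instead of relying on a callee such as `genSectionsOp` to pop it.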

@tblah tblah force-pushed the users/tblah/stack-frame-1 branch from 8b0e0a3 to ec40d1a Compare June 20, 2025 11:27
Member

@ergawy ergawy left a comment

LGTM! Thanks

Base automatically changed from users/tblah/stack-frame-0 to main June 24, 2025 16:45
@tblah tblah force-pushed the users/tblah/stack-frame-1 branch from ec40d1a to 89f1ee4 Compare June 24, 2025 16:46
@tblah tblah merged commit 8f7f48a into main Jun 24, 2025
6 of 7 checks passed
@tblah tblah deleted the users/tblah/stack-frame-1 branch June 24, 2025 17:30
DrSergei pushed a commit to DrSergei/llvm-project that referenced this pull request Jun 24, 2025
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
Labels
flang:fir-hlfir, flang, mlir:core, mlir
4 participants