-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][OpenMP][NFC] remove globals with mlir::StateStack #144898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-fir-hlfir Author: Tom Eccles (tblah) Changes: Idea suggested by @skatrak. Full diff: https://github.com/llvm/llvm-project/pull/144898.diff 4 Files Affected:
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 8ae68e143cd2f..de3e833f60699 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -26,6 +26,7 @@
namespace mlir {
class SymbolTable;
+class StateStack;
}
namespace fir {
@@ -361,6 +362,8 @@ class AbstractConverter {
/// functions in order to be in sync).
virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
+ virtual mlir::StateStack &getStateStack() = 0;
+
private:
/// Options controlling lowering behavior.
const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 64b16b3abe991..462ceb8dff736 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -78,6 +78,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include "llvm/Target/TargetMachine.h"
+#include "mlir/Support/StateStack.h"
#include <optional>
#define DEBUG_TYPE "flang-lower-bridge"
@@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
+ mlir::StateStack &getStateStack() override { return stateStack; }
+
/// Add the symbol to the local map and return `true`. If the symbol is
/// already in the map and \p forced is `false`, the map is not updated.
/// Instead the value `false` is returned.
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// attribute since mlirSymbolTable must pro-actively be maintained when
/// new Symbol operations are created.
mlir::SymbolTable mlirSymbolTable;
+
+ /// Used to store context while recursing into regions during lowering.
+ mlir::StateStack stateStack;
};
} // namespace
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7ad8869597274..bff3321af2814 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -38,6 +38,7 @@
#include "flang/Support/OpenMP-utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Support/StateStack.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -200,9 +201,41 @@ class HostEvalInfo {
/// the handling of the outer region by keeping a stack of information
/// structures, but it will probably still require some further work to support
/// reverse offloading.
-static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
-static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0>
- sectionsStack;
+class HostEvalInfoStackFrame
+ : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame)
+
+ HostEvalInfo info;
+};
+
+static HostEvalInfo *
+getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
+ HostEvalInfoStackFrame *frame =
+ converter.getStateStack().getStackTop<HostEvalInfoStackFrame>();
+ return frame ? &frame->info : nullptr;
+}
+
+/// Stack frame for storing the OpenMPSectionsConstruct currently being
+/// processed so that it can be refered to when lowering the construct.
+class SectionsConstructStackFrame
+ : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame)
+
+ explicit SectionsConstructStackFrame(
+ const parser::OpenMPSectionsConstruct &sectionsConstruct)
+ : sectionsConstruct{sectionsConstruct} {}
+
+ const parser::OpenMPSectionsConstruct &sectionsConstruct;
+};
+
+static const parser::OpenMPSectionsConstruct *
+getSectionsConstructStackTop(lower::AbstractConverter &converter) {
+ SectionsConstructStackFrame *frame =
+ converter.getStateStack().getStackTop<SectionsConstructStackFrame>();
+ return frame ? &frame->sectionsConstruct : nullptr;
+}
/// Bind symbols to their corresponding entry block arguments.
///
@@ -537,31 +570,32 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
if (!ompEval)
return;
- HostEvalInfo &hostInfo = hostEvalInfo.back();
+ HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
+ assert(hostInfo && "expected HOST_EVAL info structure");
switch (extractOmpDirective(*ompEval)) {
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_distribute_parallel_do:
case OMPD_target_teams_distribute_parallel_do_simd:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_distribute_parallel_do:
case OMPD_distribute_parallel_do_simd:
- cp.processNumThreads(stmtCtx, hostInfo.ops);
+ cp.processNumThreads(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_distribute:
case OMPD_distribute_simd:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;
case OMPD_teams:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
processSingleNestedIf([](Directive nestedDir) {
return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
});
@@ -569,22 +603,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
case OMPD_teams_distribute:
case OMPD_teams_distribute_simd:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_distribute:
case OMPD_target_teams_distribute_simd:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
break;
case OMPD_teams_loop:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_loop:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_loop:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;
// Standalone 'target' case.
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
}
};
- assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
-
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
assert(ompEval &&
llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
@@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp(
mlir::Region &region = targetOp.getRegion();
mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
bindEntryBlockArgs(converter, targetOp, args);
- if (!hostEvalInfo.empty())
- hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
+ if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter))
+ hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs());
// Check if cloning the bounds introduced any dependency on the outer region.
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
cp.processCollapse(loc, eval, clauseOps, iv);
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
@@ -1753,7 +1786,8 @@ static void genParallelClauses(
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps))
cp.processNumThreads(stmtCtx, clauseOps);
cp.processProcBind(clauseOps);
@@ -1818,16 +1852,17 @@ static void genTargetClauses(
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processBare(clauseOps);
cp.processDefaultMap(stmtCtx, defaultMaps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
- if (!hostEvalInfo.empty()) {
+ if (hostEvalInfo) {
// Only process host_eval if compiling for the host device.
processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
- hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
+ hostEvalInfo->collectValues(clauseOps.hostEvalVars);
}
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
@@ -1963,7 +1998,8 @@ static void genTeamsClauses(
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) {
cp.processNumTeams(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
}
@@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
lower::pft::Evaluation &eval, mlir::Location loc,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
- assert(!sectionsStack.empty());
+ const parser::OpenMPSectionsConstruct *sectionsConstruct =
+ getSectionsConstructStackTop(converter);
+ assert(sectionsConstruct);
+
const auto &sectionBlocks =
- std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
- sectionsStack.pop_back();
+ std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
+ converter.getStateStack().stackPop();
mlir::omp::SectionsOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
@@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// Introduce a new host_eval information structure for this target region.
if (!isTargetDevice)
- hostEvalInfo.emplace_back();
+ converter.getStateStack().stackPush<HostEvalInfoStackFrame>();
mlir::omp::TargetOperands clauseOps;
DefaultMapsTy defaultMaps;
@@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// Remove the host_eval information structure created for this target region.
if (!isTargetDevice)
- hostEvalInfo.pop_back();
+ converter.getStateStack().stackPop();
return targetOp;
}
@@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
eval, source, directive, clauses)};
- sectionsStack.push_back(&sectionsConstruct);
+ converter.getStateStack().stackPush<SectionsConstructStackFrame>(
+ sectionsConstruct);
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
queue.begin());
}
diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h
index aca2375028246..9641a22c47776 100644
--- a/mlir/include/mlir/Support/StateStack.h
+++ b/mlir/include/mlir/Support/StateStack.h
@@ -83,6 +83,17 @@ class StateStack {
return WalkResult::advance();
}
+ /// Get the top instance of frame type `T` or nullptr if none are found
+ template <typename T>
+ T *getStackTop() {
+ T *top = nullptr;
+ stackWalk<T>([&](T &frame) -> mlir::WalkResult {
+ top = &frame;
+ return mlir::WalkResult::interrupt();
+ });
+ return top;
+ }
+
private:
SmallVector<std::unique_ptr<StateStackFrame>> stack;
};
|
@llvm/pr-subscribers-mlir-core Author: Tom Eccles (tblah) Changes: Idea suggested by @skatrak. Full diff: https://github.com/llvm/llvm-project/pull/144898.diff 4 Files Affected:
diff --git a/flang/include/flang/Lower/AbstractConverter.h b/flang/include/flang/Lower/AbstractConverter.h
index 8ae68e143cd2f..de3e833f60699 100644
--- a/flang/include/flang/Lower/AbstractConverter.h
+++ b/flang/include/flang/Lower/AbstractConverter.h
@@ -26,6 +26,7 @@
namespace mlir {
class SymbolTable;
+class StateStack;
}
namespace fir {
@@ -361,6 +362,8 @@ class AbstractConverter {
/// functions in order to be in sync).
virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
+ virtual mlir::StateStack &getStateStack() = 0;
+
private:
/// Options controlling lowering behavior.
const Fortran::lower::LoweringOptions &loweringOptions;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 64b16b3abe991..462ceb8dff736 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -78,6 +78,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include "llvm/Target/TargetMachine.h"
+#include "mlir/Support/StateStack.h"
#include <optional>
#define DEBUG_TYPE "flang-lower-bridge"
@@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
+ mlir::StateStack &getStateStack() override { return stateStack; }
+
/// Add the symbol to the local map and return `true`. If the symbol is
/// already in the map and \p forced is `false`, the map is not updated.
/// Instead the value `false` is returned.
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// attribute since mlirSymbolTable must pro-actively be maintained when
/// new Symbol operations are created.
mlir::SymbolTable mlirSymbolTable;
+
+ /// Used to store context while recursing into regions during lowering.
+ mlir::StateStack stateStack;
};
} // namespace
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7ad8869597274..bff3321af2814 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -38,6 +38,7 @@
#include "flang/Support/OpenMP-utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Support/StateStack.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -200,9 +201,41 @@ class HostEvalInfo {
/// the handling of the outer region by keeping a stack of information
/// structures, but it will probably still require some further work to support
/// reverse offloading.
-static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
-static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0>
- sectionsStack;
+class HostEvalInfoStackFrame
+ : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame)
+
+ HostEvalInfo info;
+};
+
+static HostEvalInfo *
+getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
+ HostEvalInfoStackFrame *frame =
+ converter.getStateStack().getStackTop<HostEvalInfoStackFrame>();
+ return frame ? &frame->info : nullptr;
+}
+
+/// Stack frame for storing the OpenMPSectionsConstruct currently being
+/// processed so that it can be refered to when lowering the construct.
+class SectionsConstructStackFrame
+ : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame)
+
+ explicit SectionsConstructStackFrame(
+ const parser::OpenMPSectionsConstruct &sectionsConstruct)
+ : sectionsConstruct{sectionsConstruct} {}
+
+ const parser::OpenMPSectionsConstruct &sectionsConstruct;
+};
+
+static const parser::OpenMPSectionsConstruct *
+getSectionsConstructStackTop(lower::AbstractConverter &converter) {
+ SectionsConstructStackFrame *frame =
+ converter.getStateStack().getStackTop<SectionsConstructStackFrame>();
+ return frame ? &frame->sectionsConstruct : nullptr;
+}
/// Bind symbols to their corresponding entry block arguments.
///
@@ -537,31 +570,32 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
if (!ompEval)
return;
- HostEvalInfo &hostInfo = hostEvalInfo.back();
+ HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
+ assert(hostInfo && "expected HOST_EVAL info structure");
switch (extractOmpDirective(*ompEval)) {
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_distribute_parallel_do:
case OMPD_target_teams_distribute_parallel_do_simd:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_distribute_parallel_do:
case OMPD_distribute_parallel_do_simd:
- cp.processNumThreads(stmtCtx, hostInfo.ops);
+ cp.processNumThreads(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_distribute:
case OMPD_distribute_simd:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;
case OMPD_teams:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
processSingleNestedIf([](Directive nestedDir) {
return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
});
@@ -569,22 +603,22 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
case OMPD_teams_distribute:
case OMPD_teams_distribute_simd:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_distribute:
case OMPD_target_teams_distribute_simd:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
break;
case OMPD_teams_loop:
- cp.processThreadLimit(stmtCtx, hostInfo.ops);
+ cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_loop:
- cp.processNumTeams(stmtCtx, hostInfo.ops);
+ cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_loop:
- cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
+ cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;
// Standalone 'target' case.
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
}
};
- assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
-
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
assert(ompEval &&
llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
@@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp(
mlir::Region &region = targetOp.getRegion();
mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
bindEntryBlockArgs(converter, targetOp, args);
- if (!hostEvalInfo.empty())
- hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
+ if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter))
+ hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs());
// Check if cloning the bounds introduced any dependency on the outer region.
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
cp.processCollapse(loc, eval, clauseOps, iv);
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
@@ -1753,7 +1786,8 @@ static void genParallelClauses(
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps))
cp.processNumThreads(stmtCtx, clauseOps);
cp.processProcBind(clauseOps);
@@ -1818,16 +1852,17 @@ static void genTargetClauses(
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processBare(clauseOps);
cp.processDefaultMap(stmtCtx, defaultMaps);
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
- if (!hostEvalInfo.empty()) {
+ if (hostEvalInfo) {
// Only process host_eval if compiling for the host device.
processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
- hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
+ hostEvalInfo->collectValues(clauseOps.hostEvalVars);
}
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
@@ -1963,7 +1998,8 @@ static void genTeamsClauses(
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
- if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
+ if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) {
cp.processNumTeams(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
}
@@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
lower::pft::Evaluation &eval, mlir::Location loc,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
- assert(!sectionsStack.empty());
+ const parser::OpenMPSectionsConstruct *sectionsConstruct =
+ getSectionsConstructStackTop(converter);
+ assert(sectionsConstruct);
+
const auto &sectionBlocks =
- std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
- sectionsStack.pop_back();
+ std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
+ converter.getStateStack().stackPop();
mlir::omp::SectionsOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
@@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// Introduce a new host_eval information structure for this target region.
if (!isTargetDevice)
- hostEvalInfo.emplace_back();
+ converter.getStateStack().stackPush<HostEvalInfoStackFrame>();
mlir::omp::TargetOperands clauseOps;
DefaultMapsTy defaultMaps;
@@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
// Remove the host_eval information structure created for this target region.
if (!isTargetDevice)
- hostEvalInfo.pop_back();
+ converter.getStateStack().stackPop();
return targetOp;
}
@@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
eval, source, directive, clauses)};
- sectionsStack.push_back(&sectionsConstruct);
+ converter.getStateStack().stackPush<SectionsConstructStackFrame>(
+ sectionsConstruct);
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
queue.begin());
}
diff --git a/mlir/include/mlir/Support/StateStack.h b/mlir/include/mlir/Support/StateStack.h
index aca2375028246..9641a22c47776 100644
--- a/mlir/include/mlir/Support/StateStack.h
+++ b/mlir/include/mlir/Support/StateStack.h
@@ -83,6 +83,17 @@ class StateStack {
return WalkResult::advance();
}
+ /// Get the top instance of frame type `T` or nullptr if none are found
+ template <typename T>
+ T *getStackTop() {
+ T *top = nullptr;
+ stackWalk<T>([&](T &frame) -> mlir::WalkResult {
+ top = &frame;
+ return mlir::WalkResult::interrupt();
+ });
+ return top;
+ }
+
private:
SmallVector<std::unique_ptr<StateStackFrame>> stack;
};
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
280e55d
to
8b0e0a3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you Tom for working on this, I think it's a very nice improvement! LGTM.
flang/lib/Lower/OpenMP/OpenMP.cpp
Outdated
} | ||
|
||
/// Stack frame for storing the OpenMPSectionsConstruct currently being | ||
/// processed so that it can be refered to when lowering the construct. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// processed so that it can be refered to when lowering the construct. | |
/// processed so that it can be referred to when lowering the construct. |
flang/lib/Lower/OpenMP/OpenMP.cpp
Outdated
ClauseProcessor cp(converter, semaCtx, clauses); | ||
cp.processBare(clauseOps); | ||
cp.processDefaultMap(stmtCtx, defaultMaps); | ||
cp.processDepend(symTable, stmtCtx, clauseOps); | ||
cp.processDevice(stmtCtx, clauseOps); | ||
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms); | ||
if (!hostEvalInfo.empty()) { | ||
if (hostEvalInfo) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: For consistency with `genBodyOfTargetOp`, and general conciseness
if (hostEvalInfo) { | |
if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) { |
flang/lib/Lower/OpenMP/OpenMP.cpp
Outdated
assert(!sectionsStack.empty()); | ||
const parser::OpenMPSectionsConstruct *sectionsConstruct = | ||
getSectionsConstructStackTop(converter); | ||
assert(sectionsConstruct); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Add small message to the assert.
flang/lib/Lower/OpenMP/OpenMP.cpp
Outdated
std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t); | ||
sectionsStack.pop_back(); | ||
std::get<parser::OmpSectionBlocks>(sectionsConstruct->t); | ||
converter.getStateStack().stackPop(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Wouldn't it be possible to let this call be handled by the same function that pushes the stack frame? We could potentially use `SaveStateStack` then.
8b0e0a3
to
ec40d1a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks
ec40d1a
to
89f1ee4
Compare
Idea suggested by @skatrak