Skip to content

[flang][OpenMP][NFC] remove globals with mlir::StateStack #144898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions flang/include/flang/Lower/AbstractConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

namespace mlir {
class SymbolTable;
class StateStack;
}

namespace fir {
Expand Down Expand Up @@ -361,6 +362,8 @@ class AbstractConverter {
/// functions in order to be in sync).
virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;

/// Get the mlir::StateStack exposed by this converter. Lowering code pushes
/// frames onto it and queries it for the top frame of a given type instead
/// of keeping that context in file-level globals.
virtual mlir::StateStack &getStateStack() = 0;

private:
/// Options controlling lowering behavior.
const Fortran::lower::LoweringOptions &loweringOptions;
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/StateStack.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
Expand Down Expand Up @@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {

mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }

/// Return a reference to the `stateStack` member of this converter.
mlir::StateStack &getStateStack() override { return stateStack; }

/// Add the symbol to the local map and return `true`. If the symbol is
/// already in the map and \p forced is `false`, the map is not updated.
/// Instead the value `false` is returned.
Expand Down Expand Up @@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
/// attribute since mlirSymbolTable must pro-actively be maintained when
/// new Symbol operations are created.
mlir::SymbolTable mlirSymbolTable;

/// Used to store context while recursing into regions during lowering.
mlir::StateStack stateStack;
};

} // namespace
Expand Down
103 changes: 69 additions & 34 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "flang/Support/OpenMP-utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Support/StateStack.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
Expand Down Expand Up @@ -198,9 +199,41 @@ class HostEvalInfo {
/// the handling of the outer region by keeping a stack of information
/// structures, but it will probably still require some further work to support
/// reverse offloading.
static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0>
sectionsStack;
/// State-stack frame that carries one HostEvalInfo structure.
///
/// genTargetOp pushes a frame of this type onto the converter's
/// mlir::StateStack when not compiling for the target device and pops it
/// again before returning; getHostEvalInfoStackTop() gives access to the
/// `info` member of the top frame.
class HostEvalInfoStackFrame
: public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
public:
// Defines the MLIR type ID for this frame type.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame)

/// The host_eval information stored in this frame.
HostEvalInfo info;
};

/// Fetch the HostEvalInfo stored in the top HostEvalInfoStackFrame of the
/// converter's state stack.
///
/// \param converter converter whose state stack is queried.
/// \returns a pointer to the frame's `info` member, or nullptr when the state
/// stack holds no HostEvalInfoStackFrame.
static HostEvalInfo *
getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
  auto &stateStack = converter.getStateStack();
  if (HostEvalInfoStackFrame *topFrame =
          stateStack.getStackTop<HostEvalInfoStackFrame>())
    return &topFrame->info;
  return nullptr;
}

/// Stack frame for storing the OpenMPSectionsConstruct currently being
/// processed so that it can be referred to when lowering the construct.
///
/// genOMP places a frame of this type on the converter's mlir::StateStack
/// (through mlir::SaveStateStack) before dispatching the construct queue, and
/// genSectionsOp retrieves it with getSectionsConstructStackTop().
class SectionsConstructStackFrame
: public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
public:
// Defines the MLIR type ID for this frame type.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame)

/// Build a frame that refers to \p sectionsConstruct. Only a reference is
/// kept; the construct itself is not copied.
explicit SectionsConstructStackFrame(
const parser::OpenMPSectionsConstruct &sectionsConstruct)
: sectionsConstruct{sectionsConstruct} {}

/// The sections construct this frame refers to (held by reference).
const parser::OpenMPSectionsConstruct &sectionsConstruct;
};

/// Fetch the OpenMPSectionsConstruct referenced by the top
/// SectionsConstructStackFrame of the converter's state stack.
///
/// \param converter converter whose state stack is queried.
/// \returns a pointer to the referenced construct, or nullptr when the state
/// stack holds no SectionsConstructStackFrame.
static const parser::OpenMPSectionsConstruct *
getSectionsConstructStackTop(lower::AbstractConverter &converter) {
  auto &stateStack = converter.getStateStack();
  if (SectionsConstructStackFrame *topFrame =
          stateStack.getStackTop<SectionsConstructStackFrame>())
    return &topFrame->sectionsConstruct;
  return nullptr;
}

/// Bind symbols to their corresponding entry block arguments.
///
Expand Down Expand Up @@ -535,54 +568,55 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
if (!ompEval)
return;

HostEvalInfo &hostInfo = hostEvalInfo.back();
HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
assert(hostInfo && "expected HOST_EVAL info structure");

switch (extractOmpDirective(*ompEval)) {
case OMPD_teams_distribute_parallel_do:
case OMPD_teams_distribute_parallel_do_simd:
cp.processThreadLimit(stmtCtx, hostInfo.ops);
cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_distribute_parallel_do:
case OMPD_target_teams_distribute_parallel_do_simd:
cp.processNumTeams(stmtCtx, hostInfo.ops);
cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_distribute_parallel_do:
case OMPD_distribute_parallel_do_simd:
cp.processNumThreads(stmtCtx, hostInfo.ops);
cp.processNumThreads(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_distribute:
case OMPD_distribute_simd:
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;

case OMPD_teams:
cp.processThreadLimit(stmtCtx, hostInfo.ops);
cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams:
cp.processNumTeams(stmtCtx, hostInfo.ops);
cp.processNumTeams(stmtCtx, hostInfo->ops);
processSingleNestedIf([](Directive nestedDir) {
return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
});
break;

case OMPD_teams_distribute:
case OMPD_teams_distribute_simd:
cp.processThreadLimit(stmtCtx, hostInfo.ops);
cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_distribute:
case OMPD_target_teams_distribute_simd:
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
cp.processNumTeams(stmtCtx, hostInfo.ops);
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
cp.processNumTeams(stmtCtx, hostInfo->ops);
break;

case OMPD_teams_loop:
cp.processThreadLimit(stmtCtx, hostInfo.ops);
cp.processThreadLimit(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_target_teams_loop:
cp.processNumTeams(stmtCtx, hostInfo.ops);
cp.processNumTeams(stmtCtx, hostInfo->ops);
[[fallthrough]];
case OMPD_loop:
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
break;

// Standalone 'target' case.
Expand All @@ -596,8 +630,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
}
};

assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");

const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
assert(ompEval &&
llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
Expand Down Expand Up @@ -1456,8 +1488,8 @@ static void genBodyOfTargetOp(
mlir::Region &region = targetOp.getRegion();
mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
bindEntryBlockArgs(converter, targetOp, args);
if (!hostEvalInfo.empty())
hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter))
hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs());

// Check if cloning the bounds introduced any dependency on the outer region.
// If so, then either clone them as well if they are MemoryEffectFree, or else
Expand Down Expand Up @@ -1696,7 +1728,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
ClauseProcessor cp(converter, semaCtx, clauses);

if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
cp.processCollapse(loc, eval, clauseOps, iv);

clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
Expand Down Expand Up @@ -1741,7 +1774,8 @@ static void genParallelClauses(
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);

if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps))
cp.processNumThreads(stmtCtx, clauseOps);

cp.processProcBind(clauseOps);
Expand Down Expand Up @@ -1812,10 +1846,10 @@ static void genTargetClauses(
cp.processDepend(symTable, stmtCtx, clauseOps);
cp.processDevice(stmtCtx, clauseOps);
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
if (!hostEvalInfo.empty()) {
if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) {
// Only process host_eval if compiling for the host device.
processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
}
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
Expand Down Expand Up @@ -1952,7 +1986,8 @@ static void genTeamsClauses(
cp.processAllocate(clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);

if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) {
cp.processNumTeams(stmtCtx, clauseOps);
cp.processThreadLimit(stmtCtx, clauseOps);
}
Expand Down Expand Up @@ -2204,19 +2239,18 @@ genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
converter.getCurrentLocation(), clauseOps);
}

/// This breaks the normal prototype of the gen*Op functions: adding the
/// sectionBlocks argument so that the enclosed section constructs can be
/// lowered here with correct reduction symbol remapping.
static mlir::omp::SectionsOp
genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
lower::pft::Evaluation &eval, mlir::Location loc,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
assert(!sectionsStack.empty());
const parser::OpenMPSectionsConstruct *sectionsConstruct =
getSectionsConstructStackTop(converter);
assert(sectionsConstruct && "Missing additional parsing information");

const auto &sectionBlocks =
std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
sectionsStack.pop_back();
std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
mlir::omp::SectionsOperands clauseOps;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
Expand Down Expand Up @@ -2370,7 +2404,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,

// Introduce a new host_eval information structure for this target region.
if (!isTargetDevice)
hostEvalInfo.emplace_back();
converter.getStateStack().stackPush<HostEvalInfoStackFrame>();

mlir::omp::TargetOperands clauseOps;
DefaultMapsTy defaultMaps;
Expand Down Expand Up @@ -2497,7 +2531,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,

// Remove the host_eval information structure created for this target region.
if (!isTargetDevice)
hostEvalInfo.pop_back();
converter.getStateStack().stackPop();
return targetOp;
}

Expand Down Expand Up @@ -3771,7 +3805,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
eval, source, directive, clauses)};

sectionsStack.push_back(&sectionsConstruct);
mlir::SaveStateStack<SectionsConstructStackFrame> saveStateStack{
converter.getStateStack(), sectionsConstruct};
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
queue.begin());
}
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Support/StateStack.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ class StateStack {
return WalkResult::advance();
}

/// Look up the top frame of type `T` held by this stack.
///
/// \returns a pointer to that frame, or nullptr if the walk reports no frame
/// of type `T`.
template <typename T>
T *getStackTop() {
  T *found = nullptr;
  // Record the first frame of type `T` handed to the callback and interrupt
  // the walk right away, so no later frame can overwrite it.
  auto recordAndStop = [&found](T &frame) -> mlir::WalkResult {
    found = &frame;
    return mlir::WalkResult::interrupt();
  };
  stackWalk<T>(recordAndStop);
  return found;
}

private:
SmallVector<std::unique_ptr<StateStackFrame>> stack;
};
Expand Down
Loading