Skip to content

Commit 280e55d

Browse files
committed
[flang][OpenMP][NFC] remove globals with mlir::StateStack
Idea suggested by @skatrak
1 parent 7ac1750 commit 280e55d

File tree

4 files changed

+91
-31
lines changed

4 files changed

+91
-31
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
namespace mlir {
2828
class SymbolTable;
29+
class StateStack;
2930
}
3031

3132
namespace fir {
@@ -361,6 +362,8 @@ class AbstractConverter {
361362
/// functions in order to be in sync).
362363
virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
363364

365+
virtual mlir::StateStack &getStateStack() = 0;
366+
364367
private:
365368
/// Options controlling lowering behavior.
366369
const Fortran::lower::LoweringOptions &loweringOptions;

flang/lib/Lower/Bridge.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
#include "llvm/Support/FileSystem.h"
7979
#include "llvm/Support/Path.h"
8080
#include "llvm/Target/TargetMachine.h"
81+
#include "mlir/Support/StateStack.h"
8182
#include <optional>
8283

8384
#define DEBUG_TYPE "flang-lower-bridge"
@@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
12371238

12381239
mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
12391240

1241+
mlir::StateStack &getStateStack() override { return stateStack; }
1242+
12401243
/// Add the symbol to the local map and return `true`. If the symbol is
12411244
/// already in the map and \p forced is `false`, the map is not updated.
12421245
/// Instead the value `false` is returned.
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
65526555
/// attribute since mlirSymbolTable must pro-actively be maintained when
65536556
/// new Symbol operations are created.
65546557
mlir::SymbolTable mlirSymbolTable;
6558+
6559+
/// Used to store context while recursing into regions during lowering.
6560+
mlir::StateStack stateStack;
65556561
};
65566562

65576563
} // namespace

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 71 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "flang/Support/OpenMP-utils.h"
3939
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4040
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
41+
#include "mlir/Support/StateStack.h"
4142
#include "mlir/Transforms/RegionUtils.h"
4243
#include "llvm/ADT/STLExtras.h"
4344
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -200,9 +201,41 @@ class HostEvalInfo {
200201
/// the handling of the outer region by keeping a stack of information
201202
/// structures, but it will probably still require some further work to support
202203
/// reverse offloading.
203-
static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
204-
static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0>
205-
sectionsStack;
204+
class HostEvalInfoStackFrame
205+
: public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
206+
public:
207+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame)
208+
209+
HostEvalInfo info;
210+
};
211+
212+
static HostEvalInfo *
213+
getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
214+
HostEvalInfoStackFrame *frame =
215+
converter.getStateStack().getStackTop<HostEvalInfoStackFrame>();
216+
return frame ? &frame->info : nullptr;
217+
}
218+
219+
/// Stack frame for storing the OpenMPSectionsConstruct currently being
220+
/// processed so that it can be referred to when lowering the construct.
221+
class SectionsConstructStackFrame
222+
: public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
223+
public:
224+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame)
225+
226+
explicit SectionsConstructStackFrame(
227+
const parser::OpenMPSectionsConstruct &sectionsConstruct)
228+
: sectionsConstruct{sectionsConstruct} {}
229+
230+
const parser::OpenMPSectionsConstruct &sectionsConstruct;
231+
};
232+
233+
static const parser::OpenMPSectionsConstruct *
234+
getSectionsConstructStackTop(lower::AbstractConverter &converter) {
235+
SectionsConstructStackFrame *frame =
236+
converter.getStateStack().getStackTop<SectionsConstructStackFrame>();
237+
return frame ? &frame->sectionsConstruct : nullptr;
238+
}
206239

207240
/// Bind symbols to their corresponding entry block arguments.
208241
///
@@ -537,54 +570,55 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
537570
if (!ompEval)
538571
return;
539572

540-
HostEvalInfo &hostInfo = hostEvalInfo.back();
573+
HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
574+
assert(hostInfo && "expected HOST_EVAL info structure");
541575

542576
switch (extractOmpDirective(*ompEval)) {
543577
case OMPD_teams_distribute_parallel_do:
544578
case OMPD_teams_distribute_parallel_do_simd:
545-
cp.processThreadLimit(stmtCtx, hostInfo.ops);
579+
cp.processThreadLimit(stmtCtx, hostInfo->ops);
546580
[[fallthrough]];
547581
case OMPD_target_teams_distribute_parallel_do:
548582
case OMPD_target_teams_distribute_parallel_do_simd:
549-
cp.processNumTeams(stmtCtx, hostInfo.ops);
583+
cp.processNumTeams(stmtCtx, hostInfo->ops);
550584
[[fallthrough]];
551585
case OMPD_distribute_parallel_do:
552586
case OMPD_distribute_parallel_do_simd:
553-
cp.processNumThreads(stmtCtx, hostInfo.ops);
587+
cp.processNumThreads(stmtCtx, hostInfo->ops);
554588
[[fallthrough]];
555589
case OMPD_distribute:
556590
case OMPD_distribute_simd:
557-
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
591+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
558592
break;
559593

560594
case OMPD_teams:
561-
cp.processThreadLimit(stmtCtx, hostInfo.ops);
595+
cp.processThreadLimit(stmtCtx, hostInfo->ops);
562596
[[fallthrough]];
563597
case OMPD_target_teams:
564-
cp.processNumTeams(stmtCtx, hostInfo.ops);
598+
cp.processNumTeams(stmtCtx, hostInfo->ops);
565599
processSingleNestedIf([](Directive nestedDir) {
566600
return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
567601
});
568602
break;
569603

570604
case OMPD_teams_distribute:
571605
case OMPD_teams_distribute_simd:
572-
cp.processThreadLimit(stmtCtx, hostInfo.ops);
606+
cp.processThreadLimit(stmtCtx, hostInfo->ops);
573607
[[fallthrough]];
574608
case OMPD_target_teams_distribute:
575609
case OMPD_target_teams_distribute_simd:
576-
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
577-
cp.processNumTeams(stmtCtx, hostInfo.ops);
610+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
611+
cp.processNumTeams(stmtCtx, hostInfo->ops);
578612
break;
579613

580614
case OMPD_teams_loop:
581-
cp.processThreadLimit(stmtCtx, hostInfo.ops);
615+
cp.processThreadLimit(stmtCtx, hostInfo->ops);
582616
[[fallthrough]];
583617
case OMPD_target_teams_loop:
584-
cp.processNumTeams(stmtCtx, hostInfo.ops);
618+
cp.processNumTeams(stmtCtx, hostInfo->ops);
585619
[[fallthrough]];
586620
case OMPD_loop:
587-
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
621+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
588622
break;
589623

590624
// Standalone 'target' case.
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
598632
}
599633
};
600634

601-
assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
602-
603635
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
604636
assert(ompEval &&
605637
llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
@@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp(
14681500
mlir::Region &region = targetOp.getRegion();
14691501
mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
14701502
bindEntryBlockArgs(converter, targetOp, args);
1471-
if (!hostEvalInfo.empty())
1472-
hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
1503+
if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter))
1504+
hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs());
14731505

14741506
// Check if cloning the bounds introduced any dependency on the outer region.
14751507
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
17081740
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
17091741
ClauseProcessor cp(converter, semaCtx, clauses);
17101742

1711-
if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
1743+
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
1744+
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
17121745
cp.processCollapse(loc, eval, clauseOps, iv);
17131746

17141747
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
@@ -1753,7 +1786,8 @@ static void genParallelClauses(
17531786
cp.processAllocate(clauseOps);
17541787
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
17551788

1756-
if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
1789+
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
1790+
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps))
17571791
cp.processNumThreads(stmtCtx, clauseOps);
17581792

17591793
cp.processProcBind(clauseOps);
@@ -1818,16 +1852,17 @@ static void genTargetClauses(
18181852
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
18191853
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
18201854
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
1855+
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
18211856
ClauseProcessor cp(converter, semaCtx, clauses);
18221857
cp.processBare(clauseOps);
18231858
cp.processDefaultMap(stmtCtx, defaultMaps);
18241859
cp.processDepend(symTable, stmtCtx, clauseOps);
18251860
cp.processDevice(stmtCtx, clauseOps);
18261861
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
1827-
if (!hostEvalInfo.empty()) {
1862+
if (hostEvalInfo) {
18281863
// Only process host_eval if compiling for the host device.
18291864
processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
1830-
hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
1865+
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
18311866
}
18321867
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
18331868
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
@@ -1963,7 +1998,8 @@ static void genTeamsClauses(
19631998
cp.processAllocate(clauseOps);
19641999
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
19652000

1966-
if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
2001+
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
2002+
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) {
19672003
cp.processNumTeams(stmtCtx, clauseOps);
19682004
cp.processThreadLimit(stmtCtx, clauseOps);
19692005
}
@@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
22242260
lower::pft::Evaluation &eval, mlir::Location loc,
22252261
const ConstructQueue &queue,
22262262
ConstructQueue::const_iterator item) {
2227-
assert(!sectionsStack.empty());
2263+
const parser::OpenMPSectionsConstruct *sectionsConstruct =
2264+
getSectionsConstructStackTop(converter);
2265+
assert(sectionsConstruct);
2266+
22282267
const auto &sectionBlocks =
2229-
std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
2230-
sectionsStack.pop_back();
2268+
std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
2269+
converter.getStateStack().stackPop();
22312270
mlir::omp::SectionsOperands clauseOps;
22322271
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
22332272
genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
@@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23812420

23822421
// Introduce a new host_eval information structure for this target region.
23832422
if (!isTargetDevice)
2384-
hostEvalInfo.emplace_back();
2423+
converter.getStateStack().stackPush<HostEvalInfoStackFrame>();
23852424

23862425
mlir::omp::TargetOperands clauseOps;
23872426
DefaultMapsTy defaultMaps;
@@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
25082547

25092548
// Remove the host_eval information structure created for this target region.
25102549
if (!isTargetDevice)
2511-
hostEvalInfo.pop_back();
2550+
converter.getStateStack().stackPop();
25122551
return targetOp;
25132552
}
25142553

@@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
42354274
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
42364275
eval, source, directive, clauses)};
42374276

4238-
sectionsStack.push_back(&sectionsConstruct);
4277+
converter.getStateStack().stackPush<SectionsConstructStackFrame>(
4278+
sectionsConstruct);
42394279
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
42404280
queue.begin());
42414281
}

mlir/include/mlir/Support/StateStack.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ class StateStack {
8383
return WalkResult::advance();
8484
}
8585

86+
/// Get the top instance of frame type `T`, or nullptr if none is found.
87+
template <typename T>
88+
T *getStackTop() {
89+
T *top = nullptr;
90+
stackWalk<T>([&](T &frame) -> mlir::WalkResult {
91+
top = &frame;
92+
return mlir::WalkResult::interrupt();
93+
});
94+
return top;
95+
}
96+
8697
private:
8798
SmallVector<std::unique_ptr<StateStackFrame>> stack;
8899
};

0 commit comments

Comments
 (0)