Skip to content

Commit 0382b23

Browse files
tblah authored and Anthony Tran committed
[flang][OpenMP][NFC] remove globals with mlir::StateStack (llvm#144898)
Idea suggested by @skatrak
1 parent 95ab482 commit 0382b23

File tree

4 files changed

+89
-34
lines changed

4 files changed

+89
-34
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626

2727
namespace mlir {
2828
class SymbolTable;
29+
class StateStack;
2930
}
3031

3132
namespace fir {
@@ -361,6 +362,8 @@ class AbstractConverter {
361362
/// functions in order to be in sync).
362363
virtual mlir::SymbolTable *getMLIRSymbolTable() = 0;
363364

365+
virtual mlir::StateStack &getStateStack() = 0;
366+
364367
private:
365368
/// Options controlling lowering behavior.
366369
const Fortran::lower::LoweringOptions &loweringOptions;

flang/lib/Lower/Bridge.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@
6969
#include "mlir/IR/Matchers.h"
7070
#include "mlir/IR/PatternMatch.h"
7171
#include "mlir/Parser/Parser.h"
72+
#include "mlir/Support/StateStack.h"
7273
#include "mlir/Transforms/RegionUtils.h"
7374
#include "llvm/ADT/SmallVector.h"
7475
#include "llvm/ADT/StringSet.h"
@@ -1237,6 +1238,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
12371238

12381239
mlir::SymbolTable *getMLIRSymbolTable() override { return &mlirSymbolTable; }
12391240

1241+
mlir::StateStack &getStateStack() override { return stateStack; }
1242+
12401243
/// Add the symbol to the local map and return `true`. If the symbol is
12411244
/// already in the map and \p forced is `false`, the map is not updated.
12421245
/// Instead the value `false` is returned.
@@ -6552,6 +6555,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
65526555
/// attribute since mlirSymbolTable must pro-actively be maintained when
65536556
/// new Symbol operations are created.
65546557
mlir::SymbolTable mlirSymbolTable;
6558+
6559+
/// Used to store context while recursing into regions during lowering.
6560+
mlir::StateStack stateStack;
65556561
};
65566562

65576563
} // namespace

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 69 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@
3939
#include "flang/Support/OpenMP-utils.h"
4040
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
4141
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
42+
#include "mlir/Support/StateStack.h"
4243
#include "mlir/Transforms/RegionUtils.h"
4344
#include "llvm/ADT/STLExtras.h"
4445
#include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -198,9 +199,41 @@ class HostEvalInfo {
198199
/// the handling of the outer region by keeping a stack of information
199200
/// structures, but it will probably still require some further work to support
200201
/// reverse offloading.
201-
static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo;
202-
static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0>
203-
sectionsStack;
202+
class HostEvalInfoStackFrame
203+
: public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
204+
public:
205+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HostEvalInfoStackFrame)
206+
207+
HostEvalInfo info;
208+
};
209+
210+
static HostEvalInfo *
211+
getHostEvalInfoStackTop(lower::AbstractConverter &converter) {
212+
HostEvalInfoStackFrame *frame =
213+
converter.getStateStack().getStackTop<HostEvalInfoStackFrame>();
214+
return frame ? &frame->info : nullptr;
215+
}
216+
217+
/// Stack frame for storing the OpenMPSectionsConstruct currently being
218+
/// processed so that it can be referred to when lowering the construct.
219+
class SectionsConstructStackFrame
220+
: public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
221+
public:
222+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SectionsConstructStackFrame)
223+
224+
explicit SectionsConstructStackFrame(
225+
const parser::OpenMPSectionsConstruct &sectionsConstruct)
226+
: sectionsConstruct{sectionsConstruct} {}
227+
228+
const parser::OpenMPSectionsConstruct &sectionsConstruct;
229+
};
230+
231+
static const parser::OpenMPSectionsConstruct *
232+
getSectionsConstructStackTop(lower::AbstractConverter &converter) {
233+
SectionsConstructStackFrame *frame =
234+
converter.getStateStack().getStackTop<SectionsConstructStackFrame>();
235+
return frame ? &frame->sectionsConstruct : nullptr;
236+
}
204237

205238
/// Bind symbols to their corresponding entry block arguments.
206239
///
@@ -535,54 +568,55 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
535568
if (!ompEval)
536569
return;
537570

538-
HostEvalInfo &hostInfo = hostEvalInfo.back();
571+
HostEvalInfo *hostInfo = getHostEvalInfoStackTop(converter);
572+
assert(hostInfo && "expected HOST_EVAL info structure");
539573

540574
switch (extractOmpDirective(*ompEval)) {
541575
case OMPD_teams_distribute_parallel_do:
542576
case OMPD_teams_distribute_parallel_do_simd:
543-
cp.processThreadLimit(stmtCtx, hostInfo.ops);
577+
cp.processThreadLimit(stmtCtx, hostInfo->ops);
544578
[[fallthrough]];
545579
case OMPD_target_teams_distribute_parallel_do:
546580
case OMPD_target_teams_distribute_parallel_do_simd:
547-
cp.processNumTeams(stmtCtx, hostInfo.ops);
581+
cp.processNumTeams(stmtCtx, hostInfo->ops);
548582
[[fallthrough]];
549583
case OMPD_distribute_parallel_do:
550584
case OMPD_distribute_parallel_do_simd:
551-
cp.processNumThreads(stmtCtx, hostInfo.ops);
585+
cp.processNumThreads(stmtCtx, hostInfo->ops);
552586
[[fallthrough]];
553587
case OMPD_distribute:
554588
case OMPD_distribute_simd:
555-
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
589+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
556590
break;
557591

558592
case OMPD_teams:
559-
cp.processThreadLimit(stmtCtx, hostInfo.ops);
593+
cp.processThreadLimit(stmtCtx, hostInfo->ops);
560594
[[fallthrough]];
561595
case OMPD_target_teams:
562-
cp.processNumTeams(stmtCtx, hostInfo.ops);
596+
cp.processNumTeams(stmtCtx, hostInfo->ops);
563597
processSingleNestedIf([](Directive nestedDir) {
564598
return topDistributeSet.test(nestedDir) || topLoopSet.test(nestedDir);
565599
});
566600
break;
567601

568602
case OMPD_teams_distribute:
569603
case OMPD_teams_distribute_simd:
570-
cp.processThreadLimit(stmtCtx, hostInfo.ops);
604+
cp.processThreadLimit(stmtCtx, hostInfo->ops);
571605
[[fallthrough]];
572606
case OMPD_target_teams_distribute:
573607
case OMPD_target_teams_distribute_simd:
574-
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
575-
cp.processNumTeams(stmtCtx, hostInfo.ops);
608+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
609+
cp.processNumTeams(stmtCtx, hostInfo->ops);
576610
break;
577611

578612
case OMPD_teams_loop:
579-
cp.processThreadLimit(stmtCtx, hostInfo.ops);
613+
cp.processThreadLimit(stmtCtx, hostInfo->ops);
580614
[[fallthrough]];
581615
case OMPD_target_teams_loop:
582-
cp.processNumTeams(stmtCtx, hostInfo.ops);
616+
cp.processNumTeams(stmtCtx, hostInfo->ops);
583617
[[fallthrough]];
584618
case OMPD_loop:
585-
cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv);
619+
cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv);
586620
break;
587621

588622
// Standalone 'target' case.
@@ -596,8 +630,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
596630
}
597631
};
598632

599-
assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure");
600-
601633
const auto *ompEval = eval.getIf<parser::OpenMPConstruct>();
602634
assert(ompEval &&
603635
llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) &&
@@ -1456,8 +1488,8 @@ static void genBodyOfTargetOp(
14561488
mlir::Region &region = targetOp.getRegion();
14571489
mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
14581490
bindEntryBlockArgs(converter, targetOp, args);
1459-
if (!hostEvalInfo.empty())
1460-
hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs());
1491+
if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter))
1492+
hostEvalInfo->bindOperands(argIface.getHostEvalBlockArgs());
14611493

14621494
// Check if cloning the bounds introduced any dependency on the outer region.
14631495
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1696,7 +1728,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
16961728
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
16971729
ClauseProcessor cp(converter, semaCtx, clauses);
16981730

1699-
if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv))
1731+
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
1732+
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps, iv))
17001733
cp.processCollapse(loc, eval, clauseOps, iv);
17011734

17021735
clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr();
@@ -1741,7 +1774,8 @@ static void genParallelClauses(
17411774
cp.processAllocate(clauseOps);
17421775
cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps);
17431776

1744-
if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps))
1777+
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
1778+
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps))
17451779
cp.processNumThreads(stmtCtx, clauseOps);
17461780

17471781
cp.processProcBind(clauseOps);
@@ -1812,10 +1846,10 @@ static void genTargetClauses(
18121846
cp.processDepend(symTable, stmtCtx, clauseOps);
18131847
cp.processDevice(stmtCtx, clauseOps);
18141848
cp.processHasDeviceAddr(stmtCtx, clauseOps, hasDeviceAddrSyms);
1815-
if (!hostEvalInfo.empty()) {
1849+
if (HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter)) {
18161850
// Only process host_eval if compiling for the host device.
18171851
processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc);
1818-
hostEvalInfo.back().collectValues(clauseOps.hostEvalVars);
1852+
hostEvalInfo->collectValues(clauseOps.hostEvalVars);
18191853
}
18201854
cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps);
18211855
cp.processIsDevicePtr(clauseOps, isDevicePtrSyms);
@@ -1952,7 +1986,8 @@ static void genTeamsClauses(
19521986
cp.processAllocate(clauseOps);
19531987
cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps);
19541988

1955-
if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) {
1989+
HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop(converter);
1990+
if (!hostEvalInfo || !hostEvalInfo->apply(clauseOps)) {
19561991
cp.processNumTeams(stmtCtx, clauseOps);
19571992
cp.processThreadLimit(stmtCtx, clauseOps);
19581993
}
@@ -2204,19 +2239,18 @@ genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
22042239
converter.getCurrentLocation(), clauseOps);
22052240
}
22062241

2207-
/// This breaks the normal prototype of the gen*Op functions: adding the
2208-
/// sectionBlocks argument so that the enclosed section constructs can be
2209-
/// lowered here with correct reduction symbol remapping.
22102242
static mlir::omp::SectionsOp
22112243
genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
22122244
semantics::SemanticsContext &semaCtx,
22132245
lower::pft::Evaluation &eval, mlir::Location loc,
22142246
const ConstructQueue &queue,
22152247
ConstructQueue::const_iterator item) {
2216-
assert(!sectionsStack.empty());
2248+
const parser::OpenMPSectionsConstruct *sectionsConstruct =
2249+
getSectionsConstructStackTop(converter);
2250+
assert(sectionsConstruct && "Missing additional parsing information");
2251+
22172252
const auto &sectionBlocks =
2218-
std::get<parser::OmpSectionBlocks>(sectionsStack.back()->t);
2219-
sectionsStack.pop_back();
2253+
std::get<parser::OmpSectionBlocks>(sectionsConstruct->t);
22202254
mlir::omp::SectionsOperands clauseOps;
22212255
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
22222256
genSectionsClauses(converter, semaCtx, item->clauses, loc, clauseOps,
@@ -2370,7 +2404,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
23702404

23712405
// Introduce a new host_eval information structure for this target region.
23722406
if (!isTargetDevice)
2373-
hostEvalInfo.emplace_back();
2407+
converter.getStateStack().stackPush<HostEvalInfoStackFrame>();
23742408

23752409
mlir::omp::TargetOperands clauseOps;
23762410
DefaultMapsTy defaultMaps;
@@ -2497,7 +2531,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
24972531

24982532
// Remove the host_eval information structure created for this target region.
24992533
if (!isTargetDevice)
2500-
hostEvalInfo.pop_back();
2534+
converter.getStateStack().stackPop();
25012535
return targetOp;
25022536
}
25032537

@@ -3771,7 +3805,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
37713805
buildConstructQueue(converter.getFirOpBuilder().getModule(), semaCtx,
37723806
eval, source, directive, clauses)};
37733807

3774-
sectionsStack.push_back(&sectionsConstruct);
3808+
mlir::SaveStateStack<SectionsConstructStackFrame> saveStateStack{
3809+
converter.getStateStack(), sectionsConstruct};
37753810
genOMPDispatch(converter, symTable, semaCtx, eval, currentLocation, queue,
37763811
queue.begin());
37773812
}

mlir/include/mlir/Support/StateStack.h

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,6 +84,17 @@ class StateStack {
8484
return WalkResult::advance();
8585
}
8686

87+
/// Get the top instance of frame type `T` or nullptr if none are found
88+
template <typename T>
89+
T *getStackTop() {
90+
T *top = nullptr;
91+
stackWalk<T>([&](T &frame) -> mlir::WalkResult {
92+
top = &frame;
93+
return mlir::WalkResult::interrupt();
94+
});
95+
return top;
96+
}
97+
8798
private:
8899
SmallVector<std::unique_ptr<StateStackFrame>> stack;
89100
};

0 commit comments

Comments
 (0)