39
39
#include " flang/Support/OpenMP-utils.h"
40
40
#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
41
41
#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
42
+ #include " mlir/Support/StateStack.h"
42
43
#include " mlir/Transforms/RegionUtils.h"
43
44
#include " llvm/ADT/STLExtras.h"
44
45
#include " llvm/Frontend/OpenMP/OMPConstants.h"
@@ -198,9 +199,41 @@ class HostEvalInfo {
198
199
// / the handling of the outer region by keeping a stack of information
199
200
// / structures, but it will probably still require some further work to support
200
201
// / reverse offloading.
201
- static llvm::SmallVector<HostEvalInfo, 0 > hostEvalInfo;
202
- static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0 >
203
- sectionsStack;
202
+ class HostEvalInfoStackFrame
203
+ : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
204
+ public:
205
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (HostEvalInfoStackFrame)
206
+
207
+ HostEvalInfo info;
208
+ };
209
+
210
+ static HostEvalInfo *
211
+ getHostEvalInfoStackTop (lower::AbstractConverter &converter) {
212
+ HostEvalInfoStackFrame *frame =
213
+ converter.getStateStack ().getStackTop <HostEvalInfoStackFrame>();
214
+ return frame ? &frame->info : nullptr ;
215
+ }
216
+
217
+ // / Stack frame for storing the OpenMPSectionsConstruct currently being
218
+ // / processed so that it can be referred to when lowering the construct.
219
+ class SectionsConstructStackFrame
220
+ : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
221
+ public:
222
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (SectionsConstructStackFrame)
223
+
224
+ explicit SectionsConstructStackFrame (
225
+ const parser::OpenMPSectionsConstruct §ionsConstruct)
226
+ : sectionsConstruct{sectionsConstruct} {}
227
+
228
+ const parser::OpenMPSectionsConstruct §ionsConstruct;
229
+ };
230
+
231
+ static const parser::OpenMPSectionsConstruct *
232
+ getSectionsConstructStackTop (lower::AbstractConverter &converter) {
233
+ SectionsConstructStackFrame *frame =
234
+ converter.getStateStack ().getStackTop <SectionsConstructStackFrame>();
235
+ return frame ? &frame->sectionsConstruct : nullptr ;
236
+ }
204
237
205
238
// / Bind symbols to their corresponding entry block arguments.
206
239
// /
@@ -535,54 +568,55 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
535
568
if (!ompEval)
536
569
return ;
537
570
538
- HostEvalInfo &hostInfo = hostEvalInfo.back ();
571
+ HostEvalInfo *hostInfo = getHostEvalInfoStackTop (converter);
572
+ assert (hostInfo && " expected HOST_EVAL info structure" );
539
573
540
574
switch (extractOmpDirective (*ompEval)) {
541
575
case OMPD_teams_distribute_parallel_do:
542
576
case OMPD_teams_distribute_parallel_do_simd:
543
- cp.processThreadLimit (stmtCtx, hostInfo. ops );
577
+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
544
578
[[fallthrough]];
545
579
case OMPD_target_teams_distribute_parallel_do:
546
580
case OMPD_target_teams_distribute_parallel_do_simd:
547
- cp.processNumTeams (stmtCtx, hostInfo. ops );
581
+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
548
582
[[fallthrough]];
549
583
case OMPD_distribute_parallel_do:
550
584
case OMPD_distribute_parallel_do_simd:
551
- cp.processNumThreads (stmtCtx, hostInfo. ops );
585
+ cp.processNumThreads (stmtCtx, hostInfo-> ops );
552
586
[[fallthrough]];
553
587
case OMPD_distribute:
554
588
case OMPD_distribute_simd:
555
- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
589
+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
556
590
break ;
557
591
558
592
case OMPD_teams:
559
- cp.processThreadLimit (stmtCtx, hostInfo. ops );
593
+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
560
594
[[fallthrough]];
561
595
case OMPD_target_teams:
562
- cp.processNumTeams (stmtCtx, hostInfo. ops );
596
+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
563
597
processSingleNestedIf ([](Directive nestedDir) {
564
598
return topDistributeSet.test (nestedDir) || topLoopSet.test (nestedDir);
565
599
});
566
600
break ;
567
601
568
602
case OMPD_teams_distribute:
569
603
case OMPD_teams_distribute_simd:
570
- cp.processThreadLimit (stmtCtx, hostInfo. ops );
604
+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
571
605
[[fallthrough]];
572
606
case OMPD_target_teams_distribute:
573
607
case OMPD_target_teams_distribute_simd:
574
- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
575
- cp.processNumTeams (stmtCtx, hostInfo. ops );
608
+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
609
+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
576
610
break ;
577
611
578
612
case OMPD_teams_loop:
579
- cp.processThreadLimit (stmtCtx, hostInfo. ops );
613
+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
580
614
[[fallthrough]];
581
615
case OMPD_target_teams_loop:
582
- cp.processNumTeams (stmtCtx, hostInfo. ops );
616
+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
583
617
[[fallthrough]];
584
618
case OMPD_loop:
585
- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
619
+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
586
620
break ;
587
621
588
622
// Standalone 'target' case.
@@ -596,8 +630,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
596
630
}
597
631
};
598
632
599
- assert (!hostEvalInfo.empty () && " expected HOST_EVAL info structure" );
600
-
601
633
const auto *ompEval = eval.getIf <parser::OpenMPConstruct>();
602
634
assert (ompEval &&
603
635
llvm::omp::allTargetSet.test (extractOmpDirective (*ompEval)) &&
@@ -1456,8 +1488,8 @@ static void genBodyOfTargetOp(
1456
1488
mlir::Region ®ion = targetOp.getRegion ();
1457
1489
mlir::Block *entryBlock = genEntryBlock (firOpBuilder, args, region);
1458
1490
bindEntryBlockArgs (converter, targetOp, args);
1459
- if (! hostEvalInfo. empty ( ))
1460
- hostEvalInfo. back (). bindOperands (argIface.getHostEvalBlockArgs ());
1491
+ if (HostEvalInfo * hostEvalInfo = getHostEvalInfoStackTop (converter ))
1492
+ hostEvalInfo-> bindOperands (argIface.getHostEvalBlockArgs ());
1461
1493
1462
1494
// Check if cloning the bounds introduced any dependency on the outer region.
1463
1495
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1696,7 +1728,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
1696
1728
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
1697
1729
ClauseProcessor cp (converter, semaCtx, clauses);
1698
1730
1699
- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps, iv))
1731
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1732
+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps, iv))
1700
1733
cp.processCollapse (loc, eval, clauseOps, iv);
1701
1734
1702
1735
clauseOps.loopInclusive = converter.getFirOpBuilder ().getUnitAttr ();
@@ -1741,7 +1774,8 @@ static void genParallelClauses(
1741
1774
cp.processAllocate (clauseOps);
1742
1775
cp.processIf (llvm::omp::Directive::OMPD_parallel, clauseOps);
1743
1776
1744
- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps))
1777
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1778
+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps))
1745
1779
cp.processNumThreads (stmtCtx, clauseOps);
1746
1780
1747
1781
cp.processProcBind (clauseOps);
@@ -1812,10 +1846,10 @@ static void genTargetClauses(
1812
1846
cp.processDepend (symTable, stmtCtx, clauseOps);
1813
1847
cp.processDevice (stmtCtx, clauseOps);
1814
1848
cp.processHasDeviceAddr (stmtCtx, clauseOps, hasDeviceAddrSyms);
1815
- if (! hostEvalInfo. empty ( )) {
1849
+ if (HostEvalInfo * hostEvalInfo = getHostEvalInfoStackTop (converter )) {
1816
1850
// Only process host_eval if compiling for the host device.
1817
1851
processHostEvalClauses (converter, semaCtx, stmtCtx, eval, loc);
1818
- hostEvalInfo. back (). collectValues (clauseOps.hostEvalVars );
1852
+ hostEvalInfo-> collectValues (clauseOps.hostEvalVars );
1819
1853
}
1820
1854
cp.processIf (llvm::omp::Directive::OMPD_target, clauseOps);
1821
1855
cp.processIsDevicePtr (clauseOps, isDevicePtrSyms);
@@ -1952,7 +1986,8 @@ static void genTeamsClauses(
1952
1986
cp.processAllocate (clauseOps);
1953
1987
cp.processIf (llvm::omp::Directive::OMPD_teams, clauseOps);
1954
1988
1955
- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps)) {
1989
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1990
+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps)) {
1956
1991
cp.processNumTeams (stmtCtx, clauseOps);
1957
1992
cp.processThreadLimit (stmtCtx, clauseOps);
1958
1993
}
@@ -2204,19 +2239,18 @@ genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2204
2239
converter.getCurrentLocation (), clauseOps);
2205
2240
}
2206
2241
2207
- // / This breaks the normal prototype of the gen*Op functions: adding the
2208
- // / sectionBlocks argument so that the enclosed section constructs can be
2209
- // / lowered here with correct reduction symbol remapping.
2210
2242
static mlir::omp::SectionsOp
2211
2243
genSectionsOp (lower::AbstractConverter &converter, lower::SymMap &symTable,
2212
2244
semantics::SemanticsContext &semaCtx,
2213
2245
lower::pft::Evaluation &eval, mlir::Location loc,
2214
2246
const ConstructQueue &queue,
2215
2247
ConstructQueue::const_iterator item) {
2216
- assert (!sectionsStack.empty ());
2248
+ const parser::OpenMPSectionsConstruct *sectionsConstruct =
2249
+ getSectionsConstructStackTop (converter);
2250
+ assert (sectionsConstruct && " Missing additional parsing information" );
2251
+
2217
2252
const auto §ionBlocks =
2218
- std::get<parser::OmpSectionBlocks>(sectionsStack.back ()->t );
2219
- sectionsStack.pop_back ();
2253
+ std::get<parser::OmpSectionBlocks>(sectionsConstruct->t );
2220
2254
mlir::omp::SectionsOperands clauseOps;
2221
2255
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
2222
2256
genSectionsClauses (converter, semaCtx, item->clauses , loc, clauseOps,
@@ -2370,7 +2404,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2370
2404
2371
2405
// Introduce a new host_eval information structure for this target region.
2372
2406
if (!isTargetDevice)
2373
- hostEvalInfo. emplace_back ();
2407
+ converter. getStateStack (). stackPush <HostEvalInfoStackFrame> ();
2374
2408
2375
2409
mlir::omp::TargetOperands clauseOps;
2376
2410
DefaultMapsTy defaultMaps;
@@ -2497,7 +2531,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2497
2531
2498
2532
// Remove the host_eval information structure created for this target region.
2499
2533
if (!isTargetDevice)
2500
- hostEvalInfo. pop_back ();
2534
+ converter. getStateStack (). stackPop ();
2501
2535
return targetOp;
2502
2536
}
2503
2537
@@ -3771,7 +3805,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
3771
3805
buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
3772
3806
eval, source, directive, clauses)};
3773
3807
3774
- sectionsStack.push_back (§ionsConstruct);
3808
+ mlir::SaveStateStack<SectionsConstructStackFrame> saveStateStack{
3809
+ converter.getStateStack (), sectionsConstruct};
3775
3810
genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
3776
3811
queue.begin ());
3777
3812
}
0 commit comments