38
38
#include " flang/Support/OpenMP-utils.h"
39
39
#include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
40
40
#include " mlir/Dialect/OpenMP/OpenMPDialect.h"
41
+ #include " mlir/Support/StateStack.h"
41
42
#include " mlir/Transforms/RegionUtils.h"
42
43
#include " llvm/ADT/STLExtras.h"
43
44
#include " llvm/Frontend/OpenMP/OMPConstants.h"
@@ -200,9 +201,41 @@ class HostEvalInfo {
200
201
// / the handling of the outer region by keeping a stack of information
201
202
// / structures, but it will probably still require some further work to support
202
203
// / reverse offloading.
203
- static llvm::SmallVector<HostEvalInfo, 0 > hostEvalInfo;
204
- static llvm::SmallVector<const parser::OpenMPSectionsConstruct *, 0 >
205
- sectionsStack;
204
+ class HostEvalInfoStackFrame
205
+ : public mlir::StateStackFrameBase<HostEvalInfoStackFrame> {
206
+ public:
207
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (HostEvalInfoStackFrame)
208
+
209
+ HostEvalInfo info;
210
+ };
211
+
212
+ static HostEvalInfo *
213
+ getHostEvalInfoStackTop (lower::AbstractConverter &converter) {
214
+ HostEvalInfoStackFrame *frame =
215
+ converter.getStateStack ().getStackTop <HostEvalInfoStackFrame>();
216
+ return frame ? &frame->info : nullptr ;
217
+ }
218
+
219
+ // / Stack frame for storing the OpenMPSectionsConstruct currently being
220
+ // / processed so that it can be refered to when lowering the construct.
221
+ class SectionsConstructStackFrame
222
+ : public mlir::StateStackFrameBase<SectionsConstructStackFrame> {
223
+ public:
224
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (SectionsConstructStackFrame)
225
+
226
+ explicit SectionsConstructStackFrame (
227
+ const parser::OpenMPSectionsConstruct §ionsConstruct)
228
+ : sectionsConstruct{sectionsConstruct} {}
229
+
230
+ const parser::OpenMPSectionsConstruct §ionsConstruct;
231
+ };
232
+
233
+ static const parser::OpenMPSectionsConstruct *
234
+ getSectionsConstructStackTop (lower::AbstractConverter &converter) {
235
+ SectionsConstructStackFrame *frame =
236
+ converter.getStateStack ().getStackTop <SectionsConstructStackFrame>();
237
+ return frame ? &frame->sectionsConstruct : nullptr ;
238
+ }
206
239
207
240
// / Bind symbols to their corresponding entry block arguments.
208
241
// /
@@ -537,54 +570,55 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
537
570
if (!ompEval)
538
571
return ;
539
572
540
- HostEvalInfo &hostInfo = hostEvalInfo.back ();
573
+ HostEvalInfo *hostInfo = getHostEvalInfoStackTop (converter);
574
+ assert (hostInfo && " expected HOST_EVAL info structure" );
541
575
542
576
switch (extractOmpDirective (*ompEval)) {
543
577
case OMPD_teams_distribute_parallel_do:
544
578
case OMPD_teams_distribute_parallel_do_simd:
545
- cp.processThreadLimit (stmtCtx, hostInfo. ops );
579
+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
546
580
[[fallthrough]];
547
581
case OMPD_target_teams_distribute_parallel_do:
548
582
case OMPD_target_teams_distribute_parallel_do_simd:
549
- cp.processNumTeams (stmtCtx, hostInfo. ops );
583
+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
550
584
[[fallthrough]];
551
585
case OMPD_distribute_parallel_do:
552
586
case OMPD_distribute_parallel_do_simd:
553
- cp.processNumThreads (stmtCtx, hostInfo. ops );
587
+ cp.processNumThreads (stmtCtx, hostInfo-> ops );
554
588
[[fallthrough]];
555
589
case OMPD_distribute:
556
590
case OMPD_distribute_simd:
557
- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
591
+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
558
592
break ;
559
593
560
594
case OMPD_teams:
561
- cp.processThreadLimit (stmtCtx, hostInfo. ops );
595
+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
562
596
[[fallthrough]];
563
597
case OMPD_target_teams:
564
- cp.processNumTeams (stmtCtx, hostInfo. ops );
598
+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
565
599
processSingleNestedIf ([](Directive nestedDir) {
566
600
return topDistributeSet.test (nestedDir) || topLoopSet.test (nestedDir);
567
601
});
568
602
break ;
569
603
570
604
case OMPD_teams_distribute:
571
605
case OMPD_teams_distribute_simd:
572
- cp.processThreadLimit (stmtCtx, hostInfo. ops );
606
+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
573
607
[[fallthrough]];
574
608
case OMPD_target_teams_distribute:
575
609
case OMPD_target_teams_distribute_simd:
576
- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
577
- cp.processNumTeams (stmtCtx, hostInfo. ops );
610
+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
611
+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
578
612
break ;
579
613
580
614
case OMPD_teams_loop:
581
- cp.processThreadLimit (stmtCtx, hostInfo. ops );
615
+ cp.processThreadLimit (stmtCtx, hostInfo-> ops );
582
616
[[fallthrough]];
583
617
case OMPD_target_teams_loop:
584
- cp.processNumTeams (stmtCtx, hostInfo. ops );
618
+ cp.processNumTeams (stmtCtx, hostInfo-> ops );
585
619
[[fallthrough]];
586
620
case OMPD_loop:
587
- cp.processCollapse (loc, eval, hostInfo. ops , hostInfo. iv );
621
+ cp.processCollapse (loc, eval, hostInfo-> ops , hostInfo-> iv );
588
622
break ;
589
623
590
624
// Standalone 'target' case.
@@ -598,8 +632,6 @@ static void processHostEvalClauses(lower::AbstractConverter &converter,
598
632
}
599
633
};
600
634
601
- assert (!hostEvalInfo.empty () && " expected HOST_EVAL info structure" );
602
-
603
635
const auto *ompEval = eval.getIf <parser::OpenMPConstruct>();
604
636
assert (ompEval &&
605
637
llvm::omp::allTargetSet.test (extractOmpDirective (*ompEval)) &&
@@ -1468,8 +1500,8 @@ static void genBodyOfTargetOp(
1468
1500
mlir::Region ®ion = targetOp.getRegion ();
1469
1501
mlir::Block *entryBlock = genEntryBlock (firOpBuilder, args, region);
1470
1502
bindEntryBlockArgs (converter, targetOp, args);
1471
- if (! hostEvalInfo. empty ( ))
1472
- hostEvalInfo. back (). bindOperands (argIface.getHostEvalBlockArgs ());
1503
+ if (HostEvalInfo * hostEvalInfo = getHostEvalInfoStackTop (converter ))
1504
+ hostEvalInfo-> bindOperands (argIface.getHostEvalBlockArgs ());
1473
1505
1474
1506
// Check if cloning the bounds introduced any dependency on the outer region.
1475
1507
// If so, then either clone them as well if they are MemoryEffectFree, or else
@@ -1708,7 +1740,8 @@ genLoopNestClauses(lower::AbstractConverter &converter,
1708
1740
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
1709
1741
ClauseProcessor cp (converter, semaCtx, clauses);
1710
1742
1711
- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps, iv))
1743
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1744
+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps, iv))
1712
1745
cp.processCollapse (loc, eval, clauseOps, iv);
1713
1746
1714
1747
clauseOps.loopInclusive = converter.getFirOpBuilder ().getUnitAttr ();
@@ -1753,7 +1786,8 @@ static void genParallelClauses(
1753
1786
cp.processAllocate (clauseOps);
1754
1787
cp.processIf (llvm::omp::Directive::OMPD_parallel, clauseOps);
1755
1788
1756
- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps))
1789
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1790
+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps))
1757
1791
cp.processNumThreads (stmtCtx, clauseOps);
1758
1792
1759
1793
cp.processProcBind (clauseOps);
@@ -1818,16 +1852,17 @@ static void genTargetClauses(
1818
1852
llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms,
1819
1853
llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms,
1820
1854
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) {
1855
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
1821
1856
ClauseProcessor cp (converter, semaCtx, clauses);
1822
1857
cp.processBare (clauseOps);
1823
1858
cp.processDefaultMap (stmtCtx, defaultMaps);
1824
1859
cp.processDepend (symTable, stmtCtx, clauseOps);
1825
1860
cp.processDevice (stmtCtx, clauseOps);
1826
1861
cp.processHasDeviceAddr (stmtCtx, clauseOps, hasDeviceAddrSyms);
1827
- if (! hostEvalInfo. empty () ) {
1862
+ if (hostEvalInfo) {
1828
1863
// Only process host_eval if compiling for the host device.
1829
1864
processHostEvalClauses (converter, semaCtx, stmtCtx, eval, loc);
1830
- hostEvalInfo. back (). collectValues (clauseOps.hostEvalVars );
1865
+ hostEvalInfo-> collectValues (clauseOps.hostEvalVars );
1831
1866
}
1832
1867
cp.processIf (llvm::omp::Directive::OMPD_target, clauseOps);
1833
1868
cp.processIsDevicePtr (clauseOps, isDevicePtrSyms);
@@ -1963,7 +1998,8 @@ static void genTeamsClauses(
1963
1998
cp.processAllocate (clauseOps);
1964
1999
cp.processIf (llvm::omp::Directive::OMPD_teams, clauseOps);
1965
2000
1966
- if (hostEvalInfo.empty () || !hostEvalInfo.back ().apply (clauseOps)) {
2001
+ HostEvalInfo *hostEvalInfo = getHostEvalInfoStackTop (converter);
2002
+ if (!hostEvalInfo || !hostEvalInfo->apply (clauseOps)) {
1967
2003
cp.processNumTeams (stmtCtx, clauseOps);
1968
2004
cp.processThreadLimit (stmtCtx, clauseOps);
1969
2005
}
@@ -2224,10 +2260,13 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2224
2260
lower::pft::Evaluation &eval, mlir::Location loc,
2225
2261
const ConstructQueue &queue,
2226
2262
ConstructQueue::const_iterator item) {
2227
- assert (!sectionsStack.empty ());
2263
+ const parser::OpenMPSectionsConstruct *sectionsConstruct =
2264
+ getSectionsConstructStackTop (converter);
2265
+ assert (sectionsConstruct);
2266
+
2228
2267
const auto §ionBlocks =
2229
- std::get<parser::OmpSectionBlocks>(sectionsStack. back () ->t );
2230
- sectionsStack. pop_back ();
2268
+ std::get<parser::OmpSectionBlocks>(sectionsConstruct ->t );
2269
+ converter. getStateStack (). stackPop ();
2231
2270
mlir::omp::SectionsOperands clauseOps;
2232
2271
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
2233
2272
genSectionsClauses (converter, semaCtx, item->clauses , loc, clauseOps,
@@ -2381,7 +2420,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2381
2420
2382
2421
// Introduce a new host_eval information structure for this target region.
2383
2422
if (!isTargetDevice)
2384
- hostEvalInfo. emplace_back ();
2423
+ converter. getStateStack (). stackPush <HostEvalInfoStackFrame> ();
2385
2424
2386
2425
mlir::omp::TargetOperands clauseOps;
2387
2426
DefaultMapsTy defaultMaps;
@@ -2508,7 +2547,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
2508
2547
2509
2548
// Remove the host_eval information structure created for this target region.
2510
2549
if (!isTargetDevice)
2511
- hostEvalInfo. pop_back ();
2550
+ converter. getStateStack (). stackPop ();
2512
2551
return targetOp;
2513
2552
}
2514
2553
@@ -4235,7 +4274,8 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
4235
4274
buildConstructQueue (converter.getFirOpBuilder ().getModule (), semaCtx,
4236
4275
eval, source, directive, clauses)};
4237
4276
4238
- sectionsStack.push_back (§ionsConstruct);
4277
+ converter.getStateStack ().stackPush <SectionsConstructStackFrame>(
4278
+ sectionsConstruct);
4239
4279
genOMPDispatch (converter, symTable, semaCtx, eval, currentLocation, queue,
4240
4280
queue.begin ());
4241
4281
}
0 commit comments