@@ -366,10 +366,29 @@ getDeclareTargetFunctionDevice(
366
366
return std::nullopt;
367
367
}
368
368
369
- static llvm::SmallVector<const Fortran::semantics::Symbol *>
369
+ // / Set up the entry block of the given `omp.loop_nest` operation, adding a
370
+ // / block argument for each loop induction variable and allocating and
371
+ // / initializing a private value to hold each of them.
372
+ // /
373
+ // / This function can also bind the symbols of any variables that should match
374
+ // / block arguments on parent loop wrapper operations attached to the same
375
+ // / loop. This allows the introduction of any necessary `hlfir.declare`
376
+ // / operations inside of the entry block of the `omp.loop_nest` operation and
377
+ // / not directly under any of the wrappers, which would invalidate them.
378
+ // /
379
+ // / \param [in] op - the loop nest operation.
380
+ // / \param [in] converter - PFT to MLIR conversion interface.
381
+ // / \param [in] loc - location.
382
+ // / \param [in] args - symbols of induction variables.
383
+ // / \param [in] wrapperSyms - symbols of variables to be mapped to loop wrapper
384
+ // / entry block arguments.
385
+ // / \param [in] wrapperArgs - entry block arguments of parent loop wrappers.
386
+ static void
370
387
genLoopVars (mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
371
388
mlir::Location &loc,
372
- llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
389
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> args,
390
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> wrapperSyms = {},
391
+ llvm::ArrayRef<mlir::BlockArgument> wrapperArgs = {}) {
373
392
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
374
393
auto ®ion = op->getRegion (0 );
375
394
@@ -380,6 +399,12 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
380
399
llvm::SmallVector<mlir::Type> tiv (args.size (), loopVarType);
381
400
llvm::SmallVector<mlir::Location> locs (args.size (), loc);
382
401
firOpBuilder.createBlock (®ion, {}, tiv, locs);
402
+
403
+ // Bind the entry block arguments of parent wrappers to the corresponding
404
+ // symbols.
405
+ for (auto [arg, prv] : llvm::zip_equal (wrapperSyms, wrapperArgs))
406
+ converter.bindSymbol (*arg, prv);
407
+
383
408
// The argument is not currently in memory, so make a temporary for the
384
409
// argument, and store it there, then bind that location to the argument.
385
410
mlir::Operation *storeOp = nullptr ;
@@ -389,7 +414,6 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
389
414
createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
390
415
}
391
416
firOpBuilder.setInsertionPointAfter (storeOp);
392
- return llvm::SmallVector<const Fortran::semantics::Symbol *>(args);
393
417
}
394
418
395
419
static void genReductionVars (
@@ -410,58 +434,6 @@ static void genReductionVars(
410
434
}
411
435
}
412
436
413
- static llvm::SmallVector<const Fortran::semantics::Symbol *>
414
- genLoopAndReductionVars (
415
- mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
416
- mlir::Location &loc,
417
- llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
418
- llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
419
- llvm::ArrayRef<mlir::Type> reductionTypes) {
420
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
421
-
422
- llvm::SmallVector<mlir::Type> blockArgTypes;
423
- llvm::SmallVector<mlir::Location> blockArgLocs;
424
- blockArgTypes.reserve (loopArgs.size () + reductionArgs.size ());
425
- blockArgLocs.reserve (blockArgTypes.size ());
426
- mlir::Block *entryBlock;
427
-
428
- if (loopArgs.size ()) {
429
- std::size_t loopVarTypeSize = 0 ;
430
- for (const Fortran::semantics::Symbol *arg : loopArgs)
431
- loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
432
- mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
433
- std::fill_n (std::back_inserter (blockArgTypes), loopArgs.size (),
434
- loopVarType);
435
- std::fill_n (std::back_inserter (blockArgLocs), loopArgs.size (), loc);
436
- }
437
- if (reductionArgs.size ()) {
438
- llvm::copy (reductionTypes, std::back_inserter (blockArgTypes));
439
- std::fill_n (std::back_inserter (blockArgLocs), reductionArgs.size (), loc);
440
- }
441
- entryBlock = firOpBuilder.createBlock (&op->getRegion (0 ), {}, blockArgTypes,
442
- blockArgLocs);
443
- // The argument is not currently in memory, so make a temporary for the
444
- // argument, and store it there, then bind that location to the argument.
445
- if (loopArgs.size ()) {
446
- mlir::Operation *storeOp = nullptr ;
447
- for (auto [argIndex, argSymbol] : llvm::enumerate (loopArgs)) {
448
- mlir::Value indexVal =
449
- fir::getBase (op->getRegion (0 ).front ().getArgument (argIndex));
450
- storeOp =
451
- createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
452
- }
453
- firOpBuilder.setInsertionPointAfter (storeOp);
454
- }
455
- // Bind the reduction arguments to their block arguments
456
- for (auto [arg, prv] : llvm::zip_equal (
457
- reductionArgs,
458
- llvm::drop_begin (entryBlock->getArguments (), loopArgs.size ()))) {
459
- converter.bindSymbol (*arg, prv);
460
- }
461
-
462
- return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
463
- }
464
-
465
437
static void
466
438
markDeclareTarget (mlir::Operation *op,
467
439
Fortran::lower::AbstractConverter &converter,
@@ -1270,20 +1242,16 @@ static void genTeamsClauses(Fortran::lower::AbstractConverter &converter,
1270
1242
static void genWsloopClauses (
1271
1243
Fortran::lower::AbstractConverter &converter,
1272
1244
Fortran::semantics::SemanticsContext &semaCtx,
1273
- Fortran::lower::StatementContext &stmtCtx,
1274
- Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses,
1245
+ Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses,
1275
1246
mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps,
1276
- llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv,
1277
1247
llvm::SmallVectorImpl<mlir::Type> &reductionTypes,
1278
1248
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) {
1279
1249
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1280
1250
ClauseProcessor cp (converter, semaCtx, clauses);
1281
- cp.processCollapse (loc, eval, clauseOps, iv);
1282
1251
cp.processNowait (clauseOps);
1283
1252
cp.processOrdered (clauseOps);
1284
1253
cp.processReduction (loc, clauseOps, &reductionTypes, &reductionSyms);
1285
1254
cp.processSchedule (stmtCtx, clauseOps);
1286
- clauseOps.loopInclusiveAttr = firOpBuilder.getUnitAttr ();
1287
1255
// TODO Support delayed privatization.
1288
1256
1289
1257
if (ReductionProcessor::doReductionByRef (clauseOps.reductionVars ))
@@ -1526,7 +1494,8 @@ genSimdOp(Fortran::lower::AbstractConverter &converter,
1526
1494
auto *nestedEval = getCollapsedLoopEval (eval, getCollapseValue (clauses));
1527
1495
1528
1496
auto ivCallback = [&](mlir::Operation *op) {
1529
- return genLoopVars (op, converter, loc, iv);
1497
+ genLoopVars (op, converter, loc, iv);
1498
+ return iv;
1530
1499
};
1531
1500
1532
1501
createBodyOfOp (*loopOp,
@@ -1801,32 +1770,48 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter,
1801
1770
Fortran::semantics::SemanticsContext &semaCtx,
1802
1771
Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
1803
1772
const List<Clause> &clauses) {
1773
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
1804
1774
DataSharingProcessor dsp (converter, semaCtx, clauses, eval);
1805
1775
dsp.processStep1 ();
1806
1776
1807
1777
Fortran::lower::StatementContext stmtCtx;
1808
- mlir::omp::WsloopClauseOps clauseOps;
1778
+ mlir::omp::LoopNestClauseOps loopClauseOps;
1779
+ mlir::omp::WsloopClauseOps wsClauseOps;
1809
1780
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
1810
1781
llvm::SmallVector<mlir::Type> reductionTypes;
1811
1782
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms;
1812
- genWsloopClauses (converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps,
1813
- iv, reductionTypes, reductionSyms);
1783
+ genLoopNestClauses (converter, semaCtx, eval, clauses, loc, loopClauseOps, iv);
1784
+ genWsloopClauses (converter, semaCtx, stmtCtx, clauses, loc, wsClauseOps,
1785
+ reductionTypes, reductionSyms);
1786
+
1787
+ // Create omp.wsloop wrapper and populate entry block arguments with reduction
1788
+ // variables.
1789
+ auto wsloopOp = firOpBuilder.create <mlir::omp::WsloopOp>(loc, wsClauseOps);
1790
+ llvm::SmallVector<mlir::Location> reductionLocs (reductionSyms.size (), loc);
1791
+ mlir::Block *wsloopEntryBlock = firOpBuilder.createBlock (
1792
+ &wsloopOp.getRegion (), {}, reductionTypes, reductionLocs);
1793
+ firOpBuilder.setInsertionPoint (
1794
+ Fortran::lower::genOpenMPTerminator (firOpBuilder, wsloopOp, loc));
1795
+
1796
+ // Create nested omp.loop_nest and fill body with loop contents.
1797
+ auto loopOp = firOpBuilder.create <mlir::omp::LoopNestOp>(loc, loopClauseOps);
1814
1798
1815
1799
auto *nestedEval = getCollapsedLoopEval (eval, getCollapseValue (clauses));
1816
1800
1817
1801
auto ivCallback = [&](mlir::Operation *op) {
1818
- return genLoopAndReductionVars (op, converter, loc, iv, reductionSyms,
1819
- reductionTypes);
1802
+ genLoopVars (op, converter, loc, iv, reductionSyms,
1803
+ wsloopEntryBlock->getArguments ());
1804
+ return iv;
1820
1805
};
1821
1806
1822
- return genOpWithBody<mlir::omp::WsloopOp>(
1823
- OpWithBodyGenInfo (converter, semaCtx, loc, *nestedEval,
1824
- llvm::omp::Directive::OMPD_do)
1825
- .setClauses (&clauses)
1826
- .setDataSharingProcessor (&dsp)
1827
- .setReductions (&reductionSyms, &reductionTypes)
1828
- .setGenRegionEntryCb (ivCallback),
1829
- clauseOps) ;
1807
+ createBodyOfOp (*loopOp,
1808
+ OpWithBodyGenInfo (converter, semaCtx, loc, *nestedEval,
1809
+ llvm::omp::Directive::OMPD_do)
1810
+ .setClauses (&clauses)
1811
+ .setDataSharingProcessor (&dsp)
1812
+ .setReductions (&reductionSyms, &reductionTypes)
1813
+ .setGenRegionEntryCb (ivCallback));
1814
+ return wsloopOp ;
1830
1815
}
1831
1816
1832
1817
// ===----------------------------------------------------------------------===//
@@ -2482,8 +2467,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
2482
2467
mlir::Operation *Fortran::lower::genOpenMPTerminator (fir::FirOpBuilder &builder,
2483
2468
mlir::Operation *op,
2484
2469
mlir::Location loc) {
2485
- if (mlir::isa<mlir::omp::WsloopOp , mlir::omp::DeclareReductionOp,
2486
- mlir::omp::AtomicUpdateOp, mlir::omp:: LoopNestOp>(op))
2470
+ if (mlir::isa<mlir::omp::AtomicUpdateOp , mlir::omp::DeclareReductionOp,
2471
+ mlir::omp::LoopNestOp>(op))
2487
2472
return builder.create <mlir::omp::YieldOp>(loc);
2488
2473
return builder.create <mlir::omp::TerminatorOp>(loc);
2489
2474
}
0 commit comments