25
25
#include " flang/Optimizer/Builder/HLFIRTools.h"
26
26
#include " flang/Optimizer/Builder/IntrinsicCall.h"
27
27
#include " flang/Optimizer/Builder/Todo.h"
28
+ #include " flang/Parser/parse-tree-visitor.h"
28
29
#include " flang/Parser/parse-tree.h"
29
30
#include " flang/Semantics/expression.h"
30
31
#include " flang/Semantics/scope.h"
31
32
#include " flang/Semantics/tools.h"
33
+ #include " mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
32
34
#include " llvm/Frontend/OpenACC/ACC.h.inc"
33
35
34
36
// Special value for * passed in device_type or gang clauses.
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
1381
1383
Fortran::lower::pft::Evaluation &eval,
1382
1384
const llvm::SmallVectorImpl<mlir::Value> &operands,
1383
1385
const llvm::SmallVectorImpl<int32_t > &operandSegments,
1384
- bool outerCombined = false ) {
1385
- llvm::ArrayRef<mlir::Type> argTy;
1386
- Op op = builder.create <Op>(loc, argTy, operands);
1386
+ bool outerCombined = false ,
1387
+ llvm::SmallVector<mlir::Type> retTy = {},
1388
+ mlir::Value yieldValue = {}) {
1389
+ Op op = builder.create <Op>(loc, retTy, operands);
1387
1390
builder.createBlock (&op.getRegion ());
1388
1391
mlir::Block &block = op.getRegion ().back ();
1389
1392
builder.setInsertionPointToStart (&block);
@@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
1401
1404
mlir::acc::YieldOp>(
1402
1405
builder, eval.getNestedEvaluations ());
1403
1406
1404
- builder.create <Terminator>(loc);
1407
+ if (yieldValue) {
1408
+ if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
1409
+ Terminator yieldOp = builder.create <Terminator>(loc, yieldValue);
1410
+ yieldValue.getDefiningOp ()->moveBefore (yieldOp);
1411
+ } else {
1412
+ builder.create <Terminator>(loc);
1413
+ }
1414
+ } else {
1415
+ builder.create <Terminator>(loc);
1416
+ }
1405
1417
builder.setInsertionPointToStart (&block);
1406
1418
return op;
1407
1419
}
@@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
1494
1506
Fortran::lower::pft::Evaluation &eval,
1495
1507
Fortran::semantics::SemanticsContext &semanticsContext,
1496
1508
Fortran::lower::StatementContext &stmtCtx,
1497
- const Fortran::parser::AccClauseList &accClauseList) {
1509
+ const Fortran::parser::AccClauseList &accClauseList,
1510
+ bool needEarlyReturnHandling = false ) {
1498
1511
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
1499
1512
1500
1513
mlir::Value workerNum;
@@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
1599
1612
addOperands (operands, operandSegments, privateOperands);
1600
1613
addOperands (operands, operandSegments, reductionOperands);
1601
1614
1615
+ llvm::SmallVector<mlir::Type> retTy;
1616
+ mlir::Value yieldValue;
1617
+ if (needEarlyReturnHandling) {
1618
+ mlir::Type i1Ty = builder.getI1Type ();
1619
+ yieldValue = builder.createIntegerConstant (currentLocation, i1Ty, 0 );
1620
+ retTy.push_back (i1Ty);
1621
+ }
1622
+
1602
1623
auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
1603
- builder, currentLocation, eval, operands, operandSegments);
1624
+ builder, currentLocation, eval, operands, operandSegments,
1625
+ /* outerCombined=*/ false , retTy, yieldValue);
1604
1626
1605
1627
if (hasGang)
1606
1628
loopOp.setHasGangAttr (builder.getUnitAttr ());
@@ -1647,26 +1669,48 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
1647
1669
return loopOp;
1648
1670
}
1649
1671
1650
- static void genACC (Fortran::lower::AbstractConverter &converter,
1651
- Fortran::semantics::SemanticsContext &semanticsContext,
1652
- Fortran::lower::pft::Evaluation &eval,
1653
- const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
1672
+ static bool hasEarlyReturn (Fortran::lower::pft::Evaluation &eval) {
1673
+ bool hasReturnStmt = false ;
1674
+ for (auto &e : eval.getNestedEvaluations ()) {
1675
+ e.visit (Fortran::common::visitors{
1676
+ [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true ; },
1677
+ [&](const auto &s) {},
1678
+ });
1679
+ if (e.hasNestedEvaluations ())
1680
+ hasReturnStmt = hasEarlyReturn (e);
1681
+ }
1682
+ return hasReturnStmt;
1683
+ }
1684
+
1685
+ static mlir::Value
1686
+ genACC (Fortran::lower::AbstractConverter &converter,
1687
+ Fortran::semantics::SemanticsContext &semanticsContext,
1688
+ Fortran::lower::pft::Evaluation &eval,
1689
+ const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
1654
1690
1655
1691
const auto &beginLoopDirective =
1656
1692
std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t );
1657
1693
const auto &loopDirective =
1658
1694
std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t );
1659
1695
1696
+ bool needEarlyExitHandling = false ;
1697
+ if (eval.lowerAsUnstructured ())
1698
+ needEarlyExitHandling = hasEarlyReturn (eval);
1699
+
1660
1700
mlir::Location currentLocation =
1661
1701
converter.genLocation (beginLoopDirective.source );
1662
1702
Fortran::lower::StatementContext stmtCtx;
1663
1703
1664
1704
if (loopDirective.v == llvm::acc::ACCD_loop) {
1665
1705
const auto &accClauseList =
1666
1706
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t );
1667
- createLoopOp (converter, currentLocation, eval, semanticsContext, stmtCtx,
1668
- accClauseList);
1707
+ auto loopOp =
1708
+ createLoopOp (converter, currentLocation, eval, semanticsContext,
1709
+ stmtCtx, accClauseList, needEarlyExitHandling);
1710
+ if (needEarlyExitHandling)
1711
+ return loopOp.getResult (0 );
1669
1712
}
1713
+ return mlir::Value{};
1670
1714
}
1671
1715
1672
1716
template <typename Op, typename Clause>
@@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
3431
3475
builder.restoreInsertionPoint (crtPos);
3432
3476
}
3433
3477
3434
- void Fortran::lower::genOpenACCConstruct (
3478
+ mlir::Value Fortran::lower::genOpenACCConstruct (
3435
3479
Fortran::lower::AbstractConverter &converter,
3436
3480
Fortran::semantics::SemanticsContext &semanticsContext,
3437
3481
Fortran::lower::pft::Evaluation &eval,
3438
3482
const Fortran::parser::OpenACCConstruct &accConstruct) {
3439
3483
3484
+ mlir::Value exitCond;
3440
3485
std::visit (
3441
3486
common::visitors{
3442
3487
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
@@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct(
3447
3492
genACC (converter, semanticsContext, eval, combinedConstruct);
3448
3493
},
3449
3494
[&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
3450
- genACC (converter, semanticsContext, eval, loopConstruct);
3495
+ exitCond = genACC (converter, semanticsContext, eval, loopConstruct);
3451
3496
},
3452
3497
[&](const Fortran::parser::OpenACCStandaloneConstruct
3453
3498
&standaloneConstruct) {
@@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct(
3467
3512
},
3468
3513
},
3469
3514
accConstruct.u );
3515
+ return exitCond;
3470
3516
}
3471
3517
3472
3518
void Fortran::lower::genOpenACCDeclarativeConstruct (
@@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
3560
3606
else
3561
3607
builder.create <mlir::acc::TerminatorOp>(loc);
3562
3608
}
3609
+
3610
+ bool Fortran::lower::isInOpenACCLoop (fir::FirOpBuilder &builder) {
3611
+ if (builder.getBlock ()->getParent ()->getParentOfType <mlir::acc::LoopOp>())
3612
+ return true ;
3613
+ return false ;
3614
+ }
3615
+
3616
+ void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside (
3617
+ fir::FirOpBuilder &builder) {
3618
+ if (auto loopOp =
3619
+ builder.getBlock ()->getParent ()->getParentOfType <mlir::acc::LoopOp>())
3620
+ builder.setInsertionPointAfter (loopOp);
3621
+ }
3622
+
3623
+ void Fortran::lower::genEarlyReturnInOpenACCLoop (fir::FirOpBuilder &builder,
3624
+ mlir::Location loc) {
3625
+ mlir::Value yieldValue =
3626
+ builder.createIntegerConstant (loc, builder.getI1Type (), 1 );
3627
+ builder.create <mlir::acc::YieldOp>(loc, yieldValue);
3628
+ }
0 commit comments