Skip to content

Commit a9a5af8

Browse files
authored
[flang][openacc] Support early return in acc.loop (llvm#73841)
Early return is accepted in an OpenACC loop that is not directly nested in a compute construct. Since the `acc.loop` operation has a region, the `func.return` operation cannot be used directly inside that region. An early return is therefore materialized by an `acc.yield` operation returning a `true` value, while the standard end of the `acc.loop` region yields a `false` value. A conditional branch on the `acc.loop` result then branches either to the `finalBlock` or to the continue block, depending on whether an early exit was produced in the `acc.loop`.
1 parent 6fb7c2d commit a9a5af8

File tree

4 files changed

+148
-19
lines changed

4 files changed

+148
-19
lines changed

flang/include/flang/Lower/OpenACC.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,10 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
6464
static constexpr llvm::StringRef declarePostDeallocSuffix =
6565
"_acc_declare_update_desc_post_dealloc";
6666

67-
void genOpenACCConstruct(AbstractConverter &,
68-
Fortran::semantics::SemanticsContext &,
69-
pft::Evaluation &, const parser::OpenACCConstruct &);
67+
/// Lower an OpenACC construct.
///
/// \return for an OpenACC loop construct that needs early-return handling,
/// the i1 result of the generated acc.loop (true when the loop region was
/// left through an early return); a null mlir::Value for every other
/// construct, which callers must test before use.
mlir::Value genOpenACCConstruct(AbstractConverter &,
68+
Fortran::semantics::SemanticsContext &,
69+
pft::Evaluation &,
70+
const parser::OpenACCConstruct &);
7071
void genOpenACCDeclarativeConstruct(AbstractConverter &,
7172
Fortran::semantics::SemanticsContext &,
7273
StatementContext &,
@@ -112,6 +113,12 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
112113
void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
113114
mlir::Location);
114115

116+
/// Return true if the builder's current insertion block has an enclosing
/// acc.loop operation.
bool isInOpenACCLoop(fir::FirOpBuilder &);
117+
118+
/// If the builder's current insertion block has an enclosing acc.loop
/// operation, move the insertion point right after that operation;
/// otherwise leave the insertion point untouched.
void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);
119+
120+
/// Materialize an early return inside an acc.loop region: create an i1
/// constant of value 1 and an acc.yield of that value at the builder's
/// current insertion point.
void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location);
121+
115122
} // namespace lower
116123
} // namespace Fortran
117124

flang/lib/Lower/Bridge.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,11 +2382,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
23822382
void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
23832383
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
23842384
localSymbols.pushScope();
2385-
genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
2385+
mlir::Value exitCond = genOpenACCConstruct(
2386+
*this, bridge.getSemanticsContext(), getEval(), acc);
23862387
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
23872388
genFIR(e);
23882389
localSymbols.popScope();
23892390
builder->restoreInsertionPoint(insertPt);
2391+
2392+
const Fortran::parser::OpenACCLoopConstruct *accLoop =
2393+
std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
2394+
if (accLoop && exitCond) {
2395+
Fortran::lower::pft::FunctionLikeUnit *funit =
2396+
getEval().getOwningProcedure();
2397+
assert(funit && "not inside main program, function or subroutine");
2398+
mlir::Block *continueBlock =
2399+
builder->getBlock()->splitBlock(builder->getBlock()->end());
2400+
builder->create<mlir::cf::CondBranchOp>(toLocation(), exitCond,
2401+
funit->finalBlock, continueBlock);
2402+
builder->setInsertionPointToEnd(continueBlock);
2403+
}
23902404
}
23912405

23922406
void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
@@ -4091,10 +4105,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
40914105
// Branch to the last block of the SUBROUTINE, which has the actual return.
40924106
if (!funit->finalBlock) {
40934107
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
4108+
Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(*builder);
40944109
funit->finalBlock = builder->createBlock(&builder->getRegion());
40954110
builder->restoreInsertionPoint(insPt);
40964111
}
4097-
builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
4112+
4113+
if (Fortran::lower::isInOpenACCLoop(*builder))
4114+
Fortran::lower::genEarlyReturnInOpenACCLoop(*builder, loc);
4115+
else
4116+
builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
40984117
}
40994118

41004119
void genFIR(const Fortran::parser::CycleStmt &) {

flang/lib/Lower/OpenACC.cpp

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
#include "flang/Optimizer/Builder/HLFIRTools.h"
2626
#include "flang/Optimizer/Builder/IntrinsicCall.h"
2727
#include "flang/Optimizer/Builder/Todo.h"
28+
#include "flang/Parser/parse-tree-visitor.h"
2829
#include "flang/Parser/parse-tree.h"
2930
#include "flang/Semantics/expression.h"
3031
#include "flang/Semantics/scope.h"
3132
#include "flang/Semantics/tools.h"
33+
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
3234
#include "llvm/Frontend/OpenACC/ACC.h.inc"
3335

3436
// Special value for * passed in device_type or gang clauses.
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
13811383
Fortran::lower::pft::Evaluation &eval,
13821384
const llvm::SmallVectorImpl<mlir::Value> &operands,
13831385
const llvm::SmallVectorImpl<int32_t> &operandSegments,
1384-
bool outerCombined = false) {
1385-
llvm::ArrayRef<mlir::Type> argTy;
1386-
Op op = builder.create<Op>(loc, argTy, operands);
1386+
bool outerCombined = false,
1387+
llvm::SmallVector<mlir::Type> retTy = {},
1388+
mlir::Value yieldValue = {}) {
1389+
Op op = builder.create<Op>(loc, retTy, operands);
13871390
builder.createBlock(&op.getRegion());
13881391
mlir::Block &block = op.getRegion().back();
13891392
builder.setInsertionPointToStart(&block);
@@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
14011404
mlir::acc::YieldOp>(
14021405
builder, eval.getNestedEvaluations());
14031406

1404-
builder.create<Terminator>(loc);
1407+
if (yieldValue) {
1408+
if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
1409+
Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
1410+
yieldValue.getDefiningOp()->moveBefore(yieldOp);
1411+
} else {
1412+
builder.create<Terminator>(loc);
1413+
}
1414+
} else {
1415+
builder.create<Terminator>(loc);
1416+
}
14051417
builder.setInsertionPointToStart(&block);
14061418
return op;
14071419
}
@@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
14941506
Fortran::lower::pft::Evaluation &eval,
14951507
Fortran::semantics::SemanticsContext &semanticsContext,
14961508
Fortran::lower::StatementContext &stmtCtx,
1497-
const Fortran::parser::AccClauseList &accClauseList) {
1509+
const Fortran::parser::AccClauseList &accClauseList,
1510+
bool needEarlyReturnHandling = false) {
14981511
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
14991512

15001513
mlir::Value workerNum;
@@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
15991612
addOperands(operands, operandSegments, privateOperands);
16001613
addOperands(operands, operandSegments, reductionOperands);
16011614

1615+
llvm::SmallVector<mlir::Type> retTy;
1616+
mlir::Value yieldValue;
1617+
if (needEarlyReturnHandling) {
1618+
mlir::Type i1Ty = builder.getI1Type();
1619+
yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
1620+
retTy.push_back(i1Ty);
1621+
}
1622+
16021623
auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
1603-
builder, currentLocation, eval, operands, operandSegments);
1624+
builder, currentLocation, eval, operands, operandSegments,
1625+
/*outerCombined=*/false, retTy, yieldValue);
16041626

16051627
if (hasGang)
16061628
loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1647,26 +1669,48 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
16471669
return loopOp;
16481670
}
16491671

1650-
static void genACC(Fortran::lower::AbstractConverter &converter,
1651-
Fortran::semantics::SemanticsContext &semanticsContext,
1652-
Fortran::lower::pft::Evaluation &eval,
1653-
const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
1672+
/// Return true if a RETURN statement appears anywhere in the evaluations
/// nested under \p eval, at any nesting depth.
///
/// \param eval evaluation whose nested evaluations are searched; it must
///        have nested evaluations.
/// \return true as soon as one ReturnStmt is found, false otherwise.
static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
  for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
    bool isReturnStmt = false;
    e.visit(Fortran::common::visitors{
        [&](const Fortran::parser::ReturnStmt &) { isReturnStmt = true; },
        [&](const auto &) {},
    });
    // Stop at the first RETURN found. The answer for a later sibling must
    // never overwrite an earlier positive answer, otherwise a RETURN that is
    // followed by a nested construct without RETURN would be missed.
    if (isReturnStmt || (e.hasNestedEvaluations() && hasEarlyReturn(e)))
      return true;
  }
  return false;
}
1684+
1685+
static mlir::Value
1686+
genACC(Fortran::lower::AbstractConverter &converter,
1687+
Fortran::semantics::SemanticsContext &semanticsContext,
1688+
Fortran::lower::pft::Evaluation &eval,
1689+
const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
16541690

16551691
const auto &beginLoopDirective =
16561692
std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
16571693
const auto &loopDirective =
16581694
std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
16591695

1696+
bool needEarlyExitHandling = false;
1697+
if (eval.lowerAsUnstructured())
1698+
needEarlyExitHandling = hasEarlyReturn(eval);
1699+
16601700
mlir::Location currentLocation =
16611701
converter.genLocation(beginLoopDirective.source);
16621702
Fortran::lower::StatementContext stmtCtx;
16631703

16641704
if (loopDirective.v == llvm::acc::ACCD_loop) {
16651705
const auto &accClauseList =
16661706
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
1667-
createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
1668-
accClauseList);
1707+
auto loopOp =
1708+
createLoopOp(converter, currentLocation, eval, semanticsContext,
1709+
stmtCtx, accClauseList, needEarlyExitHandling);
1710+
if (needEarlyExitHandling)
1711+
return loopOp.getResult(0);
16691712
}
1713+
return mlir::Value{};
16701714
}
16711715

16721716
template <typename Op, typename Clause>
@@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
34313475
builder.restoreInsertionPoint(crtPos);
34323476
}
34333477

3434-
void Fortran::lower::genOpenACCConstruct(
3478+
mlir::Value Fortran::lower::genOpenACCConstruct(
34353479
Fortran::lower::AbstractConverter &converter,
34363480
Fortran::semantics::SemanticsContext &semanticsContext,
34373481
Fortran::lower::pft::Evaluation &eval,
34383482
const Fortran::parser::OpenACCConstruct &accConstruct) {
34393483

3484+
mlir::Value exitCond;
34403485
std::visit(
34413486
common::visitors{
34423487
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
@@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct(
34473492
genACC(converter, semanticsContext, eval, combinedConstruct);
34483493
},
34493494
[&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
3450-
genACC(converter, semanticsContext, eval, loopConstruct);
3495+
exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
34513496
},
34523497
[&](const Fortran::parser::OpenACCStandaloneConstruct
34533498
&standaloneConstruct) {
@@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct(
34673512
},
34683513
},
34693514
accConstruct.u);
3515+
return exitCond;
34703516
}
34713517

34723518
void Fortran::lower::genOpenACCDeclarativeConstruct(
@@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
35603606
else
35613607
builder.create<mlir::acc::TerminatorOp>(loc);
35623608
}
3609+
3610+
/// Tell whether the builder's current insertion block has an enclosing
/// acc.loop operation.
bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
  mlir::Region *currentRegion = builder.getBlock()->getParent();
  // A null op handle converts to false: no enclosing acc.loop was found.
  return static_cast<bool>(
      currentRegion->getParentOfType<mlir::acc::LoopOp>());
}
3615+
3616+
/// When the builder's current insertion block has an enclosing acc.loop
/// operation, move the insertion point right after that operation.
/// Otherwise the insertion point is left unchanged.
void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
    fir::FirOpBuilder &builder) {
  mlir::Region *currentRegion = builder.getBlock()->getParent();
  auto enclosingLoop = currentRegion->getParentOfType<mlir::acc::LoopOp>();
  if (!enclosingLoop)
    return;
  builder.setInsertionPointAfter(enclosingLoop);
}
3622+
3623+
/// Emit the early-return marker at the builder's current insertion point:
/// an i1 constant of value 1 followed by an acc.yield of that constant.
void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
                                                 mlir::Location loc) {
  mlir::Type i1Ty = builder.getI1Type();
  mlir::Value earlyExitFlag = builder.createIntegerConstant(loc, i1Ty, 1);
  builder.create<mlir::acc::YieldOp>(loc, earlyExitFlag);
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
2+
3+
subroutine sub1(x, a)
4+
real :: x(200)
5+
integer :: a
6+
7+
!$acc loop
8+
do i = 100, 200
9+
x(i) = 1.0
10+
if (i == a) return
11+
end do
12+
13+
i = 2
14+
end
15+
16+
! CHECK-LABEL: func.func @_QPsub1
17+
! CHECK: %[[A:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFsub1Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
18+
! CHECK: %[[EXIT_COND:.*]] = acc.loop {
19+
! CHECK: ^bb{{.*}}:
20+
! CHECK: ^bb{{.*}}:
21+
! CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
22+
! CHECK: %[[CMP:.*]] = arith.cmpi eq, %{{.*}}, %[[LOAD_A]] : i32
23+
! CHECK: cf.cond_br %[[CMP]], ^[[EARLY_RET:.*]], ^[[NO_RET:.*]]
24+
! CHECK: ^[[EARLY_RET]]:
25+
! CHECK: acc.yield %true : i1
26+
! CHECK: ^[[NO_RET]]:
27+
! CHECK: cf.br ^bb{{.*}}
28+
! CHECK: ^bb{{.*}}:
29+
! CHECK: acc.yield %false : i1
30+
! CHECK: }(i1)
31+
! CHECK: cf.cond_br %[[EXIT_COND]], ^[[EXIT_BLOCK:.*]], ^[[CONTINUE_BLOCK:.*]]
32+
! CHECK: ^[[CONTINUE_BLOCK]]:
33+
! CHECK: hlfir.assign
34+
! CHECK: cf.br ^[[EXIT_BLOCK]]
35+
! CHECK: ^[[EXIT_BLOCK]]:
36+
! CHECK: return
37+
! CHECK: }

0 commit comments

Comments
 (0)