-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][openacc] Support early return in acc.loop #73841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-openacc Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesEarly return is accepted in OpenACC loop not directly nested in a compute construct. Since acc.loop operation has a region, the Full diff: https://github.com/llvm/llvm-project/pull/73841.diff 4 Files Affected:
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 409956f0ecb309f..f23e4726f33e00a 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -64,9 +64,10 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
static constexpr llvm::StringRef declarePostDeallocSuffix =
"_acc_declare_update_desc_post_dealloc";
-void genOpenACCConstruct(AbstractConverter &,
- Fortran::semantics::SemanticsContext &,
- pft::Evaluation &, const parser::OpenACCConstruct &);
+mlir::Value genOpenACCConstruct(AbstractConverter &,
+ Fortran::semantics::SemanticsContext &,
+ pft::Evaluation &,
+ const parser::OpenACCConstruct &);
void genOpenACCDeclarativeConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
StatementContext &,
@@ -112,6 +113,12 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
mlir::Location);
+bool isInOpenACCLoop(fir::FirOpBuilder &);
+
+void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);
+
+void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location);
+
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 23c48cc7bd97874..45da1355df168e2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2382,11 +2382,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
- genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
+ mlir::Value exitCond = genOpenACCConstruct(
+ *this, bridge.getSemanticsContext(), getEval(), acc);
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
genFIR(e);
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
+
+ const Fortran::parser::OpenACCLoopConstruct *accLoop =
+ std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
+ if (accLoop && exitCond) {
+ Fortran::lower::pft::FunctionLikeUnit *funit =
+ getEval().getOwningProcedure();
+ assert(funit && "not inside main program, function or subroutine");
+ mlir::Block *continueBlock =
+ builder->getBlock()->splitBlock(builder->getBlock()->end());
+ builder->create<mlir::cf::CondBranchOp>(toLocation(), exitCond,
+ funit->finalBlock, continueBlock);
+ builder->setInsertionPointToEnd(continueBlock);
+ }
}
void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
@@ -4091,10 +4105,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Branch to the last block of the SUBROUTINE, which has the actual return.
if (!funit->finalBlock) {
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+ Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(*builder);
funit->finalBlock = builder->createBlock(&builder->getRegion());
builder->restoreInsertionPoint(insPt);
}
- builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
+
+ if (Fortran::lower::isInOpenACCLoop(*builder))
+ Fortran::lower::genEarlyReturnInOpenACCLoop(*builder, loc);
+ else
+ builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
}
void genFIR(const Fortran::parser::CycleStmt &) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 8c6c22210cf0894..e2abed1b9f4f675 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -25,10 +25,12 @@
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "llvm/Frontend/OpenACC/ACC.h.inc"
// Special value for * passed in device_type or gang clauses.
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::pft::Evaluation &eval,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments,
- bool outerCombined = false) {
- llvm::ArrayRef<mlir::Type> argTy;
- Op op = builder.create<Op>(loc, argTy, operands);
+ bool outerCombined = false,
+ llvm::SmallVector<mlir::Type> retTy = {},
+ mlir::Value yieldValue = {}) {
+ Op op = builder.create<Op>(loc, retTy, operands);
builder.createBlock(&op.getRegion());
mlir::Block &block = op.getRegion().back();
builder.setInsertionPointToStart(&block);
@@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::acc::YieldOp>(
builder, eval.getNestedEvaluations());
- builder.create<Terminator>(loc);
+ if (yieldValue) {
+ if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
+ Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
+ yieldValue.getDefiningOp()->moveBefore(yieldOp);
+ } else {
+ builder.create<Terminator>(loc);
+ }
+ } else {
+ builder.create<Terminator>(loc);
+ }
builder.setInsertionPointToStart(&block);
return op;
}
@@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ bool needEarlyReturnHandling = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Value workerNum;
@@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
addOperands(operands, operandSegments, privateOperands);
addOperands(operands, operandSegments, reductionOperands);
+ llvm::SmallVector<mlir::Type> retTy;
+ mlir::Value yieldValue;
+ if (needEarlyReturnHandling) {
+ mlir::Type i1Ty = builder.getI1Type();
+ yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
+ retTy.push_back(i1Ty);
+ }
+
auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
- builder, currentLocation, eval, operands, operandSegments);
+ builder, currentLocation, eval, operands, operandSegments,
+ /*outerCombined=*/false, retTy, yieldValue);
if (hasGang)
loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1647,16 +1669,34 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
return loopOp;
}
-static void genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
+static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
+ bool hasReturnStmt = false;
+ for (auto &e : eval.getNestedEvaluations()) {
+ e.visit(Fortran::common::visitors{
+ [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
+ [&](const auto &s) {},
+ });
+ if (e.hasNestedEvaluations())
+ hasReturnStmt = hasEarlyReturn(e);
+ }
+ return hasReturnStmt;
+}
+
+static mlir::Value
+genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
const auto &loopDirective =
std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
+ bool needEarlyExitHandling = false;
+ if (eval.lowerAsUnstructured())
+ needEarlyExitHandling = hasEarlyReturn(eval);
+
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
Fortran::lower::StatementContext stmtCtx;
@@ -1664,9 +1704,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
if (loopDirective.v == llvm::acc::ACCD_loop) {
const auto &accClauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
- createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
- accClauseList);
+ auto loopOp =
+ createLoopOp(converter, currentLocation, eval, semanticsContext,
+ stmtCtx, accClauseList, needEarlyExitHandling);
+ if (needEarlyExitHandling)
+ return loopOp.getResult(0);
}
+ return mlir::Value{};
}
template <typename Op, typename Clause>
@@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
builder.restoreInsertionPoint(crtPos);
}
-void Fortran::lower::genOpenACCConstruct(
+mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCConstruct &accConstruct) {
+ mlir::Value exitCond;
std::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
@@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct(
genACC(converter, semanticsContext, eval, combinedConstruct);
},
[&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
- genACC(converter, semanticsContext, eval, loopConstruct);
+ exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
},
[&](const Fortran::parser::OpenACCStandaloneConstruct
&standaloneConstruct) {
@@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct(
},
},
accConstruct.u);
+ return exitCond;
}
void Fortran::lower::genOpenACCDeclarativeConstruct(
@@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
else
builder.create<mlir::acc::TerminatorOp>(loc);
}
+
+bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
+ if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+ return true;
+ return false;
+}
+
+void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
+ fir::FirOpBuilder &builder) {
+ if (auto loopOp =
+ builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+ builder.setInsertionPointAfter(loopOp);
+}
+
+void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
+ mlir::Location loc) {
+ mlir::Value yieldValue =
+ builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+ builder.create<mlir::acc::YieldOp>(loc, yieldValue);
+}
diff --git a/flang/test/Lower/OpenACC/acc-loop-exit.f90 b/flang/test/Lower/OpenACC/acc-loop-exit.f90
new file mode 100644
index 000000000000000..75f1c3073327228
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-exit.f90
@@ -0,0 +1,37 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1(x, a)
+ real :: x(200)
+ integer :: a
+
+ !$acc loop
+ do i = 100, 200
+ x(i) = 1.0
+ if (i == a) return
+ end do
+
+ i = 2
+end
+
+! CHECK-LABEL: func.func @_QPsub1
+! CHECK: %[[A:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFsub1Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[EXIT_COND:.*]] = acc.loop {
+! CHECK: ^bb{{.*}}:
+! CHECK: ^bb{{.*}}:
+! CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+! CHECK: %[[CMP:.*]] = arith.cmpi eq, %15, %[[LOAD_A]] : i32
+! CHECK: cf.cond_br %[[CMP]], ^[[EARLY_RET:.*]], ^[[NO_RET:.*]]
+! CHECK: ^[[EARLY_RET]]:
+! CHECK: acc.yield %true : i1
+! CHECK: ^[[NO_RET]]:
+! CHECK: cf.br ^bb{{.*}}
+! CHECK: ^bb{{.*}}:
+! CHECK: acc.yield %false : i1
+! CHECK: }(i1)
+! CHECK: cf.cond_br %[[EXIT_COND]], ^[[EXIT_BLOCK:.*]], ^[[CONTINUE_BLOCK:.*]]
+! CHECK: ^[[CONTINUE_BLOCK]]:
+! CHECK: hlfir.assign
+! CHECK: cf.br ^[[EXIT_BLOCK]]
+! CHECK: ^[[EXIT_BLOCK]]:
+! CHECK: return
+! CHECK: }
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesEarly return is accepted in OpenACC loop not directly nested in a compute construct. Since acc.loop operation has a region, the Full diff: https://github.com/llvm/llvm-project/pull/73841.diff 4 Files Affected:
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 409956f0ecb309f..f23e4726f33e00a 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -64,9 +64,10 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
static constexpr llvm::StringRef declarePostDeallocSuffix =
"_acc_declare_update_desc_post_dealloc";
-void genOpenACCConstruct(AbstractConverter &,
- Fortran::semantics::SemanticsContext &,
- pft::Evaluation &, const parser::OpenACCConstruct &);
+mlir::Value genOpenACCConstruct(AbstractConverter &,
+ Fortran::semantics::SemanticsContext &,
+ pft::Evaluation &,
+ const parser::OpenACCConstruct &);
void genOpenACCDeclarativeConstruct(AbstractConverter &,
Fortran::semantics::SemanticsContext &,
StatementContext &,
@@ -112,6 +113,12 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
mlir::Location);
+bool isInOpenACCLoop(fir::FirOpBuilder &);
+
+void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);
+
+void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location);
+
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 23c48cc7bd97874..45da1355df168e2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2382,11 +2382,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
localSymbols.pushScope();
- genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
+ mlir::Value exitCond = genOpenACCConstruct(
+ *this, bridge.getSemanticsContext(), getEval(), acc);
for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
genFIR(e);
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
+
+ const Fortran::parser::OpenACCLoopConstruct *accLoop =
+ std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
+ if (accLoop && exitCond) {
+ Fortran::lower::pft::FunctionLikeUnit *funit =
+ getEval().getOwningProcedure();
+ assert(funit && "not inside main program, function or subroutine");
+ mlir::Block *continueBlock =
+ builder->getBlock()->splitBlock(builder->getBlock()->end());
+ builder->create<mlir::cf::CondBranchOp>(toLocation(), exitCond,
+ funit->finalBlock, continueBlock);
+ builder->setInsertionPointToEnd(continueBlock);
+ }
}
void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
@@ -4091,10 +4105,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Branch to the last block of the SUBROUTINE, which has the actual return.
if (!funit->finalBlock) {
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+ Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(*builder);
funit->finalBlock = builder->createBlock(&builder->getRegion());
builder->restoreInsertionPoint(insPt);
}
- builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
+
+ if (Fortran::lower::isInOpenACCLoop(*builder))
+ Fortran::lower::genEarlyReturnInOpenACCLoop(*builder, loc);
+ else
+ builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
}
void genFIR(const Fortran::parser::CycleStmt &) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 8c6c22210cf0894..e2abed1b9f4f675 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -25,10 +25,12 @@
#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Builder/IntrinsicCall.h"
#include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/expression.h"
#include "flang/Semantics/scope.h"
#include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "llvm/Frontend/OpenACC/ACC.h.inc"
// Special value for * passed in device_type or gang clauses.
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
Fortran::lower::pft::Evaluation &eval,
const llvm::SmallVectorImpl<mlir::Value> &operands,
const llvm::SmallVectorImpl<int32_t> &operandSegments,
- bool outerCombined = false) {
- llvm::ArrayRef<mlir::Type> argTy;
- Op op = builder.create<Op>(loc, argTy, operands);
+ bool outerCombined = false,
+ llvm::SmallVector<mlir::Type> retTy = {},
+ mlir::Value yieldValue = {}) {
+ Op op = builder.create<Op>(loc, retTy, operands);
builder.createBlock(&op.getRegion());
mlir::Block &block = op.getRegion().back();
builder.setInsertionPointToStart(&block);
@@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::acc::YieldOp>(
builder, eval.getNestedEvaluations());
- builder.create<Terminator>(loc);
+ if (yieldValue) {
+ if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
+ Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
+ yieldValue.getDefiningOp()->moveBefore(yieldOp);
+ } else {
+ builder.create<Terminator>(loc);
+ }
+ } else {
+ builder.create<Terminator>(loc);
+ }
builder.setInsertionPointToStart(&block);
return op;
}
@@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::pft::Evaluation &eval,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::StatementContext &stmtCtx,
- const Fortran::parser::AccClauseList &accClauseList) {
+ const Fortran::parser::AccClauseList &accClauseList,
+ bool needEarlyReturnHandling = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Value workerNum;
@@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
addOperands(operands, operandSegments, privateOperands);
addOperands(operands, operandSegments, reductionOperands);
+ llvm::SmallVector<mlir::Type> retTy;
+ mlir::Value yieldValue;
+ if (needEarlyReturnHandling) {
+ mlir::Type i1Ty = builder.getI1Type();
+ yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
+ retTy.push_back(i1Ty);
+ }
+
auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
- builder, currentLocation, eval, operands, operandSegments);
+ builder, currentLocation, eval, operands, operandSegments,
+ /*outerCombined=*/false, retTy, yieldValue);
if (hasGang)
loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1647,16 +1669,34 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
return loopOp;
}
-static void genACC(Fortran::lower::AbstractConverter &converter,
- Fortran::semantics::SemanticsContext &semanticsContext,
- Fortran::lower::pft::Evaluation &eval,
- const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
+static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
+ bool hasReturnStmt = false;
+ for (auto &e : eval.getNestedEvaluations()) {
+ e.visit(Fortran::common::visitors{
+ [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
+ [&](const auto &s) {},
+ });
+ if (e.hasNestedEvaluations())
+ hasReturnStmt = hasEarlyReturn(e);
+ }
+ return hasReturnStmt;
+}
+
+static mlir::Value
+genACC(Fortran::lower::AbstractConverter &converter,
+ Fortran::semantics::SemanticsContext &semanticsContext,
+ Fortran::lower::pft::Evaluation &eval,
+ const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
const auto &beginLoopDirective =
std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
const auto &loopDirective =
std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
+ bool needEarlyExitHandling = false;
+ if (eval.lowerAsUnstructured())
+ needEarlyExitHandling = hasEarlyReturn(eval);
+
mlir::Location currentLocation =
converter.genLocation(beginLoopDirective.source);
Fortran::lower::StatementContext stmtCtx;
@@ -1664,9 +1704,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
if (loopDirective.v == llvm::acc::ACCD_loop) {
const auto &accClauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
- createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
- accClauseList);
+ auto loopOp =
+ createLoopOp(converter, currentLocation, eval, semanticsContext,
+ stmtCtx, accClauseList, needEarlyExitHandling);
+ if (needEarlyExitHandling)
+ return loopOp.getResult(0);
}
+ return mlir::Value{};
}
template <typename Op, typename Clause>
@@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
builder.restoreInsertionPoint(crtPos);
}
-void Fortran::lower::genOpenACCConstruct(
+mlir::Value Fortran::lower::genOpenACCConstruct(
Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semanticsContext,
Fortran::lower::pft::Evaluation &eval,
const Fortran::parser::OpenACCConstruct &accConstruct) {
+ mlir::Value exitCond;
std::visit(
common::visitors{
[&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
@@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct(
genACC(converter, semanticsContext, eval, combinedConstruct);
},
[&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
- genACC(converter, semanticsContext, eval, loopConstruct);
+ exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
},
[&](const Fortran::parser::OpenACCStandaloneConstruct
&standaloneConstruct) {
@@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct(
},
},
accConstruct.u);
+ return exitCond;
}
void Fortran::lower::genOpenACCDeclarativeConstruct(
@@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
else
builder.create<mlir::acc::TerminatorOp>(loc);
}
+
+bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
+ if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+ return true;
+ return false;
+}
+
+void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
+ fir::FirOpBuilder &builder) {
+ if (auto loopOp =
+ builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+ builder.setInsertionPointAfter(loopOp);
+}
+
+void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
+ mlir::Location loc) {
+ mlir::Value yieldValue =
+ builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+ builder.create<mlir::acc::YieldOp>(loc, yieldValue);
+}
diff --git a/flang/test/Lower/OpenACC/acc-loop-exit.f90 b/flang/test/Lower/OpenACC/acc-loop-exit.f90
new file mode 100644
index 000000000000000..75f1c3073327228
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-exit.f90
@@ -0,0 +1,37 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1(x, a)
+ real :: x(200)
+ integer :: a
+
+ !$acc loop
+ do i = 100, 200
+ x(i) = 1.0
+ if (i == a) return
+ end do
+
+ i = 2
+end
+
+! CHECK-LABEL: func.func @_QPsub1
+! CHECK: %[[A:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFsub1Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[EXIT_COND:.*]] = acc.loop {
+! CHECK: ^bb{{.*}}:
+! CHECK: ^bb{{.*}}:
+! CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+! CHECK: %[[CMP:.*]] = arith.cmpi eq, %15, %[[LOAD_A]] : i32
+! CHECK: cf.cond_br %[[CMP]], ^[[EARLY_RET:.*]], ^[[NO_RET:.*]]
+! CHECK: ^[[EARLY_RET]]:
+! CHECK: acc.yield %true : i1
+! CHECK: ^[[NO_RET]]:
+! CHECK: cf.br ^bb{{.*}}
+! CHECK: ^bb{{.*}}:
+! CHECK: acc.yield %false : i1
+! CHECK: }(i1)
+! CHECK: cf.cond_br %[[EXIT_COND]], ^[[EXIT_BLOCK:.*]], ^[[CONTINUE_BLOCK:.*]]
+! CHECK: ^[[CONTINUE_BLOCK]]:
+! CHECK: hlfir.assign
+! CHECK: cf.br ^[[EXIT_BLOCK]]
+! CHECK: ^[[EXIT_BLOCK]]:
+! CHECK: return
+! CHECK: }
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Thank you!
Early return is accepted in OpenACC loop not directly nested in a compute construct. Since acc.loop operation has a region, the
func.return
operation cannot be directly used inside the region.An early return is materialized by an
acc.yield
operation returning atrue
value. The standard end of theacc.loop
region yield afalse
value in this case.A conditional branch operation on the
acc.loop
result will branch to thefinalBlock
or just to the continue block whether an early exit was produce in the acc.loop.