Skip to content

[flang][openacc] Support early return in acc.loop #73841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2023

Conversation

clementval
Copy link
Contributor

Early return is accepted in OpenACC loop not directly nested in a compute construct. Since acc.loop operation has a region, the func.return operation cannot be directly used inside the region.
An early return is materialized by an acc.yield operation returning a true value. The standard end of the acc.loop region yield a false value in this case.
A conditional branch operation on the acc.loop result will branch to the finalBlock or just to the continue block whether an early exit was produce in the acc.loop.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir openacc labels Nov 29, 2023
@llvmbot
Copy link
Member

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-openacc

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Early return is accepted in OpenACC loop not directly nested in a compute construct. Since acc.loop operation has a region, the func.return operation cannot be directly used inside the region.
An early return is materialized by an acc.yield operation returning a true value. The standard end of the acc.loop region yield a false value in this case.
A conditional branch operation on the acc.loop result will branch to the finalBlock or just to the continue block whether an early exit was produce in the acc.loop.


Full diff: https://github.com/llvm/llvm-project/pull/73841.diff

4 Files Affected:

  • (modified) flang/include/flang/Lower/OpenACC.h (+10-3)
  • (modified) flang/lib/Lower/Bridge.cpp (+21-2)
  • (modified) flang/lib/Lower/OpenACC.cpp (+80-14)
  • (added) flang/test/Lower/OpenACC/acc-loop-exit.f90 (+37)
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 409956f0ecb309f..f23e4726f33e00a 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -64,9 +64,10 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
 static constexpr llvm::StringRef declarePostDeallocSuffix =
     "_acc_declare_update_desc_post_dealloc";
 
-void genOpenACCConstruct(AbstractConverter &,
-                         Fortran::semantics::SemanticsContext &,
-                         pft::Evaluation &, const parser::OpenACCConstruct &);
+mlir::Value genOpenACCConstruct(AbstractConverter &,
+                                Fortran::semantics::SemanticsContext &,
+                                pft::Evaluation &,
+                                const parser::OpenACCConstruct &);
 void genOpenACCDeclarativeConstruct(AbstractConverter &,
                                     Fortran::semantics::SemanticsContext &,
                                     StatementContext &,
@@ -112,6 +113,12 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
 void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
                           mlir::Location);
 
+bool isInOpenACCLoop(fir::FirOpBuilder &);
+
+void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);
+
+void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location);
+
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 23c48cc7bd97874..45da1355df168e2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2382,11 +2382,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
-    genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
+    mlir::Value exitCond = genOpenACCConstruct(
+        *this, bridge.getSemanticsContext(), getEval(), acc);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
+
+    const Fortran::parser::OpenACCLoopConstruct *accLoop =
+        std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
+    if (accLoop && exitCond) {
+      Fortran::lower::pft::FunctionLikeUnit *funit =
+          getEval().getOwningProcedure();
+      assert(funit && "not inside main program, function or subroutine");
+      mlir::Block *continueBlock =
+          builder->getBlock()->splitBlock(builder->getBlock()->end());
+      builder->create<mlir::cf::CondBranchOp>(toLocation(), exitCond,
+                                              funit->finalBlock, continueBlock);
+      builder->setInsertionPointToEnd(continueBlock);
+    }
   }
 
   void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
@@ -4091,10 +4105,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // Branch to the last block of the SUBROUTINE, which has the actual return.
     if (!funit->finalBlock) {
       mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+      Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(*builder);
       funit->finalBlock = builder->createBlock(&builder->getRegion());
       builder->restoreInsertionPoint(insPt);
     }
-    builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
+
+    if (Fortran::lower::isInOpenACCLoop(*builder))
+      Fortran::lower::genEarlyReturnInOpenACCLoop(*builder, loc);
+    else
+      builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
   }
 
   void genFIR(const Fortran::parser::CycleStmt &) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 8c6c22210cf0894..e2abed1b9f4f675 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -25,10 +25,12 @@
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/IntrinsicCall.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "llvm/Frontend/OpenACC/ACC.h.inc"
 
 // Special value for * passed in device_type or gang clauses.
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
                          Fortran::lower::pft::Evaluation &eval,
                          const llvm::SmallVectorImpl<mlir::Value> &operands,
                          const llvm::SmallVectorImpl<int32_t> &operandSegments,
-                         bool outerCombined = false) {
-  llvm::ArrayRef<mlir::Type> argTy;
-  Op op = builder.create<Op>(loc, argTy, operands);
+                         bool outerCombined = false,
+                         llvm::SmallVector<mlir::Type> retTy = {},
+                         mlir::Value yieldValue = {}) {
+  Op op = builder.create<Op>(loc, retTy, operands);
   builder.createBlock(&op.getRegion());
   mlir::Block &block = op.getRegion().back();
   builder.setInsertionPointToStart(&block);
@@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
                                             mlir::acc::YieldOp>(
         builder, eval.getNestedEvaluations());
 
-  builder.create<Terminator>(loc);
+  if (yieldValue) {
+    if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
+      Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
+      yieldValue.getDefiningOp()->moveBefore(yieldOp);
+    } else {
+      builder.create<Terminator>(loc);
+    }
+  } else {
+    builder.create<Terminator>(loc);
+  }
   builder.setInsertionPointToStart(&block);
   return op;
 }
@@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
              Fortran::lower::pft::Evaluation &eval,
              Fortran::semantics::SemanticsContext &semanticsContext,
              Fortran::lower::StatementContext &stmtCtx,
-             const Fortran::parser::AccClauseList &accClauseList) {
+             const Fortran::parser::AccClauseList &accClauseList,
+             bool needEarlyReturnHandling = false) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
   mlir::Value workerNum;
@@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   addOperands(operands, operandSegments, privateOperands);
   addOperands(operands, operandSegments, reductionOperands);
 
+  llvm::SmallVector<mlir::Type> retTy;
+  mlir::Value yieldValue;
+  if (needEarlyReturnHandling) {
+    mlir::Type i1Ty = builder.getI1Type();
+    yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
+    retTy.push_back(i1Ty);
+  }
+
   auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
-      builder, currentLocation, eval, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments,
+      /*outerCombined=*/false, retTy, yieldValue);
 
   if (hasGang)
     loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1647,16 +1669,34 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   return loopOp;
 }
 
-static void genACC(Fortran::lower::AbstractConverter &converter,
-                   Fortran::semantics::SemanticsContext &semanticsContext,
-                   Fortran::lower::pft::Evaluation &eval,
-                   const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
+static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
+  bool hasReturnStmt = false;
+  for (auto &e : eval.getNestedEvaluations()) {
+    e.visit(Fortran::common::visitors{
+        [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
+        [&](const auto &s) {},
+    });
+    if (e.hasNestedEvaluations())
+      hasReturnStmt = hasEarlyReturn(e);
+  }
+  return hasReturnStmt;
+}
+
+static mlir::Value
+genACC(Fortran::lower::AbstractConverter &converter,
+       Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
+       const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
 
   const auto &beginLoopDirective =
       std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
   const auto &loopDirective =
       std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
 
+  bool needEarlyExitHandling = false;
+  if (eval.lowerAsUnstructured())
+    needEarlyExitHandling = hasEarlyReturn(eval);
+
   mlir::Location currentLocation =
       converter.genLocation(beginLoopDirective.source);
   Fortran::lower::StatementContext stmtCtx;
@@ -1664,9 +1704,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   if (loopDirective.v == llvm::acc::ACCD_loop) {
     const auto &accClauseList =
         std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    auto loopOp =
+        createLoopOp(converter, currentLocation, eval, semanticsContext,
+                     stmtCtx, accClauseList, needEarlyExitHandling);
+    if (needEarlyExitHandling)
+      return loopOp.getResult(0);
   }
+  return mlir::Value{};
 }
 
 template <typename Op, typename Clause>
@@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
   builder.restoreInsertionPoint(crtPos);
 }
 
-void Fortran::lower::genOpenACCConstruct(
+mlir::Value Fortran::lower::genOpenACCConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OpenACCConstruct &accConstruct) {
 
+  mlir::Value exitCond;
   std::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
@@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct(
             genACC(converter, semanticsContext, eval, combinedConstruct);
           },
           [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
-            genACC(converter, semanticsContext, eval, loopConstruct);
+            exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
           },
           [&](const Fortran::parser::OpenACCStandaloneConstruct
                   &standaloneConstruct) {
@@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct(
           },
       },
       accConstruct.u);
+  return exitCond;
 }
 
 void Fortran::lower::genOpenACCDeclarativeConstruct(
@@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
   else
     builder.create<mlir::acc::TerminatorOp>(loc);
 }
+
+bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
+  if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+    return true;
+  return false;
+}
+
+void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
+    fir::FirOpBuilder &builder) {
+  if (auto loopOp =
+          builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+    builder.setInsertionPointAfter(loopOp);
+}
+
+void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
+                                                 mlir::Location loc) {
+  mlir::Value yieldValue =
+      builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+  builder.create<mlir::acc::YieldOp>(loc, yieldValue);
+}
diff --git a/flang/test/Lower/OpenACC/acc-loop-exit.f90 b/flang/test/Lower/OpenACC/acc-loop-exit.f90
new file mode 100644
index 000000000000000..75f1c3073327228
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-exit.f90
@@ -0,0 +1,37 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1(x, a)
+  real :: x(200)
+  integer :: a
+
+  !$acc loop
+  do i = 100, 200
+    x(i) = 1.0
+    if (i == a) return
+  end do
+
+  i = 2
+end 
+
+! CHECK-LABEL: func.func @_QPsub1
+! CHECK: %[[A:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFsub1Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[EXIT_COND:.*]] = acc.loop {
+! CHECK: ^bb{{.*}}:
+! CHECK: ^bb{{.*}}:
+! CHECK:   %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+! CHECK:   %[[CMP:.*]] = arith.cmpi eq, %15, %[[LOAD_A]] : i32
+! CHECK:   cf.cond_br %[[CMP]], ^[[EARLY_RET:.*]], ^[[NO_RET:.*]]
+! CHECK: ^[[EARLY_RET]]:
+! CHECK:   acc.yield %true : i1
+! CHECK: ^[[NO_RET]]:
+! CHECK:   cf.br ^bb{{.*}}
+! CHECK: ^bb{{.*}}:
+! CHECK:   acc.yield %false : i1
+! CHECK: }(i1)
+! CHECK: cf.cond_br %[[EXIT_COND]], ^[[EXIT_BLOCK:.*]], ^[[CONTINUE_BLOCK:.*]]
+! CHECK: ^[[CONTINUE_BLOCK]]:
+! CHECK:   hlfir.assign
+! CHECK:   cf.br ^[[EXIT_BLOCK]]
+! CHECK: ^[[EXIT_BLOCK]]:
+! CHECK:   return
+! CHECK: }

@llvmbot
Copy link
Member

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Early return is accepted in OpenACC loop not directly nested in a compute construct. Since acc.loop operation has a region, the func.return operation cannot be directly used inside the region.
An early return is materialized by an acc.yield operation returning a true value. The standard end of the acc.loop region yield a false value in this case.
A conditional branch operation on the acc.loop result will branch to the finalBlock or just to the continue block whether an early exit was produce in the acc.loop.


Full diff: https://github.com/llvm/llvm-project/pull/73841.diff

4 Files Affected:

  • (modified) flang/include/flang/Lower/OpenACC.h (+10-3)
  • (modified) flang/lib/Lower/Bridge.cpp (+21-2)
  • (modified) flang/lib/Lower/OpenACC.cpp (+80-14)
  • (added) flang/test/Lower/OpenACC/acc-loop-exit.f90 (+37)
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 409956f0ecb309f..f23e4726f33e00a 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -64,9 +64,10 @@ static constexpr llvm::StringRef declarePreDeallocSuffix =
 static constexpr llvm::StringRef declarePostDeallocSuffix =
     "_acc_declare_update_desc_post_dealloc";
 
-void genOpenACCConstruct(AbstractConverter &,
-                         Fortran::semantics::SemanticsContext &,
-                         pft::Evaluation &, const parser::OpenACCConstruct &);
+mlir::Value genOpenACCConstruct(AbstractConverter &,
+                                Fortran::semantics::SemanticsContext &,
+                                pft::Evaluation &,
+                                const parser::OpenACCConstruct &);
 void genOpenACCDeclarativeConstruct(AbstractConverter &,
                                     Fortran::semantics::SemanticsContext &,
                                     StatementContext &,
@@ -112,6 +113,12 @@ void attachDeclarePostDeallocAction(AbstractConverter &, fir::FirOpBuilder &,
 void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
                           mlir::Location);
 
+bool isInOpenACCLoop(fir::FirOpBuilder &);
+
+void setInsertionPointAfterOpenACCLoopIfInside(fir::FirOpBuilder &);
+
+void genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &, mlir::Location);
+
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 23c48cc7bd97874..45da1355df168e2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2382,11 +2382,25 @@ class FirConverter : public Fortran::lower::AbstractConverter {
   void genFIR(const Fortran::parser::OpenACCConstruct &acc) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     localSymbols.pushScope();
-    genOpenACCConstruct(*this, bridge.getSemanticsContext(), getEval(), acc);
+    mlir::Value exitCond = genOpenACCConstruct(
+        *this, bridge.getSemanticsContext(), getEval(), acc);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
+
+    const Fortran::parser::OpenACCLoopConstruct *accLoop =
+        std::get_if<Fortran::parser::OpenACCLoopConstruct>(&acc.u);
+    if (accLoop && exitCond) {
+      Fortran::lower::pft::FunctionLikeUnit *funit =
+          getEval().getOwningProcedure();
+      assert(funit && "not inside main program, function or subroutine");
+      mlir::Block *continueBlock =
+          builder->getBlock()->splitBlock(builder->getBlock()->end());
+      builder->create<mlir::cf::CondBranchOp>(toLocation(), exitCond,
+                                              funit->finalBlock, continueBlock);
+      builder->setInsertionPointToEnd(continueBlock);
+    }
   }
 
   void genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &accDecl) {
@@ -4091,10 +4105,15 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // Branch to the last block of the SUBROUTINE, which has the actual return.
     if (!funit->finalBlock) {
       mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
+      Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(*builder);
       funit->finalBlock = builder->createBlock(&builder->getRegion());
       builder->restoreInsertionPoint(insPt);
     }
-    builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
+
+    if (Fortran::lower::isInOpenACCLoop(*builder))
+      Fortran::lower::genEarlyReturnInOpenACCLoop(*builder, loc);
+    else
+      builder->create<mlir::cf::BranchOp>(loc, funit->finalBlock);
   }
 
   void genFIR(const Fortran::parser::CycleStmt &) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 8c6c22210cf0894..e2abed1b9f4f675 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -25,10 +25,12 @@
 #include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Builder/IntrinsicCall.h"
 #include "flang/Optimizer/Builder/Todo.h"
+#include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/scope.h"
 #include "flang/Semantics/tools.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "llvm/Frontend/OpenACC/ACC.h.inc"
 
 // Special value for * passed in device_type or gang clauses.
@@ -1381,9 +1383,10 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
                          Fortran::lower::pft::Evaluation &eval,
                          const llvm::SmallVectorImpl<mlir::Value> &operands,
                          const llvm::SmallVectorImpl<int32_t> &operandSegments,
-                         bool outerCombined = false) {
-  llvm::ArrayRef<mlir::Type> argTy;
-  Op op = builder.create<Op>(loc, argTy, operands);
+                         bool outerCombined = false,
+                         llvm::SmallVector<mlir::Type> retTy = {},
+                         mlir::Value yieldValue = {}) {
+  Op op = builder.create<Op>(loc, retTy, operands);
   builder.createBlock(&op.getRegion());
   mlir::Block &block = op.getRegion().back();
   builder.setInsertionPointToStart(&block);
@@ -1401,7 +1404,16 @@ static Op createRegionOp(fir::FirOpBuilder &builder, mlir::Location loc,
                                             mlir::acc::YieldOp>(
         builder, eval.getNestedEvaluations());
 
-  builder.create<Terminator>(loc);
+  if (yieldValue) {
+    if constexpr (std::is_same_v<Terminator, mlir::acc::YieldOp>) {
+      Terminator yieldOp = builder.create<Terminator>(loc, yieldValue);
+      yieldValue.getDefiningOp()->moveBefore(yieldOp);
+    } else {
+      builder.create<Terminator>(loc);
+    }
+  } else {
+    builder.create<Terminator>(loc);
+  }
   builder.setInsertionPointToStart(&block);
   return op;
 }
@@ -1494,7 +1506,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
              Fortran::lower::pft::Evaluation &eval,
              Fortran::semantics::SemanticsContext &semanticsContext,
              Fortran::lower::StatementContext &stmtCtx,
-             const Fortran::parser::AccClauseList &accClauseList) {
+             const Fortran::parser::AccClauseList &accClauseList,
+             bool needEarlyReturnHandling = false) {
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
   mlir::Value workerNum;
@@ -1599,8 +1612,17 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   addOperands(operands, operandSegments, privateOperands);
   addOperands(operands, operandSegments, reductionOperands);
 
+  llvm::SmallVector<mlir::Type> retTy;
+  mlir::Value yieldValue;
+  if (needEarlyReturnHandling) {
+    mlir::Type i1Ty = builder.getI1Type();
+    yieldValue = builder.createIntegerConstant(currentLocation, i1Ty, 0);
+    retTy.push_back(i1Ty);
+  }
+
   auto loopOp = createRegionOp<mlir::acc::LoopOp, mlir::acc::YieldOp>(
-      builder, currentLocation, eval, operands, operandSegments);
+      builder, currentLocation, eval, operands, operandSegments,
+      /*outerCombined=*/false, retTy, yieldValue);
 
   if (hasGang)
     loopOp.setHasGangAttr(builder.getUnitAttr());
@@ -1647,16 +1669,34 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
   return loopOp;
 }
 
-static void genACC(Fortran::lower::AbstractConverter &converter,
-                   Fortran::semantics::SemanticsContext &semanticsContext,
-                   Fortran::lower::pft::Evaluation &eval,
-                   const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
+static bool hasEarlyReturn(Fortran::lower::pft::Evaluation &eval) {
+  bool hasReturnStmt = false;
+  for (auto &e : eval.getNestedEvaluations()) {
+    e.visit(Fortran::common::visitors{
+        [&](const Fortran::parser::ReturnStmt &) { hasReturnStmt = true; },
+        [&](const auto &s) {},
+    });
+    if (e.hasNestedEvaluations())
+      hasReturnStmt = hasEarlyReturn(e);
+  }
+  return hasReturnStmt;
+}
+
+static mlir::Value
+genACC(Fortran::lower::AbstractConverter &converter,
+       Fortran::semantics::SemanticsContext &semanticsContext,
+       Fortran::lower::pft::Evaluation &eval,
+       const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
 
   const auto &beginLoopDirective =
       std::get<Fortran::parser::AccBeginLoopDirective>(loopConstruct.t);
   const auto &loopDirective =
       std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
 
+  bool needEarlyExitHandling = false;
+  if (eval.lowerAsUnstructured())
+    needEarlyExitHandling = hasEarlyReturn(eval);
+
   mlir::Location currentLocation =
       converter.genLocation(beginLoopDirective.source);
   Fortran::lower::StatementContext stmtCtx;
@@ -1664,9 +1704,13 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
   if (loopDirective.v == llvm::acc::ACCD_loop) {
     const auto &accClauseList =
         std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
-    createLoopOp(converter, currentLocation, eval, semanticsContext, stmtCtx,
-                 accClauseList);
+    auto loopOp =
+        createLoopOp(converter, currentLocation, eval, semanticsContext,
+                     stmtCtx, accClauseList, needEarlyExitHandling);
+    if (needEarlyExitHandling)
+      return loopOp.getResult(0);
   }
+  return mlir::Value{};
 }
 
 template <typename Op, typename Clause>
@@ -3431,12 +3475,13 @@ genACC(Fortran::lower::AbstractConverter &converter,
   builder.restoreInsertionPoint(crtPos);
 }
 
-void Fortran::lower::genOpenACCConstruct(
+mlir::Value Fortran::lower::genOpenACCConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::semantics::SemanticsContext &semanticsContext,
     Fortran::lower::pft::Evaluation &eval,
     const Fortran::parser::OpenACCConstruct &accConstruct) {
 
+  mlir::Value exitCond;
   std::visit(
       common::visitors{
           [&](const Fortran::parser::OpenACCBlockConstruct &blockConstruct) {
@@ -3447,7 +3492,7 @@ void Fortran::lower::genOpenACCConstruct(
             genACC(converter, semanticsContext, eval, combinedConstruct);
           },
           [&](const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
-            genACC(converter, semanticsContext, eval, loopConstruct);
+            exitCond = genACC(converter, semanticsContext, eval, loopConstruct);
           },
           [&](const Fortran::parser::OpenACCStandaloneConstruct
                   &standaloneConstruct) {
@@ -3467,6 +3512,7 @@ void Fortran::lower::genOpenACCConstruct(
           },
       },
       accConstruct.u);
+  return exitCond;
 }
 
 void Fortran::lower::genOpenACCDeclarativeConstruct(
@@ -3560,3 +3606,23 @@ void Fortran::lower::genOpenACCTerminator(fir::FirOpBuilder &builder,
   else
     builder.create<mlir::acc::TerminatorOp>(loc);
 }
+
+bool Fortran::lower::isInOpenACCLoop(fir::FirOpBuilder &builder) {
+  if (builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+    return true;
+  return false;
+}
+
+void Fortran::lower::setInsertionPointAfterOpenACCLoopIfInside(
+    fir::FirOpBuilder &builder) {
+  if (auto loopOp =
+          builder.getBlock()->getParent()->getParentOfType<mlir::acc::LoopOp>())
+    builder.setInsertionPointAfter(loopOp);
+}
+
+void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
+                                                 mlir::Location loc) {
+  mlir::Value yieldValue =
+      builder.createIntegerConstant(loc, builder.getI1Type(), 1);
+  builder.create<mlir::acc::YieldOp>(loc, yieldValue);
+}
diff --git a/flang/test/Lower/OpenACC/acc-loop-exit.f90 b/flang/test/Lower/OpenACC/acc-loop-exit.f90
new file mode 100644
index 000000000000000..75f1c3073327228
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-exit.f90
@@ -0,0 +1,37 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+subroutine sub1(x, a)
+  real :: x(200)
+  integer :: a
+
+  !$acc loop
+  do i = 100, 200
+    x(i) = 1.0
+    if (i == a) return
+  end do
+
+  i = 2
+end 
+
+! CHECK-LABEL: func.func @_QPsub1
+! CHECK: %[[A:.*]]:2 = hlfir.declare %arg1 {uniq_name = "_QFsub1Ea"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+! CHECK: %[[EXIT_COND:.*]] = acc.loop {
+! CHECK: ^bb{{.*}}:
+! CHECK: ^bb{{.*}}:
+! CHECK:   %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+! CHECK:   %[[CMP:.*]] = arith.cmpi eq, %15, %[[LOAD_A]] : i32
+! CHECK:   cf.cond_br %[[CMP]], ^[[EARLY_RET:.*]], ^[[NO_RET:.*]]
+! CHECK: ^[[EARLY_RET]]:
+! CHECK:   acc.yield %true : i1
+! CHECK: ^[[NO_RET]]:
+! CHECK:   cf.br ^bb{{.*}}
+! CHECK: ^bb{{.*}}:
+! CHECK:   acc.yield %false : i1
+! CHECK: }(i1)
+! CHECK: cf.cond_br %[[EXIT_COND]], ^[[EXIT_BLOCK:.*]], ^[[CONTINUE_BLOCK:.*]]
+! CHECK: ^[[CONTINUE_BLOCK]]:
+! CHECK:   hlfir.assign
+! CHECK:   cf.br ^[[EXIT_BLOCK]]
+! CHECK: ^[[EXIT_BLOCK]]:
+! CHECK:   return
+! CHECK: }

Copy link
Contributor

@razvanlupusoru razvanlupusoru left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Thank you!

@clementval clementval merged commit a9a5af8 into llvm:main Nov 30, 2023
@clementval clementval deleted the acc_loop_return branch November 30, 2023 22:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category openacc
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants