Skip to content

[flang][Lower] Emit exiting branches from within constructs #92455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 66 additions & 22 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,43 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genBranch(targetEval.block);
}

/// A construct contains nested evaluations. Some of these evaluations
/// may start a new basic block, others will add code to an existing
/// block.
/// Collect the list of nested evaluations that are last in their block,
/// organize them into two sets:
/// 1. Exiting evaluations: they may need a branch exiting from their
/// parent construct,
/// 2. Fall-through evaluations: they will continue to the following
/// evaluation. They may still need a branch, but they do not exit
/// the construct. They appear in cases where the following evaluation
/// is a target of some branch.
void collectFinalEvaluations(
Fortran::lower::pft::Evaluation &construct,
llvm::SmallVector<Fortran::lower::pft::Evaluation *> &exits,
llvm::SmallVector<Fortran::lower::pft::Evaluation *> &fallThroughs) {
Fortran::lower::pft::EvaluationList &nested =
construct.getNestedEvaluations();
if (nested.empty())
return;

Fortran::lower::pft::Evaluation *exit = construct.constructExit;
Fortran::lower::pft::Evaluation *previous = &nested.front();

for (auto it = ++nested.begin(), end = nested.end(); it != end;
previous = &*it++) {
if (it->block == nullptr)
continue;
// "*it" starts a new block, check what to do with "previous"
if (it->isIntermediateConstructStmt() && previous != exit)
exits.push_back(previous);
else if (previous->lexicalSuccessor && previous->lexicalSuccessor->block)
fallThroughs.push_back(previous);
}
if (previous != exit)
exits.push_back(previous);
}

/// Generate a SelectOp or branch sequence that compares \p selector against
/// values in \p valueList and targets corresponding labels in \p labelList.
/// If no value matches the selector, branch to \p defaultEval.
Expand Down Expand Up @@ -2107,6 +2144,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}

// Unstructured branch sequence.
llvm::SmallVector<Fortran::lower::pft::Evaluation *> exits, fallThroughs;
collectFinalEvaluations(eval, exits, fallThroughs);

for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
auto genIfBranch = [&](mlir::Value cond) {
if (e.lexicalSuccessor == e.controlSuccessor) // empty block -> exit
Expand All @@ -2127,6 +2167,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
genIfBranch(genIfCondition(s));
} else {
genFIR(e);
if (blockIsUnterminated()) {
if (llvm::is_contained(exits, &e))
genConstructExitBranch(*eval.constructExit);
else if (llvm::is_contained(fallThroughs, &e))
genBranch(e.lexicalSuccessor->block);
}
}
}
}
Expand All @@ -2135,11 +2181,21 @@ class FirConverter : public Fortran::lower::AbstractConverter {
Fortran::lower::pft::Evaluation &eval = getEval();
Fortran::lower::StatementContext stmtCtx;
pushActiveConstruct(eval, stmtCtx);

llvm::SmallVector<Fortran::lower::pft::Evaluation *> exits, fallThroughs;
collectFinalEvaluations(eval, exits, fallThroughs);

for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) {
if (e.getIf<Fortran::parser::EndSelectStmt>())
maybeStartBlock(e.block);
else
genFIR(e);
if (blockIsUnterminated()) {
if (llvm::is_contained(exits, &e))
genConstructExitBranch(*eval.constructExit);
else if (llvm::is_contained(fallThroughs, &e))
genBranch(e.lexicalSuccessor->block);
}
}
popActiveConstruct();
}
Expand Down Expand Up @@ -3005,6 +3061,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}

pushActiveConstruct(getEval(), stmtCtx);
llvm::SmallVector<Fortran::lower::pft::Evaluation *> exits, fallThroughs;
collectFinalEvaluations(getEval(), exits, fallThroughs);
Fortran::lower::pft::Evaluation &constructExit = *getEval().constructExit;

for (Fortran::lower::pft::Evaluation &eval :
getEval().getNestedEvaluations()) {
setCurrentPosition(eval.position);
Expand Down Expand Up @@ -3201,6 +3261,12 @@ class FirConverter : public Fortran::lower::AbstractConverter {
} else {
genFIR(eval);
}
if (blockIsUnterminated()) {
if (llvm::is_contained(exits, &eval))
genConstructExitBranch(constructExit);
else if (llvm::is_contained(fallThroughs, &eval))
genBranch(eval.lexicalSuccessor->block);
}
}
popActiveConstruct();
}
Expand Down Expand Up @@ -4535,28 +4601,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
setCurrentEval(eval);
setCurrentPosition(eval.position);
eval.visit([&](const auto &stmt) { genFIR(stmt); });

// Generate an end-of-block branch for several special cases. For
// constructs, this can be done for either the end construct statement,
// or for the construct itself, which will skip this code if the
// end statement was visited first and generated a branch.
Fortran::lower::pft::Evaluation *successor = [&]() {
if (eval.isConstruct() ||
(eval.isDirective() && eval.hasNestedEvaluations()))
return eval.getLastNestedEvaluation().lexicalSuccessor;
return eval.lexicalSuccessor;
}();

if (successor && blockIsUnterminated()) {
if (successor->isIntermediateConstructStmt() &&
successor->parentConstruct->lowerAsUnstructured())
// Exit from an intermediate unstructured IF or SELECT construct block.
genBranch(successor->parentConstruct->constructExit->block);
else if (unstructuredContext && eval.isConstructStmt() &&
successor == eval.controlSuccessor)
// Exit from a degenerate, empty construct block.
genBranch(eval.parentConstruct->constructExit->block);
}
}

/// Map mlir function block arguments to the corresponding Fortran dummy
Expand Down
77 changes: 70 additions & 7 deletions flang/test/Lower/branching-directive.f90
Original file line number Diff line number Diff line change
@@ -1,25 +1,88 @@
!RUN: flang-new -fc1 -emit-hlfir -fopenmp -o - %s | FileCheck %s
!RUN: bbc -emit-hlfir -fopenacc -fopenmp -o - %s | FileCheck %s

!https://github.com/llvm/llvm-project/issues/91526

!CHECK-LABEL: func.func @_QPsimple1
!CHECK: cf.cond_br %{{[0-9]+}}, ^bb[[THEN:[0-9]+]], ^bb[[ELSE:[0-9]+]]
!CHECK: ^bb[[THEN]]:
!CHECK: cf.br ^bb[[EXIT:[0-9]+]]
!CHECK: omp.parallel
!CHECK: cf.br ^bb[[ENDIF:[0-9]+]]
!CHECK: ^bb[[ELSE]]:
!CHECK: fir.call @_FortranAStopStatement
!CHECK: fir.unreachable
!CHECK: ^bb[[EXIT]]:
!CHECK: ^bb[[ENDIF]]:
!CHECK: return

subroutine simple(y)
subroutine simple1(y)
implicit none
logical, intent(in) :: y
integer :: i
if (y) then
!$omp parallel
!$omp parallel
i = 1
!$omp end parallel
!$omp end parallel
else
stop 1
end if
end subroutine simple
end subroutine

!CHECK-LABEL: func.func @_QPsimple2
!CHECK: cf.cond_br %{{[0-9]+}}, ^bb[[THEN:[0-9]+]], ^bb[[ELSE:[0-9]+]]
!CHECK: ^bb[[THEN]]:
!CHECK: omp.parallel
!CHECK: cf.br ^bb[[ENDIF:[0-9]+]]
!CHECK: ^bb[[ELSE]]:
!CHECK: fir.call @_FortranAStopStatement
!CHECK: fir.unreachable
!CHECK: ^bb[[ENDIF]]:
!CHECK: fir.call @_FortranAioOutputReal64
!CHECK: return
subroutine simple2(x, yn)
implicit none
logical, intent(in) :: yn
integer, intent(in) :: x
integer :: i
real(8) :: E
E = 0d0

if (yn) then
!$omp parallel do private(i) reduction(+:E)
do i = 1, x
E = E + i
end do
!$omp end parallel do
else
stop 1
end if
print *, E
end subroutine

!CHECK-LABEL: func.func @_QPacccase
!CHECK: fir.select_case %{{[0-9]+}} : i32 [{{.*}}, ^bb[[CASE1:[0-9]+]], {{.*}}, ^bb[[CASE2:[0-9]+]], {{.*}}, ^bb[[CASE3:[0-9]+]]]
!CHECK: ^bb[[CASE1]]:
!CHECK: acc.serial
!CHECK: cf.br ^bb[[EXIT:[0-9]+]]
!CHECK: ^bb[[CASE2]]:
!CHECK: fir.call @_FortranAioOutputAscii
!CHECK: cf.br ^bb[[EXIT]]
!CHECK: ^bb[[CASE3]]:
!CHECK: fir.call @_FortranAioOutputAscii
!CHECK: cf.br ^bb[[EXIT]]
!CHECK: ^bb[[EXIT]]:
!CHECK: return
subroutine acccase(var)
integer :: var
integer :: res(10)
select case (var)
case (1)
print *, "case 1"
!$acc serial
res(1) = 1
!$acc end serial
case (2)
print *, "case 2"
case default
print *, "case default"
end select
end subroutine

31 changes: 31 additions & 0 deletions flang/test/Lower/unstructured-control-flow.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
!RUN: bbc -emit-hlfir -o - %s | FileCheck %s

!CHECK-LABEL: func.func @_QPunstructured1
!CHECK: fir.select %{{[0-9]+}} : i32 [{{.*}}, ^bb[[BLOCK3:[0-9]+]], {{.*}}, ^bb[[BLOCK4:[0-9]+]], {{.*}}, ^bb[[BLOCK5:[0-9]+]], {{.*}}, ^bb[[BLOCK1:[0-9]+]]]
!CHECK: ^bb[[BLOCK1]]:
!CHECK: cf.cond_br %{{[0-9]+}}, ^bb[[BLOCK2:[0-9]+]], ^bb[[BLOCK4]]
!CHECK: ^bb[[BLOCK2]]:
!CHECK: fir.if
!CHECK: cf.br ^bb[[BLOCK3]]
!CHECK: ^bb[[BLOCK3]]:
!CHECK: %[[C10:[a-z0-9_]+]] = arith.constant 10 : i32
!CHECK: arith.addi {{.*}}, %[[C10]]
!CHECK: cf.br ^bb[[BLOCK4]]
!CHECK: ^bb[[BLOCK4]]:
!CHECK: %[[C100:[a-z0-9_]+]] = arith.constant 100 : i32
!CHECK: arith.addi {{.*}}, %[[C100]]
!CHECK: cf.br ^bb[[BLOCK5]]
!CHECK: ^bb[[BLOCK5]]:
!CHECK: %[[C1000:[a-z0-9_]+]] = arith.constant 1000 : i32
!CHECK: arith.addi {{.*}}, %[[C1000]]
!CHECK: return
subroutine unstructured1(j, k)
goto (11, 22, 33) j-3 ! computed goto - an expression outside [1,3] is a nop
if (j == 2) goto 22
if (j == 1) goto 11
k = k + 1
11 k = k + 10
22 k = k + 100
33 k = k + 1000
end

Loading