Skip to content

Select Case constructs with character selector expressions #685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 99 additions & 24 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/CallInterface.h"
#include "flang/Lower/CharacterExpr.h"
#include "flang/Lower/CharacterRuntime.h"
#include "flang/Lower/Coarray.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertType.h"
Expand Down Expand Up @@ -395,6 +396,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
return cat == Fortran::common::TypeCategory::Derived;
}

/// Insert a new block before \p block. Leave the insertion point unchanged.
mlir::Block *insertBlock(mlir::Block *block) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use mlir's splitBlock function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mlir functions splitBlock and (one overload of) createBlock are complementary. createBlock inserts a new block before a given block, and splitBlock inserts the new block after the given block. Since "the given block" may actually be the first of multiple blocks, createBlock makes much more sense here. Of course, where blocks are inserted doesn't affect actual code relationships between blocks. But it is much easier to read and debug IR when blocks are placed in more natural locations.

auto insertPt = builder->saveInsertionPoint();
auto newBlock = builder->createBlock(block);
builder->restoreInsertionPoint(insertPt);
return newBlock;
}

mlir::Block *blockOfLabel(Fortran::lower::pft::Evaluation &eval,
Fortran::parser::Label label) {
const auto &labelEvaluationMap =
Expand Down Expand Up @@ -791,10 +800,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// exit block of the immediately enclosed dimension.
auto createNextExitBlock = [&]() {
// Create unstructured loop exit blocks, outermost to innermost.
auto insertPt = builder->saveInsertionPoint();
exitBlock = builder->createBlock(exitBlock);
builder->restoreInsertionPoint(insertPt);
return exitBlock;
return exitBlock = insertBlock(exitBlock);
};
auto isInnermost = &info == &incrementLoopNestInfo.back();
auto isOutermost = &info == &incrementLoopNestInfo.front();
Expand Down Expand Up @@ -1130,36 +1136,53 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->restoreInsertionPoint(insertPt);
}

/// Generate FIR for a SELECT CASE statement.
/// The type may be CHARACTER, INTEGER, or LOGICAL.
void genFIR(const Fortran::parser::SelectCaseStmt &stmt) {
auto &eval = getEval();
auto *context = builder->getContext();
auto loc = toLocation();
Fortran::lower::StatementContext stmtCtx;
const auto *expr = Fortran::semantics::GetExpr(
std::get<Fortran::parser::Scalar<Fortran::parser::Expr>>(stmt.t));
auto exprType = expr->GetType();
mlir::Value selectExpr;
if (isCharacterCategory(exprType->category())) {
TODO(loc, "Select Case selector of type Character");
bool isCharSelector = isCharacterCategory(expr->GetType()->category());
bool isLogicalSelector = isLogicalCategory(expr->GetType()->category());

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we have moved to using the FIR type for these tests. @jeanPerier ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general yes, when there is a mlir::Type available working with it is the preferred way. But Here I think it would make the code more complex given this is needed before producing any mlir values, and working with mlir::Type would not be 100% straightforward given logicals are mlir::IntegerType. So what Val has here looks clean to me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, that was the motivation for doing it this way.

auto charValue = [&](const Fortran::lower::SomeExpr *expr) {
fir::ExtendedValue exv = genExprAddr(*expr, stmtCtx, &loc);
return exv.match(
[&](const fir::CharBoxValue &cbv) {
return Fortran::lower::CharacterExprHelper{*builder, loc}
.createEmboxChar(cbv.getAddr(), cbv.getLen());
},
[&](auto) {
fir::emitFatalError(loc, "not a character");
return mlir::Value{};
});
};
mlir::Value selector;
if (isCharSelector) {
selector = charValue(expr);
} else {
selectExpr = createFIRExpr(loc, expr, stmtCtx);
if (isLogicalCategory(exprType->category()))
selectExpr =
builder->createConvert(loc, builder->getI1Type(), selectExpr);
selector = createFIRExpr(loc, expr, stmtCtx);
if (isLogicalSelector)
selector = builder->createConvert(loc, builder->getI1Type(), selector);
}
auto selectType = selectExpr.getType();
llvm::SmallVector<mlir::Attribute, 10> attrList;
llvm::SmallVector<mlir::Value, 10> valueList;
llvm::SmallVector<mlir::Block *, 10> blockList;
auto selectType = selector.getType();
llvm::SmallVector<mlir::Attribute> attrList;
llvm::SmallVector<mlir::Value> valueList;
llvm::SmallVector<mlir::Block *> blockList;
auto *defaultBlock = eval.parentConstruct->constructExit->block;
using CaseValue = Fortran::parser::Scalar<Fortran::parser::ConstantExpr>;
auto addValue = [&](const CaseValue &caseValue) {
const auto *expr = Fortran::semantics::GetExpr(caseValue.thing);
const auto v = Fortran::evaluate::ToInt64(*expr);
valueList.push_back(
v ? builder->createIntegerConstant(loc, selectType, *v)
: builder->createConvert(
loc, selectType, createFIRExpr(toLocation(), expr, stmtCtx)));
if (isCharSelector)
valueList.push_back(charValue(expr));
else if (isLogicalSelector)
valueList.push_back(builder->createConvert(
loc, selectType, createFIRExpr(toLocation(), expr, stmtCtx)));
else
valueList.push_back(builder->createIntegerConstant(
loc, selectType, *Fortran::evaluate::ToInt64(*expr)));
};
for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e;
e = e->controlSuccessor) {
Expand Down Expand Up @@ -1197,13 +1220,65 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}
// Skip a logical default block that can never be referenced.
if (selectType == builder->getI1Type() && attrList.size() == 2)
if (isLogicalSelector && attrList.size() == 2)
defaultBlock = eval.parentConstruct->constructExit->block;
attrList.push_back(mlir::UnitAttr::get(context));
blockList.push_back(defaultBlock);
stmtCtx.finalize();
builder->create<fir::SelectCaseOp>(toLocation(), selectExpr, attrList,
valueList, blockList);

// Generate a fir::SelectCaseOp.
// Explicit branch code is better for the LOGICAL type. The CHARACTER type
// does not yet have downstream support, and also uses explicit branch code.
// The -no-structured-fir option can be used to force generation of INTEGER
// type branch code.
if (!isLogicalSelector && !isCharSelector && eval.lowerAsStructured()) {
builder->create<fir::SelectCaseOp>(loc, selector, attrList, valueList,
blockList);
return;
}

// Generate a sequence of case value comparisons and branches.
auto caseValue = valueList.begin();
auto caseBlock = blockList.begin();
for (auto attr : attrList) {
if (attr.isa<mlir::UnitAttr>()) {
genFIRBranch(*caseBlock++);
break;
}
auto genCond = [&](mlir::Value rhs,
mlir::CmpIPredicate pred) -> mlir::Value {
if (!isCharSelector)
return builder->create<mlir::CmpIOp>(loc, pred, selector, rhs);
Fortran::lower::CharacterExprHelper charHelper{*builder, loc};
auto [lhsAddr, lhsLen] = charHelper.createUnboxChar(selector);
auto [rhsAddr, rhsLen] = charHelper.createUnboxChar(rhs);
return Fortran::lower::genRawCharCompare(*builder, loc, pred, lhsAddr,
lhsLen, rhsAddr, rhsLen);
};
auto *newBlock = insertBlock(*caseBlock);
if (attr.isa<fir::ClosedIntervalAttr>()) {
auto *newBlock2 = insertBlock(*caseBlock);
auto cond = genCond(*caseValue++, mlir::CmpIPredicate::sge);
genFIRConditionalBranch(cond, newBlock, newBlock2);
builder->setInsertionPointToEnd(newBlock);
auto cond2 = genCond(*caseValue++, mlir::CmpIPredicate::sle);
genFIRConditionalBranch(cond2, *caseBlock++, newBlock2);
builder->setInsertionPointToEnd(newBlock2);
continue;
}
mlir::CmpIPredicate pred;
if (attr.isa<fir::PointIntervalAttr>())
pred = mlir::CmpIPredicate::eq;
else if (attr.isa<fir::LowerBoundAttr>())
pred = mlir::CmpIPredicate::sge;
else if (attr.isa<fir::UpperBoundAttr>())
pred = mlir::CmpIPredicate::sle;
auto cond = genCond(*caseValue++, pred);
genFIRConditionalBranch(cond, *caseBlock++, newBlock);
builder->setInsertionPointToEnd(newBlock);
}
assert(caseValue == valueList.end() && caseBlock == blockList.end() &&
"select case list mismatch");
}

void genFIR(const Fortran::parser::AssociateConstruct &) {
Expand Down
109 changes: 81 additions & 28 deletions flang/test/Lower/select-case-statement.f90
Original file line number Diff line number Diff line change
Expand Up @@ -45,72 +45,64 @@ subroutine slogical(L)
n7 = 0
n8 = 0

! CHECK: fir.select_case {{.*}} : i1
! CHECK-SAME: unit
select case (L)
end select

! CHECK: fir.select_case {{.*}} : i1
! CHECK-SAME: point, %false
! CHECK-SAME: unit
select case (L)
! CHECK: cmpi eq, {{.*}} %false
! CHECK: cond_br
case (.false.)
n2 = 1
end select

! CHECK: fir.select_case {{.*}} : i1
! CHECK-SAME: point, %true
! CHECK-SAME: unit
select case (L)
! CHECK: cmpi eq, {{.*}} %true
! CHECK: cond_br
case (.true.)
n3 = 2
end select

! CHECK: fir.select_case {{.*}} : i1
! CHECK-SAME: unit
select case (L)
case default
n4 = 3
end select

! CHECK: fir.select_case {{.*}} : i1
! CHECK-SAME: point, %false
! CHECK-SAME: point, %true
! CHECK-SAME: unit
select case (L)
! CHECK: cmpi eq, {{.*}} %false
! CHECK: cond_br
case (.false.)
n5 = 1
! CHECK: cmpi eq, {{.*}} %true
! CHECK: cond_br
case (.true.)
n5 = 2
end select

! CHECK: fir.select_case {{.*}} : i1
! CHECK-SAME: point, %false
! CHECK-SAME: unit
select case (L)
! CHECK: cmpi eq, {{.*}} %false
! CHECK: cond_br
case (.false.)
n6 = 1
case default
n6 = 3
end select

! CHECK: fir.select_case {{.*}} : i1
! CHECK-SAME: point, %true
! CHECK-SAME: unit
select case (L)
! CHECK: cmpi eq, {{.*}} %true
! CHECK: cond_br
case (.true.)
n7 = 2
case default
n7 = 3
end select

! CHECK: fir.select_case {{.*}} : i1
! CHECK-SAME: point, %false
! CHECK-SAME: point, %true
! CHECK-SAME: unit
select case (L)
! CHECK: cmpi eq, {{.*}} %false
! CHECK: cond_br
case (.false.)
n8 = 1
! CHECK: cmpi eq, {{.*}} %true
! CHECK: cond_br
case (.true.)
n8 = 2
! CHECK-NOT: 888
Expand All @@ -121,6 +113,52 @@ subroutine slogical(L)
print*, n1, n2, n3, n4, n5, n6, n7, n8
end

! CHECK-LABEL: scharacter
subroutine scharacter(c)
character(*) :: c
nn = 0
select case (c)
case default
nn = -1
! CHECK: CharacterCompareScalar1
! CHECK-NEXT: constant 0
! CHECK-NEXT: cmpi sle, {{.*}} %c0
! CHECK-NEXT: cond_br
case (:'d')
nn = 10
! CHECK: CharacterCompareScalar1
! CHECK-NEXT: constant 0
! CHECK-NEXT: cmpi sge, {{.*}} %c0
! CHECK-NEXT: cond_br
! CHECK: CharacterCompareScalar1
! CHECK-NEXT: constant 0
! CHECK-NEXT: cmpi sle, {{.*}} %c0
! CHECK-NEXT: cond_br
case ('ff':'ffff')
nn = 20
! CHECK: CharacterCompareScalar1
! CHECK-NEXT: constant 0
! CHECK-NEXT: cmpi eq, {{.*}} %c0
! CHECK-NEXT: cond_br
case ('m')
nn = 30
! CHECK: CharacterCompareScalar1
! CHECK-NEXT: constant 0
! CHECK-NEXT: cmpi eq, {{.*}} %c0
! CHECK-NEXT: cond_br
case ('qq')
nn = 40
! CHECK: CharacterCompareScalar1
! CHECK-NEXT: constant 0
! CHECK-NEXT: cmpi sge, {{.*}} %c0
! CHECK-NEXT: cond_br
case ('x':)
nn = 50
end select
print*, nn
end

! CHECK-LABEL: main
program p
integer sinteger, v(10)

Expand All @@ -138,8 +176,23 @@ program p
enddo

print*
! expected output: 0 1 0 3 1 1 3 1
call slogical(.false.)
! expected output: 0 0 2 3 2 3 2 2
call slogical(.true.)
call slogical(.false.) ! expected output: 0 1 0 3 1 1 3 1
call slogical(.true.) ! expected output: 0 0 2 3 2 3 2 2

print*
call scharacter('aa') ! expected output: 10
call scharacter('d') ! expected output: 10
call scharacter('f') ! expected output: -1
call scharacter('ff') ! expected output: 20
call scharacter('fff') ! expected output: 20
call scharacter('ffff') ! expected output: 20
call scharacter('fffff') ! expected output: -1
call scharacter('jj') ! expected output: -1
call scharacter('m') ! expected output: 30
call scharacter('q') ! expected output: -1
call scharacter('qq') ! expected output: 40
call scharacter('qqq') ! expected output: -1
call scharacter('vv') ! expected output: -1
call scharacter('xx') ! expected output: 50
call scharacter('zz') ! expected output: 50
end