Skip to content

Commit 79679e5

Browse files
vdonaldsonjeanPerier
authored andcommitted
Select Case constructs with character selector expressions (#685)
Add the capability to lower select case constructs of any type to either FIR SelectCaseOp's or a sequence of comparisons and branches. Actually generate SelectCaseOp's for integer type constructs, but use comparisons and branches for logical types (the code is better) and for character types (it isn't possible yet to handle character SelectCaseOp's downstream).
1 parent 5a76017 commit 79679e5

File tree

2 files changed

+180
-52
lines changed

2 files changed

+180
-52
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 99 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "flang/Lower/Allocatable.h"
1818
#include "flang/Lower/CallInterface.h"
1919
#include "flang/Lower/CharacterExpr.h"
20+
#include "flang/Lower/CharacterRuntime.h"
2021
#include "flang/Lower/Coarray.h"
2122
#include "flang/Lower/ConvertExpr.h"
2223
#include "flang/Lower/ConvertType.h"
@@ -395,6 +396,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
395396
return cat == Fortran::common::TypeCategory::Derived;
396397
}
397398

399+
/// Insert a new block before \p block. Leave the insertion point unchanged.
400+
mlir::Block *insertBlock(mlir::Block *block) {
401+
auto insertPt = builder->saveInsertionPoint();
402+
auto newBlock = builder->createBlock(block);
403+
builder->restoreInsertionPoint(insertPt);
404+
return newBlock;
405+
}
406+
398407
mlir::Block *blockOfLabel(Fortran::lower::pft::Evaluation &eval,
399408
Fortran::parser::Label label) {
400409
const auto &labelEvaluationMap =
@@ -791,10 +800,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
791800
// exit block of the immediately enclosed dimension.
792801
auto createNextExitBlock = [&]() {
793802
// Create unstructured loop exit blocks, outermost to innermost.
794-
auto insertPt = builder->saveInsertionPoint();
795-
exitBlock = builder->createBlock(exitBlock);
796-
builder->restoreInsertionPoint(insertPt);
797-
return exitBlock;
803+
return exitBlock = insertBlock(exitBlock);
798804
};
799805
auto isInnermost = &info == &incrementLoopNestInfo.back();
800806
auto isOutermost = &info == &incrementLoopNestInfo.front();
@@ -1130,36 +1136,53 @@ class FirConverter : public Fortran::lower::AbstractConverter {
11301136
builder->restoreInsertionPoint(insertPt);
11311137
}
11321138

1139+
/// Generate FIR for a SELECT CASE statement.
1140+
/// The type may be CHARACTER, INTEGER, or LOGICAL.
11331141
void genFIR(const Fortran::parser::SelectCaseStmt &stmt) {
11341142
auto &eval = getEval();
11351143
auto *context = builder->getContext();
11361144
auto loc = toLocation();
11371145
Fortran::lower::StatementContext stmtCtx;
11381146
const auto *expr = Fortran::semantics::GetExpr(
11391147
std::get<Fortran::parser::Scalar<Fortran::parser::Expr>>(stmt.t));
1140-
auto exprType = expr->GetType();
1141-
mlir::Value selectExpr;
1142-
if (isCharacterCategory(exprType->category())) {
1143-
TODO(loc, "Select Case selector of type Character");
1148+
bool isCharSelector = isCharacterCategory(expr->GetType()->category());
1149+
bool isLogicalSelector = isLogicalCategory(expr->GetType()->category());
1150+
auto charValue = [&](const Fortran::lower::SomeExpr *expr) {
1151+
fir::ExtendedValue exv = genExprAddr(*expr, stmtCtx, &loc);
1152+
return exv.match(
1153+
[&](const fir::CharBoxValue &cbv) {
1154+
return Fortran::lower::CharacterExprHelper{*builder, loc}
1155+
.createEmboxChar(cbv.getAddr(), cbv.getLen());
1156+
},
1157+
[&](auto) {
1158+
fir::emitFatalError(loc, "not a character");
1159+
return mlir::Value{};
1160+
});
1161+
};
1162+
mlir::Value selector;
1163+
if (isCharSelector) {
1164+
selector = charValue(expr);
11441165
} else {
1145-
selectExpr = createFIRExpr(loc, expr, stmtCtx);
1146-
if (isLogicalCategory(exprType->category()))
1147-
selectExpr =
1148-
builder->createConvert(loc, builder->getI1Type(), selectExpr);
1166+
selector = createFIRExpr(loc, expr, stmtCtx);
1167+
if (isLogicalSelector)
1168+
selector = builder->createConvert(loc, builder->getI1Type(), selector);
11491169
}
1150-
auto selectType = selectExpr.getType();
1151-
llvm::SmallVector<mlir::Attribute, 10> attrList;
1152-
llvm::SmallVector<mlir::Value, 10> valueList;
1153-
llvm::SmallVector<mlir::Block *, 10> blockList;
1170+
auto selectType = selector.getType();
1171+
llvm::SmallVector<mlir::Attribute> attrList;
1172+
llvm::SmallVector<mlir::Value> valueList;
1173+
llvm::SmallVector<mlir::Block *> blockList;
11541174
auto *defaultBlock = eval.parentConstruct->constructExit->block;
11551175
using CaseValue = Fortran::parser::Scalar<Fortran::parser::ConstantExpr>;
11561176
auto addValue = [&](const CaseValue &caseValue) {
11571177
const auto *expr = Fortran::semantics::GetExpr(caseValue.thing);
1158-
const auto v = Fortran::evaluate::ToInt64(*expr);
1159-
valueList.push_back(
1160-
v ? builder->createIntegerConstant(loc, selectType, *v)
1161-
: builder->createConvert(
1162-
loc, selectType, createFIRExpr(toLocation(), expr, stmtCtx)));
1178+
if (isCharSelector)
1179+
valueList.push_back(charValue(expr));
1180+
else if (isLogicalSelector)
1181+
valueList.push_back(builder->createConvert(
1182+
loc, selectType, createFIRExpr(toLocation(), expr, stmtCtx)));
1183+
else
1184+
valueList.push_back(builder->createIntegerConstant(
1185+
loc, selectType, *Fortran::evaluate::ToInt64(*expr)));
11631186
};
11641187
for (Fortran::lower::pft::Evaluation *e = eval.controlSuccessor; e;
11651188
e = e->controlSuccessor) {
@@ -1197,13 +1220,65 @@ class FirConverter : public Fortran::lower::AbstractConverter {
11971220
}
11981221
}
11991222
// Skip a logical default block that can never be referenced.
1200-
if (selectType == builder->getI1Type() && attrList.size() == 2)
1223+
if (isLogicalSelector && attrList.size() == 2)
12011224
defaultBlock = eval.parentConstruct->constructExit->block;
12021225
attrList.push_back(mlir::UnitAttr::get(context));
12031226
blockList.push_back(defaultBlock);
12041227
stmtCtx.finalize();
1205-
builder->create<fir::SelectCaseOp>(toLocation(), selectExpr, attrList,
1206-
valueList, blockList);
1228+
1229+
// Generate a fir::SelectCaseOp.
1230+
// Explicit branch code is better for the LOGICAL type. The CHARACTER type
1231+
// does not yet have downstream support, and also uses explicit branch code.
1232+
// The -no-structured-fir option can be used to force generation of INTEGER
1233+
// type branch code.
1234+
if (!isLogicalSelector && !isCharSelector && eval.lowerAsStructured()) {
1235+
builder->create<fir::SelectCaseOp>(loc, selector, attrList, valueList,
1236+
blockList);
1237+
return;
1238+
}
1239+
1240+
// Generate a sequence of case value comparisons and branches.
1241+
auto caseValue = valueList.begin();
1242+
auto caseBlock = blockList.begin();
1243+
for (auto attr : attrList) {
1244+
if (attr.isa<mlir::UnitAttr>()) {
1245+
genFIRBranch(*caseBlock++);
1246+
break;
1247+
}
1248+
auto genCond = [&](mlir::Value rhs,
1249+
mlir::CmpIPredicate pred) -> mlir::Value {
1250+
if (!isCharSelector)
1251+
return builder->create<mlir::CmpIOp>(loc, pred, selector, rhs);
1252+
Fortran::lower::CharacterExprHelper charHelper{*builder, loc};
1253+
auto [lhsAddr, lhsLen] = charHelper.createUnboxChar(selector);
1254+
auto [rhsAddr, rhsLen] = charHelper.createUnboxChar(rhs);
1255+
return Fortran::lower::genRawCharCompare(*builder, loc, pred, lhsAddr,
1256+
lhsLen, rhsAddr, rhsLen);
1257+
};
1258+
auto *newBlock = insertBlock(*caseBlock);
1259+
if (attr.isa<fir::ClosedIntervalAttr>()) {
1260+
auto *newBlock2 = insertBlock(*caseBlock);
1261+
auto cond = genCond(*caseValue++, mlir::CmpIPredicate::sge);
1262+
genFIRConditionalBranch(cond, newBlock, newBlock2);
1263+
builder->setInsertionPointToEnd(newBlock);
1264+
auto cond2 = genCond(*caseValue++, mlir::CmpIPredicate::sle);
1265+
genFIRConditionalBranch(cond2, *caseBlock++, newBlock2);
1266+
builder->setInsertionPointToEnd(newBlock2);
1267+
continue;
1268+
}
1269+
mlir::CmpIPredicate pred;
1270+
if (attr.isa<fir::PointIntervalAttr>())
1271+
pred = mlir::CmpIPredicate::eq;
1272+
else if (attr.isa<fir::LowerBoundAttr>())
1273+
pred = mlir::CmpIPredicate::sge;
1274+
else if (attr.isa<fir::UpperBoundAttr>())
1275+
pred = mlir::CmpIPredicate::sle;
1276+
auto cond = genCond(*caseValue++, pred);
1277+
genFIRConditionalBranch(cond, *caseBlock++, newBlock);
1278+
builder->setInsertionPointToEnd(newBlock);
1279+
}
1280+
assert(caseValue == valueList.end() && caseBlock == blockList.end() &&
1281+
"select case list mismatch");
12071282
}
12081283

12091284
void genFIR(const Fortran::parser::AssociateConstruct &) {

flang/test/Lower/select-case-statement.f90

Lines changed: 81 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -45,72 +45,64 @@ subroutine slogical(L)
4545
n7 = 0
4646
n8 = 0
4747

48-
! CHECK: fir.select_case {{.*}} : i1
49-
! CHECK-SAME: unit
5048
select case (L)
5149
end select
5250

53-
! CHECK: fir.select_case {{.*}} : i1
54-
! CHECK-SAME: point, %false
55-
! CHECK-SAME: unit
5651
select case (L)
52+
! CHECK: cmpi eq, {{.*}} %false
53+
! CHECK: cond_br
5754
case (.false.)
5855
n2 = 1
5956
end select
6057

61-
! CHECK: fir.select_case {{.*}} : i1
62-
! CHECK-SAME: point, %true
63-
! CHECK-SAME: unit
6458
select case (L)
59+
! CHECK: cmpi eq, {{.*}} %true
60+
! CHECK: cond_br
6561
case (.true.)
6662
n3 = 2
6763
end select
6864

69-
! CHECK: fir.select_case {{.*}} : i1
70-
! CHECK-SAME: unit
7165
select case (L)
7266
case default
7367
n4 = 3
7468
end select
7569

76-
! CHECK: fir.select_case {{.*}} : i1
77-
! CHECK-SAME: point, %false
78-
! CHECK-SAME: point, %true
79-
! CHECK-SAME: unit
8070
select case (L)
71+
! CHECK: cmpi eq, {{.*}} %false
72+
! CHECK: cond_br
8173
case (.false.)
8274
n5 = 1
75+
! CHECK: cmpi eq, {{.*}} %true
76+
! CHECK: cond_br
8377
case (.true.)
8478
n5 = 2
8579
end select
8680

87-
! CHECK: fir.select_case {{.*}} : i1
88-
! CHECK-SAME: point, %false
89-
! CHECK-SAME: unit
9081
select case (L)
82+
! CHECK: cmpi eq, {{.*}} %false
83+
! CHECK: cond_br
9184
case (.false.)
9285
n6 = 1
9386
case default
9487
n6 = 3
9588
end select
9689

97-
! CHECK: fir.select_case {{.*}} : i1
98-
! CHECK-SAME: point, %true
99-
! CHECK-SAME: unit
10090
select case (L)
91+
! CHECK: cmpi eq, {{.*}} %true
92+
! CHECK: cond_br
10193
case (.true.)
10294
n7 = 2
10395
case default
10496
n7 = 3
10597
end select
10698

107-
! CHECK: fir.select_case {{.*}} : i1
108-
! CHECK-SAME: point, %false
109-
! CHECK-SAME: point, %true
110-
! CHECK-SAME: unit
11199
select case (L)
100+
! CHECK: cmpi eq, {{.*}} %false
101+
! CHECK: cond_br
112102
case (.false.)
113103
n8 = 1
104+
! CHECK: cmpi eq, {{.*}} %true
105+
! CHECK: cond_br
114106
case (.true.)
115107
n8 = 2
116108
! CHECK-NOT: 888
@@ -121,6 +113,52 @@ subroutine slogical(L)
121113
print*, n1, n2, n3, n4, n5, n6, n7, n8
122114
end
123115

116+
! CHECK-LABEL: scharacter
117+
subroutine scharacter(c)
118+
character(*) :: c
119+
nn = 0
120+
select case (c)
121+
case default
122+
nn = -1
123+
! CHECK: CharacterCompareScalar1
124+
! CHECK-NEXT: constant 0
125+
! CHECK-NEXT: cmpi sle, {{.*}} %c0
126+
! CHECK-NEXT: cond_br
127+
case (:'d')
128+
nn = 10
129+
! CHECK: CharacterCompareScalar1
130+
! CHECK-NEXT: constant 0
131+
! CHECK-NEXT: cmpi sge, {{.*}} %c0
132+
! CHECK-NEXT: cond_br
133+
! CHECK: CharacterCompareScalar1
134+
! CHECK-NEXT: constant 0
135+
! CHECK-NEXT: cmpi sle, {{.*}} %c0
136+
! CHECK-NEXT: cond_br
137+
case ('ff':'ffff')
138+
nn = 20
139+
! CHECK: CharacterCompareScalar1
140+
! CHECK-NEXT: constant 0
141+
! CHECK-NEXT: cmpi eq, {{.*}} %c0
142+
! CHECK-NEXT: cond_br
143+
case ('m')
144+
nn = 30
145+
! CHECK: CharacterCompareScalar1
146+
! CHECK-NEXT: constant 0
147+
! CHECK-NEXT: cmpi eq, {{.*}} %c0
148+
! CHECK-NEXT: cond_br
149+
case ('qq')
150+
nn = 40
151+
! CHECK: CharacterCompareScalar1
152+
! CHECK-NEXT: constant 0
153+
! CHECK-NEXT: cmpi sge, {{.*}} %c0
154+
! CHECK-NEXT: cond_br
155+
case ('x':)
156+
nn = 50
157+
end select
158+
print*, nn
159+
end
160+
161+
! CHECK-LABEL: main
124162
program p
125163
integer sinteger, v(10)
126164

@@ -138,8 +176,23 @@ program p
138176
enddo
139177

140178
print*
141-
! expected output: 0 1 0 3 1 1 3 1
142-
call slogical(.false.)
143-
! expected output: 0 0 2 3 2 3 2 2
144-
call slogical(.true.)
179+
call slogical(.false.) ! expected output: 0 1 0 3 1 1 3 1
180+
call slogical(.true.) ! expected output: 0 0 2 3 2 3 2 2
181+
182+
print*
183+
call scharacter('aa') ! expected output: 10
184+
call scharacter('d') ! expected output: 10
185+
call scharacter('f') ! expected output: -1
186+
call scharacter('ff') ! expected output: 20
187+
call scharacter('fff') ! expected output: 20
188+
call scharacter('ffff') ! expected output: 20
189+
call scharacter('fffff') ! expected output: -1
190+
call scharacter('jj') ! expected output: -1
191+
call scharacter('m') ! expected output: 30
192+
call scharacter('q') ! expected output: -1
193+
call scharacter('qq') ! expected output: 40
194+
call scharacter('qqq') ! expected output: -1
195+
call scharacter('vv') ! expected output: -1
196+
call scharacter('xx') ! expected output: 50
197+
call scharacter('zz') ! expected output: 50
145198
end

0 commit comments

Comments
 (0)