Skip to content

Commit 5213bec

Browse files
committed
[Sema] Eliminate duplication in CaseStmt typechecking for switch and do-catch statements
1 parent 4264b39 commit 5213bec

File tree

1 file changed

+61
-135
lines changed

1 file changed

+61
-135
lines changed

lib/Sema/TypeCheckStmt.cpp

Lines changed: 61 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,40 +1072,28 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10721072
}
10731073
}
10741074

1075-
Stmt *visitSwitchStmt(SwitchStmt *switchStmt) {
1076-
// Type-check the subject expression.
1077-
Expr *subjectExpr = switchStmt->getSubjectExpr();
1078-
auto resultTy = TypeChecker::typeCheckExpression(subjectExpr, DC);
1079-
auto limitExhaustivityChecks = !resultTy;
1080-
if (Expr *newSubjectExpr =
1081-
TypeChecker::coerceToRValue(getASTContext(), subjectExpr))
1082-
subjectExpr = newSubjectExpr;
1083-
switchStmt->setSubjectExpr(subjectExpr);
1084-
Type subjectType = switchStmt->getSubjectExpr()->getType();
1085-
1086-
// Type-check the case blocks.
1087-
AddSwitchNest switchNest(*this);
1088-
AddLabeledStmt labelNest(*this, switchStmt);
1089-
1090-
// Pre-emptively visit all Decls (#if/#warning/#error) that still exist in
1091-
// the list of raw cases.
1092-
for (auto &node : switchStmt->getRawCases()) {
1093-
if (!node.is<Decl *>())
1094-
continue;
1095-
TypeChecker::typeCheckDecl(node.get<Decl *>());
1096-
}
1075+
template <typename Iterator>
1076+
void checkSiblingCaseStmts(Iterator casesBegin, Iterator casesEnd,
1077+
CaseParentKind parentKind,
1078+
bool &limitExhaustivityChecks, Type subjectType) {
1079+
static_assert(
1080+
std::is_same<typename std::iterator_traits<Iterator>::value_type,
1081+
CaseStmt *>::value,
1082+
"Expected an iterator over CaseStmt *");
10971083

10981084
SmallVector<VarDecl *, 8> scratchMemory1;
10991085
SmallVector<VarDecl *, 8> scratchMemory2;
1100-
1101-
auto cases = switchStmt->getCases();
11021086
CaseStmt *previousBlock = nullptr;
1103-
for (auto i = cases.begin(), e = cases.end(); i != e; ++i) {
1087+
1088+
for (auto i = casesBegin; i != casesEnd; ++i) {
11041089
auto *caseBlock = *i;
1105-
// Fallthrough transfers control to the next case block. In the
1106-
// final case block, it is invalid.
1107-
FallthroughSource = caseBlock;
1108-
FallthroughDest = std::next(i) == e ? nullptr : *std::next(i);
1090+
1091+
if (parentKind == CaseParentKind::Switch) {
1092+
// Fallthrough transfers control to the next case block. In the
1093+
// final case block, it is invalid. Only switch supports fallthrough.
1094+
FallthroughSource = caseBlock;
1095+
FallthroughDest = std::next(i) == casesEnd ? nullptr : *std::next(i);
1096+
}
11091097

11101098
scratchMemory1.clear();
11111099
scratchMemory2.clear();
@@ -1193,24 +1181,57 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
11931181

11941182
// Check restrictions on '@unknown'.
11951183
if (caseBlock->hasUnknownAttr()) {
1184+
assert(parentKind == CaseParentKind::Switch &&
1185+
"'@unknown' can only appear on switch cases");
11961186
checkUnknownAttrRestrictions(
11971187
getASTContext(), caseBlock, FallthroughDest,
11981188
limitExhaustivityChecks);
11991189
}
12001190

1201-
// If the previous case fellthrough, similarly check that that case's
1202-
// bindings includes our first label item's pattern bindings and types.
1203-
if (PreviousFallthrough && previousBlock) {
1204-
checkFallthroughPatternBindingsAndTypes(caseBlock, previousBlock);
1191+
if (parentKind == CaseParentKind::Switch) {
1192+
// If the previous case fellthrough, similarly check that that case's
1193+
// bindings includes our first label item's pattern bindings and types.
1194+
// Only switch statements support fallthrough.
1195+
if (PreviousFallthrough && previousBlock) {
1196+
checkFallthroughPatternBindingsAndTypes(caseBlock, previousBlock);
1197+
}
1198+
PreviousFallthrough = nullptr;
12051199
}
12061200

12071201
// Type-check the body statements.
1208-
PreviousFallthrough = nullptr;
12091202
Stmt *body = caseBlock->getBody();
12101203
limitExhaustivityChecks |= typeCheckStmt(body);
12111204
caseBlock->setBody(body);
12121205
previousBlock = caseBlock;
12131206
}
1207+
}
1208+
1209+
Stmt *visitSwitchStmt(SwitchStmt *switchStmt) {
1210+
// Type-check the subject expression.
1211+
Expr *subjectExpr = switchStmt->getSubjectExpr();
1212+
auto resultTy = TypeChecker::typeCheckExpression(subjectExpr, DC);
1213+
auto limitExhaustivityChecks = !resultTy;
1214+
if (Expr *newSubjectExpr =
1215+
TypeChecker::coerceToRValue(getASTContext(), subjectExpr))
1216+
subjectExpr = newSubjectExpr;
1217+
switchStmt->setSubjectExpr(subjectExpr);
1218+
Type subjectType = switchStmt->getSubjectExpr()->getType();
1219+
1220+
// Type-check the case blocks.
1221+
AddSwitchNest switchNest(*this);
1222+
AddLabeledStmt labelNest(*this, switchStmt);
1223+
1224+
// Pre-emptively visit all Decls (#if/#warning/#error) that still exist in
1225+
// the list of raw cases.
1226+
for (auto &node : switchStmt->getRawCases()) {
1227+
if (!node.is<Decl *>())
1228+
continue;
1229+
TypeChecker::typeCheckDecl(node.get<Decl *>());
1230+
}
1231+
1232+
auto cases = switchStmt->getCases();
1233+
checkSiblingCaseStmts(cases.begin(), cases.end(), CaseParentKind::Switch,
1234+
limitExhaustivityChecks, subjectType);
12141235

12151236
if (!switchStmt->isImplicit()) {
12161237
TypeChecker::checkSwitchExhaustiveness(switchStmt, DC,
@@ -1237,108 +1258,13 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
12371258
typeCheckStmt(newBody);
12381259
S->setBody(newBody);
12391260

1240-
SmallVector<VarDecl *, 8> scratchMemory1;
1241-
SmallVector<VarDecl *, 8> scratchMemory2;
1242-
1243-
auto clauses = S->getCatches();
1244-
CaseStmt *previousBlock = nullptr;
1245-
for (auto i = clauses.begin(), e = clauses.end(); i != e; ++i) {
1246-
auto *caseBlock = *i;
1247-
1248-
scratchMemory1.clear();
1249-
scratchMemory2.clear();
1250-
1251-
SmallVectorImpl<VarDecl *> *prevCaseDecls = nullptr;
1252-
SmallVectorImpl<VarDecl *> *nextCaseDecls = &scratchMemory1;
1253-
1254-
auto caseLabelItemArray = caseBlock->getMutableCaseLabelItems();
1255-
{
1256-
// Peel off the first iteration so we handle the first case label
1257-
// especially since we use it to begin the validation chain.
1258-
auto &labelItem = caseLabelItemArray.front();
1259-
1260-
// Resolve the pattern in our case label if it has not been resolved and
1261-
// check that our var decls follow invariants.
1262-
bool limit = true;
1263-
checkCaseLabelItemPattern(caseBlock, labelItem, limit,
1264-
getASTContext().getExceptionType(),
1265-
&prevCaseDecls, &nextCaseDecls);
1266-
1267-
// After this is complete, prevCaseDecls will be pointing at
1268-
// scratchMemory1 which contains the initial case block's var decls and
1269-
// nextCaseDecls will be a nullptr. Set nextCaseDecls to point at
1270-
// scratchMemory2 for the next iterations.
1271-
assert(prevCaseDecls == &scratchMemory1);
1272-
assert(nextCaseDecls == nullptr);
1273-
nextCaseDecls = &scratchMemory2;
1274-
1275-
// Check the guard expression, if present.
1276-
if (auto *guard = labelItem.getGuardExpr()) {
1277-
TypeChecker::typeCheckCondition(guard, DC);
1278-
labelItem.setGuardExpr(guard);
1279-
}
1280-
}
1281-
1282-
// Setup the types of our case body var decls.
1283-
for (auto *expected : caseBlock->getCaseBodyVariablesOrEmptyArray()) {
1284-
assert(expected->hasName());
1285-
for (auto *prev : *prevCaseDecls) {
1286-
if (!prev->hasName() || expected->getName() != prev->getName()) {
1287-
continue;
1288-
}
1289-
if (prev->hasInterfaceType())
1290-
expected->setInterfaceType(prev->getInterfaceType());
1291-
break;
1292-
}
1293-
}
1261+
// Do-catch statements always limit exhaustivity checks.
1262+
bool limitExhaustivityChecks = true;
12941263

1295-
// Then check the rest.
1296-
for (auto &labelItem : caseLabelItemArray.drop_front()) {
1297-
// Resolve the pattern in our case label if it has not been resolved
1298-
// and check that our var decls follow invariants.
1299-
bool limit = true;
1300-
checkCaseLabelItemPattern(caseBlock, labelItem, limit,
1301-
getASTContext().getExceptionType(),
1302-
&prevCaseDecls, &nextCaseDecls);
1303-
// Check the guard expression, if present.
1304-
if (auto *guard = labelItem.getGuardExpr()) {
1305-
TypeChecker::typeCheckCondition(guard, DC);
1306-
labelItem.setGuardExpr(guard);
1307-
}
1308-
}
1309-
1310-
// Our last CaseLabelItem's VarDecls are now in
1311-
// prevCaseDecls. Wire them up as parents of our case body var
1312-
// decls.
1313-
//
1314-
// NOTE: We know that the two lists of var decls must be in sync. Remember
1315-
// that we constructed our case body VarDecls from the first
1316-
// CaseLabelItems var decls. Just now we proved that all other
1317-
// CaseLabelItems have matching var decls of the first meaning
1318-
// transitively that our last case label item must have matching var decls
1319-
// for our case stmts CaseBodyVarDecls.
1320-
//
1321-
// NOTE: We do not check that we matched everything here. That is because
1322-
// the check has already been done by comparing the 1st CaseLabelItem var
1323-
// decls. If we insert a check here, we will emit the same error multiple
1324-
// times.
1325-
for (auto *expected : caseBlock->getCaseBodyVariablesOrEmptyArray()) {
1326-
assert(expected->hasName());
1327-
for (auto *prev : *prevCaseDecls) {
1328-
if (!prev->hasName() || expected->getName() != prev->getName()) {
1329-
continue;
1330-
}
1331-
expected->setParentVarDecl(prev);
1332-
break;
1333-
}
1334-
}
1335-
1336-
// Type-check the body statements.
1337-
Stmt *body = caseBlock->getBody();
1338-
typeCheckStmt(body);
1339-
caseBlock->setBody(body);
1340-
previousBlock = caseBlock;
1341-
}
1264+
auto catches = S->getCatches();
1265+
checkSiblingCaseStmts(catches.begin(), catches.end(),
1266+
CaseParentKind::DoCatch, limitExhaustivityChecks,
1267+
getASTContext().getExceptionType());
13421268

13431269
return S;
13441270
}

0 commit comments

Comments
 (0)