Skip to content

Revert "[Sema] Eliminate duplication in CaseStmt typechecking for switch and do-catch" #30923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 135 additions & 61 deletions lib/Sema/TypeCheckStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,28 +1072,40 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
}
}

template <typename Iterator>
void checkSiblingCaseStmts(Iterator casesBegin, Iterator casesEnd,
CaseParentKind parentKind,
bool &limitExhaustivityChecks, Type subjectType) {
static_assert(
std::is_same<typename std::iterator_traits<Iterator>::value_type,
CaseStmt *>::value,
"Expected an iterator over CaseStmt *");
Stmt *visitSwitchStmt(SwitchStmt *switchStmt) {
// Type-check the subject expression.
Expr *subjectExpr = switchStmt->getSubjectExpr();
auto resultTy = TypeChecker::typeCheckExpression(subjectExpr, DC);
auto limitExhaustivityChecks = !resultTy;
if (Expr *newSubjectExpr =
TypeChecker::coerceToRValue(getASTContext(), subjectExpr))
subjectExpr = newSubjectExpr;
switchStmt->setSubjectExpr(subjectExpr);
Type subjectType = switchStmt->getSubjectExpr()->getType();

// Type-check the case blocks.
AddSwitchNest switchNest(*this);
AddLabeledStmt labelNest(*this, switchStmt);

// Pre-emptively visit all Decls (#if/#warning/#error) that still exist in
// the list of raw cases.
for (auto &node : switchStmt->getRawCases()) {
if (!node.is<Decl *>())
continue;
TypeChecker::typeCheckDecl(node.get<Decl *>());
}

SmallVector<VarDecl *, 8> scratchMemory1;
SmallVector<VarDecl *, 8> scratchMemory2;
CaseStmt *previousBlock = nullptr;

for (auto i = casesBegin; i != casesEnd; ++i) {
auto cases = switchStmt->getCases();
CaseStmt *previousBlock = nullptr;
for (auto i = cases.begin(), e = cases.end(); i != e; ++i) {
auto *caseBlock = *i;

if (parentKind == CaseParentKind::Switch) {
// Fallthrough transfers control to the next case block. In the
// final case block, it is invalid. Only switch supports fallthrough.
FallthroughSource = caseBlock;
FallthroughDest = std::next(i) == casesEnd ? nullptr : *std::next(i);
}
// Fallthrough transfers control to the next case block. In the
// final case block, it is invalid.
FallthroughSource = caseBlock;
FallthroughDest = std::next(i) == e ? nullptr : *std::next(i);

scratchMemory1.clear();
scratchMemory2.clear();
Expand Down Expand Up @@ -1181,57 +1193,24 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {

// Check restrictions on '@unknown'.
if (caseBlock->hasUnknownAttr()) {
assert(parentKind == CaseParentKind::Switch &&
"'@unknown' can only appear on switch cases");
checkUnknownAttrRestrictions(
getASTContext(), caseBlock, FallthroughDest,
limitExhaustivityChecks);
}

if (parentKind == CaseParentKind::Switch) {
// If the previous case fellthrough, similarly check that that case's
// bindings includes our first label item's pattern bindings and types.
// Only switch statements support fallthrough.
if (PreviousFallthrough && previousBlock) {
checkFallthroughPatternBindingsAndTypes(caseBlock, previousBlock);
}
PreviousFallthrough = nullptr;
// If the previous case fellthrough, similarly check that that case's
// bindings includes our first label item's pattern bindings and types.
if (PreviousFallthrough && previousBlock) {
checkFallthroughPatternBindingsAndTypes(caseBlock, previousBlock);
}

// Type-check the body statements.
PreviousFallthrough = nullptr;
Stmt *body = caseBlock->getBody();
limitExhaustivityChecks |= typeCheckStmt(body);
caseBlock->setBody(body);
previousBlock = caseBlock;
}
}

Stmt *visitSwitchStmt(SwitchStmt *switchStmt) {
// Type-check the subject expression.
Expr *subjectExpr = switchStmt->getSubjectExpr();
auto resultTy = TypeChecker::typeCheckExpression(subjectExpr, DC);
auto limitExhaustivityChecks = !resultTy;
if (Expr *newSubjectExpr =
TypeChecker::coerceToRValue(getASTContext(), subjectExpr))
subjectExpr = newSubjectExpr;
switchStmt->setSubjectExpr(subjectExpr);
Type subjectType = switchStmt->getSubjectExpr()->getType();

// Type-check the case blocks.
AddSwitchNest switchNest(*this);
AddLabeledStmt labelNest(*this, switchStmt);

// Pre-emptively visit all Decls (#if/#warning/#error) that still exist in
// the list of raw cases.
for (auto &node : switchStmt->getRawCases()) {
if (!node.is<Decl *>())
continue;
TypeChecker::typeCheckDecl(node.get<Decl *>());
}

auto cases = switchStmt->getCases();
checkSiblingCaseStmts(cases.begin(), cases.end(), CaseParentKind::Switch,
limitExhaustivityChecks, subjectType);

if (!switchStmt->isImplicit()) {
TypeChecker::checkSwitchExhaustiveness(switchStmt, DC,
Expand All @@ -1258,13 +1237,108 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
typeCheckStmt(newBody);
S->setBody(newBody);

// Do-catch statements always limit exhaustivity checks.
bool limitExhaustivityChecks = true;
SmallVector<VarDecl *, 8> scratchMemory1;
SmallVector<VarDecl *, 8> scratchMemory2;

auto clauses = S->getCatches();
CaseStmt *previousBlock = nullptr;
for (auto i = clauses.begin(), e = clauses.end(); i != e; ++i) {
auto *caseBlock = *i;

scratchMemory1.clear();
scratchMemory2.clear();

SmallVectorImpl<VarDecl *> *prevCaseDecls = nullptr;
SmallVectorImpl<VarDecl *> *nextCaseDecls = &scratchMemory1;

auto caseLabelItemArray = caseBlock->getMutableCaseLabelItems();
{
// Peel off the first iteration so we handle the first case label
// especially since we use it to begin the validation chain.
auto &labelItem = caseLabelItemArray.front();

// Resolve the pattern in our case label if it has not been resolved and
// check that our var decls follow invariants.
bool limit = true;
checkCaseLabelItemPattern(caseBlock, labelItem, limit,
getASTContext().getExceptionType(),
&prevCaseDecls, &nextCaseDecls);

// After this is complete, prevCaseDecls will be pointing at
// scratchMemory1 which contains the initial case block's var decls and
// nextCaseDecls will be a nullptr. Set nextCaseDecls to point at
// scratchMemory2 for the next iterations.
assert(prevCaseDecls == &scratchMemory1);
assert(nextCaseDecls == nullptr);
nextCaseDecls = &scratchMemory2;

// Check the guard expression, if present.
if (auto *guard = labelItem.getGuardExpr()) {
TypeChecker::typeCheckCondition(guard, DC);
labelItem.setGuardExpr(guard);
}
}

// Setup the types of our case body var decls.
for (auto *expected : caseBlock->getCaseBodyVariablesOrEmptyArray()) {
assert(expected->hasName());
for (auto *prev : *prevCaseDecls) {
if (!prev->hasName() || expected->getName() != prev->getName()) {
continue;
}
if (prev->hasInterfaceType())
expected->setInterfaceType(prev->getInterfaceType());
break;
}
}

auto catches = S->getCatches();
checkSiblingCaseStmts(catches.begin(), catches.end(),
CaseParentKind::DoCatch, limitExhaustivityChecks,
getASTContext().getExceptionType());
// Then check the rest.
for (auto &labelItem : caseLabelItemArray.drop_front()) {
// Resolve the pattern in our case label if it has not been resolved
// and check that our var decls follow invariants.
bool limit = true;
checkCaseLabelItemPattern(caseBlock, labelItem, limit,
getASTContext().getExceptionType(),
&prevCaseDecls, &nextCaseDecls);
// Check the guard expression, if present.
if (auto *guard = labelItem.getGuardExpr()) {
TypeChecker::typeCheckCondition(guard, DC);
labelItem.setGuardExpr(guard);
}
}

// Our last CaseLabelItem's VarDecls are now in
// prevCaseDecls. Wire them up as parents of our case body var
// decls.
//
// NOTE: We know that the two lists of var decls must be in sync. Remember
// that we constructed our case body VarDecls from the first
// CaseLabelItems var decls. Just now we proved that all other
// CaseLabelItems have matching var decls of the first meaning
// transitively that our last case label item must have matching var decls
// for our case stmts CaseBodyVarDecls.
//
// NOTE: We do not check that we matched everything here. That is because
// the check has already been done by comparing the 1st CaseLabelItem var
// decls. If we insert a check here, we will emit the same error multiple
// times.
for (auto *expected : caseBlock->getCaseBodyVariablesOrEmptyArray()) {
assert(expected->hasName());
for (auto *prev : *prevCaseDecls) {
if (!prev->hasName() || expected->getName() != prev->getName()) {
continue;
}
expected->setParentVarDecl(prev);
break;
}
}

// Type-check the body statements.
Stmt *body = caseBlock->getBody();
typeCheckStmt(body);
caseBlock->setBody(body);
previousBlock = caseBlock;
}

return S;
}
Expand Down