Skip to content

Commit 4d7d2ef

Browse files
authored
Merge pull request #23334 from gottesmm/pr-cafc42f13f3d474469835f12b2fdc97e9c947c76
2 parents e038f7e + 8d3d952 commit 4d7d2ef

File tree

1 file changed

+142
-126
lines changed

1 file changed

+142
-126
lines changed

lib/Sema/TypeCheckStmt.cpp

Lines changed: 142 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,142 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
947947
return S;
948948
}
949949

950+
void checkCaseLabelItem(CaseStmt *caseBlock, CaseLabelItem &labelItem,
951+
bool &limitExhaustivityChecks, Type subjectType) {
952+
SWIFT_DEFER {
953+
// Check the guard expression, if present.
954+
if (auto *guard = labelItem.getGuardExpr()) {
955+
limitExhaustivityChecks |= TC.typeCheckCondition(guard, DC);
956+
labelItem.setGuardExpr(guard);
957+
}
958+
};
959+
960+
Pattern *pattern = labelItem.getPattern();
961+
auto *newPattern = TC.resolvePattern(pattern, DC,
962+
/*isStmtCondition*/ false);
963+
if (!newPattern)
964+
return;
965+
pattern = newPattern;
966+
// Coerce the pattern to the subject's type.
967+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
968+
if (!subjectType ||
969+
TC.coercePatternToType(pattern, TypeResolution::forContextual(DC),
970+
subjectType, patternOptions)) {
971+
limitExhaustivityChecks = true;
972+
973+
// If that failed, mark any variables binding pieces of the pattern
974+
// as invalid to silence follow-on errors.
975+
pattern->forEachVariable([&](VarDecl *VD) { VD->markInvalid(); });
976+
}
977+
labelItem.setPattern(pattern);
978+
979+
// For each variable in the pattern, make sure its type is identical to what
980+
// it was in the first label item's pattern.
981+
auto *firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
982+
SmallVector<VarDecl *, 4> vars;
983+
firstPattern->collectVariables(vars);
984+
pattern->forEachVariable([&](VarDecl *vd) {
985+
if (!vd->hasName())
986+
return;
987+
for (auto *expected : vars) {
988+
if (expected->hasName() && expected->getName() == vd->getName()) {
989+
if (vd->hasType() && expected->hasType() && !expected->isInvalid() &&
990+
!vd->getType()->isEqual(expected->getType())) {
991+
TC.diagnose(vd->getLoc(), diag::type_mismatch_multiple_pattern_list,
992+
vd->getType(), expected->getType());
993+
vd->markInvalid();
994+
expected->markInvalid();
995+
}
996+
if (expected->isLet() != vd->isLet()) {
997+
auto diag = TC.diagnose(
998+
vd->getLoc(), diag::mutability_mismatch_multiple_pattern_list,
999+
vd->isLet(), expected->isLet());
1000+
1001+
VarPattern *foundVP = nullptr;
1002+
vd->getParentPattern()->forEachNode([&](Pattern *P) {
1003+
if (auto *VP = dyn_cast<VarPattern>(P))
1004+
if (VP->getSingleVar() == vd)
1005+
foundVP = VP;
1006+
});
1007+
if (foundVP)
1008+
diag.fixItReplace(foundVP->getLoc(),
1009+
expected->isLet() ? "let" : "var");
1010+
vd->markInvalid();
1011+
expected->markInvalid();
1012+
}
1013+
return;
1014+
}
1015+
}
1016+
});
1017+
}
1018+
1019+
void checkUnknownAttrRestrictions(CaseStmt *caseBlock,
1020+
bool &limitExhaustivityChecks) {
1021+
if (caseBlock->getCaseLabelItems().size() != 1) {
1022+
assert(!caseBlock->getCaseLabelItems().empty() &&
1023+
"parser should not produce case blocks with no items");
1024+
TC.diagnose(caseBlock->getLoc(), diag::unknown_case_multiple_patterns)
1025+
.highlight(caseBlock->getCaseLabelItems()[1].getSourceRange());
1026+
limitExhaustivityChecks = true;
1027+
}
1028+
1029+
if (FallthroughDest != nullptr) {
1030+
if (!caseBlock->isDefault())
1031+
TC.diagnose(caseBlock->getLoc(), diag::unknown_case_must_be_last);
1032+
limitExhaustivityChecks = true;
1033+
}
1034+
1035+
const auto &labelItem = caseBlock->getCaseLabelItems().front();
1036+
if (labelItem.getGuardExpr() && !labelItem.isDefault()) {
1037+
TC.diagnose(labelItem.getStartLoc(), diag::unknown_case_where_clause)
1038+
.highlight(labelItem.getGuardExpr()->getSourceRange());
1039+
}
1040+
1041+
const Pattern *pattern =
1042+
labelItem.getPattern()->getSemanticsProvidingPattern();
1043+
if (!isa<AnyPattern>(pattern)) {
1044+
TC.diagnose(labelItem.getStartLoc(), diag::unknown_case_must_be_catchall)
1045+
.highlight(pattern->getSourceRange());
1046+
}
1047+
}
1048+
1049+
void checkFallthroughPatternBindingsAndTypes(CaseStmt *caseBlock,
1050+
CaseStmt *previousBlock) {
1051+
auto firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
1052+
SmallVector<VarDecl *, 4> vars;
1053+
firstPattern->collectVariables(vars);
1054+
1055+
for (auto &labelItem : previousBlock->getCaseLabelItems()) {
1056+
const Pattern *pattern = labelItem.getPattern();
1057+
SmallVector<VarDecl *, 4> PreviousVars;
1058+
pattern->collectVariables(PreviousVars);
1059+
for (auto expected : vars) {
1060+
bool matched = false;
1061+
if (!expected->hasName())
1062+
continue;
1063+
for (auto previous : PreviousVars) {
1064+
if (previous->hasName() &&
1065+
expected->getName() == previous->getName()) {
1066+
if (!previous->getType()->isEqual(expected->getType())) {
1067+
TC.diagnose(previous->getLoc(),
1068+
diag::type_mismatch_fallthrough_pattern_list,
1069+
previous->getType(), expected->getType());
1070+
previous->markInvalid();
1071+
expected->markInvalid();
1072+
}
1073+
matched = true;
1074+
break;
1075+
}
1076+
}
1077+
if (!matched) {
1078+
TC.diagnose(PreviousFallthrough->getLoc(),
1079+
diag::fallthrough_into_case_with_var_binding,
1080+
expected->getName());
1081+
}
1082+
}
1083+
}
1084+
}
1085+
9501086
Stmt *visitSwitchStmt(SwitchStmt *switchStmt) {
9511087
// Type-check the subject expression.
9521088
Expr *subjectExpr = switchStmt->getSubjectExpr();
@@ -979,141 +1115,21 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
9791115
FallthroughDest = std::next(i) == e ? nullptr : *std::next(i);
9801116

9811117
for (auto &labelItem : caseBlock->getMutableCaseLabelItems()) {
982-
// Resolve the pattern in the label.
983-
Pattern *pattern = labelItem.getPattern();
984-
if (auto *newPattern = TC.resolvePattern(pattern, DC,
985-
/*isStmtCondition*/false)) {
986-
pattern = newPattern;
987-
// Coerce the pattern to the subject's type.
988-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
989-
if (!subjectType ||
990-
TC.coercePatternToType(pattern, TypeResolution::forContextual(DC),
991-
subjectType, patternOptions)) {
992-
limitExhaustivityChecks = true;
993-
994-
// If that failed, mark any variables binding pieces of the pattern
995-
// as invalid to silence follow-on errors.
996-
pattern->forEachVariable([&](VarDecl *VD) {
997-
VD->markInvalid();
998-
});
999-
}
1000-
labelItem.setPattern(pattern);
1001-
1002-
// For each variable in the pattern, make sure its type is identical to what it
1003-
// was in the first label item's pattern.
1004-
auto *firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
1005-
SmallVector<VarDecl *, 4> vars;
1006-
firstPattern->collectVariables(vars);
1007-
pattern->forEachVariable([&](VarDecl *vd) {
1008-
if (!vd->hasName())
1009-
return;
1010-
for (auto *expected : vars) {
1011-
if (expected->hasName() && expected->getName() == vd->getName()) {
1012-
if (vd->hasType() && expected->hasType() &&
1013-
!expected->isInvalid() &&
1014-
!vd->getType()->isEqual(expected->getType())) {
1015-
TC.diagnose(vd->getLoc(),
1016-
diag::type_mismatch_multiple_pattern_list,
1017-
vd->getType(), expected->getType());
1018-
vd->markInvalid();
1019-
expected->markInvalid();
1020-
}
1021-
if (expected->isLet() != vd->isLet()) {
1022-
auto diag = TC.diagnose(
1023-
vd->getLoc(),
1024-
diag::mutability_mismatch_multiple_pattern_list,
1025-
vd->isLet(), expected->isLet());
1026-
1027-
VarPattern *foundVP = nullptr;
1028-
vd->getParentPattern()->forEachNode([&](Pattern *P) {
1029-
if (auto *VP = dyn_cast<VarPattern>(P))
1030-
if (VP->getSingleVar() == vd)
1031-
foundVP = VP;
1032-
});
1033-
if (foundVP)
1034-
diag.fixItReplace(foundVP->getLoc(),
1035-
expected->isLet() ? "let" : "var");
1036-
vd->markInvalid();
1037-
expected->markInvalid();
1038-
}
1039-
return;
1040-
}
1041-
}
1042-
});
1043-
}
1044-
// Check the guard expression, if present.
1045-
if (auto *guard = labelItem.getGuardExpr()) {
1046-
limitExhaustivityChecks |= TC.typeCheckCondition(guard, DC);
1047-
labelItem.setGuardExpr(guard);
1048-
}
1118+
// Resolve the pattern in our case label if it has not been resolved
1119+
// and check that our var decls follow invariants.
1120+
checkCaseLabelItem(caseBlock, labelItem, limitExhaustivityChecks,
1121+
subjectType);
10491122
}
10501123

10511124
// Check restrictions on '@unknown'.
10521125
if (caseBlock->hasUnknownAttr()) {
1053-
if (caseBlock->getCaseLabelItems().size() != 1) {
1054-
assert(!caseBlock->getCaseLabelItems().empty() &&
1055-
"parser should not produce case blocks with no items");
1056-
TC.diagnose(caseBlock->getLoc(),
1057-
diag::unknown_case_multiple_patterns)
1058-
.highlight(caseBlock->getCaseLabelItems()[1].getSourceRange());
1059-
limitExhaustivityChecks = true;
1060-
}
1061-
1062-
if (FallthroughDest != nullptr) {
1063-
if (!caseBlock->isDefault())
1064-
TC.diagnose(caseBlock->getLoc(), diag::unknown_case_must_be_last);
1065-
limitExhaustivityChecks = true;
1066-
}
1067-
1068-
const auto &labelItem = caseBlock->getCaseLabelItems().front();
1069-
if (labelItem.getGuardExpr() && !labelItem.isDefault()) {
1070-
TC.diagnose(labelItem.getStartLoc(),
1071-
diag::unknown_case_where_clause)
1072-
.highlight(labelItem.getGuardExpr()->getSourceRange());
1073-
}
1074-
1075-
const Pattern *pattern =
1076-
labelItem.getPattern()->getSemanticsProvidingPattern();
1077-
if (!isa<AnyPattern>(pattern)) {
1078-
TC.diagnose(labelItem.getStartLoc(),
1079-
diag::unknown_case_must_be_catchall)
1080-
.highlight(pattern->getSourceRange());
1081-
}
1126+
checkUnknownAttrRestrictions(caseBlock, limitExhaustivityChecks);
10821127
}
10831128

10841129
// If the previous case fellthrough, similarly check that that case's bindings
10851130
// includes our first label item's pattern bindings and types.
10861131
if (PreviousFallthrough && previousBlock) {
1087-
auto firstPattern = caseBlock->getCaseLabelItems()[0].getPattern();
1088-
SmallVector<VarDecl *, 4> vars;
1089-
firstPattern->collectVariables(vars);
1090-
1091-
for (auto &labelItem : previousBlock->getCaseLabelItems()) {
1092-
const Pattern *pattern = labelItem.getPattern();
1093-
SmallVector<VarDecl *, 4> PreviousVars;
1094-
pattern->collectVariables(PreviousVars);
1095-
for (auto expected : vars) {
1096-
bool matched = false;
1097-
if (!expected->hasName())
1098-
continue;
1099-
for (auto previous: PreviousVars) {
1100-
if (previous->hasName() && expected->getName() == previous->getName()) {
1101-
if (!previous->getType()->isEqual(expected->getType())) {
1102-
TC.diagnose(previous->getLoc(), diag::type_mismatch_fallthrough_pattern_list,
1103-
previous->getType(), expected->getType());
1104-
previous->markInvalid();
1105-
expected->markInvalid();
1106-
}
1107-
matched = true;
1108-
break;
1109-
}
1110-
}
1111-
if (!matched) {
1112-
TC.diagnose(PreviousFallthrough->getLoc(),
1113-
diag::fallthrough_into_case_with_var_binding, expected->getName());
1114-
}
1115-
}
1116-
}
1132+
checkFallthroughPatternBindingsAndTypes(caseBlock, previousBlock);
11171133
}
11181134

11191135
// Type-check the body statements.

0 commit comments

Comments
 (0)