Skip to content

Commit 7ce6fa6

Browse files
committed
[CSClosure] Detect and diagnose conflicting pattern variables in case statements
Avoid mutating case label items while solving, instead let's use types recorded in the constraint system for each pattern variable and use that for var reference in the case body. This also helps to detect and diagnose type conflicts while solving.
1 parent f704f80 commit 7ce6fa6

File tree

2 files changed

+107
-6
lines changed

2 files changed

+107
-6
lines changed

lib/Sema/CSClosure.cpp

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,8 @@ class SyntacticElementConstraintGenerator
346346
hadError = true;
347347
return;
348348
}
349+
350+
caseItem->setPattern(pattern, /*resolved=*/true);
349351
}
350352

351353
// Let's generate constraints for pattern + where clause.
@@ -903,8 +905,6 @@ class SyntacticElementConstraintGenerator
903905
}
904906
}
905907

906-
bindSwitchCasePatternVars(context.getAsDeclContext(), caseStmt);
907-
908908
auto *caseLoc = cs.getConstraintLocator(
909909
locator, LocatorPathElt::SyntacticElement(caseStmt));
910910

@@ -928,10 +928,8 @@ class SyntacticElementConstraintGenerator
928928
locator->castLastElementTo<LocatorPathElt::SyntacticElement>()
929929
.asStmt());
930930

931-
for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
932-
auto parentVar = caseBodyVar->getParentVarDecl();
933-
assert(parentVar && "Case body variables always have parents");
934-
cs.setType(caseBodyVar, cs.getType(parentVar));
931+
if (recordInferredSwitchCasePatternVars(caseStmt)) {
932+
hadError = true;
935933
}
936934
}
937935

@@ -1058,6 +1056,75 @@ class SyntacticElementConstraintGenerator
10581056
locator->getLastElementAs<LocatorPathElt::SyntacticElement>();
10591057
return parentElt ? parentElt->getElement().isStmt(kind) : false;
10601058
}
1059+
1060+
bool recordInferredSwitchCasePatternVars(CaseStmt *caseStmt) {
1061+
llvm::SmallDenseMap<Identifier, SmallVector<VarDecl *, 2>, 4> patternVars;
1062+
1063+
auto recordVar = [&](VarDecl *var) {
1064+
if (!var->hasName())
1065+
return;
1066+
patternVars[var->getName()].push_back(var);
1067+
};
1068+
1069+
for (auto &caseItem : caseStmt->getMutableCaseLabelItems()) {
1070+
assert(caseItem.isPatternResolved());
1071+
1072+
auto *pattern = caseItem.getPattern();
1073+
pattern->forEachVariable([&](VarDecl *var) { recordVar(var); });
1074+
}
1075+
1076+
for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
1077+
if (!bodyVar->hasName())
1078+
continue;
1079+
1080+
const auto &variants = patternVars[bodyVar->getName()];
1081+
1082+
auto getType = [&](VarDecl *var) {
1083+
auto type = cs.simplifyType(cs.getType(var));
1084+
assert(!type->hasTypeVariable());
1085+
return type;
1086+
};
1087+
1088+
switch (variants.size()) {
1089+
case 0:
1090+
break;
1091+
1092+
case 1:
1093+
// If there is only one choice here, let's use it directly.
1094+
cs.setType(bodyVar, getType(variants.front()));
1095+
break;
1096+
1097+
default: {
1098+
// If there are multiple choices it could only mean multiple
1099+
// patterns e.g. `.a(let x), .b(let x), ...:`. Let's join them.
1100+
Type joinType = getType(variants.front());
1101+
1102+
SmallVector<VarDecl *, 2> conflicts;
1103+
for (auto *var : llvm::drop_begin(variants)) {
1104+
auto varType = getType(var);
1105+
// Type mismatch between different patterns.
1106+
if (!joinType->isEqual(varType))
1107+
conflicts.push_back(var);
1108+
}
1109+
1110+
if (!conflicts.empty()) {
1111+
if (!cs.shouldAttemptFixes())
1112+
return true;
1113+
1114+
// dfdf
1115+
auto *locator = cs.getConstraintLocator(bodyVar);
1116+
if (cs.recordFix(RenameConflictingPatternVariables::create(
1117+
cs, joinType, conflicts, locator)))
1118+
return true;
1119+
}
1120+
1121+
cs.setType(bodyVar, joinType);
1122+
}
1123+
}
1124+
}
1125+
1126+
return false;
1127+
}
10611128
};
10621129
}
10631130

@@ -1465,6 +1532,8 @@ class SyntacticElementSolutionApplication
14651532
}
14661533
}
14671534

1535+
bindSwitchCasePatternVars(context.getAsDeclContext(), caseStmt);
1536+
14681537
for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
14691538
assert(expected->hasName());
14701539
auto prev = expected->getParentVarDecl();

test/expr/closure/multi_statement.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,35 @@ func test_fallthrough_stmt() {
454454
}
455455
}()
456456
}
457+
458+
func test_conflicting_pattern_vars() {
459+
enum E {
460+
case a(Int, String)
461+
case b(String, Int)
462+
}
463+
464+
func fn(_: (E) -> Void) {}
465+
func fn<T>(_: (E) -> T) {}
466+
467+
func test(e: E) {
468+
fn {
469+
switch $0 {
470+
case .a(let x, let y),
471+
.b(let x, let y):
472+
// expected-error@-1 {{pattern variable bound to type 'String', expected type 'Int'}}
473+
// expected-error@-2 {{pattern variable bound to type 'Int', expected type 'String'}}
474+
_ = x
475+
_ = y
476+
}
477+
}
478+
479+
fn {
480+
switch $0 {
481+
case .a(let x, let y),
482+
.b(let y, let x): // Ok
483+
_ = x
484+
_ = y
485+
}
486+
}
487+
}
488+
}

0 commit comments

Comments
 (0)