Skip to content

Commit 9c8bebe

Browse files
committed
[CSClosure] Detect and diagnose conflicting pattern variables in case statements
Avoid mutating case label items while solving, instead let's use types recorded in the constraint system for each pattern variable and use that for var reference in the case body. This also helps to detect and diagnose type conflicts while solving.
1 parent 62ba749 commit 9c8bebe

File tree

2 files changed

+107
-6
lines changed

2 files changed

+107
-6
lines changed

lib/Sema/CSClosure.cpp

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ class SyntacticElementConstraintGenerator
338338
hadError = true;
339339
return;
340340
}
341+
342+
caseItem->setPattern(pattern, /*resolved=*/true);
341343
}
342344

343345
// Let's generate constraints for pattern + where clause.
@@ -774,8 +776,6 @@ class SyntacticElementConstraintGenerator
774776
}
775777
}
776778

777-
bindSwitchCasePatternVars(context.getAsDeclContext(), caseStmt);
778-
779779
auto *caseLoc = cs.getConstraintLocator(
780780
locator, LocatorPathElt::SyntacticElement(caseStmt));
781781

@@ -799,10 +799,8 @@ class SyntacticElementConstraintGenerator
799799
locator->castLastElementTo<LocatorPathElt::SyntacticElement>()
800800
.asStmt());
801801

802-
for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
803-
auto parentVar = caseBodyVar->getParentVarDecl();
804-
assert(parentVar && "Case body variables always have parents");
805-
cs.setType(caseBodyVar, cs.getType(parentVar));
802+
if (recordInferredSwitchCasePatternVars(caseStmt)) {
803+
hadError = true;
806804
}
807805
}
808806

@@ -929,6 +927,75 @@ class SyntacticElementConstraintGenerator
929927
locator->getLastElementAs<LocatorPathElt::SyntacticElement>();
930928
return parentElt ? parentElt->getElement().isStmt(kind) : false;
931929
}
930+
931+
bool recordInferredSwitchCasePatternVars(CaseStmt *caseStmt) {
932+
llvm::SmallDenseMap<Identifier, SmallVector<VarDecl *, 2>, 4> patternVars;
933+
934+
auto recordVar = [&](VarDecl *var) {
935+
if (!var->hasName())
936+
return;
937+
patternVars[var->getName()].push_back(var);
938+
};
939+
940+
for (auto &caseItem : caseStmt->getMutableCaseLabelItems()) {
941+
assert(caseItem.isPatternResolved());
942+
943+
auto *pattern = caseItem.getPattern();
944+
pattern->forEachVariable([&](VarDecl *var) { recordVar(var); });
945+
}
946+
947+
for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
948+
if (!bodyVar->hasName())
949+
continue;
950+
951+
const auto &variants = patternVars[bodyVar->getName()];
952+
953+
auto getType = [&](VarDecl *var) {
954+
auto type = cs.simplifyType(cs.getType(var));
955+
assert(!type->hasTypeVariable());
956+
return type;
957+
};
958+
959+
switch (variants.size()) {
960+
case 0:
961+
break;
962+
963+
case 1:
964+
// If there is only one choice here, let's use it directly.
965+
cs.setType(bodyVar, getType(variants.front()));
966+
break;
967+
968+
default: {
969+
// If there are multiple choices it could only mean multiple
970+
// patterns e.g. `.a(let x), .b(let x), ...:`. Let's join them.
971+
Type joinType = getType(variants.front());
972+
973+
SmallVector<VarDecl *, 2> conflicts;
974+
for (auto *var : llvm::drop_begin(variants)) {
975+
auto varType = getType(var);
976+
// Type mismatch between different patterns.
977+
if (!joinType->isEqual(varType))
978+
conflicts.push_back(var);
979+
}
980+
981+
if (!conflicts.empty()) {
982+
if (!cs.shouldAttemptFixes())
983+
return true;
984+
985+
// dfdf
986+
auto *locator = cs.getConstraintLocator(bodyVar);
987+
if (cs.recordFix(RenameConflictingPatternVariables::create(
988+
cs, joinType, conflicts, locator)))
989+
return true;
990+
}
991+
992+
cs.setType(bodyVar, joinType);
993+
}
994+
}
995+
}
996+
997+
return false;
998+
}
932999
};
9331000
}
9341001

@@ -1336,6 +1403,8 @@ class SyntacticElementSolutionApplication
13361403
}
13371404
}
13381405

1406+
bindSwitchCasePatternVars(context.getAsDeclContext(), caseStmt);
1407+
13391408
for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray()) {
13401409
assert(expected->hasName());
13411410
auto prev = expected->getParentVarDecl();

test/expr/closure/multi_statement.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,35 @@ func test_missing_conformance_diagnostics_in_for_sequence() {
515515
}
516516
}
517517
}
518+
519+
func test_conflicting_pattern_vars() {
520+
enum E {
521+
case a(Int, String)
522+
case b(String, Int)
523+
}
524+
525+
func fn(_: (E) -> Void) {}
526+
func fn<T>(_: (E) -> T) {}
527+
528+
func test(e: E) {
529+
fn {
530+
switch $0 {
531+
case .a(let x, let y),
532+
.b(let x, let y):
533+
// expected-error@-1 {{pattern variable bound to type 'String', expected type 'Int'}}
534+
// expected-error@-2 {{pattern variable bound to type 'Int', expected type 'String'}}
535+
_ = x
536+
_ = y
537+
}
538+
}
539+
540+
fn {
541+
switch $0 {
542+
case .a(let x, let y),
543+
.b(let y, let x): // Ok
544+
_ = x
545+
_ = y
546+
}
547+
}
548+
}
549+
}

0 commit comments

Comments
 (0)