@@ -338,6 +338,8 @@ class SyntacticElementConstraintGenerator
338
338
hadError = true ;
339
339
return ;
340
340
}
341
+
342
+ caseItem->setPattern (pattern, /* resolved=*/ true );
341
343
}
342
344
343
345
// Let's generate constraints for pattern + where clause.
@@ -774,8 +776,6 @@ class SyntacticElementConstraintGenerator
774
776
}
775
777
}
776
778
777
- bindSwitchCasePatternVars (context.getAsDeclContext (), caseStmt);
778
-
779
779
auto *caseLoc = cs.getConstraintLocator (
780
780
locator, LocatorPathElt::SyntacticElement (caseStmt));
781
781
@@ -799,10 +799,8 @@ class SyntacticElementConstraintGenerator
799
799
locator->castLastElementTo <LocatorPathElt::SyntacticElement>()
800
800
.asStmt ());
801
801
802
- for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
803
- auto parentVar = caseBodyVar->getParentVarDecl ();
804
- assert (parentVar && " Case body variables always have parents" );
805
- cs.setType (caseBodyVar, cs.getType (parentVar));
802
+ if (recordInferredSwitchCasePatternVars (caseStmt)) {
803
+ hadError = true ;
806
804
}
807
805
}
808
806
@@ -929,6 +927,75 @@ class SyntacticElementConstraintGenerator
929
927
locator->getLastElementAs <LocatorPathElt::SyntacticElement>();
930
928
return parentElt ? parentElt->getElement ().isStmt (kind) : false ;
931
929
}
930
+
931
+ bool recordInferredSwitchCasePatternVars (CaseStmt *caseStmt) {
932
+ llvm::SmallDenseMap<Identifier, SmallVector<VarDecl *, 2 >, 4 > patternVars;
933
+
934
+ auto recordVar = [&](VarDecl *var) {
935
+ if (!var->hasName ())
936
+ return ;
937
+ patternVars[var->getName ()].push_back (var);
938
+ };
939
+
940
+ for (auto &caseItem : caseStmt->getMutableCaseLabelItems ()) {
941
+ assert (caseItem.isPatternResolved ());
942
+
943
+ auto *pattern = caseItem.getPattern ();
944
+ pattern->forEachVariable ([&](VarDecl *var) { recordVar (var); });
945
+ }
946
+
947
+ for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
948
+ if (!bodyVar->hasName ())
949
+ continue ;
950
+
951
+ const auto &variants = patternVars[bodyVar->getName ()];
952
+
953
+ auto getType = [&](VarDecl *var) {
954
+ auto type = cs.simplifyType (cs.getType (var));
955
+ assert (!type->hasTypeVariable ());
956
+ return type;
957
+ };
958
+
959
+ switch (variants.size ()) {
960
+ case 0 :
961
+ break ;
962
+
963
+ case 1 :
964
+ // If there is only one choice here, let's use it directly.
965
+ cs.setType (bodyVar, getType (variants.front ()));
966
+ break ;
967
+
968
+ default : {
969
+ // If there are multiple choices it could only mean multiple
970
+ // patterns e.g. `.a(let x), .b(let x), ...:`. Let's join them.
971
+ Type joinType = getType (variants.front ());
972
+
973
+ SmallVector<VarDecl *, 2 > conflicts;
974
+ for (auto *var : llvm::drop_begin (variants)) {
975
+ auto varType = getType (var);
976
+ // Type mismatch between different patterns.
977
+ if (!joinType->isEqual (varType))
978
+ conflicts.push_back (var);
979
+ }
980
+
981
+ if (!conflicts.empty ()) {
982
+ if (!cs.shouldAttemptFixes ())
983
+ return true ;
984
+
985
+ // dfdf
986
+ auto *locator = cs.getConstraintLocator (bodyVar);
987
+ if (cs.recordFix (RenameConflictingPatternVariables::create (
988
+ cs, joinType, conflicts, locator)))
989
+ return true ;
990
+ }
991
+
992
+ cs.setType (bodyVar, joinType);
993
+ }
994
+ }
995
+ }
996
+
997
+ return false ;
998
+ }
932
999
};
933
1000
}
934
1001
@@ -1336,6 +1403,8 @@ class SyntacticElementSolutionApplication
1336
1403
}
1337
1404
}
1338
1405
1406
+ bindSwitchCasePatternVars (context.getAsDeclContext (), caseStmt);
1407
+
1339
1408
for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
1340
1409
assert (expected->hasName ());
1341
1410
auto prev = expected->getParentVarDecl ();
0 commit comments