@@ -346,6 +346,8 @@ class SyntacticElementConstraintGenerator
346
346
hadError = true ;
347
347
return ;
348
348
}
349
+
350
+ caseItem->setPattern (pattern, /* resolved=*/ true );
349
351
}
350
352
351
353
// Let's generate constraints for pattern + where clause.
@@ -903,8 +905,6 @@ class SyntacticElementConstraintGenerator
903
905
}
904
906
}
905
907
906
- bindSwitchCasePatternVars (context.getAsDeclContext (), caseStmt);
907
-
908
908
auto *caseLoc = cs.getConstraintLocator (
909
909
locator, LocatorPathElt::SyntacticElement (caseStmt));
910
910
@@ -928,10 +928,8 @@ class SyntacticElementConstraintGenerator
928
928
locator->castLastElementTo <LocatorPathElt::SyntacticElement>()
929
929
.asStmt ());
930
930
931
- for (auto caseBodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
932
- auto parentVar = caseBodyVar->getParentVarDecl ();
933
- assert (parentVar && " Case body variables always have parents" );
934
- cs.setType (caseBodyVar, cs.getType (parentVar));
931
+ if (recordInferredSwitchCasePatternVars (caseStmt)) {
932
+ hadError = true ;
935
933
}
936
934
}
937
935
@@ -1058,6 +1056,75 @@ class SyntacticElementConstraintGenerator
1058
1056
locator->getLastElementAs <LocatorPathElt::SyntacticElement>();
1059
1057
return parentElt ? parentElt->getElement ().isStmt (kind) : false ;
1060
1058
}
1059
+
1060
+ bool recordInferredSwitchCasePatternVars (CaseStmt *caseStmt) {
1061
+ llvm::SmallDenseMap<Identifier, SmallVector<VarDecl *, 2 >, 4 > patternVars;
1062
+
1063
+ auto recordVar = [&](VarDecl *var) {
1064
+ if (!var->hasName ())
1065
+ return ;
1066
+ patternVars[var->getName ()].push_back (var);
1067
+ };
1068
+
1069
+ for (auto &caseItem : caseStmt->getMutableCaseLabelItems ()) {
1070
+ assert (caseItem.isPatternResolved ());
1071
+
1072
+ auto *pattern = caseItem.getPattern ();
1073
+ pattern->forEachVariable ([&](VarDecl *var) { recordVar (var); });
1074
+ }
1075
+
1076
+ for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
1077
+ if (!bodyVar->hasName ())
1078
+ continue ;
1079
+
1080
+ const auto &variants = patternVars[bodyVar->getName ()];
1081
+
1082
+ auto getType = [&](VarDecl *var) {
1083
+ auto type = cs.simplifyType (cs.getType (var));
1084
+ assert (!type->hasTypeVariable ());
1085
+ return type;
1086
+ };
1087
+
1088
+ switch (variants.size ()) {
1089
+ case 0 :
1090
+ break ;
1091
+
1092
+ case 1 :
1093
+ // If there is only one choice here, let's use it directly.
1094
+ cs.setType (bodyVar, getType (variants.front ()));
1095
+ break ;
1096
+
1097
+ default : {
1098
+ // If there are multiple choices it could only mean multiple
1099
+ // patterns e.g. `.a(let x), .b(let x), ...:`. Let's join them.
1100
+ Type joinType = getType (variants.front ());
1101
+
1102
+ SmallVector<VarDecl *, 2 > conflicts;
1103
+ for (auto *var : llvm::drop_begin (variants)) {
1104
+ auto varType = getType (var);
1105
+ // Type mismatch between different patterns.
1106
+ if (!joinType->isEqual (varType))
1107
+ conflicts.push_back (var);
1108
+ }
1109
+
1110
+ if (!conflicts.empty ()) {
1111
+ if (!cs.shouldAttemptFixes ())
1112
+ return true ;
1113
+
1114
+ // dfdf
1115
+ auto *locator = cs.getConstraintLocator (bodyVar);
1116
+ if (cs.recordFix (RenameConflictingPatternVariables::create (
1117
+ cs, joinType, conflicts, locator)))
1118
+ return true ;
1119
+ }
1120
+
1121
+ cs.setType (bodyVar, joinType);
1122
+ }
1123
+ }
1124
+ }
1125
+
1126
+ return false ;
1127
+ }
1061
1128
};
1062
1129
}
1063
1130
@@ -1465,6 +1532,8 @@ class SyntacticElementSolutionApplication
1465
1532
}
1466
1533
}
1467
1534
1535
+ bindSwitchCasePatternVars (context.getAsDeclContext (), caseStmt);
1536
+
1468
1537
for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray ()) {
1469
1538
assert (expected->hasName ());
1470
1539
auto prev = expected->getParentVarDecl ();
0 commit comments