Skip to content

Commit ee3934a

Browse files
committed
[CS] Allow ExprPatterns to be type-checked in the solver
Previously we would wait until CSApply, which would trigger their type-checking in `coercePatternToType`. This caused a number of bugs, and hampered solver-based completion, which does not run CSApply. Instead, form a conjunction of all the ExprPatterns present, which preserves some of the previous isolation behavior (though does not provide complete isolation). We can then modify `coercePatternToType` to accept a closure, which allows the solver to take over rewriting the ExprPatterns it has already solved. This then sets the stage for the complete removal of `coercePatternToType`, and doing all pattern type-checking in the solver.
1 parent 6d2c014 commit ee3934a

23 files changed

+487
-147
lines changed

include/swift/Sema/ConstraintLocatorPathElts.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,9 @@ CUSTOM_LOCATOR_PATH_ELT(TernaryBranch)
222222
/// Performing a pattern patch.
223223
CUSTOM_LOCATOR_PATH_ELT(PatternMatch)
224224

225+
/// A conjunction for the ExprPatterns in a pattern.
226+
SIMPLE_LOCATOR_PATH_ELT(ExprPatternConjunction)
227+
225228
/// Points to a particular attribute associated with one of
226229
/// the arguments e.g. `inout` or its type e.g. `@escaping`.
227230
///

include/swift/Sema/ConstraintSystem.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,13 +699,27 @@ T *getAsStmt(ASTNode node) {
699699
return nullptr;
700700
}
701701

702+
template <typename T>
703+
bool isPattern(ASTNode node) {
704+
if (node.isNull() || !node.is<Pattern *>())
705+
return false;
706+
707+
auto *P = node.get<Pattern *>();
708+
return isa<T>(P);
709+
}
710+
702711
template <typename T = Pattern>
703712
T *getAsPattern(ASTNode node) {
704713
if (auto *P = node.dyn_cast<Pattern *>())
705714
return dyn_cast_or_null<T>(P);
706715
return nullptr;
707716
}
708717

718+
template <typename T = Pattern>
719+
T *castToPattern(ASTNode node) {
720+
return cast<T>(node.get<Pattern *>());
721+
}
722+
709723
template <typename T = Stmt> T *castToStmt(ASTNode node) {
710724
return cast<T>(node.get<Stmt *>());
711725
}
@@ -1514,6 +1528,10 @@ class Solution {
15141528
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
15151529
caseLabelItems;
15161530

1531+
/// A map of expressions to the ExprPatterns that they are being solved as
1532+
/// a part of.
1533+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
1534+
15171535
/// The set of parameters that have been inferred to be 'isolated'.
15181536
llvm::SmallVector<ParamDecl *, 2> isolatedParams;
15191537

@@ -1699,6 +1717,20 @@ class Solution {
16991717
: nullptr;
17001718
}
17011719

1720+
CaseLabelItemInfo getCaseLabelItemInfo(CaseLabelItem *labelItem) {
1721+
auto result = caseLabelItems.find(labelItem);
1722+
assert(result != caseLabelItems.end());
1723+
return result->second;
1724+
}
1725+
1726+
NullablePtr<ExprPattern> getExprPattern(Expr *E) const {
1727+
auto result = exprPatterns.find(E);
1728+
if (result == exprPatterns.end())
1729+
return nullptr;
1730+
1731+
return result->second;
1732+
}
1733+
17021734
/// This method implements functionality of `Expr::isTypeReference`
17031735
/// with data provided by a given solution.
17041736
bool isTypeReference(Expr *E) const;
@@ -2162,6 +2194,10 @@ class ConstraintSystem {
21622194
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
21632195
caseLabelItems;
21642196

2197+
/// A map of expressions to the ExprPatterns that they are being solved as
2198+
/// a part of.
2199+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
2200+
21652201
/// The set of parameters that have been inferred to be 'isolated'.
21662202
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;
21672203

@@ -2752,6 +2788,9 @@ class ConstraintSystem {
27522788
/// The length of \c caseLabelItems.
27532789
unsigned numCaseLabelItems;
27542790

2791+
/// The length of \c exprPatterns.
2792+
unsigned numExprPatterns;
2793+
27552794
/// The length of \c isolatedParams.
27562795
unsigned numIsolatedParams;
27572796

@@ -3173,6 +3212,14 @@ class ConstraintSystem {
31733212
caseLabelItems[item] = info;
31743213
}
31753214

3215+
void setExprPatternFor(Expr *E, ExprPattern *EP) {
3216+
assert(E);
3217+
assert(EP);
3218+
auto inserted = exprPatterns.insert({E, EP}).second;
3219+
assert(inserted && "Mapping already defined?");
3220+
(void)inserted;
3221+
}
3222+
31763223
Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
31773224
const CaseLabelItem *item) const {
31783225
auto known = caseLabelItems.find(item);
@@ -4300,6 +4347,9 @@ class ConstraintSystem {
43004347
/// \returns \c true if constraint generation failed, \c false otherwise
43014348
bool generateConstraints(SingleValueStmtExpr *E);
43024349

4350+
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4351+
ConstraintLocatorBuilder locator);
4352+
43034353
/// Generate constraints for the given (unchecked) expression.
43044354
///
43054355
/// \returns a possibly-sanitized expression, or null if an error occurred.
@@ -4324,6 +4374,10 @@ class ConstraintSystem {
43244374
[[nodiscard]]
43254375
bool generateConstraints(StmtCondition condition, DeclContext *dc);
43264376

4377+
[[nodiscard]]
4378+
bool generateConstraints(CaseLabelItem *caseLabelItem, DeclContext *dc,
4379+
Type convertTy, ConstraintLocator *locator);
4380+
43274381
/// Generate constraints for a case statement.
43284382
///
43294383
/// \param subjectType The type of the "subject" expression in the enclosing

lib/IDE/TypeCheckCompletionCallback.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
8181
/// \endcode
8282
/// If the code completion expression occurs in such an AST, return the
8383
/// declaration of the \c $match variable, otherwise return \c nullptr.
84-
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
84+
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
85+
if (auto EP = S.getExprPattern(E))
86+
return EP.get()->getMatchVar();
87+
88+
// TODO: Once ExprPattern type-checking is fully moved into the solver,
89+
// the below can be deleted.
90+
auto &CS = S.getConstraintSystem();
8591
auto &Context = CS.getASTContext();
8692

8793
auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
@@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
109115
}
110116

111117
Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
112-
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
113-
Type MatchVarType;
114-
// If the MatchVar has an explicit type, it's not part of the solution. But
115-
// we can look it up in the constraint system directly.
116-
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
117-
MatchVarType = T;
118-
} else {
119-
MatchVarType = getTypeForCompletion(S, MatchVar);
120-
}
121-
if (MatchVarType) {
122-
return MatchVarType;
123-
}
124-
}
125-
return nullptr;
118+
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
119+
if (!MatchVar)
120+
return nullptr;
121+
122+
if (S.hasType(MatchVar))
123+
return S.getResolvedType(MatchVar);
124+
125+
// If the ExprPattern wasn't solved as part of the constraint system, it's
126+
// not part of the solution.
127+
// TODO: This can be removed once ExprPattern type-checking is fully part
128+
// of the constraint system.
129+
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
130+
return T;
131+
132+
return getTypeForCompletion(S, MatchVar);
126133
}
127134

128135
void swift::ide::getSolutionSpecificVarTypes(

lib/Sema/CSApply.cpp

Lines changed: 85 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8570,6 +8570,9 @@ namespace {
85708570
return Action::SkipChildren();
85718571
}
85728572

8573+
NullablePtr<Pattern>
8574+
rewritePattern(Pattern *pattern, DeclContext *DC);
8575+
85738576
/// Rewrite the target, producing a new target.
85748577
Optional<SyntacticElementTarget>
85758578
rewriteTarget(SyntacticElementTarget target);
@@ -8816,12 +8819,68 @@ static Expr *wrapAsyncLetInitializer(
88168819
return resultInit;
88178820
}
88188821

8822+
static Pattern *rewriteExprPattern(SyntacticElementTarget matchTarget,
8823+
Type patternTy,
8824+
RewriteTargetFn rewriteTarget) {
8825+
auto *EP = matchTarget.getExprPattern();
8826+
8827+
// See if we can simplify to another kind of pattern.
8828+
if (auto simplified = TypeChecker::trySimplifyExprPattern(EP, patternTy))
8829+
return simplified.get();
8830+
8831+
auto resultTarget = rewriteTarget(matchTarget);
8832+
if (!resultTarget)
8833+
return nullptr;
8834+
8835+
matchTarget = *resultTarget;
8836+
8837+
EP->setMatchExpr(matchTarget.getAsExpr());
8838+
EP->getMatchVar()->setInterfaceType(patternTy->mapTypeOutOfContext());
8839+
EP->setType(patternTy);
8840+
return EP;
8841+
}
8842+
8843+
static Optional<Pattern *>
8844+
tryRewriteExprPattern(Pattern *P, Solution &solution, Type patternTy,
8845+
RewriteTargetFn rewriteTarget) {
8846+
// We may either have an ExprPattern, or a pattern that was mapped to an
8847+
// ExprPattern, such as an EnumElementPattern.
8848+
auto matchTarget = solution.getTargetFor(P);
8849+
if (!matchTarget)
8850+
return None;
8851+
8852+
return rewriteExprPattern(*matchTarget, patternTy, rewriteTarget);
8853+
}
8854+
8855+
NullablePtr<Pattern> ExprWalker::rewritePattern(Pattern *pattern,
8856+
DeclContext *DC) {
8857+
auto &solution = Rewriter.solution;
8858+
8859+
// Figure out the pattern type.
8860+
auto patternTy = solution.getResolvedType(pattern);
8861+
patternTy = patternTy->reconstituteSugar(/*recursive=*/false);
8862+
8863+
// Coerce the pattern to its appropriate type.
8864+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
8865+
patternOptions |= TypeResolutionFlags::OverrideType;
8866+
8867+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8868+
return ::tryRewriteExprPattern(
8869+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
8870+
};
8871+
8872+
auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
8873+
return TypeChecker::coercePatternToType(contextualPattern, patternTy,
8874+
patternOptions, tryRewritePattern);
8875+
}
8876+
88198877
/// Apply the given solution to the initialization target.
88208878
///
88218879
/// \returns the resulting initialization expression.
88228880
static Optional<SyntacticElementTarget>
88238881
applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
8824-
Expr *initializer) {
8882+
Expr *initializer,
8883+
RewriteTargetFn rewriteTarget) {
88258884
auto wrappedVar = target.getInitializationWrappedVar();
88268885
Type initType;
88278886
if (wrappedVar) {
@@ -8886,10 +8945,15 @@ applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
88868945

88878946
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
88888947

8948+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8949+
return ::tryRewriteExprPattern(
8950+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
8951+
};
8952+
88898953
// Apply the solution to the pattern as well.
88908954
auto contextualPattern = target.getContextualPattern();
88918955
if (auto coercedPattern = TypeChecker::coercePatternToType(
8892-
contextualPattern, finalPatternType, options)) {
8956+
contextualPattern, finalPatternType, options, tryRewritePattern)) {
88938957
resultTarget.setPattern(coercedPattern);
88948958
} else {
88958959
return None;
@@ -9036,10 +9100,16 @@ static Optional<SyntacticElementTarget> applySolutionToForEachStmt(
90369100
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
90379101
options |= TypeResolutionFlags::OverrideType;
90389102

9103+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9104+
return ::tryRewriteExprPattern(
9105+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9106+
};
9107+
90399108
// Apply the solution to the pattern as well.
90409109
auto contextualPattern = target.getContextualPattern();
90419110
auto coercedPattern = TypeChecker::coercePatternToType(
9042-
contextualPattern, forEachStmtInfo.initType, options);
9111+
contextualPattern, forEachStmtInfo.initType, options,
9112+
tryRewritePattern);
90439113
if (!coercedPattern)
90449114
return None;
90459115

@@ -9127,7 +9197,8 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
91279197
switch (target.getExprContextualTypePurpose()) {
91289198
case CTP_Initialization: {
91299199
auto initResultTarget = applySolutionToInitialization(
9130-
solution, target, rewrittenExpr);
9200+
solution, target, rewrittenExpr,
9201+
[&](auto target) { return rewriteTarget(target); });
91319202
if (!initResultTarget)
91329203
return None;
91339204

@@ -9218,47 +9289,11 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
92189289
ConstraintSystem &cs = solution.getConstraintSystem();
92199290
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
92209291

9221-
// Figure out the pattern type.
9222-
Type patternType = solution.simplifyType(solution.getType(info.pattern));
9223-
patternType = patternType->reconstituteSugar(/*recursive=*/false);
9224-
9225-
// Check whether this enum element is resolved via ~= application.
9226-
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9227-
if (auto target = cs.getTargetFor(enumElement)) {
9228-
auto *EP = target->getExprPattern();
9229-
auto enumType = solution.getResolvedType(EP);
9230-
9231-
auto *matchCall = target->getAsExpr();
9232-
9233-
auto *result = matchCall->walk(*this);
9234-
if (!result)
9235-
return None;
9236-
9237-
{
9238-
auto *matchVar = EP->getMatchVar();
9239-
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9240-
}
9241-
9242-
EP->setMatchExpr(result);
9243-
EP->setType(enumType);
9244-
9245-
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9246-
return target;
9247-
}
9248-
}
9249-
9250-
// Coerce the pattern to its appropriate type.
9251-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9252-
patternOptions |= TypeResolutionFlags::OverrideType;
9253-
auto contextualPattern =
9254-
ContextualPattern::forRawPattern(info.pattern,
9255-
target.getDeclContext());
9256-
if (auto coercedPattern = TypeChecker::coercePatternToType(
9257-
contextualPattern, patternType, patternOptions)) {
9258-
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
9259-
} else {
9292+
auto pattern = rewritePattern(info.pattern, target.getDeclContext());
9293+
if (!pattern)
92609294
return None;
9261-
}
9295+
9296+
(*caseLabelItem)->setPattern(pattern.get(), /*resolved=*/true);
92629297

92639298
// If there is a guard expression, coerce that.
92649299
if (auto *guardExpr = info.guardExpr) {
@@ -9326,8 +9361,13 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
93269361
options |= TypeResolutionFlags::OverrideType;
93279362
}
93289363

9364+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9365+
return ::tryRewriteExprPattern(
9366+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9367+
};
9368+
93299369
if (auto coercedPattern = TypeChecker::coercePatternToType(
9330-
contextualPattern, patternType, options)) {
9370+
contextualPattern, patternType, options, tryRewritePattern)) {
93319371
auto resultTarget = target;
93329372
resultTarget.setPattern(coercedPattern);
93339373
return resultTarget;

0 commit comments

Comments
 (0)