Skip to content

Commit 5800450

Browse files
committed
[CS] Allow ExprPatterns to be type-checked in the solver
Previously we would wait until CSApply, which would trigger their type-checking in `coercePatternToType`. This caused a number of bugs, and hampered solver-based completion, which does not run CSApply. Instead, form a single element conjunction, which will generate constraints for the expression when solved. The reason we need a conjunction is to preserve the current behavior where the expression pattern is type-checked independently from the rest of the pattern. We can then modify `coercePatternToType` to accept a closure, which allows the solver to take over rewriting the ExprPatterns it has already solved. This then sets the stage for the complete removal of `coercePatternToType`, and doing all pattern type-checking in the solver.
1 parent 3ed47f1 commit 5800450

21 files changed

+386
-146
lines changed

include/swift/Sema/ConstraintLocator.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,14 @@ class ConstraintLocator : public llvm::FoldingSetNode {
320320
return false;
321321
}
322322

323+
/// Determine whether this locator points directly to a given pattern.
324+
template <typename P>
325+
bool directlyAtPattern() const {
326+
if (auto *pattern = getAnchor().dyn_cast<Pattern *>())
327+
return isa<P>(pattern) && getPath().empty();
328+
return false;
329+
}
330+
323331
/// Check whether the first element in the path of this locator (if any)
324332
/// is a given \c LocatorPathElt subclass.
325333
template <class T>

include/swift/Sema/ConstraintSystem.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -699,13 +699,27 @@ T *getAsStmt(ASTNode node) {
699699
return nullptr;
700700
}
701701

702+
template <typename T>
703+
bool isPattern(ASTNode node) {
704+
if (node.isNull() || !node.is<Pattern *>())
705+
return false;
706+
707+
auto *P = node.get<Pattern *>();
708+
return isa<T>(P);
709+
}
710+
702711
template <typename T = Pattern>
703712
T *getAsPattern(ASTNode node) {
704713
if (auto *P = node.dyn_cast<Pattern *>())
705714
return dyn_cast_or_null<T>(P);
706715
return nullptr;
707716
}
708717

718+
template <typename T = Pattern>
719+
T *castToPattern(ASTNode node) {
720+
return cast<T>(node.get<Pattern *>());
721+
}
722+
709723
template <typename T = Stmt> T *castToStmt(ASTNode node) {
710724
return cast<T>(node.get<Stmt *>());
711725
}
@@ -1515,6 +1529,10 @@ class Solution {
15151529
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
15161530
caseLabelItems;
15171531

1532+
/// A map of expressions to the ExprPatterns that they are being solved as
1533+
/// a part of.
1534+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
1535+
15181536
/// The set of parameters that have been inferred to be 'isolated'.
15191537
llvm::SmallVector<ParamDecl *, 2> isolatedParams;
15201538

@@ -1697,6 +1715,20 @@ class Solution {
16971715
: nullptr;
16981716
}
16991717

1718+
CaseLabelItemInfo getCaseLabelItemInfo(CaseLabelItem *labelItem) {
1719+
auto result = caseLabelItems.find(labelItem);
1720+
assert(result != caseLabelItems.end());
1721+
return result->second;
1722+
}
1723+
1724+
NullablePtr<ExprPattern> getExprPattern(Expr *E) const {
1725+
auto result = exprPatterns.find(E);
1726+
if (result == exprPatterns.end())
1727+
return nullptr;
1728+
1729+
return result->second;
1730+
}
1731+
17001732
/// This method implements functionality of `Expr::isTypeReference`
17011733
/// with data provided by a given solution.
17021734
bool isTypeReference(Expr *E) const;
@@ -2161,6 +2193,10 @@ class ConstraintSystem {
21612193
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
21622194
caseLabelItems;
21632195

2196+
/// A map of expressions to the ExprPatterns that they are being solved as
2197+
/// a part of.
2198+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
2199+
21642200
/// The set of parameters that have been inferred to be 'isolated'.
21652201
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;
21662202

@@ -2751,6 +2787,9 @@ class ConstraintSystem {
27512787
/// The length of \c caseLabelItems.
27522788
unsigned numCaseLabelItems;
27532789

2790+
/// The length of \c exprPatterns.
2791+
unsigned numExprPatterns;
2792+
27542793
/// The length of \c isolatedParams.
27552794
unsigned numIsolatedParams;
27562795

@@ -3164,6 +3203,14 @@ class ConstraintSystem {
31643203
caseLabelItems[item] = info;
31653204
}
31663205

3206+
void setExprPattern(Expr *E, ExprPattern *EP) {
3207+
assert(E);
3208+
assert(EP);
3209+
auto inserted = exprPatterns.insert({E, EP}).second;
3210+
assert(inserted && "Mapping already defined?");
3211+
(void)inserted;
3212+
}
3213+
31673214
Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
31683215
const CaseLabelItem *item) const {
31693216
auto known = caseLabelItems.find(item);
@@ -4287,6 +4334,9 @@ class ConstraintSystem {
42874334
/// \returns \c true if constraint generation failed, \c false otherwise
42884335
bool generateConstraints(SingleValueStmtExpr *E);
42894336

4337+
LLVM_NODISCARD
4338+
Type generateConstraints(ExprPattern *EP);
4339+
42904340
/// Generate constraints for the given (unchecked) expression.
42914341
///
42924342
/// \returns a possibly-sanitized expression, or null if an error occurred.
@@ -4311,6 +4361,10 @@ class ConstraintSystem {
43114361
LLVM_NODISCARD
43124362
bool generateConstraints(StmtCondition condition, DeclContext *dc);
43134363

4364+
LLVM_NODISCARD
4365+
bool generateConstraints(CaseLabelItem *caseLabelItem, DeclContext *dc,
4366+
Type convertTy, ConstraintLocator *locator);
4367+
43144368
/// Generate constraints for a case statement.
43154369
///
43164370
/// \param subjectType The type of the "subject" expression in the enclosing

lib/IDE/TypeCheckCompletionCallback.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
8181
/// \endcode
8282
/// If the code completion expression occurs in such an AST, return the
8383
/// declaration of the \c $match variable, otherwise return \c nullptr.
84-
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
84+
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
85+
if (auto EP = S.getExprPattern(E))
86+
return EP.get()->getMatchVar();
87+
88+
// TODO: Once ExprPattern type-checking is fully moved into the solver,
89+
// the below can be deleted.
90+
auto &CS = S.getConstraintSystem();
8591
auto &Context = CS.getASTContext();
8692

8793
auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
@@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
109115
}
110116

111117
Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
112-
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
113-
Type MatchVarType;
114-
// If the MatchVar has an explicit type, it's not part of the solution. But
115-
// we can look it up in the constraint system directly.
116-
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
117-
MatchVarType = T;
118-
} else {
119-
MatchVarType = getTypeForCompletion(S, MatchVar);
120-
}
121-
if (MatchVarType) {
122-
return MatchVarType;
123-
}
124-
}
125-
return nullptr;
118+
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
119+
if (!MatchVar)
120+
return nullptr;
121+
122+
if (S.hasType(MatchVar))
123+
return S.getResolvedType(MatchVar);
124+
125+
// If the ExprPattern wasn't solved as part of the constraint system, it's
126+
// not part of the solution.
127+
// TODO: This can be removed once ExprPattern type-checking is fully part
128+
// of the constraint system.
129+
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
130+
return T;
131+
132+
return getTypeForCompletion(S, MatchVar);
126133
}
127134

128135
void swift::ide::getSolutionSpecificVarTypes(

lib/Sema/CSApply.cpp

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -8554,6 +8554,9 @@ namespace {
85548554
return Action::SkipChildren();
85558555
}
85568556

8557+
NullablePtr<Pattern>
8558+
rewritePattern(Pattern *pattern, DeclContext *DC);
8559+
85578560
/// Rewrite the target, producing a new target.
85588561
Optional<SolutionApplicationTarget>
85598562
rewriteTarget(SolutionApplicationTarget target);
@@ -8800,12 +8803,68 @@ static Expr *wrapAsyncLetInitializer(
88008803
return resultInit;
88018804
}
88028805

8806+
static Pattern *rewriteExprPattern(SolutionApplicationTarget matchTarget,
8807+
Type patternTy,
8808+
RewriteTargetFn rewriteTarget) {
8809+
auto *EP = matchTarget.getExprPattern();
8810+
8811+
// See if we can simplify to another kind of pattern.
8812+
if (auto simplified = TypeChecker::trySimplifyExprPattern(EP, patternTy))
8813+
return simplified.get();
8814+
8815+
auto resultTarget = rewriteTarget(matchTarget);
8816+
if (!resultTarget)
8817+
return nullptr;
8818+
8819+
matchTarget = *resultTarget;
8820+
8821+
EP->setMatchExpr(matchTarget.getAsExpr());
8822+
EP->getMatchVar()->setInterfaceType(patternTy->mapTypeOutOfContext());
8823+
EP->setType(patternTy);
8824+
return EP;
8825+
}
8826+
8827+
static Optional<Pattern *>
8828+
tryRewriteExprPattern(Pattern *P, Solution &solution, Type patternTy,
8829+
RewriteTargetFn rewriteTarget) {
8830+
// We may either have an ExprPattern, or a pattern that was mapped to an
8831+
// ExprPattern, such as an EnumElementPattern.
8832+
auto &cs = solution.getConstraintSystem();
8833+
auto matchTarget = cs.getSolutionApplicationTarget(P);
8834+
if (!matchTarget)
8835+
return None;
8836+
8837+
return rewriteExprPattern(*matchTarget, patternTy, rewriteTarget);
8838+
}
8839+
8840+
NullablePtr<Pattern> ExprWalker::rewritePattern(Pattern *pattern,
8841+
DeclContext *DC) {
8842+
auto &solution = Rewriter.solution;
8843+
8844+
// Figure out the pattern type.
8845+
auto patternTy = solution.getResolvedType(pattern);
8846+
patternTy = patternTy->reconstituteSugar(/*recursive=*/false);
8847+
8848+
// Coerce the pattern to its appropriate type.
8849+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
8850+
patternOptions |= TypeResolutionFlags::OverrideType;
8851+
8852+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8853+
return ::tryRewriteExprPattern(
8854+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
8855+
};
8856+
8857+
auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
8858+
return TypeChecker::coercePatternToType(contextualPattern, patternTy,
8859+
patternOptions, tryRewritePattern);
8860+
}
8861+
88038862
/// Apply the given solution to the initialization target.
88048863
///
88058864
/// \returns the resulting initialization expression.
88068865
static Optional<SolutionApplicationTarget> applySolutionToInitialization(
8807-
Solution &solution, SolutionApplicationTarget target,
8808-
Expr *initializer) {
8866+
Solution &solution, SolutionApplicationTarget target, Expr *initializer,
8867+
RewriteTargetFn rewriteTarget) {
88098868
auto wrappedVar = target.getInitializationWrappedVar();
88108869
Type initType;
88118870
if (wrappedVar) {
@@ -8870,10 +8929,15 @@ static Optional<SolutionApplicationTarget> applySolutionToInitialization(
88708929

88718930
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
88728931

8932+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8933+
return ::tryRewriteExprPattern(
8934+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
8935+
};
8936+
88738937
// Apply the solution to the pattern as well.
88748938
auto contextualPattern = target.getContextualPattern();
88758939
if (auto coercedPattern = TypeChecker::coercePatternToType(
8876-
contextualPattern, finalPatternType, options)) {
8940+
contextualPattern, finalPatternType, options, tryRewritePattern)) {
88778941
resultTarget.setPattern(coercedPattern);
88788942
} else {
88798943
return None;
@@ -9023,10 +9087,16 @@ static Optional<SolutionApplicationTarget> applySolutionToForEachStmt(
90239087
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
90249088
options |= TypeResolutionFlags::OverrideType;
90259089

9090+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9091+
return ::tryRewriteExprPattern(
9092+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9093+
};
9094+
90269095
// Apply the solution to the pattern as well.
90279096
auto contextualPattern = target.getContextualPattern();
90289097
auto coercedPattern = TypeChecker::coercePatternToType(
9029-
contextualPattern, forEachStmtInfo.initType, options);
9098+
contextualPattern, forEachStmtInfo.initType, options,
9099+
tryRewritePattern);
90309100
if (!coercedPattern)
90319101
return None;
90329102

@@ -9114,7 +9184,8 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
91149184
switch (target.getExprContextualTypePurpose()) {
91159185
case CTP_Initialization: {
91169186
auto initResultTarget = applySolutionToInitialization(
9117-
solution, target, rewrittenExpr);
9187+
solution, target, rewrittenExpr,
9188+
[&](auto target) { return rewriteTarget(target); });
91189189
if (!initResultTarget)
91199190
return None;
91209191

@@ -9205,47 +9276,11 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
92059276
ConstraintSystem &cs = solution.getConstraintSystem();
92069277
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
92079278

9208-
// Figure out the pattern type.
9209-
Type patternType = solution.simplifyType(solution.getType(info.pattern));
9210-
patternType = patternType->reconstituteSugar(/*recursive=*/false);
9211-
9212-
// Check whether this enum element is resolved via ~= application.
9213-
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9214-
if (auto target = cs.getSolutionApplicationTarget(enumElement)) {
9215-
auto *EP = target->getExprPattern();
9216-
auto enumType = solution.getResolvedType(EP);
9217-
9218-
auto *matchCall = target->getAsExpr();
9219-
9220-
auto *result = matchCall->walk(*this);
9221-
if (!result)
9222-
return None;
9223-
9224-
{
9225-
auto *matchVar = EP->getMatchVar();
9226-
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9227-
}
9228-
9229-
EP->setMatchExpr(result);
9230-
EP->setType(enumType);
9231-
9232-
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9233-
return target;
9234-
}
9235-
}
9236-
9237-
// Coerce the pattern to its appropriate type.
9238-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9239-
patternOptions |= TypeResolutionFlags::OverrideType;
9240-
auto contextualPattern =
9241-
ContextualPattern::forRawPattern(info.pattern,
9242-
target.getDeclContext());
9243-
if (auto coercedPattern = TypeChecker::coercePatternToType(
9244-
contextualPattern, patternType, patternOptions)) {
9245-
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
9246-
} else {
9279+
auto pattern = rewritePattern(info.pattern, target.getDeclContext());
9280+
if (!pattern)
92479281
return None;
9248-
}
9282+
9283+
(*caseLabelItem)->setPattern(pattern.get(), /*resolved=*/true);
92499284

92509285
// If there is a guard expression, coerce that.
92519286
if (auto *guardExpr = info.guardExpr) {
@@ -9314,8 +9349,13 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
93149349
options |= TypeResolutionFlags::OverrideType;
93159350
}
93169351

9352+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9353+
return ::tryRewriteExprPattern(
9354+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9355+
};
9356+
93179357
if (auto coercedPattern = TypeChecker::coercePatternToType(
9318-
contextualPattern, patternType, options)) {
9358+
contextualPattern, patternType, options, tryRewritePattern)) {
93199359
auto resultTarget = target;
93209360
resultTarget.setPattern(coercedPattern);
93219361
return resultTarget;

0 commit comments

Comments
 (0)