Skip to content

Commit e4df0ed

Browse files
committed
[CS] Allow ExprPatterns to be type-checked in the solver
Previously we would wait until CSApply, which would trigger their type-checking in `coercePatternToType`. This caused a number of bugs, and hampered solver-based completion, which does not run CSApply. Instead, form a conjunction of all the ExprPatterns present, which preserves some of the previous isolation behavior (though does not provide complete isolation). We can then modify `coercePatternToType` to accept a closure, which allows the solver to take over rewriting the ExprPatterns it has already solved. This then sets the stage for the complete removal of `coercePatternToType`, and doing all pattern type-checking in the solver.
1 parent 3bc3be9 commit e4df0ed

24 files changed

+482
-145
lines changed

include/swift/Sema/ConstraintLocatorPathElts.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,9 @@ CUSTOM_LOCATOR_PATH_ELT(TernaryBranch)
222222
/// Performing a pattern patch.
223223
CUSTOM_LOCATOR_PATH_ELT(PatternMatch)
224224

225+
/// A conjunction for the ExprPatterns in a pattern.
226+
SIMPLE_LOCATOR_PATH_ELT(ExprPatternConjunction)
227+
225228
/// Points to a particular attribute associated with one of
226229
/// the arguments e.g. `inout` or its type e.g. `@escaping`.
227230
///

include/swift/Sema/ConstraintSystem.h

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,11 @@ T *getAsPattern(ASTNode node) {
706706
return nullptr;
707707
}
708708

709+
template <typename T = Pattern>
710+
T *castToPattern(ASTNode node) {
711+
return cast<T>(node.get<Pattern *>());
712+
}
713+
709714
template <typename T = Stmt> T *castToStmt(ASTNode node) {
710715
return cast<T>(node.get<Stmt *>());
711716
}
@@ -1515,6 +1520,10 @@ class Solution {
15151520
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
15161521
caseLabelItems;
15171522

1523+
/// A map of expressions to the ExprPatterns that they are being solved as
1524+
/// a part of.
1525+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
1526+
15181527
/// The set of parameters that have been inferred to be 'isolated'.
15191528
llvm::SmallVector<ParamDecl *, 2> isolatedParams;
15201529

@@ -1700,6 +1709,16 @@ class Solution {
17001709
: nullptr;
17011710
}
17021711

1712+
/// Retrieve the solved ExprPattern that corresponds to provided
1713+
/// sub-expression.
1714+
NullablePtr<ExprPattern> getExprPatternFor(Expr *E) const {
1715+
auto result = exprPatterns.find(E);
1716+
if (result == exprPatterns.end())
1717+
return nullptr;
1718+
1719+
return result->second;
1720+
}
1721+
17031722
/// This method implements functionality of `Expr::isTypeReference`
17041723
/// with data provided by a given solution.
17051724
bool isTypeReference(Expr *E) const;
@@ -2163,6 +2182,10 @@ class ConstraintSystem {
21632182
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
21642183
caseLabelItems;
21652184

2185+
/// A map of expressions to the ExprPatterns that they are being solved as
2186+
/// a part of.
2187+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
2188+
21662189
/// The set of parameters that have been inferred to be 'isolated'.
21672190
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;
21682191

@@ -2754,6 +2777,9 @@ class ConstraintSystem {
27542777
/// The length of \c caseLabelItems.
27552778
unsigned numCaseLabelItems;
27562779

2780+
/// The length of \c exprPatterns.
2781+
unsigned numExprPatterns;
2782+
27572783
/// The length of \c isolatedParams.
27582784
unsigned numIsolatedParams;
27592785

@@ -3175,6 +3201,15 @@ class ConstraintSystem {
31753201
caseLabelItems[item] = info;
31763202
}
31773203

3204+
/// Record a given ExprPattern as the parent of its sub-expression.
3205+
void setExprPatternFor(Expr *E, ExprPattern *EP) {
3206+
assert(E);
3207+
assert(EP);
3208+
auto inserted = exprPatterns.insert({E, EP}).second;
3209+
assert(inserted && "Mapping already defined?");
3210+
(void)inserted;
3211+
}
3212+
31783213
Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
31793214
const CaseLabelItem *item) const {
31803215
auto known = caseLabelItems.find(item);
@@ -4299,6 +4334,11 @@ class ConstraintSystem {
42994334
/// \returns \c true if constraint generation failed, \c false otherwise
43004335
bool generateConstraints(SingleValueStmtExpr *E);
43014336

4337+
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4338+
/// that solves each expression in turn.
4339+
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4340+
ConstraintLocatorBuilder locator);
4341+
43024342
/// Generate constraints for the given (unchecked) expression.
43034343
///
43044344
/// \returns a possibly-sanitized expression, or null if an error occurred.

lib/IDE/TypeCheckCompletionCallback.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
8181
/// \endcode
8282
/// If the code completion expression occurs in such an AST, return the
8383
/// declaration of the \c $match variable, otherwise return \c nullptr.
84-
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
84+
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
85+
if (auto EP = S.getExprPatternFor(E))
86+
return EP.get()->getMatchVar();
87+
88+
// TODO: Once ExprPattern type-checking is fully moved into the solver,
89+
// the below can be deleted.
90+
auto &CS = S.getConstraintSystem();
8591
auto &Context = CS.getASTContext();
8692

8793
auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
@@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
109115
}
110116

111117
Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
112-
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
113-
Type MatchVarType;
114-
// If the MatchVar has an explicit type, it's not part of the solution. But
115-
// we can look it up in the constraint system directly.
116-
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
117-
MatchVarType = T;
118-
} else {
119-
MatchVarType = getTypeForCompletion(S, MatchVar);
120-
}
121-
if (MatchVarType) {
122-
return MatchVarType;
123-
}
124-
}
125-
return nullptr;
118+
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
119+
if (!MatchVar)
120+
return nullptr;
121+
122+
if (S.hasType(MatchVar))
123+
return S.getResolvedType(MatchVar);
124+
125+
// If the ExprPattern wasn't solved as part of the constraint system, it's
126+
// not part of the solution.
127+
// TODO: This can be removed once ExprPattern type-checking is fully part
128+
// of the constraint system.
129+
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
130+
return T;
131+
132+
return getTypeForCompletion(S, MatchVar);
126133
}
127134

128135
void swift::ide::getSolutionSpecificVarTypes(

lib/Sema/CSApply.cpp

Lines changed: 85 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8574,6 +8574,9 @@ namespace {
85748574
return Action::SkipChildren();
85758575
}
85768576

8577+
NullablePtr<Pattern>
8578+
rewritePattern(Pattern *pattern, DeclContext *DC);
8579+
85778580
/// Rewrite the target, producing a new target.
85788581
Optional<SyntacticElementTarget>
85798582
rewriteTarget(SyntacticElementTarget target);
@@ -8820,12 +8823,70 @@ static Expr *wrapAsyncLetInitializer(
88208823
return resultInit;
88218824
}
88228825

8826+
static Pattern *rewriteExprPattern(SyntacticElementTarget matchTarget,
8827+
Type patternTy,
8828+
RewriteTargetFn rewriteTarget) {
8829+
auto *EP = matchTarget.getExprPattern();
8830+
8831+
// See if we can simplify to another kind of pattern.
8832+
if (auto simplified = TypeChecker::trySimplifyExprPattern(EP, patternTy))
8833+
return simplified.get();
8834+
8835+
auto resultTarget = rewriteTarget(matchTarget);
8836+
if (!resultTarget)
8837+
return nullptr;
8838+
8839+
matchTarget = *resultTarget;
8840+
8841+
EP->setMatchExpr(matchTarget.getAsExpr());
8842+
EP->getMatchVar()->setInterfaceType(patternTy->mapTypeOutOfContext());
8843+
EP->setType(patternTy);
8844+
return EP;
8845+
}
8846+
8847+
/// Attempt to rewrite either an ExprPattern, or a pattern that was solved as
8848+
/// an ExprPattern, e.g an EnumElementPattern that could not refer to an enum
8849+
/// case.
8850+
static Optional<Pattern *>
8851+
tryRewriteExprPattern(Pattern *P, Solution &solution, Type patternTy,
8852+
RewriteTargetFn rewriteTarget) {
8853+
// See if we have a match expression target.
8854+
auto matchTarget = solution.getTargetFor(P);
8855+
if (!matchTarget)
8856+
return None;
8857+
8858+
return rewriteExprPattern(*matchTarget, patternTy, rewriteTarget);
8859+
}
8860+
8861+
NullablePtr<Pattern> ExprWalker::rewritePattern(Pattern *pattern,
8862+
DeclContext *DC) {
8863+
auto &solution = Rewriter.solution;
8864+
8865+
// Figure out the pattern type.
8866+
auto patternTy = solution.getResolvedType(pattern);
8867+
patternTy = patternTy->reconstituteSugar(/*recursive=*/false);
8868+
8869+
// Coerce the pattern to its appropriate type.
8870+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
8871+
patternOptions |= TypeResolutionFlags::OverrideType;
8872+
8873+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8874+
return ::tryRewriteExprPattern(
8875+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
8876+
};
8877+
8878+
auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
8879+
return TypeChecker::coercePatternToType(contextualPattern, patternTy,
8880+
patternOptions, tryRewritePattern);
8881+
}
8882+
88238883
/// Apply the given solution to the initialization target.
88248884
///
88258885
/// \returns the resulting initialization expression.
88268886
static Optional<SyntacticElementTarget>
88278887
applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
8828-
Expr *initializer) {
8888+
Expr *initializer,
8889+
RewriteTargetFn rewriteTarget) {
88298890
auto wrappedVar = target.getInitializationWrappedVar();
88308891
Type initType;
88318892
if (wrappedVar) {
@@ -8890,10 +8951,14 @@ applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
88908951

88918952
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
88928953

8954+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8955+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
8956+
};
8957+
88938958
// Apply the solution to the pattern as well.
88948959
auto contextualPattern = target.getContextualPattern();
88958960
if (auto coercedPattern = TypeChecker::coercePatternToType(
8896-
contextualPattern, finalPatternType, options)) {
8961+
contextualPattern, finalPatternType, options, tryRewritePattern)) {
88978962
resultTarget.setPattern(coercedPattern);
88988963
} else {
88998964
return None;
@@ -9040,10 +9105,15 @@ static Optional<SyntacticElementTarget> applySolutionToForEachStmt(
90409105
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
90419106
options |= TypeResolutionFlags::OverrideType;
90429107

9108+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9109+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9110+
};
9111+
90439112
// Apply the solution to the pattern as well.
90449113
auto contextualPattern = target.getContextualPattern();
90459114
auto coercedPattern = TypeChecker::coercePatternToType(
9046-
contextualPattern, forEachStmtInfo.initType, options);
9115+
contextualPattern, forEachStmtInfo.initType, options,
9116+
tryRewritePattern);
90479117
if (!coercedPattern)
90489118
return None;
90499119

@@ -9131,7 +9201,8 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
91319201
switch (target.getExprContextualTypePurpose()) {
91329202
case CTP_Initialization: {
91339203
auto initResultTarget = applySolutionToInitialization(
9134-
solution, target, rewrittenExpr);
9204+
solution, target, rewrittenExpr,
9205+
[&](auto target) { return rewriteTarget(target); });
91359206
if (!initResultTarget)
91369207
return None;
91379208

@@ -9222,47 +9293,11 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
92229293
ConstraintSystem &cs = solution.getConstraintSystem();
92239294
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
92249295

9225-
// Figure out the pattern type.
9226-
Type patternType = solution.simplifyType(solution.getType(info.pattern));
9227-
patternType = patternType->reconstituteSugar(/*recursive=*/false);
9228-
9229-
// Check whether this enum element is resolved via ~= application.
9230-
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9231-
if (auto target = cs.getTargetFor(enumElement)) {
9232-
auto *EP = target->getExprPattern();
9233-
auto enumType = solution.getResolvedType(EP);
9234-
9235-
auto *matchCall = target->getAsExpr();
9236-
9237-
auto *result = matchCall->walk(*this);
9238-
if (!result)
9239-
return None;
9240-
9241-
{
9242-
auto *matchVar = EP->getMatchVar();
9243-
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9244-
}
9245-
9246-
EP->setMatchExpr(result);
9247-
EP->setType(enumType);
9248-
9249-
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9250-
return target;
9251-
}
9252-
}
9253-
9254-
// Coerce the pattern to its appropriate type.
9255-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9256-
patternOptions |= TypeResolutionFlags::OverrideType;
9257-
auto contextualPattern =
9258-
ContextualPattern::forRawPattern(info.pattern,
9259-
target.getDeclContext());
9260-
if (auto coercedPattern = TypeChecker::coercePatternToType(
9261-
contextualPattern, patternType, patternOptions)) {
9262-
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
9263-
} else {
9296+
auto pattern = rewritePattern(info.pattern, target.getDeclContext());
9297+
if (!pattern)
92649298
return None;
9265-
}
9299+
9300+
(*caseLabelItem)->setPattern(pattern.get(), /*resolved=*/true);
92669301

92679302
// If there is a guard expression, coerce that.
92689303
if (auto *guardExpr = info.guardExpr) {
@@ -9330,8 +9365,13 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
93309365
options |= TypeResolutionFlags::OverrideType;
93319366
}
93329367

9368+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9369+
return ::tryRewriteExprPattern(
9370+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9371+
};
9372+
93339373
if (auto coercedPattern = TypeChecker::coercePatternToType(
9334-
contextualPattern, patternType, options)) {
9374+
contextualPattern, patternType, options, tryRewritePattern)) {
93359375
auto resultTarget = target;
93369376
resultTarget.setPattern(coercedPattern);
93379377
return resultTarget;

lib/Sema/CSGen.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2393,12 +2393,6 @@ namespace {
23932393
// function, to set the type of the pattern.
23942394
auto setType = [&](Type type) {
23952395
CS.setType(pattern, type);
2396-
if (auto PE = dyn_cast<ExprPattern>(pattern)) {
2397-
// Set the type of the pattern's sub-expression as well, so code
2398-
// completion can retrieve the expression's type in case it is a code
2399-
// completion token.
2400-
CS.setType(PE->getSubExpr(), type);
2401-
}
24022396
return type;
24032397
};
24042398

@@ -2816,15 +2810,12 @@ namespace {
28162810
return setType(patternType);
28172811
}
28182812

2819-
// Refutable patterns occur when checking the PatternBindingDecls in an
2820-
// if/let or while/let condition. They always require an initial value,
2821-
// so they always allow unspecified types.
2822-
case PatternKind::Expr:
2823-
// TODO: we could try harder here, e.g. for enum elements to provide the
2824-
// enum type.
2825-
return setType(
2826-
CS.createTypeVariable(
2827-
CS.getConstraintLocator(locator), TVO_CanBindToNoEscape));
2813+
case PatternKind::Expr: {
2814+
// We generate constraints for ExprPatterns in a separate pass. For
2815+
// now, just create a type variable.
2816+
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2817+
TVO_CanBindToNoEscape));
2818+
}
28282819
}
28292820

28302821
llvm_unreachable("Unhandled pattern kind");
@@ -4627,8 +4618,20 @@ Type ConstraintSystem::generateConstraints(
46274618
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
46284619
unsigned patternIndex) {
46294620
ConstraintGenerator cg(*this, nullptr);
4630-
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4631-
patternBinding, patternIndex);
4621+
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4622+
patternBinding, patternIndex);
4623+
assert(ty);
4624+
4625+
// Gather the ExprPatterns, and form a conjunction for their expressions.
4626+
SmallVector<ExprPattern *, 4> exprPatterns;
4627+
pattern->forEachNode([&](Pattern *P) {
4628+
if (auto *EP = dyn_cast<ExprPattern>(P))
4629+
exprPatterns.push_back(EP);
4630+
});
4631+
if (!exprPatterns.empty())
4632+
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4633+
4634+
return ty;
46324635
}
46334636

46344637
bool ConstraintSystem::generateConstraints(StmtCondition condition,

0 commit comments

Comments
 (0)