Skip to content

Commit b8d5cfd

Browse files
committed
[CS] Allow ExprPatterns to be type-checked in the solver
Previously we would wait until CSApply, which would trigger their type-checking in `coercePatternToType`. This caused a number of bugs, and hampered solver-based completion, which does not run CSApply. Instead, form a conjunction of all the ExprPatterns present, which preserves some of the previous isolation behavior (though does not provide complete isolation). We can then modify `coercePatternToType` to accept a closure, which allows the solver to take over rewriting the ExprPatterns it has already solved. This then sets the stage for the complete removal of `coercePatternToType`, and doing all pattern type-checking in the solver.
1 parent 12412a8 commit b8d5cfd

21 files changed

+489
-142
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,10 @@ class Solution {
15001500
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
15011501
caseLabelItems;
15021502

1503+
/// A map of expressions to the ExprPatterns that they are being solved as
1504+
/// a part of.
1505+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
1506+
15031507
/// The set of parameters that have been inferred to be 'isolated'.
15041508
llvm::SmallVector<ParamDecl *, 2> isolatedParams;
15051509

@@ -1685,6 +1689,16 @@ class Solution {
16851689
: nullptr;
16861690
}
16871691

1692+
/// Retrieve the solved ExprPattern that corresponds to provided
1693+
/// sub-expression.
1694+
NullablePtr<ExprPattern> getExprPatternFor(Expr *E) const {
1695+
auto result = exprPatterns.find(E);
1696+
if (result == exprPatterns.end())
1697+
return nullptr;
1698+
1699+
return result->second;
1700+
}
1701+
16881702
/// This method implements functionality of `Expr::isTypeReference`
16891703
/// with data provided by a given solution.
16901704
bool isTypeReference(Expr *E) const;
@@ -2148,6 +2162,10 @@ class ConstraintSystem {
21482162
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
21492163
caseLabelItems;
21502164

2165+
/// A map of expressions to the ExprPatterns that they are being solved as
2166+
/// a part of.
2167+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
2168+
21512169
/// The set of parameters that have been inferred to be 'isolated'.
21522170
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;
21532171

@@ -2745,6 +2763,9 @@ class ConstraintSystem {
27452763
/// The length of \c caseLabelItems.
27462764
unsigned numCaseLabelItems;
27472765

2766+
/// The length of \c exprPatterns.
2767+
unsigned numExprPatterns;
2768+
27482769
/// The length of \c isolatedParams.
27492770
unsigned numIsolatedParams;
27502771

@@ -3166,6 +3187,15 @@ class ConstraintSystem {
31663187
caseLabelItems[item] = info;
31673188
}
31683189

3190+
/// Record a given ExprPattern as the parent of its sub-expression.
3191+
void setExprPatternFor(Expr *E, ExprPattern *EP) {
3192+
assert(E);
3193+
assert(EP);
3194+
auto inserted = exprPatterns.insert({E, EP}).second;
3195+
assert(inserted && "Mapping already defined?");
3196+
(void)inserted;
3197+
}
3198+
31693199
Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
31703200
const CaseLabelItem *item) const {
31713201
auto known = caseLabelItems.find(item);
@@ -4315,6 +4345,11 @@ class ConstraintSystem {
43154345
/// \returns \c true if constraint generation failed, \c false otherwise
43164346
bool generateConstraints(SingleValueStmtExpr *E);
43174347

4348+
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4349+
/// that solves each expression in turn.
4350+
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4351+
ConstraintLocatorBuilder locator);
4352+
43184353
/// Generate constraints for the given (unchecked) expression.
43194354
///
43204355
/// \returns a possibly-sanitized expression, or null if an error occurred.

lib/IDE/TypeCheckCompletionCallback.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
8181
/// \endcode
8282
/// If the code completion expression occurs in such an AST, return the
8383
/// declaration of the \c $match variable, otherwise return \c nullptr.
84-
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
84+
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
85+
if (auto EP = S.getExprPatternFor(E))
86+
return EP.get()->getMatchVar();
87+
88+
// TODO: Once ExprPattern type-checking is fully moved into the solver,
89+
// the below can be deleted.
90+
auto &CS = S.getConstraintSystem();
8591
auto &Context = CS.getASTContext();
8692

8793
auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
@@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
109115
}
110116

111117
Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
112-
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
113-
Type MatchVarType;
114-
// If the MatchVar has an explicit type, it's not part of the solution. But
115-
// we can look it up in the constraint system directly.
116-
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
117-
MatchVarType = T;
118-
} else {
119-
MatchVarType = getTypeForCompletion(S, MatchVar);
120-
}
121-
if (MatchVarType) {
122-
return MatchVarType;
123-
}
124-
}
125-
return nullptr;
118+
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
119+
if (!MatchVar)
120+
return nullptr;
121+
122+
if (S.hasType(MatchVar))
123+
return S.getResolvedType(MatchVar);
124+
125+
// If the ExprPattern wasn't solved as part of the constraint system, it's
126+
// not part of the solution.
127+
// TODO: This can be removed once ExprPattern type-checking is fully part
128+
// of the constraint system.
129+
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
130+
return T;
131+
132+
return getTypeForCompletion(S, MatchVar);
126133
}
127134

128135
void swift::ide::getSolutionSpecificVarTypes(

lib/Sema/CSApply.cpp

Lines changed: 83 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8646,6 +8646,9 @@ namespace {
86468646
return Action::SkipChildren();
86478647
}
86488648

8649+
NullablePtr<Pattern>
8650+
rewritePattern(Pattern *pattern, DeclContext *DC);
8651+
86498652
/// Rewrite the target, producing a new target.
86508653
Optional<SyntacticElementTarget>
86518654
rewriteTarget(SyntacticElementTarget target);
@@ -8892,12 +8895,68 @@ static Expr *wrapAsyncLetInitializer(
88928895
return resultInit;
88938896
}
88948897

8898+
static Pattern *rewriteExprPattern(const SyntacticElementTarget &matchTarget,
8899+
Type patternTy,
8900+
RewriteTargetFn rewriteTarget) {
8901+
auto *EP = matchTarget.getExprPattern();
8902+
8903+
// See if we can simplify to another kind of pattern.
8904+
if (auto simplified = TypeChecker::trySimplifyExprPattern(EP, patternTy))
8905+
return simplified.get();
8906+
8907+
auto resultTarget = rewriteTarget(matchTarget);
8908+
if (!resultTarget)
8909+
return nullptr;
8910+
8911+
EP->setMatchExpr(resultTarget->getAsExpr());
8912+
EP->getMatchVar()->setInterfaceType(patternTy->mapTypeOutOfContext());
8913+
EP->setType(patternTy);
8914+
return EP;
8915+
}
8916+
8917+
/// Attempt to rewrite either an ExprPattern, or a pattern that was solved as
8918+
/// an ExprPattern, e.g an EnumElementPattern that could not refer to an enum
8919+
/// case.
8920+
static Optional<Pattern *>
8921+
tryRewriteExprPattern(Pattern *P, Solution &solution, Type patternTy,
8922+
RewriteTargetFn rewriteTarget) {
8923+
// See if we have a match expression target.
8924+
auto matchTarget = solution.getTargetFor(P);
8925+
if (!matchTarget)
8926+
return None;
8927+
8928+
return rewriteExprPattern(*matchTarget, patternTy, rewriteTarget);
8929+
}
8930+
8931+
NullablePtr<Pattern> ExprWalker::rewritePattern(Pattern *pattern,
8932+
DeclContext *DC) {
8933+
auto &solution = Rewriter.solution;
8934+
8935+
// Figure out the pattern type.
8936+
auto patternTy = solution.getResolvedType(pattern);
8937+
patternTy = patternTy->reconstituteSugar(/*recursive=*/false);
8938+
8939+
// Coerce the pattern to its appropriate type.
8940+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
8941+
patternOptions |= TypeResolutionFlags::OverrideType;
8942+
8943+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8944+
return ::tryRewriteExprPattern(
8945+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
8946+
};
8947+
8948+
auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
8949+
return TypeChecker::coercePatternToType(contextualPattern, patternTy,
8950+
patternOptions, tryRewritePattern);
8951+
}
8952+
88958953
/// Apply the given solution to the initialization target.
88968954
///
88978955
/// \returns the resulting initialization expression.
88988956
static Optional<SyntacticElementTarget>
88998957
applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
8900-
Expr *initializer) {
8958+
Expr *initializer,
8959+
RewriteTargetFn rewriteTarget) {
89018960
auto wrappedVar = target.getInitializationWrappedVar();
89028961
Type initType;
89038962
if (wrappedVar) {
@@ -8962,10 +9021,14 @@ applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
89629021

89639022
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
89649023

9024+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9025+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9026+
};
9027+
89659028
// Apply the solution to the pattern as well.
89669029
auto contextualPattern = target.getContextualPattern();
89679030
if (auto coercedPattern = TypeChecker::coercePatternToType(
8968-
contextualPattern, finalPatternType, options)) {
9031+
contextualPattern, finalPatternType, options, tryRewritePattern)) {
89699032
resultTarget.setPattern(coercedPattern);
89709033
} else {
89719034
return None;
@@ -9112,10 +9175,15 @@ static Optional<SyntacticElementTarget> applySolutionToForEachStmt(
91129175
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
91139176
options |= TypeResolutionFlags::OverrideType;
91149177

9178+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9179+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9180+
};
9181+
91159182
// Apply the solution to the pattern as well.
91169183
auto contextualPattern = target.getContextualPattern();
91179184
auto coercedPattern = TypeChecker::coercePatternToType(
9118-
contextualPattern, forEachStmtInfo.initType, options);
9185+
contextualPattern, forEachStmtInfo.initType, options,
9186+
tryRewritePattern);
91199187
if (!coercedPattern)
91209188
return None;
91219189

@@ -9203,7 +9271,8 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
92039271
switch (target.getExprContextualTypePurpose()) {
92049272
case CTP_Initialization: {
92059273
auto initResultTarget = applySolutionToInitialization(
9206-
solution, target, rewrittenExpr);
9274+
solution, target, rewrittenExpr,
9275+
[&](auto target) { return rewriteTarget(target); });
92079276
if (!initResultTarget)
92089277
return None;
92099278

@@ -9294,47 +9363,11 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
92949363
ConstraintSystem &cs = solution.getConstraintSystem();
92959364
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
92969365

9297-
// Figure out the pattern type.
9298-
Type patternType = solution.simplifyType(solution.getType(info.pattern));
9299-
patternType = patternType->reconstituteSugar(/*recursive=*/false);
9300-
9301-
// Check whether this enum element is resolved via ~= application.
9302-
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9303-
if (auto target = cs.getTargetFor(enumElement)) {
9304-
auto *EP = target->getExprPattern();
9305-
auto enumType = solution.getResolvedType(EP);
9306-
9307-
auto *matchCall = target->getAsExpr();
9308-
9309-
auto *result = matchCall->walk(*this);
9310-
if (!result)
9311-
return None;
9312-
9313-
{
9314-
auto *matchVar = EP->getMatchVar();
9315-
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9316-
}
9317-
9318-
EP->setMatchExpr(result);
9319-
EP->setType(enumType);
9320-
9321-
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9322-
return target;
9323-
}
9324-
}
9325-
9326-
// Coerce the pattern to its appropriate type.
9327-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9328-
patternOptions |= TypeResolutionFlags::OverrideType;
9329-
auto contextualPattern =
9330-
ContextualPattern::forRawPattern(info.pattern,
9331-
target.getDeclContext());
9332-
if (auto coercedPattern = TypeChecker::coercePatternToType(
9333-
contextualPattern, patternType, patternOptions)) {
9334-
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
9335-
} else {
9366+
auto pattern = rewritePattern(info.pattern, target.getDeclContext());
9367+
if (!pattern)
93369368
return None;
9337-
}
9369+
9370+
(*caseLabelItem)->setPattern(pattern.get(), /*resolved=*/true);
93389371

93399372
// If there is a guard expression, coerce that.
93409373
if (auto *guardExpr = info.guardExpr) {
@@ -9402,8 +9435,13 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
94029435
options |= TypeResolutionFlags::OverrideType;
94039436
}
94049437

9438+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9439+
return ::tryRewriteExprPattern(
9440+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9441+
};
9442+
94059443
if (auto coercedPattern = TypeChecker::coercePatternToType(
9406-
contextualPattern, patternType, options)) {
9444+
contextualPattern, patternType, options, tryRewritePattern)) {
94079445
auto resultTarget = target;
94089446
resultTarget.setPattern(coercedPattern);
94099447
return resultTarget;

lib/Sema/CSGen.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,12 +2433,6 @@ namespace {
24332433
// function, to set the type of the pattern.
24342434
auto setType = [&](Type type) {
24352435
CS.setType(pattern, type);
2436-
if (auto PE = dyn_cast<ExprPattern>(pattern)) {
2437-
// Set the type of the pattern's sub-expression as well, so code
2438-
// completion can retrieve the expression's type in case it is a code
2439-
// completion token.
2440-
CS.setType(PE->getSubExpr(), type);
2441-
}
24422436
return type;
24432437
};
24442438

@@ -2863,15 +2857,12 @@ namespace {
28632857
return setType(patternType);
28642858
}
28652859

2866-
// Refutable patterns occur when checking the PatternBindingDecls in an
2867-
// if/let or while/let condition. They always require an initial value,
2868-
// so they always allow unspecified types.
2869-
case PatternKind::Expr:
2870-
// TODO: we could try harder here, e.g. for enum elements to provide the
2871-
// enum type.
2872-
return setType(
2873-
CS.createTypeVariable(CS.getConstraintLocator(locator),
2874-
TVO_CanBindToNoEscape | TVO_CanBindToHole));
2860+
case PatternKind::Expr: {
2861+
// We generate constraints for ExprPatterns in a separate pass. For
2862+
// now, just create a type variable.
2863+
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2864+
TVO_CanBindToNoEscape));
2865+
}
28752866
}
28762867

28772868
llvm_unreachable("Unhandled pattern kind");
@@ -4730,8 +4721,20 @@ Type ConstraintSystem::generateConstraints(
47304721
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
47314722
unsigned patternIndex) {
47324723
ConstraintGenerator cg(*this, nullptr);
4733-
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4734-
patternBinding, patternIndex);
4724+
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4725+
patternBinding, patternIndex);
4726+
assert(ty);
4727+
4728+
// Gather the ExprPatterns, and form a conjunction for their expressions.
4729+
SmallVector<ExprPattern *, 4> exprPatterns;
4730+
pattern->forEachNode([&](Pattern *P) {
4731+
if (auto *EP = dyn_cast<ExprPattern>(P))
4732+
exprPatterns.push_back(EP);
4733+
});
4734+
if (!exprPatterns.empty())
4735+
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4736+
4737+
return ty;
47354738
}
47364739

47374740
bool ConstraintSystem::generateConstraints(StmtCondition condition,

0 commit comments

Comments
 (0)