Skip to content

Commit 8bf9c68

Browse files
committed
[CS] Allow ExprPatterns to be type-checked in the solver
Previously we would wait until CSApply, which would trigger their type-checking in `coercePatternToType`. This caused a number of bugs, and hampered solver-based completion, which does not run CSApply. Instead, form a conjunction of all the ExprPatterns present, which preserves some of the previous isolation behavior (though does not provide complete isolation). We can then modify `coercePatternToType` to accept a closure, which allows the solver to take over rewriting the ExprPatterns it has already solved. This then sets the stage for the complete removal of `coercePatternToType`, and doing all pattern type-checking in the solver.
1 parent 8cf6765 commit 8bf9c68

21 files changed

+463
-143
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,10 @@ class Solution {
14901490
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
14911491
caseLabelItems;
14921492

1493+
/// A map of expressions to the ExprPatterns that they are being solved as
1494+
/// a part of.
1495+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
1496+
14931497
/// The set of parameters that have been inferred to be 'isolated'.
14941498
llvm::SmallVector<ParamDecl *, 2> isolatedParams;
14951499

@@ -1675,6 +1679,16 @@ class Solution {
16751679
: nullptr;
16761680
}
16771681

1682+
/// Retrieve the solved ExprPattern that corresponds to provided
1683+
/// sub-expression.
1684+
NullablePtr<ExprPattern> getExprPatternFor(Expr *E) const {
1685+
auto result = exprPatterns.find(E);
1686+
if (result == exprPatterns.end())
1687+
return nullptr;
1688+
1689+
return result->second;
1690+
}
1691+
16781692
/// This method implements functionality of `Expr::isTypeReference`
16791693
/// with data provided by a given solution.
16801694
bool isTypeReference(Expr *E) const;
@@ -2138,6 +2152,10 @@ class ConstraintSystem {
21382152
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
21392153
caseLabelItems;
21402154

2155+
/// A map of expressions to the ExprPatterns that they are being solved as
2156+
/// a part of.
2157+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
2158+
21412159
/// The set of parameters that have been inferred to be 'isolated'.
21422160
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;
21432161

@@ -2729,6 +2747,9 @@ class ConstraintSystem {
27292747
/// The length of \c caseLabelItems.
27302748
unsigned numCaseLabelItems;
27312749

2750+
/// The length of \c exprPatterns.
2751+
unsigned numExprPatterns;
2752+
27322753
/// The length of \c isolatedParams.
27332754
unsigned numIsolatedParams;
27342755

@@ -3150,6 +3171,15 @@ class ConstraintSystem {
31503171
caseLabelItems[item] = info;
31513172
}
31523173

3174+
/// Record a given ExprPattern as the parent of its sub-expression.
3175+
void setExprPatternFor(Expr *E, ExprPattern *EP) {
3176+
assert(E);
3177+
assert(EP);
3178+
auto inserted = exprPatterns.insert({E, EP}).second;
3179+
assert(inserted && "Mapping already defined?");
3180+
(void)inserted;
3181+
}
3182+
31533183
Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
31543184
const CaseLabelItem *item) const {
31553185
auto known = caseLabelItems.find(item);
@@ -4274,6 +4304,11 @@ class ConstraintSystem {
42744304
/// \returns \c true if constraint generation failed, \c false otherwise
42754305
bool generateConstraints(SingleValueStmtExpr *E);
42764306

4307+
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4308+
/// that solves each expression in turn.
4309+
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4310+
ConstraintLocatorBuilder locator);
4311+
42774312
/// Generate constraints for the given (unchecked) expression.
42784313
///
42794314
/// \returns a possibly-sanitized expression, or null if an error occurred.

lib/IDE/TypeCheckCompletionCallback.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
8181
/// \endcode
8282
/// If the code completion expression occurs in such an AST, return the
8383
/// declaration of the \c $match variable, otherwise return \c nullptr.
84-
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
84+
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
85+
if (auto EP = S.getExprPatternFor(E))
86+
return EP.get()->getMatchVar();
87+
88+
// TODO: Once ExprPattern type-checking is fully moved into the solver,
89+
// the below can be deleted.
90+
auto &CS = S.getConstraintSystem();
8591
auto &Context = CS.getASTContext();
8692

8793
auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
@@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
109115
}
110116

111117
Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
112-
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
113-
Type MatchVarType;
114-
// If the MatchVar has an explicit type, it's not part of the solution. But
115-
// we can look it up in the constraint system directly.
116-
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
117-
MatchVarType = T;
118-
} else {
119-
MatchVarType = getTypeForCompletion(S, MatchVar);
120-
}
121-
if (MatchVarType) {
122-
return MatchVarType;
123-
}
124-
}
125-
return nullptr;
118+
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
119+
if (!MatchVar)
120+
return nullptr;
121+
122+
if (S.hasType(MatchVar))
123+
return S.getResolvedType(MatchVar);
124+
125+
// If the ExprPattern wasn't solved as part of the constraint system, it's
126+
// not part of the solution.
127+
// TODO: This can be removed once ExprPattern type-checking is fully part
128+
// of the constraint system.
129+
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
130+
return T;
131+
132+
return getTypeForCompletion(S, MatchVar);
126133
}
127134

128135
void swift::ide::getSolutionSpecificVarTypes(

lib/Sema/CSApply.cpp

Lines changed: 83 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8583,6 +8583,9 @@ namespace {
85838583
return Action::SkipChildren();
85848584
}
85858585

8586+
NullablePtr<Pattern>
8587+
rewritePattern(Pattern *pattern, DeclContext *DC);
8588+
85868589
/// Rewrite the target, producing a new target.
85878590
Optional<SyntacticElementTarget>
85888591
rewriteTarget(SyntacticElementTarget target);
@@ -8829,12 +8832,68 @@ static Expr *wrapAsyncLetInitializer(
88298832
return resultInit;
88308833
}
88318834

8835+
static Pattern *rewriteExprPattern(const SyntacticElementTarget &matchTarget,
8836+
Type patternTy,
8837+
RewriteTargetFn rewriteTarget) {
8838+
auto *EP = matchTarget.getExprPattern();
8839+
8840+
// See if we can simplify to another kind of pattern.
8841+
if (auto simplified = TypeChecker::trySimplifyExprPattern(EP, patternTy))
8842+
return simplified.get();
8843+
8844+
auto resultTarget = rewriteTarget(matchTarget);
8845+
if (!resultTarget)
8846+
return nullptr;
8847+
8848+
EP->setMatchExpr(resultTarget->getAsExpr());
8849+
EP->getMatchVar()->setInterfaceType(patternTy->mapTypeOutOfContext());
8850+
EP->setType(patternTy);
8851+
return EP;
8852+
}
8853+
8854+
/// Attempt to rewrite either an ExprPattern, or a pattern that was solved as
8855+
/// an ExprPattern, e.g an EnumElementPattern that could not refer to an enum
8856+
/// case.
8857+
static Optional<Pattern *>
8858+
tryRewriteExprPattern(Pattern *P, Solution &solution, Type patternTy,
8859+
RewriteTargetFn rewriteTarget) {
8860+
// See if we have a match expression target.
8861+
auto matchTarget = solution.getTargetFor(P);
8862+
if (!matchTarget)
8863+
return None;
8864+
8865+
return rewriteExprPattern(*matchTarget, patternTy, rewriteTarget);
8866+
}
8867+
8868+
NullablePtr<Pattern> ExprWalker::rewritePattern(Pattern *pattern,
8869+
DeclContext *DC) {
8870+
auto &solution = Rewriter.solution;
8871+
8872+
// Figure out the pattern type.
8873+
auto patternTy = solution.getResolvedType(pattern);
8874+
patternTy = patternTy->reconstituteSugar(/*recursive=*/false);
8875+
8876+
// Coerce the pattern to its appropriate type.
8877+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
8878+
patternOptions |= TypeResolutionFlags::OverrideType;
8879+
8880+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8881+
return ::tryRewriteExprPattern(
8882+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
8883+
};
8884+
8885+
auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
8886+
return TypeChecker::coercePatternToType(contextualPattern, patternTy,
8887+
patternOptions, tryRewritePattern);
8888+
}
8889+
88328890
/// Apply the given solution to the initialization target.
88338891
///
88348892
/// \returns the resulting initialization expression.
88358893
static Optional<SyntacticElementTarget>
88368894
applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
8837-
Expr *initializer) {
8895+
Expr *initializer,
8896+
RewriteTargetFn rewriteTarget) {
88388897
auto wrappedVar = target.getInitializationWrappedVar();
88398898
Type initType;
88408899
if (wrappedVar) {
@@ -8899,10 +8958,14 @@ applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
88998958

89008959
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
89018960

8961+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
8962+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
8963+
};
8964+
89028965
// Apply the solution to the pattern as well.
89038966
auto contextualPattern = target.getContextualPattern();
89048967
if (auto coercedPattern = TypeChecker::coercePatternToType(
8905-
contextualPattern, finalPatternType, options)) {
8968+
contextualPattern, finalPatternType, options, tryRewritePattern)) {
89068969
resultTarget.setPattern(coercedPattern);
89078970
} else {
89088971
return None;
@@ -9049,10 +9112,15 @@ static Optional<SyntacticElementTarget> applySolutionToForEachStmt(
90499112
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
90509113
options |= TypeResolutionFlags::OverrideType;
90519114

9115+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9116+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9117+
};
9118+
90529119
// Apply the solution to the pattern as well.
90539120
auto contextualPattern = target.getContextualPattern();
90549121
auto coercedPattern = TypeChecker::coercePatternToType(
9055-
contextualPattern, forEachStmtInfo.initType, options);
9122+
contextualPattern, forEachStmtInfo.initType, options,
9123+
tryRewritePattern);
90569124
if (!coercedPattern)
90579125
return None;
90589126

@@ -9140,7 +9208,8 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
91409208
switch (target.getExprContextualTypePurpose()) {
91419209
case CTP_Initialization: {
91429210
auto initResultTarget = applySolutionToInitialization(
9143-
solution, target, rewrittenExpr);
9211+
solution, target, rewrittenExpr,
9212+
[&](auto target) { return rewriteTarget(target); });
91449213
if (!initResultTarget)
91459214
return None;
91469215

@@ -9231,47 +9300,11 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
92319300
ConstraintSystem &cs = solution.getConstraintSystem();
92329301
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
92339302

9234-
// Figure out the pattern type.
9235-
Type patternType = solution.simplifyType(solution.getType(info.pattern));
9236-
patternType = patternType->reconstituteSugar(/*recursive=*/false);
9237-
9238-
// Check whether this enum element is resolved via ~= application.
9239-
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9240-
if (auto target = cs.getTargetFor(enumElement)) {
9241-
auto *EP = target->getExprPattern();
9242-
auto enumType = solution.getResolvedType(EP);
9243-
9244-
auto *matchCall = target->getAsExpr();
9245-
9246-
auto *result = matchCall->walk(*this);
9247-
if (!result)
9248-
return None;
9249-
9250-
{
9251-
auto *matchVar = EP->getMatchVar();
9252-
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9253-
}
9254-
9255-
EP->setMatchExpr(result);
9256-
EP->setType(enumType);
9257-
9258-
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9259-
return target;
9260-
}
9261-
}
9262-
9263-
// Coerce the pattern to its appropriate type.
9264-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9265-
patternOptions |= TypeResolutionFlags::OverrideType;
9266-
auto contextualPattern =
9267-
ContextualPattern::forRawPattern(info.pattern,
9268-
target.getDeclContext());
9269-
if (auto coercedPattern = TypeChecker::coercePatternToType(
9270-
contextualPattern, patternType, patternOptions)) {
9271-
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
9272-
} else {
9303+
auto pattern = rewritePattern(info.pattern, target.getDeclContext());
9304+
if (!pattern)
92739305
return None;
9274-
}
9306+
9307+
(*caseLabelItem)->setPattern(pattern.get(), /*resolved=*/true);
92759308

92769309
// If there is a guard expression, coerce that.
92779310
if (auto *guardExpr = info.guardExpr) {
@@ -9339,8 +9372,13 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
93399372
options |= TypeResolutionFlags::OverrideType;
93409373
}
93419374

9375+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9376+
return ::tryRewriteExprPattern(
9377+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9378+
};
9379+
93429380
if (auto coercedPattern = TypeChecker::coercePatternToType(
9343-
contextualPattern, patternType, options)) {
9381+
contextualPattern, patternType, options, tryRewritePattern)) {
93449382
auto resultTarget = target;
93459383
resultTarget.setPattern(coercedPattern);
93469384
return resultTarget;

lib/Sema/CSGen.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,12 +2401,6 @@ namespace {
24012401
// function, to set the type of the pattern.
24022402
auto setType = [&](Type type) {
24032403
CS.setType(pattern, type);
2404-
if (auto PE = dyn_cast<ExprPattern>(pattern)) {
2405-
// Set the type of the pattern's sub-expression as well, so code
2406-
// completion can retrieve the expression's type in case it is a code
2407-
// completion token.
2408-
CS.setType(PE->getSubExpr(), type);
2409-
}
24102404
return type;
24112405
};
24122406

@@ -2824,15 +2818,12 @@ namespace {
28242818
return setType(patternType);
28252819
}
28262820

2827-
// Refutable patterns occur when checking the PatternBindingDecls in an
2828-
// if/let or while/let condition. They always require an initial value,
2829-
// so they always allow unspecified types.
2830-
case PatternKind::Expr:
2831-
// TODO: we could try harder here, e.g. for enum elements to provide the
2832-
// enum type.
2833-
return setType(
2834-
CS.createTypeVariable(
2835-
CS.getConstraintLocator(locator), TVO_CanBindToNoEscape));
2821+
case PatternKind::Expr: {
2822+
// We generate constraints for ExprPatterns in a separate pass. For
2823+
// now, just create a type variable.
2824+
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2825+
TVO_CanBindToNoEscape));
2826+
}
28362827
}
28372828

28382829
llvm_unreachable("Unhandled pattern kind");
@@ -4638,8 +4629,20 @@ Type ConstraintSystem::generateConstraints(
46384629
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
46394630
unsigned patternIndex) {
46404631
ConstraintGenerator cg(*this, nullptr);
4641-
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4642-
patternBinding, patternIndex);
4632+
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4633+
patternBinding, patternIndex);
4634+
assert(ty);
4635+
4636+
// Gather the ExprPatterns, and form a conjunction for their expressions.
4637+
SmallVector<ExprPattern *, 4> exprPatterns;
4638+
pattern->forEachNode([&](Pattern *P) {
4639+
if (auto *EP = dyn_cast<ExprPattern>(P))
4640+
exprPatterns.push_back(EP);
4641+
});
4642+
if (!exprPatterns.empty())
4643+
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4644+
4645+
return ty;
46434646
}
46444647

46454648
bool ConstraintSystem::generateConstraints(StmtCondition condition,

0 commit comments

Comments
 (0)