Skip to content

Commit c74fd07

Browse files
authored
Merge pull request #64280 from hamishknight/platypus
[CS] Allow ExprPatterns to be type-checked in the solver
2 parents 291fe21 + eaa61a0 commit c74fd07

37 files changed

+709
-338
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,9 @@ ERROR(cannot_match_expr_tuple_pattern_with_nontuple_value,none,
197197
ERROR(cannot_match_unresolved_expr_pattern_with_value,none,
198198
"pattern cannot match values of type %0",
199199
(Type))
200+
ERROR(cannot_match_value_with_pattern,none,
201+
"pattern of type %1 cannot match %0",
202+
(Type, Type))
200203

201204
ERROR(cannot_reference_compare_types,none,
202205
"cannot check reference equality of functions; operands here have types "

include/swift/Sema/CompletionContextFinder.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020

2121
namespace swift {
2222

23+
namespace constraints {
24+
class SyntacticElementTarget;
25+
}
26+
2327
class CompletionContextFinder : public ASTWalker {
2428
enum class ContextKind {
2529
FallbackExpression,
@@ -53,12 +57,9 @@ class CompletionContextFinder : public ASTWalker {
5357
return MacroWalking::Arguments;
5458
}
5559

56-
/// Finder for completion contexts within the provided initial expression.
57-
CompletionContextFinder(ASTNode initialNode, DeclContext *DC)
58-
: InitialExpr(initialNode.dyn_cast<Expr *>()), InitialDC(DC) {
59-
assert(DC);
60-
initialNode.walk(*this);
61-
};
60+
/// Finder for completion contexts within the provided SyntacticElementTarget.
61+
CompletionContextFinder(constraints::SyntacticElementTarget target,
62+
DeclContext *DC);
6263

6364
/// Finder for completion contexts within the outermost non-closure context of
6465
/// the code completion expression's direct context.

include/swift/Sema/ConstraintLocator.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,9 @@ class ConstraintLocator : public llvm::FoldingSetNode {
318318
/// otherwise \c nullptr.
319319
NullablePtr<Pattern> getPatternMatch() const;
320320

321+
/// Whether the locator in question is for a pattern match.
322+
bool isForPatternMatch() const;
323+
321324
/// Returns true if \p locator is ending with either of the following
322325
/// - Member
323326
/// - Member -> KeyPathDynamicMember

include/swift/Sema/ConstraintLocatorPathElts.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,10 @@ CUSTOM_LOCATOR_PATH_ELT(TernaryBranch)
225225
/// Performing a pattern patch.
226226
CUSTOM_LOCATOR_PATH_ELT(PatternMatch)
227227

228+
/// The constraint that models the allowed implicit casts for
229+
/// an EnumElementPattern.
230+
SIMPLE_LOCATOR_PATH_ELT(EnumPatternImplicitCastMatch)
231+
228232
/// Points to a particular attribute associated with one of
229233
/// the arguments e.g. `inout` or its type e.g. `@escaping`.
230234
///

include/swift/Sema/ConstraintSystem.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,10 @@ class Solution {
15001500
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
15011501
caseLabelItems;
15021502

1503+
/// A map of expressions to the ExprPatterns that they are being solved as
1504+
/// a part of.
1505+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
1506+
15031507
/// The set of parameters that have been inferred to be 'isolated'.
15041508
llvm::SmallVector<ParamDecl *, 2> isolatedParams;
15051509

@@ -1685,6 +1689,16 @@ class Solution {
16851689
: nullptr;
16861690
}
16871691

1692+
/// Retrieve the solved ExprPattern that corresponds to provided
1693+
/// sub-expression.
1694+
NullablePtr<ExprPattern> getExprPatternFor(Expr *E) const {
1695+
auto result = exprPatterns.find(E);
1696+
if (result == exprPatterns.end())
1697+
return nullptr;
1698+
1699+
return result->second;
1700+
}
1701+
16881702
/// This method implements functionality of `Expr::isTypeReference`
16891703
/// with data provided by a given solution.
16901704
bool isTypeReference(Expr *E) const;
@@ -2148,6 +2162,10 @@ class ConstraintSystem {
21482162
llvm::SmallMapVector<const CaseLabelItem *, CaseLabelItemInfo, 4>
21492163
caseLabelItems;
21502164

2165+
/// A map of expressions to the ExprPatterns that they are being solved as
2166+
/// a part of.
2167+
llvm::SmallMapVector<Expr *, ExprPattern *, 2> exprPatterns;
2168+
21512169
/// The set of parameters that have been inferred to be 'isolated'.
21522170
llvm::SmallSetVector<ParamDecl *, 2> isolatedParams;
21532171

@@ -2745,6 +2763,9 @@ class ConstraintSystem {
27452763
/// The length of \c caseLabelItems.
27462764
unsigned numCaseLabelItems;
27472765

2766+
/// The length of \c exprPatterns.
2767+
unsigned numExprPatterns;
2768+
27482769
/// The length of \c isolatedParams.
27492770
unsigned numIsolatedParams;
27502771

@@ -3166,6 +3187,15 @@ class ConstraintSystem {
31663187
caseLabelItems[item] = info;
31673188
}
31683189

3190+
/// Record a given ExprPattern as the parent of its sub-expression.
3191+
void setExprPatternFor(Expr *E, ExprPattern *EP) {
3192+
assert(E);
3193+
assert(EP);
3194+
auto inserted = exprPatterns.insert({E, EP}).second;
3195+
assert(inserted && "Mapping already defined?");
3196+
(void)inserted;
3197+
}
3198+
31693199
Optional<CaseLabelItemInfo> getCaseLabelItemInfo(
31703200
const CaseLabelItem *item) const {
31713201
auto known = caseLabelItems.find(item);
@@ -4315,6 +4345,11 @@ class ConstraintSystem {
43154345
/// \returns \c true if constraint generation failed, \c false otherwise
43164346
bool generateConstraints(SingleValueStmtExpr *E);
43174347

4348+
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4349+
/// that solves each expression in turn.
4350+
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4351+
ConstraintLocatorBuilder locator);
4352+
43184353
/// Generate constraints for the given (unchecked) expression.
43194354
///
43204355
/// \returns a possibly-sanitized expression, or null if an error occurred.

lib/IDE/TypeCheckCompletionCallback.cpp

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ Type swift::ide::getTypeForCompletion(const constraints::Solution &S,
8181
/// \endcode
8282
/// If the code completion expression occurs in such an AST, return the
8383
/// declaration of the \c $match variable, otherwise return \c nullptr.
84-
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
84+
static VarDecl *getMatchVarIfInPatternMatch(Expr *E, const Solution &S) {
85+
if (auto EP = S.getExprPatternFor(E))
86+
return EP.get()->getMatchVar();
87+
88+
// TODO: Once ExprPattern type-checking is fully moved into the solver,
89+
// the below can be deleted.
90+
auto &CS = S.getConstraintSystem();
8591
auto &Context = CS.getASTContext();
8692

8793
auto *Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(E));
@@ -109,20 +115,21 @@ static VarDecl *getMatchVarIfInPatternMatch(Expr *E, ConstraintSystem &CS) {
109115
}
110116

111117
Type swift::ide::getPatternMatchType(const constraints::Solution &S, Expr *E) {
112-
if (auto MatchVar = getMatchVarIfInPatternMatch(E, S.getConstraintSystem())) {
113-
Type MatchVarType;
114-
// If the MatchVar has an explicit type, it's not part of the solution. But
115-
// we can look it up in the constraint system directly.
116-
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
117-
MatchVarType = T;
118-
} else {
119-
MatchVarType = getTypeForCompletion(S, MatchVar);
120-
}
121-
if (MatchVarType) {
122-
return MatchVarType;
123-
}
124-
}
125-
return nullptr;
118+
auto MatchVar = getMatchVarIfInPatternMatch(E, S);
119+
if (!MatchVar)
120+
return nullptr;
121+
122+
if (S.hasType(MatchVar))
123+
return S.getResolvedType(MatchVar);
124+
125+
// If the ExprPattern wasn't solved as part of the constraint system, it's
126+
// not part of the solution.
127+
// TODO: This can be removed once ExprPattern type-checking is fully part
128+
// of the constraint system.
129+
if (auto T = S.getConstraintSystem().getVarType(MatchVar))
130+
return T;
131+
132+
return getTypeForCompletion(S, MatchVar);
126133
}
127134

128135
void swift::ide::getSolutionSpecificVarTypes(

lib/Sema/BuilderTransform.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,8 @@ Optional<BraceStmt *> TypeChecker::applyResultBuilderBodyTransform(
984984
SmallVector<Solution, 4> solutions;
985985
cs.solveForCodeCompletion(solutions);
986986

987-
CompletionContextFinder analyzer(func, func->getDeclContext());
987+
SyntacticElementTarget funcTarget(func);
988+
CompletionContextFinder analyzer(funcTarget, func->getDeclContext());
988989
if (analyzer.hasCompletion()) {
989990
filterSolutionsForCodeCompletion(solutions, analyzer);
990991
for (const auto &solution : solutions) {

lib/Sema/CSApply.cpp

Lines changed: 83 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8705,6 +8705,9 @@ namespace {
87058705
return Action::SkipChildren();
87068706
}
87078707

8708+
NullablePtr<Pattern>
8709+
rewritePattern(Pattern *pattern, DeclContext *DC);
8710+
87088711
/// Rewrite the target, producing a new target.
87098712
Optional<SyntacticElementTarget>
87108713
rewriteTarget(SyntacticElementTarget target);
@@ -8951,12 +8954,68 @@ static Expr *wrapAsyncLetInitializer(
89518954
return resultInit;
89528955
}
89538956

8957+
static Pattern *rewriteExprPattern(const SyntacticElementTarget &matchTarget,
8958+
Type patternTy,
8959+
RewriteTargetFn rewriteTarget) {
8960+
auto *EP = matchTarget.getExprPattern();
8961+
8962+
// See if we can simplify to another kind of pattern.
8963+
if (auto simplified = TypeChecker::trySimplifyExprPattern(EP, patternTy))
8964+
return simplified.get();
8965+
8966+
auto resultTarget = rewriteTarget(matchTarget);
8967+
if (!resultTarget)
8968+
return nullptr;
8969+
8970+
EP->setMatchExpr(resultTarget->getAsExpr());
8971+
EP->getMatchVar()->setInterfaceType(patternTy->mapTypeOutOfContext());
8972+
EP->setType(patternTy);
8973+
return EP;
8974+
}
8975+
8976+
/// Attempt to rewrite either an ExprPattern, or a pattern that was solved as
8977+
/// an ExprPattern, e.g an EnumElementPattern that could not refer to an enum
8978+
/// case.
8979+
static Optional<Pattern *>
8980+
tryRewriteExprPattern(Pattern *P, Solution &solution, Type patternTy,
8981+
RewriteTargetFn rewriteTarget) {
8982+
// See if we have a match expression target.
8983+
auto matchTarget = solution.getTargetFor(P);
8984+
if (!matchTarget)
8985+
return None;
8986+
8987+
return rewriteExprPattern(*matchTarget, patternTy, rewriteTarget);
8988+
}
8989+
8990+
NullablePtr<Pattern> ExprWalker::rewritePattern(Pattern *pattern,
8991+
DeclContext *DC) {
8992+
auto &solution = Rewriter.solution;
8993+
8994+
// Figure out the pattern type.
8995+
auto patternTy = solution.getResolvedType(pattern);
8996+
patternTy = patternTy->reconstituteSugar(/*recursive=*/false);
8997+
8998+
// Coerce the pattern to its appropriate type.
8999+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9000+
patternOptions |= TypeResolutionFlags::OverrideType;
9001+
9002+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9003+
return ::tryRewriteExprPattern(
9004+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9005+
};
9006+
9007+
auto contextualPattern = ContextualPattern::forRawPattern(pattern, DC);
9008+
return TypeChecker::coercePatternToType(contextualPattern, patternTy,
9009+
patternOptions, tryRewritePattern);
9010+
}
9011+
89549012
/// Apply the given solution to the initialization target.
89559013
///
89569014
/// \returns the resulting initialization expression.
89579015
static Optional<SyntacticElementTarget>
89589016
applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
8959-
Expr *initializer) {
9017+
Expr *initializer,
9018+
RewriteTargetFn rewriteTarget) {
89609019
auto wrappedVar = target.getInitializationWrappedVar();
89619020
Type initType;
89629021
if (wrappedVar) {
@@ -9021,10 +9080,14 @@ applySolutionToInitialization(Solution &solution, SyntacticElementTarget target,
90219080

90229081
finalPatternType = finalPatternType->reconstituteSugar(/*recursive =*/false);
90239082

9083+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9084+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9085+
};
9086+
90249087
// Apply the solution to the pattern as well.
90259088
auto contextualPattern = target.getContextualPattern();
90269089
if (auto coercedPattern = TypeChecker::coercePatternToType(
9027-
contextualPattern, finalPatternType, options)) {
9090+
contextualPattern, finalPatternType, options, tryRewritePattern)) {
90289091
resultTarget.setPattern(coercedPattern);
90299092
} else {
90309093
return None;
@@ -9171,10 +9234,15 @@ static Optional<SyntacticElementTarget> applySolutionToForEachStmt(
91719234
TypeResolutionOptions options(TypeResolverContext::ForEachStmt);
91729235
options |= TypeResolutionFlags::OverrideType;
91739236

9237+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9238+
return ::tryRewriteExprPattern(EP, solution, ty, rewriteTarget);
9239+
};
9240+
91749241
// Apply the solution to the pattern as well.
91759242
auto contextualPattern = target.getContextualPattern();
91769243
auto coercedPattern = TypeChecker::coercePatternToType(
9177-
contextualPattern, forEachStmtInfo.initType, options);
9244+
contextualPattern, forEachStmtInfo.initType, options,
9245+
tryRewritePattern);
91789246
if (!coercedPattern)
91799247
return None;
91809248

@@ -9262,7 +9330,8 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
92629330
switch (target.getExprContextualTypePurpose()) {
92639331
case CTP_Initialization: {
92649332
auto initResultTarget = applySolutionToInitialization(
9265-
solution, target, rewrittenExpr);
9333+
solution, target, rewrittenExpr,
9334+
[&](auto target) { return rewriteTarget(target); });
92669335
if (!initResultTarget)
92679336
return None;
92689337

@@ -9353,47 +9422,11 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
93539422
ConstraintSystem &cs = solution.getConstraintSystem();
93549423
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
93559424

9356-
// Figure out the pattern type.
9357-
Type patternType = solution.simplifyType(solution.getType(info.pattern));
9358-
patternType = patternType->reconstituteSugar(/*recursive=*/false);
9359-
9360-
// Check whether this enum element is resolved via ~= application.
9361-
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9362-
if (auto target = cs.getTargetFor(enumElement)) {
9363-
auto *EP = target->getExprPattern();
9364-
auto enumType = solution.getResolvedType(EP);
9365-
9366-
auto *matchCall = target->getAsExpr();
9367-
9368-
auto *result = matchCall->walk(*this);
9369-
if (!result)
9370-
return None;
9371-
9372-
{
9373-
auto *matchVar = EP->getMatchVar();
9374-
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9375-
}
9376-
9377-
EP->setMatchExpr(result);
9378-
EP->setType(enumType);
9379-
9380-
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9381-
return target;
9382-
}
9383-
}
9384-
9385-
// Coerce the pattern to its appropriate type.
9386-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9387-
patternOptions |= TypeResolutionFlags::OverrideType;
9388-
auto contextualPattern =
9389-
ContextualPattern::forRawPattern(info.pattern,
9390-
target.getDeclContext());
9391-
if (auto coercedPattern = TypeChecker::coercePatternToType(
9392-
contextualPattern, patternType, patternOptions)) {
9393-
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
9394-
} else {
9425+
auto pattern = rewritePattern(info.pattern, target.getDeclContext());
9426+
if (!pattern)
93959427
return None;
9396-
}
9428+
9429+
(*caseLabelItem)->setPattern(pattern.get(), /*resolved=*/true);
93979430

93989431
// If there is a guard expression, coerce that.
93999432
if (auto *guardExpr = info.guardExpr) {
@@ -9461,8 +9494,13 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
94619494
options |= TypeResolutionFlags::OverrideType;
94629495
}
94639496

9497+
auto tryRewritePattern = [&](Pattern *EP, Type ty) {
9498+
return ::tryRewriteExprPattern(
9499+
EP, solution, ty, [&](auto target) { return rewriteTarget(target); });
9500+
};
9501+
94649502
if (auto coercedPattern = TypeChecker::coercePatternToType(
9465-
contextualPattern, patternType, options)) {
9503+
contextualPattern, patternType, options, tryRewritePattern)) {
94669504
auto resultTarget = target;
94679505
resultTarget.setPattern(coercedPattern);
94689506
return resultTarget;

0 commit comments

Comments
 (0)