Skip to content

Commit 07ef390

Browse files
committed
[CS] Allow bidirectional inference for ExprPatterns
Rather than using a conjunction, just form the constraints for the match expression right away.
1 parent 052be39 commit 07ef390

File tree

8 files changed

+134
-150
lines changed

8 files changed

+134
-150
lines changed

include/swift/AST/Pattern.h

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,11 @@ class alignas(8) Pattern : public ASTAllocated<Pattern> {
140140
/// equivalent to matching this pattern.
141141
///
142142
/// Looks through ParenPattern, BindingPattern, and TypedPattern.
143-
Pattern *getSemanticsProvidingPattern();
144-
const Pattern *getSemanticsProvidingPattern() const {
145-
return const_cast<Pattern*>(this)->getSemanticsProvidingPattern();
143+
Pattern *getSemanticsProvidingPattern(bool allowTypedPattern = true);
144+
const Pattern *
145+
getSemanticsProvidingPattern(bool allowTypedPattern = true) const {
146+
return const_cast<Pattern *>(this)->getSemanticsProvidingPattern(
147+
allowTypedPattern);
146148
}
147149

148150
/// Returns whether this pattern has been type-checked yet.
@@ -799,14 +801,26 @@ class BindingPattern : public Pattern {
799801
}
800802
};
801803

802-
inline Pattern *Pattern::getSemanticsProvidingPattern() {
803-
if (auto *pp = dyn_cast<ParenPattern>(this))
804-
return pp->getSubPattern()->getSemanticsProvidingPattern();
805-
if (auto *tp = dyn_cast<TypedPattern>(this))
806-
return tp->getSubPattern()->getSemanticsProvidingPattern();
807-
if (auto *vp = dyn_cast<BindingPattern>(this))
808-
return vp->getSubPattern()->getSemanticsProvidingPattern();
809-
return this;
804+
inline Pattern *Pattern::getSemanticsProvidingPattern(bool allowTypedPattern) {
805+
auto *P = this;
806+
while (true) {
807+
if (auto *PP = dyn_cast<ParenPattern>(P)) {
808+
P = PP->getSubPattern();
809+
continue;
810+
}
811+
if (auto *BP = dyn_cast<BindingPattern>(P)) {
812+
P = BP->getSubPattern();
813+
continue;
814+
}
815+
if (allowTypedPattern) {
816+
if (auto *TP = dyn_cast<TypedPattern>(P)) {
817+
P = TP->getSubPattern();
818+
continue;
819+
}
820+
}
821+
break;
822+
}
823+
return P;
810824
}
811825

812826
/// Describes a pattern and the context in which it occurs.

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4345,27 +4345,22 @@ class ConstraintSystem {
43454345
/// \returns \c true if constraint generation failed, \c false otherwise
43464346
bool generateConstraints(SingleValueStmtExpr *E);
43474347

4348-
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4349-
/// that solves each expression in turn.
4350-
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4351-
ConstraintLocatorBuilder locator);
4352-
43534348
/// Generate constraints for the given (unchecked) expression.
43544349
///
43554350
/// \returns a possibly-sanitized expression, or null if an error occurred.
43564351
[[nodiscard]]
43574352
Expr *generateConstraints(Expr *E, DeclContext *dc,
43584353
bool isInputExpression = true);
43594354

4360-
/// Generate constraints for binding the given pattern to the
4361-
/// value of the given expression.
4355+
/// Generate constraints for a given pattern.
43624356
///
4363-
/// \returns a possibly-sanitized initializer, or null if an error occurred.
4357+
/// \returns The type of the pattern, or \c None if a failure occured.
43644358
[[nodiscard]]
4365-
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
4366-
bool bindPatternVarsOneWay,
4367-
PatternBindingDecl *patternBinding,
4368-
unsigned patternIndex);
4359+
Optional<Type> generateConstraints(Pattern *P,
4360+
ConstraintLocatorBuilder locator,
4361+
bool bindPatternVarsOneWay,
4362+
PatternBindingDecl *patternBinding,
4363+
unsigned patternIndex);
43694364

43704365
/// Generate constraints for a statement condition.
43714366
///

lib/Sema/CSGen.cpp

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2442,7 +2442,7 @@ namespace {
24422442
/// for the types of each variable declared within the pattern, along
24432443
/// with a one-way constraint binding that to the type to which the
24442444
/// variable will be ascribed or inferred.
2445-
Type getTypeForPattern(
2445+
Optional<Type> getTypeForPattern(
24462446
Pattern *pattern, ConstraintLocatorBuilder locator,
24472447
bool bindPatternVarsOneWay,
24482448
PatternBindingDecl *patternBinding = nullptr,
@@ -2466,14 +2466,21 @@ namespace {
24662466
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
24672467
bindPatternVarsOneWay);
24682468

2469-
return setType(ParenType::get(CS.getASTContext(), underlyingType));
2469+
if (!underlyingType)
2470+
return None;
2471+
2472+
return setType(ParenType::get(CS.getASTContext(), *underlyingType));
24702473
}
24712474
case PatternKind::Binding: {
24722475
auto *subPattern = cast<BindingPattern>(pattern)->getSubPattern();
24732476
auto type = getTypeForPattern(subPattern, locator,
24742477
bindPatternVarsOneWay);
2478+
2479+
if (!type)
2480+
return None;
2481+
24752482
// Var doesn't affect the type.
2476-
return setType(type);
2483+
return setType(*type);
24772484
}
24782485
case PatternKind::Any: {
24792486
Type type;
@@ -2653,6 +2660,9 @@ namespace {
26532660

26542661
Type type = TypeChecker::typeCheckPattern(contextualPattern);
26552662

2663+
if (!type)
2664+
return None;
2665+
26562666
// Look through reference storage types.
26572667
type = type->getReferenceStorageReferent();
26582668

@@ -2664,16 +2674,19 @@ namespace {
26642674
auto *subPattern = cast<TypedPattern>(pattern)->getSubPattern();
26652675
// Determine the subpattern type. It will be convertible to the
26662676
// ascribed type.
2667-
Type subPatternType = getTypeForPattern(
2677+
auto subPatternType = getTypeForPattern(
26682678
subPattern,
26692679
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26702680
bindPatternVarsOneWay);
26712681

2682+
if (!subPatternType)
2683+
return None;
2684+
26722685
// NOTE: The order here is important! Pattern matching equality is
26732686
// not symmetric (we need to fix that either by using a different
26742687
// constraint, or actually making it symmetric).
26752688
CS.addConstraint(
2676-
ConstraintKind::Equal, openedType, subPatternType,
2689+
ConstraintKind::Equal, openedType, *subPatternType,
26772690
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26782691

26792692
// FIXME [OPAQUE SUPPORT]: the distinction between where we want opaque
@@ -2693,12 +2706,15 @@ namespace {
26932706
auto &tupleElt = tuplePat->getElement(i);
26942707

26952708
auto *eltPattern = tupleElt.getPattern();
2696-
Type eltTy = getTypeForPattern(
2709+
auto eltTy = getTypeForPattern(
26972710
eltPattern,
26982711
locator.withPathElement(LocatorPathElt::PatternMatch(eltPattern)),
26992712
bindPatternVarsOneWay);
27002713

2701-
tupleTypeElts.push_back(TupleTypeElt(eltTy, tupleElt.getLabel()));
2714+
if (!eltTy)
2715+
return None;
2716+
2717+
tupleTypeElts.push_back(TupleTypeElt(*eltTy, tupleElt.getLabel()));
27022718
}
27032719

27042720
return setType(TupleType::get(tupleTypeElts, CS.getASTContext()));
@@ -2707,12 +2723,15 @@ namespace {
27072723
case PatternKind::OptionalSome: {
27082724
auto *subPattern = cast<OptionalSomePattern>(pattern)->getSubPattern();
27092725
// The subpattern must have optional type.
2710-
Type subPatternType = getTypeForPattern(
2726+
auto subPatternType = getTypeForPattern(
27112727
subPattern,
27122728
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27132729
bindPatternVarsOneWay);
27142730

2715-
return setType(OptionalType::get(subPatternType));
2731+
if (!subPatternType)
2732+
return None;
2733+
2734+
return setType(OptionalType::get(*subPatternType));
27162735
}
27172736

27182737
case PatternKind::Is: {
@@ -2742,12 +2761,14 @@ namespace {
27422761
subPattern,
27432762
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27442763
bindPatternVarsOneWay);
2764+
if (!subPatternType)
2765+
return None;
27452766

27462767
// NOTE: The order here is important! Pattern matching equality is
27472768
// not symmetric (we need to fix that either by using a different
27482769
// constraint, or actually making it symmetric).
27492770
CS.addConstraint(
2750-
ConstraintKind::Equal, castType, subPatternType,
2771+
ConstraintKind::Equal, castType, *subPatternType,
27512772
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
27522773
}
27532774
return setType(isType);
@@ -2811,6 +2832,9 @@ namespace {
28112832
TypeResolverContext::InExpression, patternMatchLoc);
28122833
}();
28132834

2835+
if (!parentType)
2836+
return None;
2837+
28142838
// Perform member lookup into the parent's metatype.
28152839
Type parentMetaType = MetatypeType::get(parentType);
28162840
CS.addValueMemberConstraint(parentMetaType, enumPattern->getName(),
@@ -2838,13 +2862,13 @@ namespace {
28382862
// When there is a subpattern, the member will have function type,
28392863
// and we're matching the type of that subpattern to the parameter
28402864
// types.
2841-
Type subPatternType = getTypeForPattern(
2865+
auto subPatternType = getTypeForPattern(
28422866
subPattern,
28432867
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
28442868
bindPatternVarsOneWay);
28452869

28462870
SmallVector<AnyFunctionType::Param, 4> params;
2847-
decomposeTuple(subPatternType, params);
2871+
decomposeTuple(*subPatternType, params);
28482872

28492873
// Remove parameter labels; they aren't used when matching cases,
28502874
// but outright conflicts will be checked during coercion.
@@ -2877,10 +2901,24 @@ namespace {
28772901
}
28782902

28792903
case PatternKind::Expr: {
2880-
// We generate constraints for ExprPatterns in a separate pass. For
2881-
// now, just create a type variable.
2882-
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2883-
TVO_CanBindToNoEscape));
2904+
auto *EP = cast<ExprPattern>(pattern);
2905+
Type patternTy = CS.createTypeVariable(CS.getConstraintLocator(locator),
2906+
TVO_CanBindToNoEscape);
2907+
2908+
auto target = SyntacticElementTarget::forExprPattern(EP);
2909+
2910+
if (CS.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
2911+
/*leaveClosureBodyUnchecked=*/false)) {
2912+
return None;
2913+
}
2914+
CS.setType(EP->getMatchVar(), patternTy);
2915+
2916+
if (CS.generateConstraints(target))
2917+
return None;
2918+
2919+
CS.setTargetFor(EP, target);
2920+
CS.setExprPatternFor(EP->getSubExpr(), EP);
2921+
return setType(patternTy);
28842922
}
28852923
}
28862924

@@ -4318,11 +4356,19 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
43184356
initializer, LocatorPathElt::ContextualType(CTP_Initialization));
43194357

43204358
Type patternType;
4359+
bool forExprPattern = false;
43214360
if (auto pattern = target.getInitializationPattern()) {
4322-
patternType = cs.generateConstraints(
4361+
auto *semanticPattern =
4362+
pattern->getSemanticsProvidingPattern(/*allowTypedPattern*/ false);
4363+
forExprPattern = isa<ExprPattern>(semanticPattern);
4364+
auto ty = cs.generateConstraints(
43234365
pattern, locator, target.shouldBindPatternVarsOneWay(),
43244366
target.getInitializationPatternBindingDecl(),
43254367
target.getInitializationPatternBindingIndex());
4368+
if (!ty)
4369+
return true;
4370+
4371+
patternType = *ty;
43264372
} else {
43274373
patternType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
43284374
}
@@ -4331,9 +4377,15 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
43314377
return cs.generateWrappedPropertyTypeConstraints(
43324378
wrappedVar, cs.getType(target.getAsExpr()), patternType);
43334379

4334-
// Add a conversion constraint between the types.
4335-
cs.addConstraint(ConstraintKind::Conversion, cs.getType(target.getAsExpr()),
4336-
patternType, locator, /*isFavored*/true);
4380+
// Add a constraint between the types. For ExprPatterns, we want an equality
4381+
// constraint, because we want to propagate the type of the initializer directly
4382+
// into the implicit '~=' call. We'll then allow conversions when matching that as
4383+
// an argument. This avoids producing bad diagnostics where we try and apply fixes
4384+
// to the conversion outside of the call.
4385+
auto kind = forExprPattern ? ConstraintKind::Equal
4386+
: ConstraintKind::Conversion;
4387+
cs.addConstraint(kind, cs.getType(target.getAsExpr()),
4388+
patternType, locator, /*isFavored*/true);
43374389

43384390
return false;
43394391
}
@@ -4481,7 +4533,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44814533
// Collect constraints from the element pattern.
44824534
auto elementLocator = cs.getConstraintLocator(
44834535
sequenceExpr, ConstraintLocator::SequenceElementType);
4484-
Type initType =
4536+
auto initType =
44854537
cs.generateConstraints(pattern, elementLocator,
44864538
target.shouldBindPatternVarsOneWay(), nullptr, 0);
44874539
if (!initType)
@@ -4500,7 +4552,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
45004552
// resolving `optional object` constraint which is sometimes too eager.
45014553
cs.addConstraint(ConstraintKind::Conversion, nextType,
45024554
OptionalType::get(elementType), elementTypeLoc);
4503-
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4555+
cs.addConstraint(ConstraintKind::Conversion, elementType, *initType,
45044556
elementLocator);
45054557
}
45064558

@@ -4526,7 +4578,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
45264578

45274579
// Populate all of the information for a for-each loop.
45284580
forEachStmtInfo.elementType = elementType;
4529-
forEachStmtInfo.initType = initType;
4581+
forEachStmtInfo.initType = *initType;
45304582
target.setPattern(pattern);
45314583
target.getForEachStmtInfo() = forEachStmtInfo;
45324584
return target;
@@ -4706,7 +4758,7 @@ bool ConstraintSystem::generateConstraints(
47064758

47074759
// Generate constraints to bind all of the internal declarations
47084760
// and verify the pattern.
4709-
Type patternType = generateConstraints(
4761+
auto patternType = generateConstraints(
47104762
pattern, locator, /*shouldBindPatternVarsOneWay*/ true,
47114763
target.getPatternBindingOfUninitializedVar(),
47124764
target.getIndexOfUninitializedVar());
@@ -4735,25 +4787,13 @@ Expr *ConstraintSystem::generateConstraints(
47354787
return generateConstraintsFor(*this, expr, dc);
47364788
}
47374789

4738-
Type ConstraintSystem::generateConstraints(
4790+
Optional<Type> ConstraintSystem::generateConstraints(
47394791
Pattern *pattern, ConstraintLocatorBuilder locator,
47404792
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
47414793
unsigned patternIndex) {
47424794
ConstraintGenerator cg(*this, nullptr);
4743-
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4744-
patternBinding, patternIndex);
4745-
assert(ty);
4746-
4747-
// Gather the ExprPatterns, and form a conjunction for their expressions.
4748-
SmallVector<ExprPattern *, 4> exprPatterns;
4749-
pattern->forEachNode([&](Pattern *P) {
4750-
if (auto *EP = dyn_cast<ExprPattern>(P))
4751-
exprPatterns.push_back(EP);
4752-
});
4753-
if (!exprPatterns.empty())
4754-
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4755-
4756-
return ty;
4795+
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4796+
patternBinding, patternIndex);
47574797
}
47584798

47594799
bool ConstraintSystem::generateConstraints(StmtCondition condition,

0 commit comments

Comments
 (0)