Skip to content

Commit 4063549

Browse files
committed
[CS] Allow bidirectional inference for ExprPatterns
Rather than using a conjunction, just form the constraints for the match expression right away.
1 parent bb5c720 commit 4063549

File tree

8 files changed

+134
-150
lines changed

8 files changed

+134
-150
lines changed

include/swift/AST/Pattern.h

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -140,9 +140,11 @@ class alignas(8) Pattern : public ASTAllocated<Pattern> {
140140
/// equivalent to matching this pattern.
141141
///
142142
/// Looks through ParenPattern, BindingPattern, and TypedPattern.
143-
Pattern *getSemanticsProvidingPattern();
144-
const Pattern *getSemanticsProvidingPattern() const {
145-
return const_cast<Pattern*>(this)->getSemanticsProvidingPattern();
143+
Pattern *getSemanticsProvidingPattern(bool allowTypedPattern = true);
144+
const Pattern *
145+
getSemanticsProvidingPattern(bool allowTypedPattern = true) const {
146+
return const_cast<Pattern *>(this)->getSemanticsProvidingPattern(
147+
allowTypedPattern);
146148
}
147149

148150
/// Returns whether this pattern has been type-checked yet.
@@ -799,14 +801,26 @@ class BindingPattern : public Pattern {
799801
}
800802
};
801803

802-
inline Pattern *Pattern::getSemanticsProvidingPattern() {
803-
if (auto *pp = dyn_cast<ParenPattern>(this))
804-
return pp->getSubPattern()->getSemanticsProvidingPattern();
805-
if (auto *tp = dyn_cast<TypedPattern>(this))
806-
return tp->getSubPattern()->getSemanticsProvidingPattern();
807-
if (auto *vp = dyn_cast<BindingPattern>(this))
808-
return vp->getSubPattern()->getSemanticsProvidingPattern();
809-
return this;
804+
inline Pattern *Pattern::getSemanticsProvidingPattern(bool allowTypedPattern) {
805+
auto *P = this;
806+
while (true) {
807+
if (auto *PP = dyn_cast<ParenPattern>(P)) {
808+
P = PP->getSubPattern();
809+
continue;
810+
}
811+
if (auto *BP = dyn_cast<BindingPattern>(P)) {
812+
P = BP->getSubPattern();
813+
continue;
814+
}
815+
if (allowTypedPattern) {
816+
if (auto *TP = dyn_cast<TypedPattern>(P)) {
817+
P = TP->getSubPattern();
818+
continue;
819+
}
820+
}
821+
break;
822+
}
823+
return P;
810824
}
811825

812826
/// Describes a pattern and the context in which it occurs.

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4345,27 +4345,22 @@ class ConstraintSystem {
43454345
/// \returns \c true if constraint generation failed, \c false otherwise
43464346
bool generateConstraints(SingleValueStmtExpr *E);
43474347

4348-
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4349-
/// that solves each expression in turn.
4350-
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4351-
ConstraintLocatorBuilder locator);
4352-
43534348
/// Generate constraints for the given (unchecked) expression.
43544349
///
43554350
/// \returns a possibly-sanitized expression, or null if an error occurred.
43564351
[[nodiscard]]
43574352
Expr *generateConstraints(Expr *E, DeclContext *dc,
43584353
bool isInputExpression = true);
43594354

4360-
/// Generate constraints for binding the given pattern to the
4361-
/// value of the given expression.
4355+
/// Generate constraints for a given pattern.
43624356
///
4363-
/// \returns a possibly-sanitized initializer, or null if an error occurred.
4357+
/// \returns The type of the pattern, or \c None if a failure occured.
43644358
[[nodiscard]]
4365-
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
4366-
bool bindPatternVarsOneWay,
4367-
PatternBindingDecl *patternBinding,
4368-
unsigned patternIndex);
4359+
Optional<Type> generateConstraints(Pattern *P,
4360+
ConstraintLocatorBuilder locator,
4361+
bool bindPatternVarsOneWay,
4362+
PatternBindingDecl *patternBinding,
4363+
unsigned patternIndex);
43694364

43704365
/// Generate constraints for a statement condition.
43714366
///

lib/Sema/CSGen.cpp

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2442,7 +2442,7 @@ namespace {
24422442
/// for the types of each variable declared within the pattern, along
24432443
/// with a one-way constraint binding that to the type to which the
24442444
/// variable will be ascribed or inferred.
2445-
Type getTypeForPattern(
2445+
Optional<Type> getTypeForPattern(
24462446
Pattern *pattern, ConstraintLocatorBuilder locator,
24472447
bool bindPatternVarsOneWay,
24482448
PatternBindingDecl *patternBinding = nullptr,
@@ -2466,14 +2466,21 @@ namespace {
24662466
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
24672467
bindPatternVarsOneWay);
24682468

2469-
return setType(ParenType::get(CS.getASTContext(), underlyingType));
2469+
if (!underlyingType)
2470+
return None;
2471+
2472+
return setType(ParenType::get(CS.getASTContext(), *underlyingType));
24702473
}
24712474
case PatternKind::Binding: {
24722475
auto *subPattern = cast<BindingPattern>(pattern)->getSubPattern();
24732476
auto type = getTypeForPattern(subPattern, locator,
24742477
bindPatternVarsOneWay);
2478+
2479+
if (!type)
2480+
return None;
2481+
24752482
// Var doesn't affect the type.
2476-
return setType(type);
2483+
return setType(*type);
24772484
}
24782485
case PatternKind::Any: {
24792486
Type type;
@@ -2653,6 +2660,9 @@ namespace {
26532660

26542661
Type type = TypeChecker::typeCheckPattern(contextualPattern);
26552662

2663+
if (!type)
2664+
return None;
2665+
26562666
// Look through reference storage types.
26572667
type = type->getReferenceStorageReferent();
26582668

@@ -2664,16 +2674,19 @@ namespace {
26642674
auto *subPattern = cast<TypedPattern>(pattern)->getSubPattern();
26652675
// Determine the subpattern type. It will be convertible to the
26662676
// ascribed type.
2667-
Type subPatternType = getTypeForPattern(
2677+
auto subPatternType = getTypeForPattern(
26682678
subPattern,
26692679
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26702680
bindPatternVarsOneWay);
26712681

2682+
if (!subPatternType)
2683+
return None;
2684+
26722685
// NOTE: The order here is important! Pattern matching equality is
26732686
// not symmetric (we need to fix that either by using a different
26742687
// constraint, or actually making it symmetric).
26752688
CS.addConstraint(
2676-
ConstraintKind::Equal, openedType, subPatternType,
2689+
ConstraintKind::Equal, openedType, *subPatternType,
26772690
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26782691

26792692
// FIXME [OPAQUE SUPPORT]: the distinction between where we want opaque
@@ -2693,12 +2706,15 @@ namespace {
26932706
auto &tupleElt = tuplePat->getElement(i);
26942707

26952708
auto *eltPattern = tupleElt.getPattern();
2696-
Type eltTy = getTypeForPattern(
2709+
auto eltTy = getTypeForPattern(
26972710
eltPattern,
26982711
locator.withPathElement(LocatorPathElt::PatternMatch(eltPattern)),
26992712
bindPatternVarsOneWay);
27002713

2701-
tupleTypeElts.push_back(TupleTypeElt(eltTy, tupleElt.getLabel()));
2714+
if (!eltTy)
2715+
return None;
2716+
2717+
tupleTypeElts.push_back(TupleTypeElt(*eltTy, tupleElt.getLabel()));
27022718
}
27032719

27042720
return setType(TupleType::get(tupleTypeElts, CS.getASTContext()));
@@ -2707,12 +2723,15 @@ namespace {
27072723
case PatternKind::OptionalSome: {
27082724
auto *subPattern = cast<OptionalSomePattern>(pattern)->getSubPattern();
27092725
// The subpattern must have optional type.
2710-
Type subPatternType = getTypeForPattern(
2726+
auto subPatternType = getTypeForPattern(
27112727
subPattern,
27122728
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27132729
bindPatternVarsOneWay);
27142730

2715-
return setType(OptionalType::get(subPatternType));
2731+
if (!subPatternType)
2732+
return None;
2733+
2734+
return setType(OptionalType::get(*subPatternType));
27162735
}
27172736

27182737
case PatternKind::Is: {
@@ -2742,12 +2761,14 @@ namespace {
27422761
subPattern,
27432762
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27442763
bindPatternVarsOneWay);
2764+
if (!subPatternType)
2765+
return None;
27452766

27462767
// NOTE: The order here is important! Pattern matching equality is
27472768
// not symmetric (we need to fix that either by using a different
27482769
// constraint, or actually making it symmetric).
27492770
CS.addConstraint(
2750-
ConstraintKind::Equal, castType, subPatternType,
2771+
ConstraintKind::Equal, castType, *subPatternType,
27512772
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
27522773
}
27532774
return setType(isType);
@@ -2811,6 +2832,9 @@ namespace {
28112832
TypeResolverContext::InExpression, patternMatchLoc);
28122833
}();
28132834

2835+
if (!parentType)
2836+
return None;
2837+
28142838
// Perform member lookup into the parent's metatype.
28152839
Type parentMetaType = MetatypeType::get(parentType);
28162840
CS.addValueMemberConstraint(parentMetaType, enumPattern->getName(),
@@ -2838,13 +2862,13 @@ namespace {
28382862
// When there is a subpattern, the member will have function type,
28392863
// and we're matching the type of that subpattern to the parameter
28402864
// types.
2841-
Type subPatternType = getTypeForPattern(
2865+
auto subPatternType = getTypeForPattern(
28422866
subPattern,
28432867
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
28442868
bindPatternVarsOneWay);
28452869

28462870
SmallVector<AnyFunctionType::Param, 4> params;
2847-
decomposeTuple(subPatternType, params);
2871+
decomposeTuple(*subPatternType, params);
28482872

28492873
// Remove parameter labels; they aren't used when matching cases,
28502874
// but outright conflicts will be checked during coercion.
@@ -2877,10 +2901,24 @@ namespace {
28772901
}
28782902

28792903
case PatternKind::Expr: {
2880-
// We generate constraints for ExprPatterns in a separate pass. For
2881-
// now, just create a type variable.
2882-
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2883-
TVO_CanBindToNoEscape));
2904+
auto *EP = cast<ExprPattern>(pattern);
2905+
Type patternTy = CS.createTypeVariable(CS.getConstraintLocator(locator),
2906+
TVO_CanBindToNoEscape);
2907+
2908+
auto target = SyntacticElementTarget::forExprPattern(EP);
2909+
2910+
if (CS.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
2911+
/*leaveClosureBodyUnchecked=*/false)) {
2912+
return None;
2913+
}
2914+
CS.setType(EP->getMatchVar(), patternTy);
2915+
2916+
if (CS.generateConstraints(target))
2917+
return None;
2918+
2919+
CS.setTargetFor(EP, target);
2920+
CS.setExprPatternFor(EP->getSubExpr(), EP);
2921+
return setType(patternTy);
28842922
}
28852923
}
28862924

@@ -4328,11 +4366,19 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
43284366
initializer, LocatorPathElt::ContextualType(CTP_Initialization));
43294367

43304368
Type patternType;
4369+
bool forExprPattern = false;
43314370
if (auto pattern = target.getInitializationPattern()) {
4332-
patternType = cs.generateConstraints(
4371+
auto *semanticPattern =
4372+
pattern->getSemanticsProvidingPattern(/*allowTypedPattern*/ false);
4373+
forExprPattern = isa<ExprPattern>(semanticPattern);
4374+
auto ty = cs.generateConstraints(
43334375
pattern, locator, target.shouldBindPatternVarsOneWay(),
43344376
target.getInitializationPatternBindingDecl(),
43354377
target.getInitializationPatternBindingIndex());
4378+
if (!ty)
4379+
return true;
4380+
4381+
patternType = *ty;
43364382
} else {
43374383
patternType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
43384384
}
@@ -4341,9 +4387,15 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
43414387
return cs.generateWrappedPropertyTypeConstraints(
43424388
wrappedVar, cs.getType(target.getAsExpr()), patternType);
43434389

4344-
// Add a conversion constraint between the types.
4345-
cs.addConstraint(ConstraintKind::Conversion, cs.getType(target.getAsExpr()),
4346-
patternType, locator, /*isFavored*/true);
4390+
// Add a constraint between the types. For ExprPatterns, we want an equality
4391+
// constraint, because we want to propagate the type of the initializer
4392+
// directly into the implicit '~=' call. We'll then allow conversions when
4393+
// matching that as an argument. This avoids producing bad diagnostics where
4394+
// we try and apply fixes to the conversion outside of the call.
4395+
auto kind = forExprPattern ? ConstraintKind::Equal
4396+
: ConstraintKind::Conversion;
4397+
cs.addConstraint(kind, cs.getType(target.getAsExpr()), patternType, locator,
4398+
/*isFavored*/ true);
43474399

43484400
return false;
43494401
}
@@ -4491,7 +4543,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44914543
// Collect constraints from the element pattern.
44924544
auto elementLocator = cs.getConstraintLocator(
44934545
sequenceExpr, ConstraintLocator::SequenceElementType);
4494-
Type initType =
4546+
auto initType =
44954547
cs.generateConstraints(pattern, elementLocator,
44964548
target.shouldBindPatternVarsOneWay(), nullptr, 0);
44974549
if (!initType)
@@ -4510,7 +4562,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
45104562
// resolving `optional object` constraint which is sometimes too eager.
45114563
cs.addConstraint(ConstraintKind::Conversion, nextType,
45124564
OptionalType::get(elementType), elementTypeLoc);
4513-
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4565+
cs.addConstraint(ConstraintKind::Conversion, elementType, *initType,
45144566
elementLocator);
45154567
}
45164568

@@ -4536,7 +4588,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
45364588

45374589
// Populate all of the information for a for-each loop.
45384590
forEachStmtInfo.elementType = elementType;
4539-
forEachStmtInfo.initType = initType;
4591+
forEachStmtInfo.initType = *initType;
45404592
target.setPattern(pattern);
45414593
target.getForEachStmtInfo() = forEachStmtInfo;
45424594
return target;
@@ -4716,7 +4768,7 @@ bool ConstraintSystem::generateConstraints(
47164768

47174769
// Generate constraints to bind all of the internal declarations
47184770
// and verify the pattern.
4719-
Type patternType = generateConstraints(
4771+
auto patternType = generateConstraints(
47204772
pattern, locator, /*shouldBindPatternVarsOneWay*/ true,
47214773
target.getPatternBindingOfUninitializedVar(),
47224774
target.getIndexOfUninitializedVar());
@@ -4745,25 +4797,13 @@ Expr *ConstraintSystem::generateConstraints(
47454797
return generateConstraintsFor(*this, expr, dc);
47464798
}
47474799

4748-
Type ConstraintSystem::generateConstraints(
4800+
Optional<Type> ConstraintSystem::generateConstraints(
47494801
Pattern *pattern, ConstraintLocatorBuilder locator,
47504802
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
47514803
unsigned patternIndex) {
47524804
ConstraintGenerator cg(*this, nullptr);
4753-
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4754-
patternBinding, patternIndex);
4755-
assert(ty);
4756-
4757-
// Gather the ExprPatterns, and form a conjunction for their expressions.
4758-
SmallVector<ExprPattern *, 4> exprPatterns;
4759-
pattern->forEachNode([&](Pattern *P) {
4760-
if (auto *EP = dyn_cast<ExprPattern>(P))
4761-
exprPatterns.push_back(EP);
4762-
});
4763-
if (!exprPatterns.empty())
4764-
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4765-
4766-
return ty;
4805+
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4806+
patternBinding, patternIndex);
47674807
}
47684808

47694809
bool ConstraintSystem::generateConstraints(StmtCondition condition,

0 commit comments

Comments
 (0)