Skip to content

Commit 3ecd55f

Browse files
committed
[CS] Allow bidirectional inference for ExprPatterns
Rather than using a conjunction, just form the constraints for the match expression right away.
1 parent 734df72 commit 3ecd55f

File tree

6 files changed

+93
-134
lines changed

6 files changed

+93
-134
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4345,27 +4345,22 @@ class ConstraintSystem {
43454345
/// \returns \c true if constraint generation failed, \c false otherwise
43464346
bool generateConstraints(SingleValueStmtExpr *E);
43474347

4348-
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4349-
/// that solves each expression in turn.
4350-
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4351-
ConstraintLocatorBuilder locator);
4352-
43534348
/// Generate constraints for the given (unchecked) expression.
43544349
///
43554350
/// \returns a possibly-sanitized expression, or null if an error occurred.
43564351
[[nodiscard]]
43574352
Expr *generateConstraints(Expr *E, DeclContext *dc,
43584353
bool isInputExpression = true);
43594354

4360-
/// Generate constraints for binding the given pattern to the
4361-
/// value of the given expression.
4355+
/// Generate constraints for a given pattern.
43624356
///
4363-
/// \returns a possibly-sanitized initializer, or null if an error occurred.
4357+
/// \returns The type of the pattern, or \c None if a failure occured.
43644358
[[nodiscard]]
4365-
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
4366-
bool bindPatternVarsOneWay,
4367-
PatternBindingDecl *patternBinding,
4368-
unsigned patternIndex);
4359+
Optional<Type> generateConstraints(Pattern *P,
4360+
ConstraintLocatorBuilder locator,
4361+
bool bindPatternVarsOneWay,
4362+
PatternBindingDecl *patternBinding,
4363+
unsigned patternIndex);
43694364

43704365
/// Generate constraints for a statement condition.
43714366
///

lib/Sema/CSGen.cpp

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2442,7 +2442,7 @@ namespace {
24422442
/// for the types of each variable declared within the pattern, along
24432443
/// with a one-way constraint binding that to the type to which the
24442444
/// variable will be ascribed or inferred.
2445-
Type getTypeForPattern(
2445+
Optional<Type> getTypeForPattern(
24462446
Pattern *pattern, ConstraintLocatorBuilder locator,
24472447
bool bindPatternVarsOneWay,
24482448
PatternBindingDecl *patternBinding = nullptr,
@@ -2466,14 +2466,21 @@ namespace {
24662466
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
24672467
bindPatternVarsOneWay);
24682468

2469-
return setType(ParenType::get(CS.getASTContext(), underlyingType));
2469+
if (!underlyingType)
2470+
return None;
2471+
2472+
return setType(ParenType::get(CS.getASTContext(), *underlyingType));
24702473
}
24712474
case PatternKind::Binding: {
24722475
auto *subPattern = cast<BindingPattern>(pattern)->getSubPattern();
24732476
auto type = getTypeForPattern(subPattern, locator,
24742477
bindPatternVarsOneWay);
2478+
2479+
if (!type)
2480+
return None;
2481+
24752482
// Var doesn't affect the type.
2476-
return setType(type);
2483+
return setType(*type);
24772484
}
24782485
case PatternKind::Any: {
24792486
Type type;
@@ -2653,6 +2660,9 @@ namespace {
26532660

26542661
Type type = TypeChecker::typeCheckPattern(contextualPattern);
26552662

2663+
if (!type)
2664+
return None;
2665+
26562666
// Look through reference storage types.
26572667
type = type->getReferenceStorageReferent();
26582668

@@ -2664,16 +2674,19 @@ namespace {
26642674
auto *subPattern = cast<TypedPattern>(pattern)->getSubPattern();
26652675
// Determine the subpattern type. It will be convertible to the
26662676
// ascribed type.
2667-
Type subPatternType = getTypeForPattern(
2677+
auto subPatternType = getTypeForPattern(
26682678
subPattern,
26692679
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26702680
bindPatternVarsOneWay);
26712681

2682+
if (!subPatternType)
2683+
return None;
2684+
26722685
// NOTE: The order here is important! Pattern matching equality is
26732686
// not symmetric (we need to fix that either by using a different
26742687
// constraint, or actually making it symmetric).
26752688
CS.addConstraint(
2676-
ConstraintKind::Equal, openedType, subPatternType,
2689+
ConstraintKind::Equal, openedType, *subPatternType,
26772690
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26782691

26792692
// FIXME [OPAQUE SUPPORT]: the distinction between where we want opaque
@@ -2693,12 +2706,15 @@ namespace {
26932706
auto &tupleElt = tuplePat->getElement(i);
26942707

26952708
auto *eltPattern = tupleElt.getPattern();
2696-
Type eltTy = getTypeForPattern(
2709+
auto eltTy = getTypeForPattern(
26972710
eltPattern,
26982711
locator.withPathElement(LocatorPathElt::PatternMatch(eltPattern)),
26992712
bindPatternVarsOneWay);
27002713

2701-
tupleTypeElts.push_back(TupleTypeElt(eltTy, tupleElt.getLabel()));
2714+
if (!eltTy)
2715+
return None;
2716+
2717+
tupleTypeElts.push_back(TupleTypeElt(*eltTy, tupleElt.getLabel()));
27022718
}
27032719

27042720
return setType(TupleType::get(tupleTypeElts, CS.getASTContext()));
@@ -2707,12 +2723,15 @@ namespace {
27072723
case PatternKind::OptionalSome: {
27082724
auto *subPattern = cast<OptionalSomePattern>(pattern)->getSubPattern();
27092725
// The subpattern must have optional type.
2710-
Type subPatternType = getTypeForPattern(
2726+
auto subPatternType = getTypeForPattern(
27112727
subPattern,
27122728
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27132729
bindPatternVarsOneWay);
27142730

2715-
return setType(OptionalType::get(subPatternType));
2731+
if (!subPatternType)
2732+
return None;
2733+
2734+
return setType(OptionalType::get(*subPatternType));
27162735
}
27172736

27182737
case PatternKind::Is: {
@@ -2742,12 +2761,14 @@ namespace {
27422761
subPattern,
27432762
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27442763
bindPatternVarsOneWay);
2764+
if (!subPatternType)
2765+
return None;
27452766

27462767
// NOTE: The order here is important! Pattern matching equality is
27472768
// not symmetric (we need to fix that either by using a different
27482769
// constraint, or actually making it symmetric).
27492770
CS.addConstraint(
2750-
ConstraintKind::Equal, castType, subPatternType,
2771+
ConstraintKind::Equal, castType, *subPatternType,
27512772
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
27522773
}
27532774
return setType(isType);
@@ -2811,6 +2832,9 @@ namespace {
28112832
TypeResolverContext::InExpression, patternMatchLoc);
28122833
}();
28132834

2835+
if (!parentType)
2836+
return None;
2837+
28142838
// Perform member lookup into the parent's metatype.
28152839
Type parentMetaType = MetatypeType::get(parentType);
28162840
CS.addValueMemberConstraint(parentMetaType, enumPattern->getName(),
@@ -2838,13 +2862,13 @@ namespace {
28382862
// When there is a subpattern, the member will have function type,
28392863
// and we're matching the type of that subpattern to the parameter
28402864
// types.
2841-
Type subPatternType = getTypeForPattern(
2865+
auto subPatternType = getTypeForPattern(
28422866
subPattern,
28432867
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
28442868
bindPatternVarsOneWay);
28452869

28462870
SmallVector<AnyFunctionType::Param, 4> params;
2847-
decomposeTuple(subPatternType, params);
2871+
decomposeTuple(*subPatternType, params);
28482872

28492873
// Remove parameter labels; they aren't used when matching cases,
28502874
// but outright conflicts will be checked during coercion.
@@ -2877,10 +2901,24 @@ namespace {
28772901
}
28782902

28792903
case PatternKind::Expr: {
2880-
// We generate constraints for ExprPatterns in a separate pass. For
2881-
// now, just create a type variable.
2882-
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2883-
TVO_CanBindToNoEscape));
2904+
auto *EP = cast<ExprPattern>(pattern);
2905+
Type patternTy = CS.createTypeVariable(CS.getConstraintLocator(locator),
2906+
TVO_CanBindToNoEscape);
2907+
2908+
auto target = SyntacticElementTarget::forExprPattern(EP);
2909+
2910+
if (CS.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
2911+
/*leaveClosureBodyUnchecked=*/false)) {
2912+
return None;
2913+
}
2914+
CS.setType(EP->getMatchVar(), patternTy);
2915+
2916+
if (CS.generateConstraints(target))
2917+
return None;
2918+
2919+
CS.setTargetFor(EP, target);
2920+
CS.setExprPatternFor(EP->getSubExpr(), EP);
2921+
return setType(patternTy);
28842922
}
28852923
}
28862924

@@ -4321,10 +4359,14 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
43214359

43224360
Type patternType;
43234361
if (auto pattern = target.getInitializationPattern()) {
4324-
patternType = cs.generateConstraints(
4362+
auto ty = cs.generateConstraints(
43254363
pattern, locator, target.shouldBindPatternVarsOneWay(),
43264364
target.getInitializationPatternBindingDecl(),
43274365
target.getInitializationPatternBindingIndex());
4366+
if (!ty)
4367+
return true;
4368+
4369+
patternType = *ty;
43284370
} else {
43294371
patternType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
43304372
}
@@ -4483,7 +4525,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44834525
// Collect constraints from the element pattern.
44844526
auto elementLocator = cs.getConstraintLocator(
44854527
sequenceExpr, ConstraintLocator::SequenceElementType);
4486-
Type initType =
4528+
auto initType =
44874529
cs.generateConstraints(pattern, elementLocator,
44884530
target.shouldBindPatternVarsOneWay(), nullptr, 0);
44894531
if (!initType)
@@ -4502,7 +4544,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
45024544
// resolving `optional object` constraint which is sometimes too eager.
45034545
cs.addConstraint(ConstraintKind::Conversion, nextType,
45044546
OptionalType::get(elementType), elementTypeLoc);
4505-
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4547+
cs.addConstraint(ConstraintKind::Conversion, elementType, *initType,
45064548
elementLocator);
45074549
}
45084550

@@ -4528,7 +4570,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
45284570

45294571
// Populate all of the information for a for-each loop.
45304572
forEachStmtInfo.elementType = elementType;
4531-
forEachStmtInfo.initType = initType;
4573+
forEachStmtInfo.initType = *initType;
45324574
target.setPattern(pattern);
45334575
target.getForEachStmtInfo() = forEachStmtInfo;
45344576
return target;
@@ -4708,7 +4750,7 @@ bool ConstraintSystem::generateConstraints(
47084750

47094751
// Generate constraints to bind all of the internal declarations
47104752
// and verify the pattern.
4711-
Type patternType = generateConstraints(
4753+
auto patternType = generateConstraints(
47124754
pattern, locator, /*shouldBindPatternVarsOneWay*/ true,
47134755
target.getPatternBindingOfUninitializedVar(),
47144756
target.getIndexOfUninitializedVar());
@@ -4737,25 +4779,13 @@ Expr *ConstraintSystem::generateConstraints(
47374779
return generateConstraintsFor(*this, expr, dc);
47384780
}
47394781

4740-
Type ConstraintSystem::generateConstraints(
4782+
Optional<Type> ConstraintSystem::generateConstraints(
47414783
Pattern *pattern, ConstraintLocatorBuilder locator,
47424784
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
47434785
unsigned patternIndex) {
47444786
ConstraintGenerator cg(*this, nullptr);
4745-
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4746-
patternBinding, patternIndex);
4747-
assert(ty);
4748-
4749-
// Gather the ExprPatterns, and form a conjunction for their expressions.
4750-
SmallVector<ExprPattern *, 4> exprPatterns;
4751-
pattern->forEachNode([&](Pattern *P) {
4752-
if (auto *EP = dyn_cast<ExprPattern>(P))
4753-
exprPatterns.push_back(EP);
4754-
});
4755-
if (!exprPatterns.empty())
4756-
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4757-
4758-
return ty;
4787+
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4788+
patternBinding, patternIndex);
47594789
}
47604790

47614791
bool ConstraintSystem::generateConstraints(StmtCondition condition,

0 commit comments

Comments
 (0)