Skip to content

Commit 0990549

Browse files
committed
[CS] Allow bidirectional inference for ExprPatterns
Rather than using a conjunction, just form the constraints for the match expression right away.
1 parent d6840d5 commit 0990549

File tree

7 files changed

+103
-134
lines changed

7 files changed

+103
-134
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4345,27 +4345,22 @@ class ConstraintSystem {
43454345
/// \returns \c true if constraint generation failed, \c false otherwise
43464346
bool generateConstraints(SingleValueStmtExpr *E);
43474347

4348-
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4349-
/// that solves each expression in turn.
4350-
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4351-
ConstraintLocatorBuilder locator);
4352-
43534348
/// Generate constraints for the given (unchecked) expression.
43544349
///
43554350
/// \returns a possibly-sanitized expression, or null if an error occurred.
43564351
[[nodiscard]]
43574352
Expr *generateConstraints(Expr *E, DeclContext *dc,
43584353
bool isInputExpression = true);
43594354

4360-
/// Generate constraints for binding the given pattern to the
4361-
/// value of the given expression.
4355+
/// Generate constraints for a given pattern.
43624356
///
4363-
/// \returns a possibly-sanitized initializer, or null if an error occurred.
4357+
/// \returns The type of the pattern, or \c None if a failure occured.
43644358
[[nodiscard]]
4365-
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
4366-
bool bindPatternVarsOneWay,
4367-
PatternBindingDecl *patternBinding,
4368-
unsigned patternIndex);
4359+
Optional<Type> generateConstraints(Pattern *P,
4360+
ConstraintLocatorBuilder locator,
4361+
bool bindPatternVarsOneWay,
4362+
PatternBindingDecl *patternBinding,
4363+
unsigned patternIndex);
43694364

43704365
/// Generate constraints for a statement condition.
43714366
///

lib/Sema/CSGen.cpp

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2422,7 +2422,7 @@ namespace {
24222422
/// for the types of each variable declared within the pattern, along
24232423
/// with a one-way constraint binding that to the type to which the
24242424
/// variable will be ascribed or inferred.
2425-
Type getTypeForPattern(
2425+
Optional<Type> getTypeForPattern(
24262426
Pattern *pattern, ConstraintLocatorBuilder locator,
24272427
bool bindPatternVarsOneWay,
24282428
PatternBindingDecl *patternBinding = nullptr,
@@ -2446,14 +2446,21 @@ namespace {
24462446
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
24472447
bindPatternVarsOneWay);
24482448

2449-
return setType(ParenType::get(CS.getASTContext(), underlyingType));
2449+
if (!underlyingType)
2450+
return None;
2451+
2452+
return setType(ParenType::get(CS.getASTContext(), *underlyingType));
24502453
}
24512454
case PatternKind::Binding: {
24522455
auto *subPattern = cast<BindingPattern>(pattern)->getSubPattern();
24532456
auto type = getTypeForPattern(subPattern, locator,
24542457
bindPatternVarsOneWay);
2458+
2459+
if (!type)
2460+
return None;
2461+
24552462
// Var doesn't affect the type.
2456-
return setType(type);
2463+
return setType(*type);
24572464
}
24582465
case PatternKind::Any: {
24592466
Type type;
@@ -2633,6 +2640,9 @@ namespace {
26332640

26342641
Type type = TypeChecker::typeCheckPattern(contextualPattern);
26352642

2643+
if (!type)
2644+
return None;
2645+
26362646
// Look through reference storage types.
26372647
type = type->getReferenceStorageReferent();
26382648

@@ -2644,16 +2654,19 @@ namespace {
26442654
auto *subPattern = cast<TypedPattern>(pattern)->getSubPattern();
26452655
// Determine the subpattern type. It will be convertible to the
26462656
// ascribed type.
2647-
Type subPatternType = getTypeForPattern(
2657+
auto subPatternType = getTypeForPattern(
26482658
subPattern,
26492659
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26502660
bindPatternVarsOneWay);
26512661

2662+
if (!subPatternType)
2663+
return None;
2664+
26522665
// NOTE: The order here is important! Pattern matching equality is
26532666
// not symmetric (we need to fix that either by using a different
26542667
// constraint, or actually making it symmetric).
26552668
CS.addConstraint(
2656-
ConstraintKind::Equal, openedType, subPatternType,
2669+
ConstraintKind::Equal, openedType, *subPatternType,
26572670
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26582671

26592672
// FIXME [OPAQUE SUPPORT]: the distinction between where we want opaque
@@ -2673,12 +2686,15 @@ namespace {
26732686
auto &tupleElt = tuplePat->getElement(i);
26742687

26752688
auto *eltPattern = tupleElt.getPattern();
2676-
Type eltTy = getTypeForPattern(
2689+
auto eltTy = getTypeForPattern(
26772690
eltPattern,
26782691
locator.withPathElement(LocatorPathElt::PatternMatch(eltPattern)),
26792692
bindPatternVarsOneWay);
26802693

2681-
tupleTypeElts.push_back(TupleTypeElt(eltTy, tupleElt.getLabel()));
2694+
if (!eltTy)
2695+
return None;
2696+
2697+
tupleTypeElts.push_back(TupleTypeElt(*eltTy, tupleElt.getLabel()));
26822698
}
26832699

26842700
return setType(TupleType::get(tupleTypeElts, CS.getASTContext()));
@@ -2687,12 +2703,15 @@ namespace {
26872703
case PatternKind::OptionalSome: {
26882704
auto *subPattern = cast<OptionalSomePattern>(pattern)->getSubPattern();
26892705
// The subpattern must have optional type.
2690-
Type subPatternType = getTypeForPattern(
2706+
auto subPatternType = getTypeForPattern(
26912707
subPattern,
26922708
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26932709
bindPatternVarsOneWay);
26942710

2695-
return setType(OptionalType::get(subPatternType));
2711+
if (!subPatternType)
2712+
return None;
2713+
2714+
return setType(OptionalType::get(*subPatternType));
26962715
}
26972716

26982717
case PatternKind::Is: {
@@ -2722,12 +2741,14 @@ namespace {
27222741
subPattern,
27232742
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27242743
bindPatternVarsOneWay);
2744+
if (!subPatternType)
2745+
return None;
27252746

27262747
// NOTE: The order here is important! Pattern matching equality is
27272748
// not symmetric (we need to fix that either by using a different
27282749
// constraint, or actually making it symmetric).
27292750
CS.addConstraint(
2730-
ConstraintKind::Equal, castType, subPatternType,
2751+
ConstraintKind::Equal, castType, *subPatternType,
27312752
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
27322753
}
27332754
return setType(isType);
@@ -2791,6 +2812,9 @@ namespace {
27912812
TypeResolverContext::InExpression, patternMatchLoc);
27922813
}();
27932814

2815+
if (!parentType)
2816+
return None;
2817+
27942818
// Perform member lookup into the parent's metatype.
27952819
Type parentMetaType = MetatypeType::get(parentType);
27962820
CS.addValueMemberConstraint(parentMetaType, enumPattern->getName(),
@@ -2818,13 +2842,13 @@ namespace {
28182842
// When there is a subpattern, the member will have function type,
28192843
// and we're matching the type of that subpattern to the parameter
28202844
// types.
2821-
Type subPatternType = getTypeForPattern(
2845+
auto subPatternType = getTypeForPattern(
28222846
subPattern,
28232847
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
28242848
bindPatternVarsOneWay);
28252849

28262850
SmallVector<AnyFunctionType::Param, 4> params;
2827-
decomposeTuple(subPatternType, params);
2851+
decomposeTuple(*subPatternType, params);
28282852

28292853
// Remove parameter labels; they aren't used when matching cases,
28302854
// but outright conflicts will be checked during coercion.
@@ -2857,10 +2881,24 @@ namespace {
28572881
}
28582882

28592883
case PatternKind::Expr: {
2860-
// We generate constraints for ExprPatterns in a separate pass. For
2861-
// now, just create a type variable.
2862-
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2863-
TVO_CanBindToNoEscape));
2884+
auto *EP = cast<ExprPattern>(pattern);
2885+
Type patternTy = CS.createTypeVariable(CS.getConstraintLocator(locator),
2886+
TVO_CanBindToNoEscape);
2887+
2888+
auto target = SyntacticElementTarget::forExprPattern(EP);
2889+
2890+
if (CS.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
2891+
/*leaveClosureBodyUnchecked=*/false)) {
2892+
return None;
2893+
}
2894+
CS.setType(EP->getMatchVar(), patternTy);
2895+
2896+
if (CS.generateConstraints(target))
2897+
return None;
2898+
2899+
CS.setTargetFor(EP, target);
2900+
CS.setExprPatternFor(EP->getSubExpr(), EP);
2901+
return setType(patternTy);
28642902
}
28652903
}
28662904

@@ -4299,10 +4337,14 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
42994337

43004338
Type patternType;
43014339
if (auto pattern = target.getInitializationPattern()) {
4302-
patternType = cs.generateConstraints(
4340+
auto ty = cs.generateConstraints(
43034341
pattern, locator, target.shouldBindPatternVarsOneWay(),
43044342
target.getInitializationPatternBindingDecl(),
43054343
target.getInitializationPatternBindingIndex());
4344+
if (!ty)
4345+
return true;
4346+
4347+
patternType = *ty;
43064348
} else {
43074349
patternType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
43084350
}
@@ -4461,7 +4503,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44614503
// Collect constraints from the element pattern.
44624504
auto elementLocator = cs.getConstraintLocator(
44634505
sequenceExpr, ConstraintLocator::SequenceElementType);
4464-
Type initType =
4506+
auto initType =
44654507
cs.generateConstraints(pattern, elementLocator,
44664508
target.shouldBindPatternVarsOneWay(), nullptr, 0);
44674509
if (!initType)
@@ -4480,7 +4522,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44804522
// resolving `optional object` constraint which is sometimes too eager.
44814523
cs.addConstraint(ConstraintKind::Conversion, nextType,
44824524
OptionalType::get(elementType), elementTypeLoc);
4483-
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4525+
cs.addConstraint(ConstraintKind::Conversion, elementType, *initType,
44844526
elementLocator);
44854527
}
44864528

@@ -4506,7 +4548,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
45064548

45074549
// Populate all of the information for a for-each loop.
45084550
forEachStmtInfo.elementType = elementType;
4509-
forEachStmtInfo.initType = initType;
4551+
forEachStmtInfo.initType = *initType;
45104552
target.setPattern(pattern);
45114553
target.getForEachStmtInfo() = forEachStmtInfo;
45124554
return target;
@@ -4686,7 +4728,7 @@ bool ConstraintSystem::generateConstraints(
46864728

46874729
// Generate constraints to bind all of the internal declarations
46884730
// and verify the pattern.
4689-
Type patternType = generateConstraints(
4731+
auto patternType = generateConstraints(
46904732
pattern, locator, /*shouldBindPatternVarsOneWay*/ true,
46914733
target.getPatternBindingOfUninitializedVar(),
46924734
target.getIndexOfUninitializedVar());
@@ -4715,25 +4757,13 @@ Expr *ConstraintSystem::generateConstraints(
47154757
return generateConstraintsFor(*this, expr, dc);
47164758
}
47174759

4718-
Type ConstraintSystem::generateConstraints(
4760+
Optional<Type> ConstraintSystem::generateConstraints(
47194761
Pattern *pattern, ConstraintLocatorBuilder locator,
47204762
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
47214763
unsigned patternIndex) {
47224764
ConstraintGenerator cg(*this, nullptr);
4723-
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4724-
patternBinding, patternIndex);
4725-
assert(ty);
4726-
4727-
// Gather the ExprPatterns, and form a conjunction for their expressions.
4728-
SmallVector<ExprPattern *, 4> exprPatterns;
4729-
pattern->forEachNode([&](Pattern *P) {
4730-
if (auto *EP = dyn_cast<ExprPattern>(P))
4731-
exprPatterns.push_back(EP);
4732-
});
4733-
if (!exprPatterns.empty())
4734-
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4735-
4736-
return ty;
4765+
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4766+
patternBinding, patternIndex);
47374767
}
47384768

47394769
bool ConstraintSystem::generateConstraints(StmtCondition condition,

lib/Sema/CSSimplify.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5993,6 +5993,16 @@ bool ConstraintSystem::repairFailures(
59935993
if (repairByConstructingRawRepresentableType(lhs, rhs))
59945994
break;
59955995

5996+
// If this is for an initialization, but we don't have a contextual type,
5997+
// then this is a pattern match of something that isn't a TypedPattern. As
5998+
// such, this is more like a regular conversion than a contextual type
5999+
// conversion, and type mismatches ought be diagnosed elsewhere (e.g for an
6000+
// ExprPattern, we should diagnose an argument mismatch).
6001+
if (purpose == CTP_Initialization &&
6002+
getContextualTypeLoc(anchor).isNull()) {
6003+
break;
6004+
}
6005+
59966006
conversionsOrFixes.push_back(IgnoreContextualType::create(
59976007
*this, lhs, rhs, getConstraintLocator(locator)));
59986008
break;

0 commit comments

Comments
 (0)