Skip to content

Commit a8b28da

Browse files
committed
[CS] Allow bidirectional inference for ExprPatterns
Rather than using a conjunction, just form the constraints for the match expression right away.
1 parent 06bc822 commit a8b28da

File tree

7 files changed

+103
-134
lines changed

7 files changed

+103
-134
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4345,27 +4345,22 @@ class ConstraintSystem {
43454345
/// \returns \c true if constraint generation failed, \c false otherwise
43464346
bool generateConstraints(SingleValueStmtExpr *E);
43474347

4348-
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4349-
/// that solves each expression in turn.
4350-
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4351-
ConstraintLocatorBuilder locator);
4352-
43534348
/// Generate constraints for the given (unchecked) expression.
43544349
///
43554350
/// \returns a possibly-sanitized expression, or null if an error occurred.
43564351
[[nodiscard]]
43574352
Expr *generateConstraints(Expr *E, DeclContext *dc,
43584353
bool isInputExpression = true);
43594354

4360-
/// Generate constraints for binding the given pattern to the
4361-
/// value of the given expression.
4355+
/// Generate constraints for a given pattern.
43624356
///
4363-
/// \returns a possibly-sanitized initializer, or null if an error occurred.
4357+
/// \returns The type of the pattern, or \c None if a failure occured.
43644358
[[nodiscard]]
4365-
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
4366-
bool bindPatternVarsOneWay,
4367-
PatternBindingDecl *patternBinding,
4368-
unsigned patternIndex);
4359+
Optional<Type> generateConstraints(Pattern *P,
4360+
ConstraintLocatorBuilder locator,
4361+
bool bindPatternVarsOneWay,
4362+
PatternBindingDecl *patternBinding,
4363+
unsigned patternIndex);
43694364

43704365
/// Generate constraints for a statement condition.
43714366
///

lib/Sema/CSGen.cpp

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2391,7 +2391,7 @@ namespace {
23912391
/// for the types of each variable declared within the pattern, along
23922392
/// with a one-way constraint binding that to the type to which the
23932393
/// variable will be ascribed or inferred.
2394-
Type getTypeForPattern(
2394+
Optional<Type> getTypeForPattern(
23952395
Pattern *pattern, ConstraintLocatorBuilder locator,
23962396
bool bindPatternVarsOneWay,
23972397
PatternBindingDecl *patternBinding = nullptr,
@@ -2415,14 +2415,21 @@ namespace {
24152415
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
24162416
bindPatternVarsOneWay);
24172417

2418-
return setType(ParenType::get(CS.getASTContext(), underlyingType));
2418+
if (!underlyingType)
2419+
return None;
2420+
2421+
return setType(ParenType::get(CS.getASTContext(), *underlyingType));
24192422
}
24202423
case PatternKind::Binding: {
24212424
auto *subPattern = cast<BindingPattern>(pattern)->getSubPattern();
24222425
auto type = getTypeForPattern(subPattern, locator,
24232426
bindPatternVarsOneWay);
2427+
2428+
if (!type)
2429+
return None;
2430+
24242431
// Var doesn't affect the type.
2425-
return setType(type);
2432+
return setType(*type);
24262433
}
24272434
case PatternKind::Any: {
24282435
Type type;
@@ -2602,6 +2609,9 @@ namespace {
26022609

26032610
Type type = TypeChecker::typeCheckPattern(contextualPattern);
26042611

2612+
if (!type)
2613+
return None;
2614+
26052615
// Look through reference storage types.
26062616
type = type->getReferenceStorageReferent();
26072617

@@ -2613,16 +2623,19 @@ namespace {
26132623
auto *subPattern = cast<TypedPattern>(pattern)->getSubPattern();
26142624
// Determine the subpattern type. It will be convertible to the
26152625
// ascribed type.
2616-
Type subPatternType = getTypeForPattern(
2626+
auto subPatternType = getTypeForPattern(
26172627
subPattern,
26182628
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26192629
bindPatternVarsOneWay);
26202630

2631+
if (!subPatternType)
2632+
return None;
2633+
26212634
// NOTE: The order here is important! Pattern matching equality is
26222635
// not symmetric (we need to fix that either by using a different
26232636
// constraint, or actually making it symmetric).
26242637
CS.addConstraint(
2625-
ConstraintKind::Equal, openedType, subPatternType,
2638+
ConstraintKind::Equal, openedType, *subPatternType,
26262639
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26272640

26282641
// FIXME [OPAQUE SUPPORT]: the distinction between where we want opaque
@@ -2642,12 +2655,15 @@ namespace {
26422655
auto &tupleElt = tuplePat->getElement(i);
26432656

26442657
auto *eltPattern = tupleElt.getPattern();
2645-
Type eltTy = getTypeForPattern(
2658+
auto eltTy = getTypeForPattern(
26462659
eltPattern,
26472660
locator.withPathElement(LocatorPathElt::PatternMatch(eltPattern)),
26482661
bindPatternVarsOneWay);
26492662

2650-
tupleTypeElts.push_back(TupleTypeElt(eltTy, tupleElt.getLabel()));
2663+
if (!eltTy)
2664+
return None;
2665+
2666+
tupleTypeElts.push_back(TupleTypeElt(*eltTy, tupleElt.getLabel()));
26512667
}
26522668

26532669
return setType(TupleType::get(tupleTypeElts, CS.getASTContext()));
@@ -2656,12 +2672,15 @@ namespace {
26562672
case PatternKind::OptionalSome: {
26572673
auto *subPattern = cast<OptionalSomePattern>(pattern)->getSubPattern();
26582674
// The subpattern must have optional type.
2659-
Type subPatternType = getTypeForPattern(
2675+
auto subPatternType = getTypeForPattern(
26602676
subPattern,
26612677
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26622678
bindPatternVarsOneWay);
26632679

2664-
return setType(OptionalType::get(subPatternType));
2680+
if (!subPatternType)
2681+
return None;
2682+
2683+
return setType(OptionalType::get(*subPatternType));
26652684
}
26662685

26672686
case PatternKind::Is: {
@@ -2691,12 +2710,14 @@ namespace {
26912710
subPattern,
26922711
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26932712
bindPatternVarsOneWay);
2713+
if (!subPatternType)
2714+
return None;
26942715

26952716
// NOTE: The order here is important! Pattern matching equality is
26962717
// not symmetric (we need to fix that either by using a different
26972718
// constraint, or actually making it symmetric).
26982719
CS.addConstraint(
2699-
ConstraintKind::Equal, castType, subPatternType,
2720+
ConstraintKind::Equal, castType, *subPatternType,
27002721
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
27012722
}
27022723
return setType(isType);
@@ -2760,6 +2781,9 @@ namespace {
27602781
TypeResolverContext::InExpression, patternMatchLoc);
27612782
}();
27622783

2784+
if (!parentType)
2785+
return None;
2786+
27632787
// Perform member lookup into the parent's metatype.
27642788
Type parentMetaType = MetatypeType::get(parentType);
27652789
CS.addValueMemberConstraint(parentMetaType, enumPattern->getName(),
@@ -2787,13 +2811,13 @@ namespace {
27872811
// When there is a subpattern, the member will have function type,
27882812
// and we're matching the type of that subpattern to the parameter
27892813
// types.
2790-
Type subPatternType = getTypeForPattern(
2814+
auto subPatternType = getTypeForPattern(
27912815
subPattern,
27922816
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27932817
bindPatternVarsOneWay);
27942818

27952819
SmallVector<AnyFunctionType::Param, 4> params;
2796-
decomposeTuple(subPatternType, params);
2820+
decomposeTuple(*subPatternType, params);
27972821

27982822
// Remove parameter labels; they aren't used when matching cases,
27992823
// but outright conflicts will be checked during coercion.
@@ -2826,10 +2850,24 @@ namespace {
28262850
}
28272851

28282852
case PatternKind::Expr: {
2829-
// We generate constraints for ExprPatterns in a separate pass. For
2830-
// now, just create a type variable.
2831-
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2832-
TVO_CanBindToNoEscape));
2853+
auto *EP = cast<ExprPattern>(pattern);
2854+
Type patternTy = CS.createTypeVariable(CS.getConstraintLocator(locator),
2855+
TVO_CanBindToNoEscape);
2856+
2857+
auto target = SyntacticElementTarget::forExprPattern(EP);
2858+
2859+
if (CS.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
2860+
/*leaveClosureBodyUnchecked=*/false)) {
2861+
return None;
2862+
}
2863+
CS.setType(EP->getMatchVar(), patternTy);
2864+
2865+
if (CS.generateConstraints(target))
2866+
return None;
2867+
2868+
CS.setTargetFor(EP, target);
2869+
CS.setExprPatternFor(EP->getSubExpr(), EP);
2870+
return setType(patternTy);
28332871
}
28342872
}
28352873

@@ -4222,10 +4260,14 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
42224260

42234261
Type patternType;
42244262
if (auto pattern = target.getInitializationPattern()) {
4225-
patternType = cs.generateConstraints(
4263+
auto ty = cs.generateConstraints(
42264264
pattern, locator, target.shouldBindPatternVarsOneWay(),
42274265
target.getInitializationPatternBindingDecl(),
42284266
target.getInitializationPatternBindingIndex());
4267+
if (!ty)
4268+
return true;
4269+
4270+
patternType = *ty;
42294271
} else {
42304272
patternType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
42314273
}
@@ -4384,7 +4426,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
43844426
// Collect constraints from the element pattern.
43854427
auto elementLocator = cs.getConstraintLocator(
43864428
sequenceExpr, ConstraintLocator::SequenceElementType);
4387-
Type initType =
4429+
auto initType =
43884430
cs.generateConstraints(pattern, elementLocator,
43894431
target.shouldBindPatternVarsOneWay(), nullptr, 0);
43904432
if (!initType)
@@ -4403,7 +4445,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44034445
// resolving `optional object` constraint which is sometimes too eager.
44044446
cs.addConstraint(ConstraintKind::Conversion, nextType,
44054447
OptionalType::get(elementType), elementTypeLoc);
4406-
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4448+
cs.addConstraint(ConstraintKind::Conversion, elementType, *initType,
44074449
elementLocator);
44084450
}
44094451

@@ -4429,7 +4471,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44294471

44304472
// Populate all of the information for a for-each loop.
44314473
forEachStmtInfo.elementType = elementType;
4432-
forEachStmtInfo.initType = initType;
4474+
forEachStmtInfo.initType = *initType;
44334475
target.setPattern(pattern);
44344476
target.getForEachStmtInfo() = forEachStmtInfo;
44354477
return target;
@@ -4609,7 +4651,7 @@ bool ConstraintSystem::generateConstraints(
46094651

46104652
// Generate constraints to bind all of the internal declarations
46114653
// and verify the pattern.
4612-
Type patternType = generateConstraints(
4654+
auto patternType = generateConstraints(
46134655
pattern, locator, /*shouldBindPatternVarsOneWay*/ true,
46144656
target.getPatternBindingOfUninitializedVar(),
46154657
target.getIndexOfUninitializedVar());
@@ -4638,25 +4680,13 @@ Expr *ConstraintSystem::generateConstraints(
46384680
return generateConstraintsFor(*this, expr, dc);
46394681
}
46404682

4641-
Type ConstraintSystem::generateConstraints(
4683+
Optional<Type> ConstraintSystem::generateConstraints(
46424684
Pattern *pattern, ConstraintLocatorBuilder locator,
46434685
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
46444686
unsigned patternIndex) {
46454687
ConstraintGenerator cg(*this, nullptr);
4646-
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4647-
patternBinding, patternIndex);
4648-
assert(ty);
4649-
4650-
// Gather the ExprPatterns, and form a conjunction for their expressions.
4651-
SmallVector<ExprPattern *, 4> exprPatterns;
4652-
pattern->forEachNode([&](Pattern *P) {
4653-
if (auto *EP = dyn_cast<ExprPattern>(P))
4654-
exprPatterns.push_back(EP);
4655-
});
4656-
if (!exprPatterns.empty())
4657-
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4658-
4659-
return ty;
4688+
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4689+
patternBinding, patternIndex);
46604690
}
46614691

46624692
bool ConstraintSystem::generateConstraints(StmtCondition condition,

lib/Sema/CSSimplify.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5993,6 +5993,16 @@ bool ConstraintSystem::repairFailures(
59935993
if (repairByConstructingRawRepresentableType(lhs, rhs))
59945994
break;
59955995

5996+
// If this is for an initialization, but we don't have a contextual type,
5997+
// then this is a pattern match of something that isn't a TypedPattern. As
5998+
// such, this is more like a regular conversion than a contextual type
5999+
// conversion, and type mismatches ought be diagnosed elsewhere (e.g for an
6000+
// ExprPattern, we should diagnose an argument mismatch).
6001+
if (purpose == CTP_Initialization &&
6002+
getContextualTypeLoc(anchor).isNull()) {
6003+
break;
6004+
}
6005+
59966006
conversionsOrFixes.push_back(IgnoreContextualType::create(
59976007
*this, lhs, rhs, getConstraintLocator(locator)));
59986008
break;

0 commit comments

Comments
 (0)