Skip to content

Commit 23d8e6a

Browse files
committed
[CS] Allow bidirectional inference for ExprPatterns
Rather than using a conjunction, just form the constraints for the match expression right away.
1 parent c037e0e commit 23d8e6a

File tree

7 files changed

+105
-136
lines changed

7 files changed

+105
-136
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4304,27 +4304,22 @@ class ConstraintSystem {
43044304
/// \returns \c true if constraint generation failed, \c false otherwise
43054305
bool generateConstraints(SingleValueStmtExpr *E);
43064306

4307-
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4308-
/// that solves each expression in turn.
4309-
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4310-
ConstraintLocatorBuilder locator);
4311-
43124307
/// Generate constraints for the given (unchecked) expression.
43134308
///
43144309
/// \returns a possibly-sanitized expression, or null if an error occurred.
43154310
[[nodiscard]]
43164311
Expr *generateConstraints(Expr *E, DeclContext *dc,
43174312
bool isInputExpression = true);
43184313

4319-
/// Generate constraints for binding the given pattern to the
4320-
/// value of the given expression.
4314+
/// Generate constraints for a given pattern.
43214315
///
4322-
/// \returns a possibly-sanitized initializer, or null if an error occurred.
4316+
/// \returns The type of the pattern, or \c None if a failure occured.
43234317
[[nodiscard]]
4324-
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
4325-
bool bindPatternVarsOneWay,
4326-
PatternBindingDecl *patternBinding,
4327-
unsigned patternIndex);
4318+
Optional<Type> generateConstraints(Pattern *P,
4319+
ConstraintLocatorBuilder locator,
4320+
bool bindPatternVarsOneWay,
4321+
PatternBindingDecl *patternBinding,
4322+
unsigned patternIndex);
43284323

43294324
/// Generate constraints for a statement condition.
43304325
///

lib/Sema/CSGen.cpp

Lines changed: 69 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2390,7 +2390,7 @@ namespace {
23902390
/// for the types of each variable declared within the pattern, along
23912391
/// with a one-way constraint binding that to the type to which the
23922392
/// variable will be ascribed or inferred.
2393-
Type getTypeForPattern(
2393+
Optional<Type> getTypeForPattern(
23942394
Pattern *pattern, ConstraintLocatorBuilder locator,
23952395
bool bindPatternVarsOneWay,
23962396
PatternBindingDecl *patternBinding = nullptr,
@@ -2414,14 +2414,21 @@ namespace {
24142414
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
24152415
bindPatternVarsOneWay);
24162416

2417-
return setType(ParenType::get(CS.getASTContext(), underlyingType));
2417+
if (!underlyingType)
2418+
return None;
2419+
2420+
return setType(ParenType::get(CS.getASTContext(), *underlyingType));
24182421
}
24192422
case PatternKind::Binding: {
24202423
auto *subPattern = cast<BindingPattern>(pattern)->getSubPattern();
24212424
auto type = getTypeForPattern(subPattern, locator,
24222425
bindPatternVarsOneWay);
2426+
2427+
if (!type)
2428+
return None;
2429+
24232430
// Var doesn't affect the type.
2424-
return setType(type);
2431+
return setType(*type);
24252432
}
24262433
case PatternKind::Any: {
24272434
Type type;
@@ -2595,6 +2602,9 @@ namespace {
25952602

25962603
Type type = TypeChecker::typeCheckPattern(contextualPattern);
25972604

2605+
if (!type)
2606+
return None;
2607+
25982608
// Look through reference storage types.
25992609
type = type->getReferenceStorageReferent();
26002610

@@ -2606,16 +2616,19 @@ namespace {
26062616
auto *subPattern = cast<TypedPattern>(pattern)->getSubPattern();
26072617
// Determine the subpattern type. It will be convertible to the
26082618
// ascribed type.
2609-
Type subPatternType = getTypeForPattern(
2619+
auto subPatternType = getTypeForPattern(
26102620
subPattern,
26112621
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26122622
bindPatternVarsOneWay);
26132623

2624+
if (!subPatternType)
2625+
return None;
2626+
26142627
// NOTE: The order here is important! Pattern matching equality is
26152628
// not symmetric (we need to fix that either by using a different
26162629
// constraint, or actually making it symmetric).
26172630
CS.addConstraint(
2618-
ConstraintKind::Equal, openedType, subPatternType,
2631+
ConstraintKind::Equal, openedType, *subPatternType,
26192632
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26202633

26212634
// FIXME [OPAQUE SUPPORT]: the distinction between where we want opaque
@@ -2635,12 +2648,15 @@ namespace {
26352648
auto &tupleElt = tuplePat->getElement(i);
26362649

26372650
auto *eltPattern = tupleElt.getPattern();
2638-
Type eltTy = getTypeForPattern(
2651+
auto eltTy = getTypeForPattern(
26392652
eltPattern,
26402653
locator.withPathElement(LocatorPathElt::PatternMatch(eltPattern)),
26412654
bindPatternVarsOneWay);
26422655

2643-
tupleTypeElts.push_back(TupleTypeElt(eltTy, tupleElt.getLabel()));
2656+
if (!eltTy)
2657+
return None;
2658+
2659+
tupleTypeElts.push_back(TupleTypeElt(*eltTy, tupleElt.getLabel()));
26442660
}
26452661

26462662
return setType(TupleType::get(tupleTypeElts, CS.getASTContext()));
@@ -2649,12 +2665,15 @@ namespace {
26492665
case PatternKind::OptionalSome: {
26502666
auto *subPattern = cast<OptionalSomePattern>(pattern)->getSubPattern();
26512667
// The subpattern must have optional type.
2652-
Type subPatternType = getTypeForPattern(
2668+
auto subPatternType = getTypeForPattern(
26532669
subPattern,
26542670
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26552671
bindPatternVarsOneWay);
26562672

2657-
return setType(OptionalType::get(subPatternType));
2673+
if (!subPatternType)
2674+
return None;
2675+
2676+
return setType(OptionalType::get(*subPatternType));
26582677
}
26592678

26602679
case PatternKind::Is: {
@@ -2683,12 +2702,14 @@ namespace {
26832702
subPattern,
26842703
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26852704
bindPatternVarsOneWay);
2705+
if (!subPatternType)
2706+
return None;
26862707

26872708
// NOTE: The order here is important! Pattern matching equality is
26882709
// not symmetric (we need to fix that either by using a different
26892710
// constraint, or actually making it symmetric).
26902711
CS.addConstraint(
2691-
ConstraintKind::Equal, castType, subPatternType,
2712+
ConstraintKind::Equal, castType, *subPatternType,
26922713
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26932714
}
26942715
return setType(isType);
@@ -2749,6 +2770,9 @@ namespace {
27492770
TypeResolverContext::InExpression, patternMatchLoc);
27502771
}();
27512772

2773+
if (!parentType)
2774+
return None;
2775+
27522776
// Perform member lookup into the parent's metatype.
27532777
Type parentMetaType = MetatypeType::get(parentType);
27542778
CS.addValueMemberConstraint(
@@ -2778,13 +2802,16 @@ namespace {
27782802
// When there is a subpattern, the member will have function type,
27792803
// and we're matching the type of that subpattern to the parameter
27802804
// types.
2781-
Type subPatternType = getTypeForPattern(
2805+
auto subPatternType = getTypeForPattern(
27822806
subPattern,
27832807
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27842808
bindPatternVarsOneWay);
27852809

2810+
if (!subPatternType)
2811+
return None;
2812+
27862813
SmallVector<AnyFunctionType::Param, 4> params;
2787-
decomposeTuple(subPatternType, params);
2814+
decomposeTuple(*subPatternType, params);
27882815

27892816
// Remove parameter labels; they aren't used when matching cases,
27902817
// but outright conflicts will be checked during coercion.
@@ -2819,10 +2846,24 @@ namespace {
28192846
}
28202847

28212848
case PatternKind::Expr: {
2822-
// We generate constraints for ExprPatterns in a separate pass. For
2823-
// now, just create a type variable.
2824-
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2825-
TVO_CanBindToNoEscape));
2849+
auto *EP = cast<ExprPattern>(pattern);
2850+
Type patternTy = CS.createTypeVariable(CS.getConstraintLocator(locator),
2851+
TVO_CanBindToNoEscape);
2852+
2853+
auto target = SyntacticElementTarget::forExprPattern(EP);
2854+
2855+
if (CS.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
2856+
/*leaveClosureBodyUnchecked=*/false)) {
2857+
return None;
2858+
}
2859+
CS.setType(EP->getMatchVar(), patternTy);
2860+
2861+
if (CS.generateConstraints(target))
2862+
return None;
2863+
2864+
CS.setTargetFor(EP, target);
2865+
CS.setExprPatternFor(EP->getSubExpr(), EP);
2866+
return setType(patternTy);
28262867
}
28272868
}
28282869

@@ -4208,10 +4249,14 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
42084249

42094250
Type patternType;
42104251
if (auto pattern = target.getInitializationPattern()) {
4211-
patternType = cs.generateConstraints(
4252+
auto ty = cs.generateConstraints(
42124253
pattern, locator, target.shouldBindPatternVarsOneWay(),
42134254
target.getInitializationPatternBindingDecl(),
42144255
target.getInitializationPatternBindingIndex());
4256+
if (!ty)
4257+
return true;
4258+
4259+
patternType = *ty;
42154260
} else {
42164261
patternType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
42174262
}
@@ -4370,7 +4415,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
43704415
// Collect constraints from the element pattern.
43714416
auto elementLocator = cs.getConstraintLocator(
43724417
sequenceExpr, ConstraintLocator::SequenceElementType);
4373-
Type initType =
4418+
auto initType =
43744419
cs.generateConstraints(pattern, elementLocator,
43754420
target.shouldBindPatternVarsOneWay(), nullptr, 0);
43764421
if (!initType)
@@ -4389,7 +4434,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
43894434
// resolving `optional object` constraint which is sometimes too eager.
43904435
cs.addConstraint(ConstraintKind::Conversion, nextType,
43914436
OptionalType::get(elementType), elementTypeLoc);
4392-
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4437+
cs.addConstraint(ConstraintKind::Conversion, elementType, *initType,
43934438
elementLocator);
43944439
}
43954440

@@ -4415,7 +4460,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44154460

44164461
// Populate all of the information for a for-each loop.
44174462
forEachStmtInfo.elementType = elementType;
4418-
forEachStmtInfo.initType = initType;
4463+
forEachStmtInfo.initType = *initType;
44194464
target.setPattern(pattern);
44204465
target.getForEachStmtInfo() = forEachStmtInfo;
44214466
return target;
@@ -4595,7 +4640,7 @@ bool ConstraintSystem::generateConstraints(
45954640

45964641
// Generate constraints to bind all of the internal declarations
45974642
// and verify the pattern.
4598-
Type patternType = generateConstraints(
4643+
auto patternType = generateConstraints(
45994644
pattern, locator, /*shouldBindPatternVarsOneWay*/ true,
46004645
target.getPatternBindingOfUninitializedVar(),
46014646
target.getIndexOfUninitializedVar());
@@ -4624,25 +4669,13 @@ Expr *ConstraintSystem::generateConstraints(
46244669
return generateConstraintsFor(*this, expr, dc);
46254670
}
46264671

4627-
Type ConstraintSystem::generateConstraints(
4672+
Optional<Type> ConstraintSystem::generateConstraints(
46284673
Pattern *pattern, ConstraintLocatorBuilder locator,
46294674
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
46304675
unsigned patternIndex) {
46314676
ConstraintGenerator cg(*this, nullptr);
4632-
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4633-
patternBinding, patternIndex);
4634-
assert(ty);
4635-
4636-
// Gather the ExprPatterns, and form a conjunction for their expressions.
4637-
SmallVector<ExprPattern *, 4> exprPatterns;
4638-
pattern->forEachNode([&](Pattern *P) {
4639-
if (auto *EP = dyn_cast<ExprPattern>(P))
4640-
exprPatterns.push_back(EP);
4641-
});
4642-
if (!exprPatterns.empty())
4643-
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4644-
4645-
return ty;
4677+
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4678+
patternBinding, patternIndex);
46464679
}
46474680

46484681
bool ConstraintSystem::generateConstraints(StmtCondition condition,

lib/Sema/CSSimplify.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5914,6 +5914,16 @@ bool ConstraintSystem::repairFailures(
59145914
if (repairByConstructingRawRepresentableType(lhs, rhs))
59155915
break;
59165916

5917+
// If this is for an initialization, but we don't have a contextual type,
5918+
// then this is a pattern match of something that isn't a TypedPattern. As
5919+
// such, this is more like a regular conversion than a contextual type
5920+
// conversion, and type mismatches ought be diagnosed elsewhere (e.g for an
5921+
// ExprPattern, we should diagnose an argument mismatch).
5922+
if (purpose == CTP_Initialization &&
5923+
getContextualTypeLoc(anchor).isNull()) {
5924+
break;
5925+
}
5926+
59175927
conversionsOrFixes.push_back(IgnoreContextualType::create(
59185928
*this, lhs, rhs, getConstraintLocator(locator)));
59195929
break;

0 commit comments

Comments
 (0)