Skip to content

Commit e7e3c5b

Browse files
committed
[CS] Allow bidirectional inference for ExprPatterns
Rather than using a conjunction, just form the constraints for the match expression right away.
1 parent b521f6f commit e7e3c5b

File tree

7 files changed

+101
-135
lines changed

7 files changed

+101
-135
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4304,27 +4304,22 @@ class ConstraintSystem {
43044304
/// \returns \c true if constraint generation failed, \c false otherwise
43054305
bool generateConstraints(SingleValueStmtExpr *E);
43064306

4307-
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4308-
/// that solves each expression in turn.
4309-
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4310-
ConstraintLocatorBuilder locator);
4311-
43124307
/// Generate constraints for the given (unchecked) expression.
43134308
///
43144309
/// \returns a possibly-sanitized expression, or null if an error occurred.
43154310
[[nodiscard]]
43164311
Expr *generateConstraints(Expr *E, DeclContext *dc,
43174312
bool isInputExpression = true);
43184313

4319-
/// Generate constraints for binding the given pattern to the
4320-
/// value of the given expression.
4314+
/// Generate constraints for a given pattern.
43214315
///
4322-
/// \returns a possibly-sanitized initializer, or null if an error occurred.
4316+
/// \returns The type of the pattern, or \c None if a failure occured.
43234317
[[nodiscard]]
4324-
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
4325-
bool bindPatternVarsOneWay,
4326-
PatternBindingDecl *patternBinding,
4327-
unsigned patternIndex);
4318+
Optional<Type> generateConstraints(Pattern *P,
4319+
ConstraintLocatorBuilder locator,
4320+
bool bindPatternVarsOneWay,
4321+
PatternBindingDecl *patternBinding,
4322+
unsigned patternIndex);
43284323

43294324
/// Generate constraints for a statement condition.
43304325
///

lib/Sema/CSGen.cpp

Lines changed: 66 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2391,7 +2391,7 @@ namespace {
23912391
/// for the types of each variable declared within the pattern, along
23922392
/// with a one-way constraint binding that to the type to which the
23932393
/// variable will be ascribed or inferred.
2394-
Type getTypeForPattern(
2394+
Optional<Type> getTypeForPattern(
23952395
Pattern *pattern, ConstraintLocatorBuilder locator,
23962396
bool bindPatternVarsOneWay,
23972397
PatternBindingDecl *patternBinding = nullptr,
@@ -2415,14 +2415,21 @@ namespace {
24152415
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
24162416
bindPatternVarsOneWay);
24172417

2418-
return setType(ParenType::get(CS.getASTContext(), underlyingType));
2418+
if (!underlyingType)
2419+
return None;
2420+
2421+
return setType(ParenType::get(CS.getASTContext(), *underlyingType));
24192422
}
24202423
case PatternKind::Binding: {
24212424
auto *subPattern = cast<BindingPattern>(pattern)->getSubPattern();
24222425
auto type = getTypeForPattern(subPattern, locator,
24232426
bindPatternVarsOneWay);
2427+
2428+
if (!type)
2429+
return None;
2430+
24242431
// Var doesn't affect the type.
2425-
return setType(type);
2432+
return setType(*type);
24262433
}
24272434
case PatternKind::Any: {
24282435
Type type;
@@ -2596,6 +2603,9 @@ namespace {
25962603

25972604
Type type = TypeChecker::typeCheckPattern(contextualPattern);
25982605

2606+
if (!type)
2607+
return None;
2608+
25992609
// Look through reference storage types.
26002610
type = type->getReferenceStorageReferent();
26012611

@@ -2607,16 +2617,19 @@ namespace {
26072617
auto *subPattern = cast<TypedPattern>(pattern)->getSubPattern();
26082618
// Determine the subpattern type. It will be convertible to the
26092619
// ascribed type.
2610-
Type subPatternType = getTypeForPattern(
2620+
auto subPatternType = getTypeForPattern(
26112621
subPattern,
26122622
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26132623
bindPatternVarsOneWay);
26142624

2625+
if (!subPatternType)
2626+
return None;
2627+
26152628
// NOTE: The order here is important! Pattern matching equality is
26162629
// not symmetric (we need to fix that either by using a different
26172630
// constraint, or actually making it symmetric).
26182631
CS.addConstraint(
2619-
ConstraintKind::Equal, openedType, subPatternType,
2632+
ConstraintKind::Equal, openedType, *subPatternType,
26202633
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26212634

26222635
// FIXME [OPAQUE SUPPORT]: the distinction between where we want opaque
@@ -2636,12 +2649,15 @@ namespace {
26362649
auto &tupleElt = tuplePat->getElement(i);
26372650

26382651
auto *eltPattern = tupleElt.getPattern();
2639-
Type eltTy = getTypeForPattern(
2652+
auto eltTy = getTypeForPattern(
26402653
eltPattern,
26412654
locator.withPathElement(LocatorPathElt::PatternMatch(eltPattern)),
26422655
bindPatternVarsOneWay);
26432656

2644-
tupleTypeElts.push_back(TupleTypeElt(eltTy, tupleElt.getLabel()));
2657+
if (!eltTy)
2658+
return None;
2659+
2660+
tupleTypeElts.push_back(TupleTypeElt(*eltTy, tupleElt.getLabel()));
26452661
}
26462662

26472663
return setType(TupleType::get(tupleTypeElts, CS.getASTContext()));
@@ -2650,12 +2666,15 @@ namespace {
26502666
case PatternKind::OptionalSome: {
26512667
auto *subPattern = cast<OptionalSomePattern>(pattern)->getSubPattern();
26522668
// The subpattern must have optional type.
2653-
Type subPatternType = getTypeForPattern(
2669+
auto subPatternType = getTypeForPattern(
26542670
subPattern,
26552671
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26562672
bindPatternVarsOneWay);
26572673

2658-
return setType(OptionalType::get(subPatternType));
2674+
if (!subPatternType)
2675+
return None;
2676+
2677+
return setType(OptionalType::get(*subPatternType));
26592678
}
26602679

26612680
case PatternKind::Is: {
@@ -2684,12 +2703,14 @@ namespace {
26842703
subPattern,
26852704
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26862705
bindPatternVarsOneWay);
2706+
if (!subPatternType)
2707+
return None;
26872708

26882709
// NOTE: The order here is important! Pattern matching equality is
26892710
// not symmetric (we need to fix that either by using a different
26902711
// constraint, or actually making it symmetric).
26912712
CS.addConstraint(
2692-
ConstraintKind::Equal, castType, subPatternType,
2713+
ConstraintKind::Equal, castType, *subPatternType,
26932714
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26942715
}
26952716
return setType(isType);
@@ -2753,6 +2774,9 @@ namespace {
27532774
TypeResolverContext::InExpression, patternMatchLoc);
27542775
}();
27552776

2777+
if (!parentType)
2778+
return None;
2779+
27562780
// Perform member lookup into the parent's metatype.
27572781
Type parentMetaType = MetatypeType::get(parentType);
27582782
CS.addValueMemberConstraint(parentMetaType, enumPattern->getName(),
@@ -2780,13 +2804,13 @@ namespace {
27802804
// When there is a subpattern, the member will have function type,
27812805
// and we're matching the type of that subpattern to the parameter
27822806
// types.
2783-
Type subPatternType = getTypeForPattern(
2807+
auto subPatternType = getTypeForPattern(
27842808
subPattern,
27852809
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27862810
bindPatternVarsOneWay);
27872811

27882812
SmallVector<AnyFunctionType::Param, 4> params;
2789-
decomposeTuple(subPatternType, params);
2813+
decomposeTuple(*subPatternType, params);
27902814

27912815
// Remove parameter labels; they aren't used when matching cases,
27922816
// but outright conflicts will be checked during coercion.
@@ -2819,10 +2843,24 @@ namespace {
28192843
}
28202844

28212845
case PatternKind::Expr: {
2822-
// We generate constraints for ExprPatterns in a separate pass. For
2823-
// now, just create a type variable.
2824-
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2825-
TVO_CanBindToNoEscape));
2846+
auto *EP = cast<ExprPattern>(pattern);
2847+
Type patternTy = CS.createTypeVariable(CS.getConstraintLocator(locator),
2848+
TVO_CanBindToNoEscape);
2849+
2850+
auto target = SyntacticElementTarget::forExprPattern(EP);
2851+
2852+
if (CS.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
2853+
/*leaveClosureBodyUnchecked=*/false)) {
2854+
return None;
2855+
}
2856+
CS.setType(EP->getMatchVar(), patternTy);
2857+
2858+
if (CS.generateConstraints(target))
2859+
return None;
2860+
2861+
CS.setTargetFor(EP, target);
2862+
CS.setExprPatternFor(EP->getSubExpr(), EP);
2863+
return setType(patternTy);
28262864
}
28272865
}
28282866

@@ -4208,10 +4246,14 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
42084246

42094247
Type patternType;
42104248
if (auto pattern = target.getInitializationPattern()) {
4211-
patternType = cs.generateConstraints(
4249+
auto ty = cs.generateConstraints(
42124250
pattern, locator, target.shouldBindPatternVarsOneWay(),
42134251
target.getInitializationPatternBindingDecl(),
42144252
target.getInitializationPatternBindingIndex());
4253+
if (!ty)
4254+
return true;
4255+
4256+
patternType = *ty;
42154257
} else {
42164258
patternType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
42174259
}
@@ -4370,7 +4412,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
43704412
// Collect constraints from the element pattern.
43714413
auto elementLocator = cs.getConstraintLocator(
43724414
sequenceExpr, ConstraintLocator::SequenceElementType);
4373-
Type initType =
4415+
auto initType =
43744416
cs.generateConstraints(pattern, elementLocator,
43754417
target.shouldBindPatternVarsOneWay(), nullptr, 0);
43764418
if (!initType)
@@ -4389,7 +4431,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
43894431
// resolving `optional object` constraint which is sometimes too eager.
43904432
cs.addConstraint(ConstraintKind::Conversion, nextType,
43914433
OptionalType::get(elementType), elementTypeLoc);
4392-
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4434+
cs.addConstraint(ConstraintKind::Conversion, elementType, *initType,
43934435
elementLocator);
43944436
}
43954437

@@ -4415,7 +4457,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44154457

44164458
// Populate all of the information for a for-each loop.
44174459
forEachStmtInfo.elementType = elementType;
4418-
forEachStmtInfo.initType = initType;
4460+
forEachStmtInfo.initType = *initType;
44194461
target.setPattern(pattern);
44204462
target.getForEachStmtInfo() = forEachStmtInfo;
44214463
return target;
@@ -4595,7 +4637,7 @@ bool ConstraintSystem::generateConstraints(
45954637

45964638
// Generate constraints to bind all of the internal declarations
45974639
// and verify the pattern.
4598-
Type patternType = generateConstraints(
4640+
auto patternType = generateConstraints(
45994641
pattern, locator, /*shouldBindPatternVarsOneWay*/ true,
46004642
target.getPatternBindingOfUninitializedVar(),
46014643
target.getIndexOfUninitializedVar());
@@ -4624,25 +4666,13 @@ Expr *ConstraintSystem::generateConstraints(
46244666
return generateConstraintsFor(*this, expr, dc);
46254667
}
46264668

4627-
Type ConstraintSystem::generateConstraints(
4669+
Optional<Type> ConstraintSystem::generateConstraints(
46284670
Pattern *pattern, ConstraintLocatorBuilder locator,
46294671
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
46304672
unsigned patternIndex) {
46314673
ConstraintGenerator cg(*this, nullptr);
4632-
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4633-
patternBinding, patternIndex);
4634-
assert(ty);
4635-
4636-
// Gather the ExprPatterns, and form a conjunction for their expressions.
4637-
SmallVector<ExprPattern *, 4> exprPatterns;
4638-
pattern->forEachNode([&](Pattern *P) {
4639-
if (auto *EP = dyn_cast<ExprPattern>(P))
4640-
exprPatterns.push_back(EP);
4641-
});
4642-
if (!exprPatterns.empty())
4643-
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4644-
4645-
return ty;
4674+
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4675+
patternBinding, patternIndex);
46464676
}
46474677

46484678
bool ConstraintSystem::generateConstraints(StmtCondition condition,

lib/Sema/CSSimplify.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5924,6 +5924,16 @@ bool ConstraintSystem::repairFailures(
59245924
if (repairByConstructingRawRepresentableType(lhs, rhs))
59255925
break;
59265926

5927+
// If this is for an initialization, but we don't have a contextual type,
5928+
// then this is a pattern match of something that isn't a TypedPattern. As
5929+
// such, this is more like a regular conversion than a contextual type
5930+
// conversion, and type mismatches ought be diagnosed elsewhere (e.g for an
5931+
// ExprPattern, we should diagnose an argument mismatch).
5932+
if (purpose == CTP_Initialization &&
5933+
getContextualTypeLoc(anchor).isNull()) {
5934+
break;
5935+
}
5936+
59275937
conversionsOrFixes.push_back(IgnoreContextualType::create(
59285938
*this, lhs, rhs, getConstraintLocator(locator)));
59295939
break;

0 commit comments

Comments
 (0)