Skip to content

Commit 691278b

Browse files
committed
[CS] Allow bidirectional inference for ExprPatterns
Rather than using a conjunction, just form the constraints for the match expression right away.
1 parent f592b5c commit 691278b

File tree

7 files changed

+109
-138
lines changed

7 files changed

+109
-138
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4329,27 +4329,22 @@ class ConstraintSystem {
43294329
/// \returns \c true if constraint generation failed, \c false otherwise
43304330
bool generateConstraints(SingleValueStmtExpr *E);
43314331

4332-
/// Generate constraints for an array of ExprPatterns, forming a conjunction
4333-
/// that solves each expression in turn.
4334-
void generateConstraints(ArrayRef<ExprPattern *> exprPatterns,
4335-
ConstraintLocatorBuilder locator);
4336-
43374332
/// Generate constraints for the given (unchecked) expression.
43384333
///
43394334
/// \returns a possibly-sanitized expression, or null if an error occurred.
43404335
[[nodiscard]]
43414336
Expr *generateConstraints(Expr *E, DeclContext *dc,
43424337
bool isInputExpression = true);
43434338

4344-
/// Generate constraints for binding the given pattern to the
4345-
/// value of the given expression.
4339+
/// Generate constraints for a given pattern.
43464340
///
4347-
/// \returns a possibly-sanitized initializer, or null if an error occurred.
4341+
/// \returns The type of the pattern, or \c None if a failure occured.
43484342
[[nodiscard]]
4349-
Type generateConstraints(Pattern *P, ConstraintLocatorBuilder locator,
4350-
bool bindPatternVarsOneWay,
4351-
PatternBindingDecl *patternBinding,
4352-
unsigned patternIndex);
4343+
Optional<Type> generateConstraints(Pattern *P,
4344+
ConstraintLocatorBuilder locator,
4345+
bool bindPatternVarsOneWay,
4346+
PatternBindingDecl *patternBinding,
4347+
unsigned patternIndex);
43534348

43544349
/// Generate constraints for a statement condition.
43554350
///

lib/Sema/CSGen.cpp

Lines changed: 73 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,7 +2382,7 @@ namespace {
23822382
/// for the types of each variable declared within the pattern, along
23832383
/// with a one-way constraint binding that to the type to which the
23842384
/// variable will be ascribed or inferred.
2385-
Type getTypeForPattern(
2385+
Optional<Type> getTypeForPattern(
23862386
Pattern *pattern, ConstraintLocatorBuilder locator,
23872387
bool bindPatternVarsOneWay,
23882388
PatternBindingDecl *patternBinding = nullptr,
@@ -2406,14 +2406,21 @@ namespace {
24062406
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
24072407
bindPatternVarsOneWay);
24082408

2409-
return setType(ParenType::get(CS.getASTContext(), underlyingType));
2409+
if (!underlyingType)
2410+
return None;
2411+
2412+
return setType(ParenType::get(CS.getASTContext(), *underlyingType));
24102413
}
24112414
case PatternKind::Binding: {
24122415
auto *subPattern = cast<BindingPattern>(pattern)->getSubPattern();
24132416
auto type = getTypeForPattern(subPattern, locator,
24142417
bindPatternVarsOneWay);
2418+
2419+
if (!type)
2420+
return None;
2421+
24152422
// Var doesn't affect the type.
2416-
return setType(type);
2423+
return setType(*type);
24172424
}
24182425
case PatternKind::Any: {
24192426
Type type;
@@ -2587,6 +2594,9 @@ namespace {
25872594

25882595
Type type = TypeChecker::typeCheckPattern(contextualPattern);
25892596

2597+
if (!type)
2598+
return None;
2599+
25902600
// Look through reference storage types.
25912601
type = type->getReferenceStorageReferent();
25922602

@@ -2598,16 +2608,19 @@ namespace {
25982608
auto *subPattern = cast<TypedPattern>(pattern)->getSubPattern();
25992609
// Determine the subpattern type. It will be convertible to the
26002610
// ascribed type.
2601-
Type subPatternType = getTypeForPattern(
2611+
auto subPatternType = getTypeForPattern(
26022612
subPattern,
26032613
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26042614
bindPatternVarsOneWay);
26052615

2616+
if (!subPatternType)
2617+
return None;
2618+
26062619
// NOTE: The order here is important! Pattern matching equality is
26072620
// not symmetric (we need to fix that either by using a different
26082621
// constraint, or actually making it symmetric).
26092622
CS.addConstraint(
2610-
ConstraintKind::Equal, openedType, subPatternType,
2623+
ConstraintKind::Equal, openedType, *subPatternType,
26112624
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26122625

26132626
// FIXME [OPAQUE SUPPORT]: the distinction between where we want opaque
@@ -2627,12 +2640,15 @@ namespace {
26272640
auto &tupleElt = tuplePat->getElement(i);
26282641

26292642
auto *eltPattern = tupleElt.getPattern();
2630-
Type eltTy = getTypeForPattern(
2643+
auto eltTy = getTypeForPattern(
26312644
eltPattern,
26322645
locator.withPathElement(LocatorPathElt::PatternMatch(eltPattern)),
26332646
bindPatternVarsOneWay);
26342647

2635-
tupleTypeElts.push_back(TupleTypeElt(eltTy, tupleElt.getLabel()));
2648+
if (!eltTy)
2649+
return None;
2650+
2651+
tupleTypeElts.push_back(TupleTypeElt(*eltTy, tupleElt.getLabel()));
26362652
}
26372653

26382654
return setType(TupleType::get(tupleTypeElts, CS.getASTContext()));
@@ -2641,12 +2657,15 @@ namespace {
26412657
case PatternKind::OptionalSome: {
26422658
auto *subPattern = cast<OptionalSomePattern>(pattern)->getSubPattern();
26432659
// The subpattern must have optional type.
2644-
Type subPatternType = getTypeForPattern(
2660+
auto subPatternType = getTypeForPattern(
26452661
subPattern,
26462662
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26472663
bindPatternVarsOneWay);
26482664

2649-
return setType(OptionalType::get(subPatternType));
2665+
if (!subPatternType)
2666+
return None;
2667+
2668+
return setType(OptionalType::get(*subPatternType));
26502669
}
26512670

26522671
case PatternKind::Is: {
@@ -2675,12 +2694,14 @@ namespace {
26752694
subPattern,
26762695
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
26772696
bindPatternVarsOneWay);
2697+
if (!subPatternType)
2698+
return None;
26782699

26792700
// NOTE: The order here is important! Pattern matching equality is
26802701
// not symmetric (we need to fix that either by using a different
26812702
// constraint, or actually making it symmetric).
26822703
CS.addConstraint(
2683-
ConstraintKind::Equal, castType, subPatternType,
2704+
ConstraintKind::Equal, castType, *subPatternType,
26842705
locator.withPathElement(LocatorPathElt::PatternMatch(pattern)));
26852706
}
26862707
return setType(isType);
@@ -2741,6 +2762,9 @@ namespace {
27412762
TypeResolverContext::InExpression, patternMatchLoc);
27422763
}();
27432764

2765+
if (!parentType)
2766+
return None;
2767+
27442768
// Perform member lookup into the parent's metatype.
27452769
Type parentMetaType = MetatypeType::get(parentType);
27462770
CS.addValueMemberConstraint(
@@ -2770,13 +2794,16 @@ namespace {
27702794
// When there is a subpattern, the member will have function type,
27712795
// and we're matching the type of that subpattern to the parameter
27722796
// types.
2773-
Type subPatternType = getTypeForPattern(
2797+
auto subPatternType = getTypeForPattern(
27742798
subPattern,
27752799
locator.withPathElement(LocatorPathElt::PatternMatch(subPattern)),
27762800
bindPatternVarsOneWay);
27772801

2802+
if (!subPatternType)
2803+
return None;
2804+
27782805
SmallVector<AnyFunctionType::Param, 4> params;
2779-
decomposeTuple(subPatternType, params);
2806+
decomposeTuple(*subPatternType, params);
27802807

27812808
// Remove parameter labels; they aren't used when matching cases,
27822809
// but outright conflicts will be checked during coercion.
@@ -2811,10 +2838,24 @@ namespace {
28112838
}
28122839

28132840
case PatternKind::Expr: {
2814-
// We generate constraints for ExprPatterns in a separate pass. For
2815-
// now, just create a type variable.
2816-
return setType(CS.createTypeVariable(CS.getConstraintLocator(locator),
2817-
TVO_CanBindToNoEscape));
2841+
auto *EP = cast<ExprPattern>(pattern);
2842+
Type patternTy = CS.createTypeVariable(CS.getConstraintLocator(locator),
2843+
TVO_CanBindToNoEscape);
2844+
2845+
auto target = SyntacticElementTarget::forExprPattern(EP);
2846+
2847+
if (CS.preCheckTarget(target, /*replaceInvalidRefWithErrors=*/true,
2848+
/*leaveClosureBodyUnchecked=*/false)) {
2849+
return None;
2850+
}
2851+
CS.setType(EP->getMatchVar(), patternTy);
2852+
2853+
if (CS.generateConstraints(target))
2854+
return None;
2855+
2856+
CS.setTargetFor(EP, target);
2857+
CS.setExprPatternFor(EP->getSubExpr(), EP);
2858+
return setType(patternTy);
28182859
}
28192860
}
28202861

@@ -4197,10 +4238,14 @@ static bool generateInitPatternConstraints(ConstraintSystem &cs,
41974238

41984239
Type patternType;
41994240
if (auto pattern = target.getInitializationPattern()) {
4200-
patternType = cs.generateConstraints(
4241+
auto ty = cs.generateConstraints(
42014242
pattern, locator, target.shouldBindPatternVarsOneWay(),
42024243
target.getInitializationPatternBindingDecl(),
42034244
target.getInitializationPatternBindingIndex());
4245+
if (!ty)
4246+
return true;
4247+
4248+
patternType = *ty;
42044249
} else {
42054250
patternType = cs.createTypeVariable(locator, TVO_CanBindToNoEscape);
42064251
}
@@ -4359,7 +4404,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
43594404
// Collect constraints from the element pattern.
43604405
auto elementLocator = cs.getConstraintLocator(
43614406
sequenceExpr, ConstraintLocator::SequenceElementType);
4362-
Type initType =
4407+
auto initType =
43634408
cs.generateConstraints(pattern, elementLocator,
43644409
target.shouldBindPatternVarsOneWay(), nullptr, 0);
43654410
if (!initType)
@@ -4378,7 +4423,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
43784423
// resolving `optional object` constraint which is sometimes too eager.
43794424
cs.addConstraint(ConstraintKind::Conversion, nextType,
43804425
OptionalType::get(elementType), elementTypeLoc);
4381-
cs.addConstraint(ConstraintKind::Conversion, elementType, initType,
4426+
cs.addConstraint(ConstraintKind::Conversion, elementType, *initType,
43824427
elementLocator);
43834428
}
43844429

@@ -4404,7 +4449,7 @@ generateForEachStmtConstraints(ConstraintSystem &cs,
44044449

44054450
// Populate all of the information for a for-each loop.
44064451
forEachStmtInfo.elementType = elementType;
4407-
forEachStmtInfo.initType = initType;
4452+
forEachStmtInfo.initType = *initType;
44084453
target.setPattern(pattern);
44094454
target.getForEachStmtInfo() = forEachStmtInfo;
44104455
return target;
@@ -4584,7 +4629,7 @@ bool ConstraintSystem::generateConstraints(
45844629

45854630
// Generate constraints to bind all of the internal declarations
45864631
// and verify the pattern.
4587-
Type patternType = generateConstraints(
4632+
auto patternType = generateConstraints(
45884633
pattern, locator, /*shouldBindPatternVarsOneWay*/ true,
45894634
target.getPatternBindingOfUninitializedVar(),
45904635
target.getIndexOfUninitializedVar());
@@ -4613,25 +4658,13 @@ Expr *ConstraintSystem::generateConstraints(
46134658
return generateConstraintsFor(*this, expr, dc);
46144659
}
46154660

4616-
Type ConstraintSystem::generateConstraints(
4661+
Optional<Type> ConstraintSystem::generateConstraints(
46174662
Pattern *pattern, ConstraintLocatorBuilder locator,
46184663
bool bindPatternVarsOneWay, PatternBindingDecl *patternBinding,
46194664
unsigned patternIndex) {
46204665
ConstraintGenerator cg(*this, nullptr);
4621-
auto ty = cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4622-
patternBinding, patternIndex);
4623-
assert(ty);
4624-
4625-
// Gather the ExprPatterns, and form a conjunction for their expressions.
4626-
SmallVector<ExprPattern *, 4> exprPatterns;
4627-
pattern->forEachNode([&](Pattern *P) {
4628-
if (auto *EP = dyn_cast<ExprPattern>(P))
4629-
exprPatterns.push_back(EP);
4630-
});
4631-
if (!exprPatterns.empty())
4632-
generateConstraints(exprPatterns, getConstraintLocator(pattern));
4633-
4634-
return ty;
4666+
return cg.getTypeForPattern(pattern, locator, bindPatternVarsOneWay,
4667+
patternBinding, patternIndex);
46354668
}
46364669

46374670
bool ConstraintSystem::generateConstraints(StmtCondition condition,
@@ -4713,14 +4746,16 @@ bool ConstraintSystem::generateConstraints(
47134746
// Generate constraints for the pattern, including one-way bindings for
47144747
// any variables that show up in this pattern, because those variables
47154748
// can be referenced in the guard expressions and the body.
4716-
Type patternType = generateConstraints(
4749+
auto patternType = generateConstraints(
47174750
pattern, locator, /* bindPatternVarsOneWay=*/true,
47184751
/*patternBinding=*/nullptr, /*patternBindingIndex=*/0);
4752+
if (!patternType)
4753+
return true;
47194754

47204755
// Convert the subject type to the pattern, which establishes the
47214756
// bindings.
47224757
addConstraint(
4723-
ConstraintKind::Conversion, subjectType, patternType, locator);
4758+
ConstraintKind::Conversion, subjectType, *patternType, locator);
47244759

47254760
// Generate constraints for the guard expression, if there is one.
47264761
Expr *guardExpr = caseLabelItem.getGuardExpr();

lib/Sema/CSSimplify.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5937,6 +5937,16 @@ bool ConstraintSystem::repairFailures(
59375937
if (repairByConstructingRawRepresentableType(lhs, rhs))
59385938
break;
59395939

5940+
// If this is for an initialization, but we don't have a contextual type,
5941+
// then this is a pattern match of something that isn't a TypedPattern. As
5942+
// such, this is more like a regular conversion than a contextual type
5943+
// conversion, and type mismatches ought be diagnosed elsewhere (e.g for an
5944+
// ExprPattern, we should diagnose an argument mismatch).
5945+
if (purpose == CTP_Initialization &&
5946+
getContextualTypeLoc(anchor).isNull()) {
5947+
break;
5948+
}
5949+
59405950
conversionsOrFixes.push_back(IgnoreContextualType::create(
59415951
*this, lhs, rhs, getConstraintLocator(locator)));
59425952
break;

0 commit comments

Comments
 (0)