Skip to content

Commit f3f6be7

Browse files
committed
Sema: Support optional promotion of tuple patterns in closure contexts
1 parent 148f86d commit f3f6be7

9 files changed

+172
-102
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1960,7 +1960,7 @@ class SolutionApplicationTarget {
19601960
closure,
19611961
function,
19621962
stmtCondition,
1963-
caseLabelItem,
1963+
caseStmt,
19641964
patternBinding,
19651965
uninitializedVar,
19661966
forEachStmt,
@@ -2044,9 +2044,14 @@ class SolutionApplicationTarget {
20442044
} stmtCondition;
20452045

20462046
struct {
2047-
CaseLabelItem *caseLabelItem;
2047+
CaseStmt *caseStmt;
2048+
2049+
/// The type to which the patterns in case label items should
2050+
/// be converted.
2051+
Type convertType;
2052+
20482053
DeclContext *dc;
2049-
} caseLabelItem;
2054+
} caseStmt;
20502055

20512056
struct {
20522057
PatternBindingDecl *binding;
@@ -2124,10 +2129,12 @@ class SolutionApplicationTarget {
21242129
function.body = body;
21252130
}
21262131

2127-
SolutionApplicationTarget(CaseLabelItem *caseLabelItem, DeclContext *dc) {
2128-
kind = Kind::caseLabelItem;
2129-
this->caseLabelItem.caseLabelItem = caseLabelItem;
2130-
this->caseLabelItem.dc = dc;
2132+
SolutionApplicationTarget(CaseStmt *caseStmt, Type convertType,
2133+
DeclContext *dc) {
2134+
kind = Kind::caseStmt;
2135+
this->caseStmt.caseStmt = caseStmt;
2136+
this->caseStmt.convertType = convertType;
2137+
this->caseStmt.dc = dc;
21312138
}
21322139

21332140
SolutionApplicationTarget(PatternBindingDecl *patternBinding) {
@@ -2229,8 +2236,8 @@ class SolutionApplicationTarget {
22292236
case Kind::stmtCondition:
22302237
return ASTNode();
22312238

2232-
case Kind::caseLabelItem:
2233-
return *getAsCaseLabelItem();
2239+
case Kind::caseStmt:
2240+
return getAsCaseStmt();
22342241

22352242
case Kind::patternBinding:
22362243
return getAsPatternBinding();
@@ -2248,7 +2255,7 @@ class SolutionApplicationTarget {
22482255
case Kind::closure:
22492256
case Kind::function:
22502257
case Kind::stmtCondition:
2251-
case Kind::caseLabelItem:
2258+
case Kind::caseStmt:
22522259
case Kind::patternBinding:
22532260
case Kind::uninitializedVar:
22542261
case Kind::forEachStmt:
@@ -2271,8 +2278,8 @@ class SolutionApplicationTarget {
22712278
case Kind::stmtCondition:
22722279
return stmtCondition.dc;
22732280

2274-
case Kind::caseLabelItem:
2275-
return caseLabelItem.dc;
2281+
case Kind::caseStmt:
2282+
return caseStmt.dc;
22762283

22772284
case Kind::patternBinding:
22782285
return patternBinding->getDeclContext();
@@ -2507,7 +2514,7 @@ class SolutionApplicationTarget {
25072514
case Kind::expression:
25082515
case Kind::closure:
25092516
case Kind::stmtCondition:
2510-
case Kind::caseLabelItem:
2517+
case Kind::caseStmt:
25112518
case Kind::patternBinding:
25122519
case Kind::uninitializedVar:
25132520
case Kind::forEachStmt:
@@ -2524,7 +2531,7 @@ class SolutionApplicationTarget {
25242531
case Kind::expression:
25252532
case Kind::closure:
25262533
case Kind::function:
2527-
case Kind::caseLabelItem:
2534+
case Kind::caseStmt:
25282535
case Kind::patternBinding:
25292536
case Kind::uninitializedVar:
25302537
case Kind::forEachStmt:
@@ -2536,7 +2543,13 @@ class SolutionApplicationTarget {
25362543
llvm_unreachable("invalid statement kind");
25372544
}
25382545

2539-
Optional<CaseLabelItem *> getAsCaseLabelItem() const {
2546+
/// Get the type to which patterns in the \c CaseStmt should be converted.
2547+
Type getCaseStmtContextualType() const {
2548+
assert(kind == Kind::caseStmt);
2549+
return caseStmt.convertType;
2550+
}
2551+
2552+
CaseStmt *getAsCaseStmt() const {
25402553
switch (kind) {
25412554
case Kind::expression:
25422555
case Kind::closure:
@@ -2545,10 +2558,10 @@ class SolutionApplicationTarget {
25452558
case Kind::patternBinding:
25462559
case Kind::uninitializedVar:
25472560
case Kind::forEachStmt:
2548-
return None;
2561+
return nullptr;
25492562

2550-
case Kind::caseLabelItem:
2551-
return caseLabelItem.caseLabelItem;
2563+
case Kind::caseStmt:
2564+
return caseStmt.caseStmt;
25522565
}
25532566
llvm_unreachable("invalid case label type");
25542567
}
@@ -2559,7 +2572,7 @@ class SolutionApplicationTarget {
25592572
case Kind::closure:
25602573
case Kind::function:
25612574
case Kind::stmtCondition:
2562-
case Kind::caseLabelItem:
2575+
case Kind::caseStmt:
25632576
case Kind::uninitializedVar:
25642577
case Kind::forEachStmt:
25652578
return nullptr;
@@ -2576,7 +2589,7 @@ class SolutionApplicationTarget {
25762589
case Kind::closure:
25772590
case Kind::function:
25782591
case Kind::stmtCondition:
2579-
case Kind::caseLabelItem:
2592+
case Kind::caseStmt:
25802593
case Kind::patternBinding:
25812594
case Kind::forEachStmt:
25822595
return nullptr;
@@ -2593,7 +2606,7 @@ class SolutionApplicationTarget {
25932606
case Kind::closure:
25942607
case Kind::function:
25952608
case Kind::stmtCondition:
2596-
case Kind::caseLabelItem:
2609+
case Kind::caseStmt:
25972610
case Kind::patternBinding:
25982611
case Kind::forEachStmt:
25992612
return nullptr;
@@ -2610,7 +2623,7 @@ class SolutionApplicationTarget {
26102623
case Kind::closure:
26112624
case Kind::function:
26122625
case Kind::stmtCondition:
2613-
case Kind::caseLabelItem:
2626+
case Kind::caseStmt:
26142627
case Kind::patternBinding:
26152628
case Kind::uninitializedVar:
26162629
return nullptr;
@@ -2627,7 +2640,7 @@ class SolutionApplicationTarget {
26272640
case Kind::closure:
26282641
case Kind::function:
26292642
case Kind::stmtCondition:
2630-
case Kind::caseLabelItem:
2643+
case Kind::caseStmt:
26312644
case Kind::patternBinding:
26322645
case Kind::forEachStmt:
26332646
return nullptr;
@@ -2644,7 +2657,7 @@ class SolutionApplicationTarget {
26442657
case Kind::closure:
26452658
case Kind::function:
26462659
case Kind::stmtCondition:
2647-
case Kind::caseLabelItem:
2660+
case Kind::caseStmt:
26482661
case Kind::patternBinding:
26492662
case Kind::forEachStmt:
26502663
return nullptr;
@@ -2661,7 +2674,7 @@ class SolutionApplicationTarget {
26612674
case Kind::closure:
26622675
case Kind::function:
26632676
case Kind::stmtCondition:
2664-
case Kind::caseLabelItem:
2677+
case Kind::caseStmt:
26652678
case Kind::patternBinding:
26662679
case Kind::forEachStmt:
26672680
return 0;
@@ -2698,8 +2711,8 @@ class SolutionApplicationTarget {
26982711
return SourceRange(stmtCondition.stmtCondition.front().getStartLoc(),
26992712
stmtCondition.stmtCondition.back().getEndLoc());
27002713

2701-
case Kind::caseLabelItem:
2702-
return caseLabelItem.caseLabelItem->getSourceRange();
2714+
case Kind::caseStmt:
2715+
return caseStmt.caseStmt->getLabelItemsRange();
27032716

27042717
case Kind::patternBinding:
27052718
return patternBinding->getSourceRange();
@@ -2742,8 +2755,8 @@ class SolutionApplicationTarget {
27422755
case Kind::stmtCondition:
27432756
return stmtCondition.stmtCondition.front().getStartLoc();
27442757

2745-
case Kind::caseLabelItem:
2746-
return caseLabelItem.caseLabelItem->getStartLoc();
2758+
case Kind::caseStmt:
2759+
return caseStmt.caseStmt->getStartLoc();
27472760

27482761
case Kind::patternBinding:
27492762
return patternBinding->getLoc();

lib/Sema/BuilderTransform.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,9 @@ class BuilderClosureVisitor
765765
subjectExpr, LocatorPathElt::ContextualType(CTP_CaseStmt));
766766
Type subjectType = cs->getType(subjectExpr);
767767

768+
cs->setSolutionApplicationTarget(
769+
caseStmt, SolutionApplicationTarget(caseStmt, subjectType, dc));
770+
768771
if (cs->generateConstraints(caseStmt, dc, subjectType, locator)) {
769772
hadError = true;
770773
return nullptr;
@@ -2106,11 +2109,9 @@ class BuilderClosureRewriter
21062109
NullablePtr<Stmt> visitCaseStmt(CaseStmt *caseStmt,
21072110
ResultBuilderTarget target) {
21082111
// Translate the patterns and guard expressions for each case label item.
2109-
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
2110-
SolutionApplicationTarget caseLabelTarget(&caseLabelItem, dc);
2111-
if (!rewriteTarget(caseLabelTarget))
2112-
return nullptr;
2113-
}
2112+
const auto &cs = solution.getConstraintSystem();
2113+
if (!rewriteTarget(*cs.getSolutionApplicationTarget(caseStmt)))
2114+
return nullptr;
21142115

21152116
// Setup the types of our case body var decls.
21162117
for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray()) {

lib/Sema/CSApply.cpp

Lines changed: 56 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -9177,60 +9177,65 @@ ExprWalker::rewriteTarget(SolutionApplicationTarget target) {
91779177
}
91789178

91799179
return target;
9180-
} else if (auto caseLabelItem = target.getAsCaseLabelItem()) {
9180+
} else if (auto *caseStmt = target.getAsCaseStmt()) {
91819181
ConstraintSystem &cs = solution.getConstraintSystem();
9182-
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
91839182

9184-
// Figure out the pattern type.
9185-
Type patternType = solution.simplifyType(solution.getType(info.pattern));
9186-
patternType = patternType->reconstituteSugar(/*recursive=*/false);
9183+
// Figure out the contextual type for patterns.
9184+
const Type contextualTy =
9185+
solution.simplifyType(target.getCaseStmtContextualType())
9186+
->reconstituteSugar(/*recursive=*/false)
9187+
->getRValueType();
91879188

9188-
// Check whether this enum element is resolved via ~= application.
9189-
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9190-
if (auto target = cs.getSolutionApplicationTarget(enumElement)) {
9191-
auto *EP = target->getExprPattern();
9192-
auto enumType = solution.getResolvedType(EP);
9189+
// Rewrite each case label item.
9190+
for (auto &item : caseStmt->getMutableCaseLabelItems()) {
9191+
auto info = *cs.getCaseLabelItemInfo(&item);
91939192

9194-
auto *matchCall = target->getAsExpr();
9193+
// Check whether this enum element is resolved via ~= application.
9194+
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9195+
if (auto target = cs.getSolutionApplicationTarget(enumElement)) {
9196+
auto *EP = target->getExprPattern();
9197+
auto enumType = solution.getResolvedType(EP);
91959198

9196-
auto *result = matchCall->walk(*this);
9197-
if (!result)
9198-
return None;
9199+
auto *matchCall = target->getAsExpr();
91999200

9200-
{
9201-
auto *matchVar = EP->getMatchVar();
9202-
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9203-
}
9201+
auto *result = matchCall->walk(*this);
9202+
if (!result)
9203+
return None;
92049204

9205-
EP->setMatchExpr(result);
9206-
EP->setType(enumType);
9205+
{
9206+
auto *matchVar = EP->getMatchVar();
9207+
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9208+
}
92079209

9208-
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9209-
return target;
9210-
}
9211-
}
9210+
EP->setMatchExpr(result);
9211+
EP->setType(enumType);
92129212

9213-
// Coerce the pattern to its appropriate type.
9214-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9215-
patternOptions |= TypeResolutionFlags::OverrideType;
9216-
auto contextualPattern =
9217-
ContextualPattern::forRawPattern(info.pattern,
9218-
target.getDeclContext());
9219-
if (auto coercedPattern = TypeChecker::coercePatternToType(
9220-
contextualPattern, patternType, patternOptions)) {
9221-
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
9222-
} else {
9223-
return None;
9224-
}
9213+
item.setPattern(EP, /*resolved=*/true);
9214+
return target;
9215+
}
9216+
}
92259217

9226-
// If there is a guard expression, coerce that.
9227-
if (auto *guardExpr = info.guardExpr) {
9228-
auto target = *cs.getSolutionApplicationTarget(guardExpr);
9229-
auto resultTarget = rewriteTarget(target);
9230-
if (!resultTarget)
9218+
// Coerce the pattern to its appropriate type.
9219+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9220+
patternOptions |= TypeResolutionFlags::OverrideType;
9221+
auto contextualPattern = ContextualPattern::forRawPattern(
9222+
info.pattern, target.getDeclContext());
9223+
if (auto coercedPattern = TypeChecker::coercePatternToType(
9224+
contextualPattern, contextualTy, patternOptions)) {
9225+
item.setPattern(coercedPattern, /*resolved=*/true);
9226+
} else {
92319227
return None;
9228+
}
9229+
9230+
// If there is a guard expression, coerce that.
9231+
if (auto *guardExpr = info.guardExpr) {
9232+
auto target = *cs.getSolutionApplicationTarget(guardExpr);
9233+
auto resultTarget = rewriteTarget(target);
9234+
if (!resultTarget)
9235+
return None;
92329236

9233-
(*caseLabelItem)->setGuardExpr(resultTarget->getAsExpr());
9237+
item.setGuardExpr(resultTarget->getAsExpr());
9238+
}
92349239
}
92359240

92369241
return target;
@@ -9654,15 +9659,15 @@ SolutionApplicationTarget SolutionApplicationTarget::walk(ASTWalker &walker) {
96549659
}
96559660
return *this;
96569661

9657-
case Kind::caseLabelItem:
9658-
if (auto newPattern =
9659-
caseLabelItem.caseLabelItem->getPattern()->walk(walker)) {
9660-
caseLabelItem.caseLabelItem->setPattern(
9661-
newPattern, caseLabelItem.caseLabelItem->isPatternResolved());
9662-
}
9663-
if (auto guardExpr = caseLabelItem.caseLabelItem->getGuardExpr()) {
9664-
if (auto newGuardExpr = guardExpr->walk(walker))
9665-
caseLabelItem.caseLabelItem->setGuardExpr(newGuardExpr);
9662+
case Kind::caseStmt:
9663+
for (auto &item : caseStmt.caseStmt->getMutableCaseLabelItems()) {
9664+
if (auto *newPattern = item.getPattern()->walk(walker)) {
9665+
item.setPattern(newPattern, item.isPatternResolved());
9666+
}
9667+
if (auto *guardExpr = item.getGuardExpr()) {
9668+
if (auto *newGuardExpr = guardExpr->walk(walker))
9669+
item.setGuardExpr(newGuardExpr);
9670+
}
96669671
}
96679672

96689673
return *this;

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4568,7 +4568,7 @@ bool ConstraintSystem::generateConstraints(
45684568
llvm_unreachable("Handled above");
45694569

45704570
case SolutionApplicationTarget::Kind::closure:
4571-
case SolutionApplicationTarget::Kind::caseLabelItem:
4571+
case SolutionApplicationTarget::Kind::caseStmt:
45724572
case SolutionApplicationTarget::Kind::function:
45734573
case SolutionApplicationTarget::Kind::stmtCondition:
45744574
llvm_unreachable("Handled separately");

0 commit comments

Comments
 (0)