Skip to content

Commit 85d8479

Browse files
committed
Sema: Support optional promotion of tuple patterns in closure contexts
1 parent afb7d42 commit 85d8479

10 files changed

+173
-102
lines changed

include/swift/Sema/SyntacticElementTarget.h

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class SyntacticElementTarget {
5555
closure,
5656
function,
5757
stmtCondition,
58-
caseLabelItem,
58+
caseStmt,
5959
patternBinding,
6060
uninitializedVar,
6161
forEachStmt,
@@ -139,9 +139,14 @@ class SyntacticElementTarget {
139139
} stmtCondition;
140140

141141
struct {
142-
CaseLabelItem *caseLabelItem;
142+
CaseStmt *caseStmt;
143+
144+
/// The type to which the patterns in case label items should
145+
/// be converted.
146+
Type convertType;
147+
143148
DeclContext *dc;
144-
} caseLabelItem;
149+
} caseStmt;
145150

146151
struct {
147152
PatternBindingDecl *binding;
@@ -218,10 +223,12 @@ class SyntacticElementTarget {
218223
function.body = body;
219224
}
220225

221-
SyntacticElementTarget(CaseLabelItem *caseLabelItem, DeclContext *dc) {
222-
kind = Kind::caseLabelItem;
223-
this->caseLabelItem.caseLabelItem = caseLabelItem;
224-
this->caseLabelItem.dc = dc;
226+
SyntacticElementTarget(CaseStmt *caseStmt, Type convertType,
227+
DeclContext *dc) {
228+
kind = Kind::caseStmt;
229+
this->caseStmt.caseStmt = caseStmt;
230+
this->caseStmt.convertType = convertType;
231+
this->caseStmt.dc = dc;
225232
}
226233

227234
SyntacticElementTarget(PatternBindingDecl *patternBinding) {
@@ -324,8 +331,8 @@ class SyntacticElementTarget {
324331
case Kind::stmtCondition:
325332
return ASTNode();
326333

327-
case Kind::caseLabelItem:
328-
return *getAsCaseLabelItem();
334+
case Kind::caseStmt:
335+
return getAsCaseStmt();
329336

330337
case Kind::patternBinding:
331338
return getAsPatternBinding();
@@ -343,7 +350,7 @@ class SyntacticElementTarget {
343350
case Kind::closure:
344351
case Kind::function:
345352
case Kind::stmtCondition:
346-
case Kind::caseLabelItem:
353+
case Kind::caseStmt:
347354
case Kind::patternBinding:
348355
case Kind::uninitializedVar:
349356
case Kind::forEachStmt:
@@ -366,8 +373,8 @@ class SyntacticElementTarget {
366373
case Kind::stmtCondition:
367374
return stmtCondition.dc;
368375

369-
case Kind::caseLabelItem:
370-
return caseLabelItem.dc;
376+
case Kind::caseStmt:
377+
return caseStmt.dc;
371378

372379
case Kind::patternBinding:
373380
return patternBinding->getDeclContext();
@@ -598,7 +605,7 @@ class SyntacticElementTarget {
598605
case Kind::expression:
599606
case Kind::closure:
600607
case Kind::stmtCondition:
601-
case Kind::caseLabelItem:
608+
case Kind::caseStmt:
602609
case Kind::patternBinding:
603610
case Kind::uninitializedVar:
604611
case Kind::forEachStmt:
@@ -615,7 +622,7 @@ class SyntacticElementTarget {
615622
case Kind::expression:
616623
case Kind::closure:
617624
case Kind::function:
618-
case Kind::caseLabelItem:
625+
case Kind::caseStmt:
619626
case Kind::patternBinding:
620627
case Kind::uninitializedVar:
621628
case Kind::forEachStmt:
@@ -627,7 +634,13 @@ class SyntacticElementTarget {
627634
llvm_unreachable("invalid statement kind");
628635
}
629636

630-
Optional<CaseLabelItem *> getAsCaseLabelItem() const {
637+
/// Get the type to which patterns in the \c CaseStmt should be converted.
638+
Type getCaseStmtContextualType() const {
639+
assert(kind == Kind::caseStmt);
640+
return caseStmt.convertType;
641+
}
642+
643+
CaseStmt *getAsCaseStmt() const {
631644
switch (kind) {
632645
case Kind::expression:
633646
case Kind::closure:
@@ -636,10 +649,10 @@ class SyntacticElementTarget {
636649
case Kind::patternBinding:
637650
case Kind::uninitializedVar:
638651
case Kind::forEachStmt:
639-
return None;
652+
return nullptr;
640653

641-
case Kind::caseLabelItem:
642-
return caseLabelItem.caseLabelItem;
654+
case Kind::caseStmt:
655+
return caseStmt.caseStmt;
643656
}
644657
llvm_unreachable("invalid case label type");
645658
}
@@ -650,7 +663,7 @@ class SyntacticElementTarget {
650663
case Kind::closure:
651664
case Kind::function:
652665
case Kind::stmtCondition:
653-
case Kind::caseLabelItem:
666+
case Kind::caseStmt:
654667
case Kind::uninitializedVar:
655668
case Kind::forEachStmt:
656669
return nullptr;
@@ -667,7 +680,7 @@ class SyntacticElementTarget {
667680
case Kind::closure:
668681
case Kind::function:
669682
case Kind::stmtCondition:
670-
case Kind::caseLabelItem:
683+
case Kind::caseStmt:
671684
case Kind::patternBinding:
672685
case Kind::forEachStmt:
673686
return nullptr;
@@ -684,7 +697,7 @@ class SyntacticElementTarget {
684697
case Kind::closure:
685698
case Kind::function:
686699
case Kind::stmtCondition:
687-
case Kind::caseLabelItem:
700+
case Kind::caseStmt:
688701
case Kind::patternBinding:
689702
case Kind::forEachStmt:
690703
return nullptr;
@@ -701,7 +714,7 @@ class SyntacticElementTarget {
701714
case Kind::closure:
702715
case Kind::function:
703716
case Kind::stmtCondition:
704-
case Kind::caseLabelItem:
717+
case Kind::caseStmt:
705718
case Kind::patternBinding:
706719
case Kind::uninitializedVar:
707720
return nullptr;
@@ -718,7 +731,7 @@ class SyntacticElementTarget {
718731
case Kind::closure:
719732
case Kind::function:
720733
case Kind::stmtCondition:
721-
case Kind::caseLabelItem:
734+
case Kind::caseStmt:
722735
case Kind::patternBinding:
723736
case Kind::forEachStmt:
724737
return nullptr;
@@ -735,7 +748,7 @@ class SyntacticElementTarget {
735748
case Kind::closure:
736749
case Kind::function:
737750
case Kind::stmtCondition:
738-
case Kind::caseLabelItem:
751+
case Kind::caseStmt:
739752
case Kind::patternBinding:
740753
case Kind::forEachStmt:
741754
return nullptr;
@@ -752,7 +765,7 @@ class SyntacticElementTarget {
752765
case Kind::closure:
753766
case Kind::function:
754767
case Kind::stmtCondition:
755-
case Kind::caseLabelItem:
768+
case Kind::caseStmt:
756769
case Kind::patternBinding:
757770
case Kind::forEachStmt:
758771
return 0;
@@ -801,8 +814,8 @@ class SyntacticElementTarget {
801814
return SourceRange(stmtCondition.stmtCondition.front().getStartLoc(),
802815
stmtCondition.stmtCondition.back().getEndLoc());
803816

804-
case Kind::caseLabelItem:
805-
return caseLabelItem.caseLabelItem->getSourceRange();
817+
case Kind::caseStmt:
818+
return caseStmt.caseStmt->getLabelItemsRange();
806819

807820
case Kind::patternBinding:
808821
return patternBinding->getSourceRange();
@@ -845,8 +858,8 @@ class SyntacticElementTarget {
845858
case Kind::stmtCondition:
846859
return stmtCondition.stmtCondition.front().getStartLoc();
847860

848-
case Kind::caseLabelItem:
849-
return caseLabelItem.caseLabelItem->getStartLoc();
861+
case Kind::caseStmt:
862+
return caseStmt.caseStmt->getStartLoc();
850863

851864
case Kind::patternBinding:
852865
return patternBinding->getLoc();

lib/Sema/BuilderTransform.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,9 @@ class BuilderClosureVisitor
766766
subjectExpr, LocatorPathElt::ContextualType(CTP_CaseStmt));
767767
Type subjectType = cs->getType(subjectExpr);
768768

769+
cs->setTargetFor(caseStmt,
770+
SyntacticElementTarget(caseStmt, subjectType, dc));
771+
769772
if (cs->generateConstraints(caseStmt, dc, subjectType, locator)) {
770773
hadError = true;
771774
return nullptr;
@@ -2102,11 +2105,9 @@ class BuilderClosureRewriter
21022105
NullablePtr<Stmt> visitCaseStmt(CaseStmt *caseStmt,
21032106
ResultBuilderTarget target) {
21042107
// Translate the patterns and guard expressions for each case label item.
2105-
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
2106-
SyntacticElementTarget caseLabelTarget(&caseLabelItem, dc);
2107-
if (!rewriteTarget(caseLabelTarget))
2108-
return nullptr;
2109-
}
2108+
const auto &cs = solution.getConstraintSystem();
2109+
if (!rewriteTarget(*cs.getTargetFor(caseStmt)))
2110+
return nullptr;
21102111

21112112
// Setup the types of our case body var decls.
21122113
for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray()) {

lib/Sema/CSApply.cpp

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -9218,60 +9218,65 @@ ExprWalker::rewriteTarget(SyntacticElementTarget target) {
92189218
}
92199219

92209220
return target;
9221-
} else if (auto caseLabelItem = target.getAsCaseLabelItem()) {
9221+
} else if (auto *caseStmt = target.getAsCaseStmt()) {
92229222
ConstraintSystem &cs = solution.getConstraintSystem();
9223-
auto info = *cs.getCaseLabelItemInfo(*caseLabelItem);
92249223

9225-
// Figure out the pattern type.
9226-
Type patternType = solution.simplifyType(solution.getType(info.pattern));
9227-
patternType = patternType->reconstituteSugar(/*recursive=*/false);
9224+
// Figure out the contextual type for patterns.
9225+
const Type contextualTy =
9226+
solution.simplifyType(target.getCaseStmtContextualType())
9227+
->reconstituteSugar(/*recursive=*/false)
9228+
->getRValueType();
92289229

9229-
// Check whether this enum element is resolved via ~= application.
9230-
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9231-
if (auto target = cs.getTargetFor(enumElement)) {
9232-
auto *EP = target->getExprPattern();
9233-
auto enumType = solution.getResolvedType(EP);
9230+
// Rewrite each case label item.
9231+
for (auto &item : caseStmt->getMutableCaseLabelItems()) {
9232+
auto info = *cs.getCaseLabelItemInfo(&item);
92349233

9235-
auto *matchCall = target->getAsExpr();
9234+
// Check whether this enum element is resolved via ~= application.
9235+
if (auto *enumElement = dyn_cast<EnumElementPattern>(info.pattern)) {
9236+
if (auto target = cs.getTargetFor(enumElement)) {
9237+
auto *EP = target->getExprPattern();
9238+
auto enumType = solution.getResolvedType(EP);
92369239

9237-
auto *result = matchCall->walk(*this);
9238-
if (!result)
9239-
return None;
9240+
auto *matchCall = target->getAsExpr();
92409241

9241-
{
9242-
auto *matchVar = EP->getMatchVar();
9243-
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9244-
}
9242+
auto *result = matchCall->walk(*this);
9243+
if (!result)
9244+
return None;
92459245

9246-
EP->setMatchExpr(result);
9247-
EP->setType(enumType);
9246+
{
9247+
auto *matchVar = EP->getMatchVar();
9248+
matchVar->setInterfaceType(enumType->mapTypeOutOfContext());
9249+
}
92489250

9249-
(*caseLabelItem)->setPattern(EP, /*resolved=*/true);
9250-
return target;
9251-
}
9252-
}
9251+
EP->setMatchExpr(result);
9252+
EP->setType(enumType);
92539253

9254-
// Coerce the pattern to its appropriate type.
9255-
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9256-
patternOptions |= TypeResolutionFlags::OverrideType;
9257-
auto contextualPattern =
9258-
ContextualPattern::forRawPattern(info.pattern,
9259-
target.getDeclContext());
9260-
if (auto coercedPattern = TypeChecker::coercePatternToType(
9261-
contextualPattern, patternType, patternOptions)) {
9262-
(*caseLabelItem)->setPattern(coercedPattern, /*resolved=*/true);
9263-
} else {
9264-
return None;
9265-
}
9254+
item.setPattern(EP, /*resolved=*/true);
9255+
return target;
9256+
}
9257+
}
92669258

9267-
// If there is a guard expression, coerce that.
9268-
if (auto *guardExpr = info.guardExpr) {
9269-
auto target = *cs.getTargetFor(guardExpr);
9270-
auto resultTarget = rewriteTarget(target);
9271-
if (!resultTarget)
9259+
// Coerce the pattern to its appropriate type.
9260+
TypeResolutionOptions patternOptions(TypeResolverContext::InExpression);
9261+
patternOptions |= TypeResolutionFlags::OverrideType;
9262+
auto contextualPattern = ContextualPattern::forRawPattern(
9263+
info.pattern, target.getDeclContext());
9264+
if (auto coercedPattern = TypeChecker::coercePatternToType(
9265+
contextualPattern, contextualTy, patternOptions)) {
9266+
item.setPattern(coercedPattern, /*resolved=*/true);
9267+
} else {
92729268
return None;
9269+
}
9270+
9271+
// If there is a guard expression, coerce that.
9272+
if (auto *guardExpr = info.guardExpr) {
9273+
auto target = *cs.getTargetFor(guardExpr);
9274+
auto resultTarget = rewriteTarget(target);
9275+
if (!resultTarget)
9276+
return None;
92739277

9274-
(*caseLabelItem)->setGuardExpr(resultTarget->getAsExpr());
9278+
item.setGuardExpr(resultTarget->getAsExpr());
9279+
}
92759280
}
92769281

92779282
return target;

lib/Sema/CSGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4607,7 +4607,7 @@ bool ConstraintSystem::generateConstraints(
46074607
llvm_unreachable("Handled above");
46084608

46094609
case SyntacticElementTarget::Kind::closure:
4610-
case SyntacticElementTarget::Kind::caseLabelItem:
4610+
case SyntacticElementTarget::Kind::caseStmt:
46114611
case SyntacticElementTarget::Kind::function:
46124612
case SyntacticElementTarget::Kind::stmtCondition:
46134613
llvm_unreachable("Handled separately");

0 commit comments

Comments
 (0)