Skip to content

Commit 7a655e3

Browse files
committed
[Multi-expression closures] Add support for 'switch' and 'fallthrough' statements.
1 parent 5c16778 commit 7a655e3

File tree

4 files changed

+144
-10
lines changed

4 files changed

+144
-10
lines changed

lib/Sema/CSClosure.cpp

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,60 @@ class ClosureConstraintGenerator
209209
cs.setSolutionApplicationTarget(poundAssertStmt, target);
210210
}
211211

212+
void visitSwitchStmt(SwitchStmt *switchStmt) {
213+
// FIXME: Very similar to BuilderTransform's visitSwitchStmt. Unify them.
214+
215+
// Generate constraints for the subject expression, and capture its
216+
// type for use in matching the various patterns.
217+
Expr *subjectExpr = switchStmt->getSubjectExpr();
218+
ASTContext &ctx = cs.getASTContext();
219+
220+
// Form a one-way constraint to prevent backward propagation.
221+
subjectExpr = new (ctx) OneWayExpr(subjectExpr);
222+
223+
// FIXME: Add contextual type purpose for switch subjects?
224+
SolutionApplicationTarget target(
225+
subjectExpr, closure, CTP_Unused, Type(), /*isDiscarded=*/false);
226+
if (cs.generateConstraints(target, FreeTypeVariableBinding::Disallow)) {
227+
hadError = true;
228+
return;
229+
}
230+
231+
cs.setSolutionApplicationTarget(switchStmt, target);
232+
subjectExpr = target.getAsExpr();
233+
assert(subjectExpr && "Must have a subject expression here");
234+
235+
// Visit the raw cases.
236+
for (auto rawCase : switchStmt->getRawCases()) {
237+
if (auto decl = rawCase.dyn_cast<Decl *>())
238+
visitDecl(decl);
239+
else
240+
visitCaseStmt(cast<CaseStmt>(rawCase.get<Stmt *>()), subjectExpr);
241+
}
242+
}
243+
244+
void visitCaseStmt(CaseStmt *caseStmt, Expr *subjectExpr) {
245+
auto locator = cs.getConstraintLocator(
246+
subjectExpr, LocatorPathElt::ContextualType());
247+
Type subjectType = cs.getType(subjectExpr);
248+
249+
if (cs.generateConstraints(caseStmt, closure, subjectType, locator)) {
250+
hadError = true;
251+
return;
252+
}
253+
254+
// Visit the body.
255+
visit(caseStmt->getBody());
256+
}
257+
258+
void visitFallthroughStmt(FallthroughStmt *fallthroughStmt) { }
259+
212260
#define UNSUPPORTED_STMT(STMT) void visit##STMT##Stmt(STMT##Stmt *) { \
213261
llvm_unreachable("Unsupported statement kind " #STMT); \
214262
}
215263
UNSUPPORTED_STMT(Yield)
216264
UNSUPPORTED_STMT(DoCatch)
217-
UNSUPPORTED_STMT(Switch)
218265
UNSUPPORTED_STMT(Case)
219-
UNSUPPORTED_STMT(Fallthrough)
220266
UNSUPPORTED_STMT(Fail)
221267
#undef UNSUPPORTED_STMT
222268
};
@@ -503,14 +549,68 @@ class ClosureConstraintApplication
503549
return poundAssertStmt;
504550
}
505551

552+
ASTNode visitSwitchStmt(SwitchStmt *switchStmt) {
553+
ConstraintSystem &cs = solution.getConstraintSystem();
554+
555+
// Rewrite the switch subject.
556+
auto subjectTarget =
557+
rewriteTarget(*cs.getSolutionApplicationTarget(switchStmt));
558+
if (subjectTarget) {
559+
switchStmt->setSubjectExpr(subjectTarget->getAsExpr());
560+
} else {
561+
hadError = true;
562+
}
563+
564+
// Visit the raw cases.
565+
bool limitExhaustivityChecks = false;
566+
for (auto rawCase : switchStmt->getRawCases()) {
567+
if (auto decl = rawCase.dyn_cast<Decl *>()) {
568+
visitDecl(decl);
569+
continue;
570+
}
571+
572+
auto caseStmt = cast<CaseStmt>(rawCase.get<Stmt *>());
573+
visitCaseStmt(caseStmt);
574+
575+
// Check restrictions on '@unknown'.
576+
if (caseStmt->hasUnknownAttr()) {
577+
checkUnknownAttrRestrictions(
578+
cs.getASTContext(), caseStmt, limitExhaustivityChecks);
579+
}
580+
}
581+
582+
TypeChecker::checkSwitchExhaustiveness(
583+
switchStmt, closure, limitExhaustivityChecks);
584+
585+
return switchStmt;
586+
}
587+
588+
ASTNode visitCaseStmt(CaseStmt *caseStmt) {
589+
// Translate the patterns and guard expressions for each case label item.
590+
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
591+
SolutionApplicationTarget caseLabelTarget(&caseLabelItem, closure);
592+
if (!rewriteTarget(caseLabelTarget)) {
593+
hadError = true;
594+
}
595+
}
596+
597+
// Translate the body.
598+
auto newBody = visit(caseStmt->getBody());
599+
caseStmt->setBody(newBody.get<Stmt *>());
600+
return caseStmt;
601+
}
602+
603+
ASTNode visitFallthroughStmt(FallthroughStmt *fallthroughStmt) {
604+
if (checkFallthroughStmt(closure, fallthroughStmt))
605+
hadError = true;
606+
return fallthroughStmt;
607+
}
608+
506609
#define UNSUPPORTED_STMT(STMT) ASTNode visit##STMT##Stmt(STMT##Stmt *) { \
507610
llvm_unreachable("Unsupported statement kind " #STMT); \
508611
}
509612
UNSUPPORTED_STMT(Yield)
510613
UNSUPPORTED_STMT(DoCatch)
511-
UNSUPPORTED_STMT(Switch)
512-
UNSUPPORTED_STMT(Case)
513-
UNSUPPORTED_STMT(Fallthrough)
514614
UNSUPPORTED_STMT(Fail)
515615
#undef UNSUPPORTED_STMT
516616

lib/Sema/TypeCheckStmt.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -616,17 +616,21 @@ static void checkFallthroughPatternBindingsAndTypes(
616616
/// \returns true if an error occurred.
617617
static bool checkFallthroughStmt(
618618
DeclContext *dc, FallthroughStmt *stmt,
619-
CaseStmt *oldFallthroughSource, CaseStmt *oldFallthroughDest) {
619+
CaseStmt *oldFallthroughSource, CaseStmt *oldFallthroughDest,
620+
bool allowOldInfo) {
620621
CaseStmt *fallthroughSource;
621622
CaseStmt *fallthroughDest;
622623
ASTContext &ctx = dc->getASTContext();
623624
if (ctx.LangOpts.EnableASTScopeLookup) {
624625
auto sourceFile = dc->getParentSourceFile();
625626
std::tie(fallthroughSource, fallthroughDest) =
626627
ASTScope::lookupFallthroughSourceAndDest(sourceFile, stmt->getLoc());
627-
assert(fallthroughSource == oldFallthroughSource);
628-
assert(fallthroughDest == oldFallthroughDest);
628+
assert(!allowOldInfo || fallthroughSource == oldFallthroughSource);
629+
assert(!allowOldInfo || fallthroughDest == oldFallthroughDest);
629630
} else {
631+
if (!allowOldInfo)
632+
return false;
633+
630634
fallthroughSource = oldFallthroughSource;
631635
fallthroughDest = oldFallthroughDest;
632636
}
@@ -647,6 +651,11 @@ static bool checkFallthroughStmt(
647651
return false;
648652
}
649653

654+
bool swift::checkFallthroughStmt(DeclContext *dc, FallthroughStmt *stmt) {
655+
return ::checkFallthroughStmt(
656+
dc, stmt, nullptr, nullptr, /*allowOldInfo=*/false);
657+
}
658+
650659
namespace {
651660
class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
652661
public:
@@ -1079,7 +1088,8 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
10791088
}
10801089

10811090
Stmt *visitFallthroughStmt(FallthroughStmt *S) {
1082-
if (checkFallthroughStmt(DC, S, FallthroughSource, FallthroughDest))
1091+
if (::checkFallthroughStmt(DC, S, FallthroughSource, FallthroughDest,
1092+
/*allowOldInfo=*/true))
10831093
return nullptr;
10841094

10851095
return S;

lib/Sema/TypeChecker.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1375,6 +1375,11 @@ LabeledStmt *findBreakOrContinueStmtTarget(
13751375
bool isContinue, DeclContext *dc,
13761376
ArrayRef<LabeledStmt *> oldActiveLabeledStmts);
13771377

1378+
/// Check the correctness of a 'fallthrough' statement.
1379+
///
1380+
/// \returns true if an error occurred.
1381+
bool checkFallthroughStmt(DeclContext *dc, FallthroughStmt *stmt);
1382+
13781383
/// Check for restrictions on the use of the @unknown attribute on a
13791384
/// case statement.
13801385
void checkUnknownAttrRestrictions(

test/expr/closure/multi_statement.swift

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,15 @@ enum MyError: Error {
1212
case featureIsTooCool
1313
}
1414

15+
enum State {
16+
case suspended
17+
case partial(Int, Int)
18+
case finished
19+
}
20+
1521
func random(_: Int) -> Bool { return false }
1622

17-
func mapWithMoreStatements(ints: [Int]) throws {
23+
func mapWithMoreStatements(ints: [Int], state: State) throws {
1824
let _ = try ints.map { i in
1925
guard var actualValue = maybeGetValue(i) else {
2026
return String(0)
@@ -51,6 +57,19 @@ func mapWithMoreStatements(ints: [Int]) throws {
5157
continue
5258
}
5359

60+
switch (state, j) {
61+
case (.suspended, 0):
62+
print("something")
63+
fallthrough
64+
case (.finished, 0):
65+
print("something else")
66+
67+
case (.partial(let current, let end), let j):
68+
print("\(current) of \(end): \(j)")
69+
70+
default:
71+
print("so, here we are")
72+
}
5473
print("even")
5574
throw MyError.featureIsTooCool
5675
}

0 commit comments

Comments
 (0)