Skip to content

Commit 594b635

Browse files
committed
[Multi-expression closures] Add support for do-catch statements.
1 parent 7a655e3 commit 594b635

File tree

4 files changed

+49
-10
lines changed

4 files changed

+49
-10
lines changed

lib/Sema/CSClosure.cpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -233,19 +233,32 @@ class ClosureConstraintGenerator
233233
assert(subjectExpr && "Must have a subject expression here");
234234

235235
// Visit the raw cases.
236+
auto subjectLocator = cs.getConstraintLocator(
237+
subjectExpr, LocatorPathElt::ContextualType());
238+
Type subjectType = cs.getType(subjectExpr);
236239
for (auto rawCase : switchStmt->getRawCases()) {
237-
if (auto decl = rawCase.dyn_cast<Decl *>())
240+
if (auto decl = rawCase.dyn_cast<Decl *>()) {
238241
visitDecl(decl);
239-
else
240-
visitCaseStmt(cast<CaseStmt>(rawCase.get<Stmt *>()), subjectExpr);
242+
} else {
243+
visitCaseStmt(
244+
cast<CaseStmt>(rawCase.get<Stmt *>()), subjectType, subjectLocator);
245+
}
241246
}
242247
}
243248

244-
void visitCaseStmt(CaseStmt *caseStmt, Expr *subjectExpr) {
245-
auto locator = cs.getConstraintLocator(
246-
subjectExpr, LocatorPathElt::ContextualType());
247-
Type subjectType = cs.getType(subjectExpr);
249+
void visitDoCatchStmt(DoCatchStmt *doCatchStmt) {
250+
visit(doCatchStmt->getBody());
251+
252+
// Visit the "catch" blocks.
253+
Type exceptionType = cs.getASTContext().getExceptionType();
254+
for (auto catchStmt : doCatchStmt->getCatches()) {
255+
auto locator = cs.getConstraintLocator(catchStmt);
256+
visitCaseStmt(catchStmt, exceptionType, locator);
257+
}
258+
}
248259

260+
void visitCaseStmt(
261+
CaseStmt *caseStmt, Type subjectType, ConstraintLocator *locator) {
249262
if (cs.generateConstraints(caseStmt, closure, subjectType, locator)) {
250263
hadError = true;
251264
return;
@@ -261,7 +274,6 @@ class ClosureConstraintGenerator
261274
llvm_unreachable("Unsupported statement kind " #STMT); \
262275
}
263276
UNSUPPORTED_STMT(Yield)
264-
UNSUPPORTED_STMT(DoCatch)
265277
UNSUPPORTED_STMT(Case)
266278
UNSUPPORTED_STMT(Fail)
267279
#undef UNSUPPORTED_STMT
@@ -585,6 +597,19 @@ class ClosureConstraintApplication
585597
return switchStmt;
586598
}
587599

600+
ASTNode visitDoCatchStmt(DoCatchStmt *doCatchStmt) {
601+
// Translate the body.
602+
auto newBody = visit(doCatchStmt->getBody());
603+
doCatchStmt->setBody(newBody.get<Stmt *>());
604+
605+
// Visit the catch blocks.
606+
for (auto catchStmt : doCatchStmt->getCatches()) {
607+
visitCaseStmt(catchStmt);
608+
}
609+
610+
return doCatchStmt;
611+
}
612+
588613
ASTNode visitCaseStmt(CaseStmt *caseStmt) {
589614
// Translate the patterns and guard expressions for each case label item.
590615
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
@@ -610,7 +635,6 @@ class ClosureConstraintApplication
610635
llvm_unreachable("Unsupported statement kind " #STMT); \
611636
}
612637
UNSUPPORTED_STMT(Yield)
613-
UNSUPPORTED_STMT(DoCatch)
614638
UNSUPPORTED_STMT(Fail)
615639
#undef UNSUPPORTED_STMT
616640

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,6 +1318,10 @@ namespace {
13181318
std::pair<bool, Stmt *> walkToStmtPre(Stmt *stmt) override {
13191319
return { true, stmt };
13201320
}
1321+
1322+
std::pair<bool, Pattern *> walkToPatternPre(Pattern *pattern) override {
1323+
return { false, pattern };
1324+
}
13211325
};
13221326
} // end anonymous namespace
13231327

lib/Sema/TypeCheckStmt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1324,8 +1324,8 @@ class StmtChecker : public StmtVisitor<StmtChecker, Stmt*> {
13241324
// There is nothing more to do.
13251325
return S;
13261326
}
1327-
13281327
};
1328+
13291329
} // end anonymous namespace
13301330

13311331
static bool isDiscardableType(Type type) {

test/expr/closure/multi_statement.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ func maybeGetValue<T>(_ value: T) -> T? {
1010

1111
enum MyError: Error {
1212
case featureIsTooCool
13+
14+
func doIt() { }
1315
}
1416

1517
enum State {
@@ -20,6 +22,8 @@ enum State {
2022

2123
func random(_: Int) -> Bool { return false }
2224

25+
func mightThrow() throws -> Bool { throw MyError.featureIsTooCool }
26+
2327
func mapWithMoreStatements(ints: [Int], state: State) throws {
2428
let _ = try ints.map { i in
2529
guard var actualValue = maybeGetValue(i) else {
@@ -76,6 +80,13 @@ func mapWithMoreStatements(ints: [Int], state: State) throws {
7680

7781
#assert(true)
7882

83+
do {
84+
print(try mightThrow())
85+
} catch let e as MyError {
86+
e.doIt()
87+
} catch {
88+
print(error)
89+
}
7990
return String(value)
8091
}
8192
}

0 commit comments

Comments
 (0)