Skip to content

Commit 48efcec

Browse files
authored
Merge pull request #38843 from hamishknight/a-game-of-draughts-5.5
[5.5] [CS] Better handle null in BuilderClosureRewriter
2 parents e097686 + ac27721 commit 48efcec

File tree

2 files changed

+134
-52
lines changed

2 files changed

+134
-52
lines changed

lib/Sema/BuilderTransform.cpp

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,8 @@ struct ResultBuilderTarget {
963963
/// Handles the rewrite of the body of a closure to which a result builder
964964
/// has been applied.
965965
class BuilderClosureRewriter
966-
: public StmtVisitor<BuilderClosureRewriter, Stmt *, ResultBuilderTarget> {
966+
: public StmtVisitor<BuilderClosureRewriter, NullablePtr<Stmt>,
967+
ResultBuilderTarget> {
967968
ASTContext &ctx;
968969
const Solution &solution;
969970
DeclContext *dc;
@@ -1124,8 +1125,9 @@ class BuilderClosureRewriter
11241125
solution(solution), dc(dc), builderTransform(builderTransform),
11251126
rewriteTarget(rewriteTarget) { }
11261127

1127-
Stmt *visitBraceStmt(BraceStmt *braceStmt, ResultBuilderTarget target,
1128-
Optional<ResultBuilderTarget> innerTarget = None) {
1128+
NullablePtr<Stmt>
1129+
visitBraceStmt(BraceStmt *braceStmt, ResultBuilderTarget target,
1130+
Optional<ResultBuilderTarget> innerTarget = None) {
11291131
std::vector<ASTNode> newElements;
11301132

11311133
// If there is an "inner" target corresponding to this brace, declare
@@ -1165,7 +1167,7 @@ class BuilderClosureRewriter
11651167
// "throw" statements produce no value. Transform them directly.
11661168
if (auto throwStmt = dyn_cast<ThrowStmt>(stmt)) {
11671169
if (auto newStmt = visitThrowStmt(throwStmt)) {
1168-
newElements.push_back(stmt);
1170+
newElements.push_back(newStmt.get());
11691171
}
11701172
continue;
11711173
}
@@ -1176,7 +1178,7 @@ class BuilderClosureRewriter
11761178

11771179
declareTemporaryVariable(captured.first, newElements);
11781180

1179-
Stmt *finalStmt = visit(
1181+
auto finalStmt = visit(
11801182
stmt,
11811183
ResultBuilderTarget{ResultBuilderTarget::TemporaryVar,
11821184
std::move(captured)});
@@ -1186,7 +1188,7 @@ class BuilderClosureRewriter
11861188
if (!finalStmt)
11871189
return nullptr;
11881190

1189-
newElements.push_back(finalStmt);
1191+
newElements.push_back(finalStmt.get());
11901192
continue;
11911193
}
11921194

@@ -1236,7 +1238,7 @@ class BuilderClosureRewriter
12361238
braceStmt->getRBraceLoc());
12371239
}
12381240

1239-
Stmt *visitIfStmt(IfStmt *ifStmt, ResultBuilderTarget target) {
1241+
NullablePtr<Stmt> visitIfStmt(IfStmt *ifStmt, ResultBuilderTarget target) {
12401242
// Rewrite the condition.
12411243
if (auto condition = rewriteTarget(
12421244
SolutionApplicationTarget(ifStmt->getCond(), dc)))
@@ -1252,7 +1254,10 @@ class BuilderClosureRewriter
12521254
temporaryVar, {target.captured.second[0]}),
12531255
ResultBuilderTarget::forAssign(
12541256
capturedThen.first, {capturedThen.second.front()}));
1255-
ifStmt->setThenStmt(newThen);
1257+
if (!newThen)
1258+
return nullptr;
1259+
1260+
ifStmt->setThenStmt(newThen.get());
12561261

12571262
// Look for a #available condition. If there is one, we need to check
12581263
// that the resulting type of the "then" doesn't refer to any types that
@@ -1315,23 +1320,28 @@ class BuilderClosureRewriter
13151320
dyn_cast_or_null<BraceStmt>(ifStmt->getElseStmt())) {
13161321
// Translate the "else" branch when it's a stmt-brace.
13171322
auto capturedElse = takeCapturedStmt(elseBraceStmt);
1318-
Stmt *newElse = visitBraceStmt(
1323+
auto newElse = visitBraceStmt(
13191324
elseBraceStmt,
13201325
ResultBuilderTarget::forAssign(
13211326
temporaryVar, {target.captured.second[1]}),
13221327
ResultBuilderTarget::forAssign(
13231328
capturedElse.first, {capturedElse.second.front()}));
1324-
ifStmt->setElseStmt(newElse);
1329+
if (!newElse)
1330+
return nullptr;
1331+
1332+
ifStmt->setElseStmt(newElse.get());
13251333
} else if (auto elseIfStmt = cast_or_null<IfStmt>(ifStmt->getElseStmt())){
13261334
// Translate the "else" branch when it's an else-if.
13271335
auto capturedElse = takeCapturedStmt(elseIfStmt);
13281336
std::vector<ASTNode> newElseElements;
13291337
declareTemporaryVariable(capturedElse.first, newElseElements);
1330-
newElseElements.push_back(
1331-
visitIfStmt(
1332-
elseIfStmt,
1333-
ResultBuilderTarget::forAssign(
1334-
capturedElse.first, capturedElse.second)));
1338+
auto newElseElt =
1339+
visitIfStmt(elseIfStmt, ResultBuilderTarget::forAssign(
1340+
capturedElse.first, capturedElse.second));
1341+
if (!newElseElt)
1342+
return nullptr;
1343+
1344+
newElseElements.push_back(newElseElt.get());
13351345
newElseElements.push_back(
13361346
initializeTarget(
13371347
ResultBuilderTarget::forAssign(
@@ -1355,23 +1365,25 @@ class BuilderClosureRewriter
13551365
return ifStmt;
13561366
}
13571367

1358-
Stmt *visitDoStmt(DoStmt *doStmt, ResultBuilderTarget target) {
1368+
NullablePtr<Stmt> visitDoStmt(DoStmt *doStmt, ResultBuilderTarget target) {
13591369
// Each statement turns into a (potential) temporary variable
13601370
// binding followed by the statement itself.
13611371
auto body = cast<BraceStmt>(doStmt->getBody());
13621372
auto captured = takeCapturedStmt(body);
13631373

1364-
auto newInnerBody = cast<BraceStmt>(
1365-
visitBraceStmt(
1366-
body,
1367-
target,
1368-
ResultBuilderTarget::forAssign(
1369-
captured.first, {captured.second.front()})));
1370-
doStmt->setBody(newInnerBody);
1374+
auto newInnerBody =
1375+
visitBraceStmt(body, target,
1376+
ResultBuilderTarget::forAssign(
1377+
captured.first, {captured.second.front()}));
1378+
if (!newInnerBody)
1379+
return nullptr;
1380+
1381+
doStmt->setBody(cast<BraceStmt>(newInnerBody.get()));
13711382
return doStmt;
13721383
}
13731384

1374-
Stmt *visitSwitchStmt(SwitchStmt *switchStmt, ResultBuilderTarget target) {
1385+
NullablePtr<Stmt> visitSwitchStmt(SwitchStmt *switchStmt,
1386+
ResultBuilderTarget target) {
13751387
// Translate the subject expression.
13761388
ConstraintSystem &cs = solution.getConstraintSystem();
13771389
auto subjectTarget =
@@ -1416,7 +1428,8 @@ class BuilderClosureRewriter
14161428
return switchStmt;
14171429
}
14181430

1419-
Stmt *visitCaseStmt(CaseStmt *caseStmt, ResultBuilderTarget target) {
1431+
NullablePtr<Stmt> visitCaseStmt(CaseStmt *caseStmt,
1432+
ResultBuilderTarget target) {
14201433
// Translate the patterns and guard expressions for each case label item.
14211434
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
14221435
SolutionApplicationTarget caseLabelTarget(&caseLabelItem, dc);
@@ -1427,19 +1440,19 @@ class BuilderClosureRewriter
14271440
// Transform the body of the case.
14281441
auto body = cast<BraceStmt>(caseStmt->getBody());
14291442
auto captured = takeCapturedStmt(body);
1430-
auto newInnerBody = cast<BraceStmt>(
1431-
visitBraceStmt(
1432-
body,
1433-
target,
1434-
ResultBuilderTarget::forAssign(
1435-
captured.first, {captured.second.front()})));
1436-
caseStmt->setBody(newInnerBody);
1443+
auto newInnerBody =
1444+
visitBraceStmt(body, target,
1445+
ResultBuilderTarget::forAssign(
1446+
captured.first, {captured.second.front()}));
1447+
if (!newInnerBody)
1448+
return nullptr;
14371449

1450+
caseStmt->setBody(cast<BraceStmt>(newInnerBody.get()));
14381451
return caseStmt;
14391452
}
14401453

1441-
Stmt *visitForEachStmt(
1442-
ForEachStmt *forEachStmt, ResultBuilderTarget target) {
1454+
NullablePtr<Stmt> visitForEachStmt(ForEachStmt *forEachStmt,
1455+
ResultBuilderTarget target) {
14431456
// Translate the for-each loop header.
14441457
ConstraintSystem &cs = solution.getConstraintSystem();
14451458
auto forEachTarget =
@@ -1469,13 +1482,14 @@ class BuilderClosureRewriter
14691482
// will append the result of executing the loop body to the array.
14701483
auto body = forEachStmt->getBody();
14711484
auto capturedBody = takeCapturedStmt(body);
1472-
auto newBody = cast<BraceStmt>(
1473-
visitBraceStmt(
1474-
body,
1475-
ResultBuilderTarget::forExpression(arrayAppendCall),
1476-
ResultBuilderTarget::forAssign(
1477-
capturedBody.first, {capturedBody.second.front()})));
1478-
forEachStmt->setBody(newBody);
1485+
auto newBody = visitBraceStmt(
1486+
body, ResultBuilderTarget::forExpression(arrayAppendCall),
1487+
ResultBuilderTarget::forAssign(capturedBody.first,
1488+
{capturedBody.second.front()}));
1489+
if (!newBody)
1490+
return nullptr;
1491+
1492+
forEachStmt->setBody(cast<BraceStmt>(newBody.get()));
14791493
outerBodySteps.push_back(forEachStmt);
14801494

14811495
// Step 3. Perform the buildArray() call to turn the array of results
@@ -1487,11 +1501,11 @@ class BuilderClosureRewriter
14871501

14881502
// Form a brace statement to put together the three main steps for the
14891503
// for-each loop translation outlined above.
1490-
return BraceStmt::create(
1491-
ctx, forEachStmt->getStartLoc(), outerBodySteps, newBody->getEndLoc());
1504+
return BraceStmt::create(ctx, forEachStmt->getStartLoc(), outerBodySteps,
1505+
newBody.get()->getEndLoc());
14921506
}
14931507

1494-
Stmt *visitThrowStmt(ThrowStmt *throwStmt) {
1508+
NullablePtr<Stmt> visitThrowStmt(ThrowStmt *throwStmt) {
14951509
// Rewrite the error.
14961510
auto target = *solution.getConstraintSystem()
14971511
.getSolutionApplicationTarget(throwStmt);
@@ -1503,12 +1517,14 @@ class BuilderClosureRewriter
15031517
return throwStmt;
15041518
}
15051519

1506-
Stmt *visitThrowStmt(ThrowStmt *throwStmt, ResultBuilderTarget target) {
1520+
NullablePtr<Stmt> visitThrowStmt(ThrowStmt *throwStmt,
1521+
ResultBuilderTarget target) {
15071522
llvm_unreachable("Throw statements produce no value");
15081523
}
15091524

15101525
#define UNHANDLED_RESULT_BUILDER_STMT(STMT) \
1511-
Stmt *visit##STMT##Stmt(STMT##Stmt *stmt, ResultBuilderTarget target) { \
1526+
NullablePtr<Stmt> \
1527+
visit##STMT##Stmt(STMT##Stmt *stmt, ResultBuilderTarget target) { \
15121528
llvm_unreachable("Function builders do not allow statement of kind " \
15131529
#STMT); \
15141530
}
@@ -1540,12 +1556,12 @@ BraceStmt *swift::applyResultBuilderTransform(
15401556
rewriteTarget) {
15411557
BuilderClosureRewriter rewriter(solution, dc, applied, rewriteTarget);
15421558
auto captured = rewriter.takeCapturedStmt(body);
1543-
return cast_or_null<BraceStmt>(
1544-
rewriter.visitBraceStmt(
1545-
body,
1546-
ResultBuilderTarget::forReturn(applied.returnExpr),
1547-
ResultBuilderTarget::forAssign(
1548-
captured.first, captured.second)));
1559+
auto result = rewriter.visitBraceStmt(
1560+
body, ResultBuilderTarget::forReturn(applied.returnExpr),
1561+
ResultBuilderTarget::forAssign(captured.first, captured.second));
1562+
if (!result)
1563+
return nullptr;
1564+
return cast<BraceStmt>(result.get());
15491565
}
15501566

15511567
Optional<BraceStmt *> TypeChecker::applyResultBuilderBodyTransform(
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: %target-typecheck-verify-swift
2+
// rdar://81228221
3+
4+
@resultBuilder
5+
struct Builder {
6+
static func buildBlock(_ components: Int...) -> Int { 0 }
7+
static func buildEither(first component: Int) -> Int { 0 }
8+
static func buildEither(second component: Int) -> Int { 0 }
9+
static func buildOptional(_ component: Int?) -> Int { 0 }
10+
static func buildArray(_ components: [Int]) -> Int { 0 }
11+
}
12+
13+
@Builder
14+
func foo(_ x: String) -> Int {
15+
if .random() {
16+
switch x {
17+
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
18+
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
19+
0
20+
default:
21+
1
22+
}
23+
}
24+
}
25+
26+
@Builder
27+
func bar(_ x: String) -> Int {
28+
switch 0 {
29+
case 0:
30+
switch x {
31+
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
32+
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
33+
0
34+
default:
35+
1
36+
}
37+
default:
38+
0
39+
}
40+
}
41+
42+
@Builder
43+
func baz(_ x: String) -> Int {
44+
do {
45+
switch x {
46+
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
47+
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
48+
0
49+
default:
50+
1
51+
}
52+
}
53+
}
54+
55+
@Builder
56+
func qux(_ x: String) -> Int {
57+
for _ in 0 ... 0 {
58+
switch x {
59+
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
60+
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
61+
0
62+
default:
63+
1
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)