Skip to content

Commit 043bc46

Browse files
authored
Merge pull request #38706 from hamishknight/a-game-of-draughts
2 parents ad7dbc5 + f69ca41 commit 043bc46

File tree

2 files changed

+134
-52
lines changed

2 files changed

+134
-52
lines changed

lib/Sema/BuilderTransform.cpp

Lines changed: 68 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,8 @@ struct ResultBuilderTarget {
977977
/// Handles the rewrite of the body of a closure to which a result builder
978978
/// has been applied.
979979
class BuilderClosureRewriter
980-
: public StmtVisitor<BuilderClosureRewriter, Stmt *, ResultBuilderTarget> {
980+
: public StmtVisitor<BuilderClosureRewriter, NullablePtr<Stmt>,
981+
ResultBuilderTarget> {
981982
ASTContext &ctx;
982983
const Solution &solution;
983984
DeclContext *dc;
@@ -1138,8 +1139,9 @@ class BuilderClosureRewriter
11381139
solution(solution), dc(dc), builderTransform(builderTransform),
11391140
rewriteTarget(rewriteTarget) { }
11401141

1141-
Stmt *visitBraceStmt(BraceStmt *braceStmt, ResultBuilderTarget target,
1142-
Optional<ResultBuilderTarget> innerTarget = None) {
1142+
NullablePtr<Stmt>
1143+
visitBraceStmt(BraceStmt *braceStmt, ResultBuilderTarget target,
1144+
Optional<ResultBuilderTarget> innerTarget = None) {
11431145
std::vector<ASTNode> newElements;
11441146

11451147
// If there is an "inner" target corresponding to this brace, declare
@@ -1179,7 +1181,7 @@ class BuilderClosureRewriter
11791181
// "throw" statements produce no value. Transform them directly.
11801182
if (auto throwStmt = dyn_cast<ThrowStmt>(stmt)) {
11811183
if (auto newStmt = visitThrowStmt(throwStmt)) {
1182-
newElements.push_back(stmt);
1184+
newElements.push_back(newStmt.get());
11831185
}
11841186
continue;
11851187
}
@@ -1190,7 +1192,7 @@ class BuilderClosureRewriter
11901192

11911193
declareTemporaryVariable(captured.first, newElements);
11921194

1193-
Stmt *finalStmt = visit(
1195+
auto finalStmt = visit(
11941196
stmt,
11951197
ResultBuilderTarget{ResultBuilderTarget::TemporaryVar,
11961198
std::move(captured)});
@@ -1200,7 +1202,7 @@ class BuilderClosureRewriter
12001202
if (!finalStmt)
12011203
return nullptr;
12021204

1203-
newElements.push_back(finalStmt);
1205+
newElements.push_back(finalStmt.get());
12041206
continue;
12051207
}
12061208

@@ -1250,7 +1252,7 @@ class BuilderClosureRewriter
12501252
braceStmt->getRBraceLoc());
12511253
}
12521254

1253-
Stmt *visitIfStmt(IfStmt *ifStmt, ResultBuilderTarget target) {
1255+
NullablePtr<Stmt> visitIfStmt(IfStmt *ifStmt, ResultBuilderTarget target) {
12541256
// Rewrite the condition.
12551257
if (auto condition = rewriteTarget(
12561258
SolutionApplicationTarget(ifStmt->getCond(), dc)))
@@ -1266,7 +1268,10 @@ class BuilderClosureRewriter
12661268
temporaryVar, {target.captured.second[0]}),
12671269
ResultBuilderTarget::forAssign(
12681270
capturedThen.first, {capturedThen.second.front()}));
1269-
ifStmt->setThenStmt(newThen);
1271+
if (!newThen)
1272+
return nullptr;
1273+
1274+
ifStmt->setThenStmt(newThen.get());
12701275

12711276
// Look for a #available condition. If there is one, we need to check
12721277
// that the resulting type of the "then" doesn't refer to any types that
@@ -1338,23 +1343,28 @@ class BuilderClosureRewriter
13381343
dyn_cast_or_null<BraceStmt>(ifStmt->getElseStmt())) {
13391344
// Translate the "else" branch when it's a stmt-brace.
13401345
auto capturedElse = takeCapturedStmt(elseBraceStmt);
1341-
Stmt *newElse = visitBraceStmt(
1346+
auto newElse = visitBraceStmt(
13421347
elseBraceStmt,
13431348
ResultBuilderTarget::forAssign(
13441349
temporaryVar, {target.captured.second[1]}),
13451350
ResultBuilderTarget::forAssign(
13461351
capturedElse.first, {capturedElse.second.front()}));
1347-
ifStmt->setElseStmt(newElse);
1352+
if (!newElse)
1353+
return nullptr;
1354+
1355+
ifStmt->setElseStmt(newElse.get());
13481356
} else if (auto elseIfStmt = cast_or_null<IfStmt>(ifStmt->getElseStmt())){
13491357
// Translate the "else" branch when it's an else-if.
13501358
auto capturedElse = takeCapturedStmt(elseIfStmt);
13511359
std::vector<ASTNode> newElseElements;
13521360
declareTemporaryVariable(capturedElse.first, newElseElements);
1353-
newElseElements.push_back(
1354-
visitIfStmt(
1355-
elseIfStmt,
1356-
ResultBuilderTarget::forAssign(
1357-
capturedElse.first, capturedElse.second)));
1361+
auto newElseElt =
1362+
visitIfStmt(elseIfStmt, ResultBuilderTarget::forAssign(
1363+
capturedElse.first, capturedElse.second));
1364+
if (!newElseElt)
1365+
return nullptr;
1366+
1367+
newElseElements.push_back(newElseElt.get());
13581368
newElseElements.push_back(
13591369
initializeTarget(
13601370
ResultBuilderTarget::forAssign(
@@ -1378,23 +1388,25 @@ class BuilderClosureRewriter
13781388
return ifStmt;
13791389
}
13801390

1381-
Stmt *visitDoStmt(DoStmt *doStmt, ResultBuilderTarget target) {
1391+
NullablePtr<Stmt> visitDoStmt(DoStmt *doStmt, ResultBuilderTarget target) {
13821392
// Each statement turns into a (potential) temporary variable
13831393
// binding followed by the statement itself.
13841394
auto body = cast<BraceStmt>(doStmt->getBody());
13851395
auto captured = takeCapturedStmt(body);
13861396

1387-
auto newInnerBody = cast<BraceStmt>(
1388-
visitBraceStmt(
1389-
body,
1390-
target,
1391-
ResultBuilderTarget::forAssign(
1392-
captured.first, {captured.second.front()})));
1393-
doStmt->setBody(newInnerBody);
1397+
auto newInnerBody =
1398+
visitBraceStmt(body, target,
1399+
ResultBuilderTarget::forAssign(
1400+
captured.first, {captured.second.front()}));
1401+
if (!newInnerBody)
1402+
return nullptr;
1403+
1404+
doStmt->setBody(cast<BraceStmt>(newInnerBody.get()));
13941405
return doStmt;
13951406
}
13961407

1397-
Stmt *visitSwitchStmt(SwitchStmt *switchStmt, ResultBuilderTarget target) {
1408+
NullablePtr<Stmt> visitSwitchStmt(SwitchStmt *switchStmt,
1409+
ResultBuilderTarget target) {
13981410
// Translate the subject expression.
13991411
ConstraintSystem &cs = solution.getConstraintSystem();
14001412
auto subjectTarget =
@@ -1439,7 +1451,8 @@ class BuilderClosureRewriter
14391451
return switchStmt;
14401452
}
14411453

1442-
Stmt *visitCaseStmt(CaseStmt *caseStmt, ResultBuilderTarget target) {
1454+
NullablePtr<Stmt> visitCaseStmt(CaseStmt *caseStmt,
1455+
ResultBuilderTarget target) {
14431456
// Translate the patterns and guard expressions for each case label item.
14441457
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
14451458
SolutionApplicationTarget caseLabelTarget(&caseLabelItem, dc);
@@ -1450,19 +1463,19 @@ class BuilderClosureRewriter
14501463
// Transform the body of the case.
14511464
auto body = cast<BraceStmt>(caseStmt->getBody());
14521465
auto captured = takeCapturedStmt(body);
1453-
auto newInnerBody = cast<BraceStmt>(
1454-
visitBraceStmt(
1455-
body,
1456-
target,
1457-
ResultBuilderTarget::forAssign(
1458-
captured.first, {captured.second.front()})));
1459-
caseStmt->setBody(newInnerBody);
1466+
auto newInnerBody =
1467+
visitBraceStmt(body, target,
1468+
ResultBuilderTarget::forAssign(
1469+
captured.first, {captured.second.front()}));
1470+
if (!newInnerBody)
1471+
return nullptr;
14601472

1473+
caseStmt->setBody(cast<BraceStmt>(newInnerBody.get()));
14611474
return caseStmt;
14621475
}
14631476

1464-
Stmt *visitForEachStmt(
1465-
ForEachStmt *forEachStmt, ResultBuilderTarget target) {
1477+
NullablePtr<Stmt> visitForEachStmt(ForEachStmt *forEachStmt,
1478+
ResultBuilderTarget target) {
14661479
// Translate the for-each loop header.
14671480
ConstraintSystem &cs = solution.getConstraintSystem();
14681481
auto forEachTarget =
@@ -1492,13 +1505,14 @@ class BuilderClosureRewriter
14921505
// will append the result of executing the loop body to the array.
14931506
auto body = forEachStmt->getBody();
14941507
auto capturedBody = takeCapturedStmt(body);
1495-
auto newBody = cast<BraceStmt>(
1496-
visitBraceStmt(
1497-
body,
1498-
ResultBuilderTarget::forExpression(arrayAppendCall),
1499-
ResultBuilderTarget::forAssign(
1500-
capturedBody.first, {capturedBody.second.front()})));
1501-
forEachStmt->setBody(newBody);
1508+
auto newBody = visitBraceStmt(
1509+
body, ResultBuilderTarget::forExpression(arrayAppendCall),
1510+
ResultBuilderTarget::forAssign(capturedBody.first,
1511+
{capturedBody.second.front()}));
1512+
if (!newBody)
1513+
return nullptr;
1514+
1515+
forEachStmt->setBody(cast<BraceStmt>(newBody.get()));
15021516
outerBodySteps.push_back(forEachStmt);
15031517

15041518
// Step 3. Perform the buildArray() call to turn the array of results
@@ -1510,11 +1524,11 @@ class BuilderClosureRewriter
15101524

15111525
// Form a brace statement to put together the three main steps for the
15121526
// for-each loop translation outlined above.
1513-
return BraceStmt::create(
1514-
ctx, forEachStmt->getStartLoc(), outerBodySteps, newBody->getEndLoc());
1527+
return BraceStmt::create(ctx, forEachStmt->getStartLoc(), outerBodySteps,
1528+
newBody.get()->getEndLoc());
15151529
}
15161530

1517-
Stmt *visitThrowStmt(ThrowStmt *throwStmt) {
1531+
NullablePtr<Stmt> visitThrowStmt(ThrowStmt *throwStmt) {
15181532
// Rewrite the error.
15191533
auto target = *solution.getConstraintSystem()
15201534
.getSolutionApplicationTarget(throwStmt);
@@ -1526,12 +1540,14 @@ class BuilderClosureRewriter
15261540
return throwStmt;
15271541
}
15281542

1529-
Stmt *visitThrowStmt(ThrowStmt *throwStmt, ResultBuilderTarget target) {
1543+
NullablePtr<Stmt> visitThrowStmt(ThrowStmt *throwStmt,
1544+
ResultBuilderTarget target) {
15301545
llvm_unreachable("Throw statements produce no value");
15311546
}
15321547

15331548
#define UNHANDLED_RESULT_BUILDER_STMT(STMT) \
1534-
Stmt *visit##STMT##Stmt(STMT##Stmt *stmt, ResultBuilderTarget target) { \
1549+
NullablePtr<Stmt> \
1550+
visit##STMT##Stmt(STMT##Stmt *stmt, ResultBuilderTarget target) { \
15351551
llvm_unreachable("Function builders do not allow statement of kind " \
15361552
#STMT); \
15371553
}
@@ -1563,12 +1579,12 @@ BraceStmt *swift::applyResultBuilderTransform(
15631579
rewriteTarget) {
15641580
BuilderClosureRewriter rewriter(solution, dc, applied, rewriteTarget);
15651581
auto captured = rewriter.takeCapturedStmt(body);
1566-
return cast_or_null<BraceStmt>(
1567-
rewriter.visitBraceStmt(
1568-
body,
1569-
ResultBuilderTarget::forReturn(applied.returnExpr),
1570-
ResultBuilderTarget::forAssign(
1571-
captured.first, captured.second)));
1582+
auto result = rewriter.visitBraceStmt(
1583+
body, ResultBuilderTarget::forReturn(applied.returnExpr),
1584+
ResultBuilderTarget::forAssign(captured.first, captured.second));
1585+
if (!result)
1586+
return nullptr;
1587+
return cast<BraceStmt>(result.get());
15721588
}
15731589

15741590
Optional<BraceStmt *> TypeChecker::applyResultBuilderBodyTransform(
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: %target-typecheck-verify-swift
2+
// rdar://81228221
3+
4+
@resultBuilder
5+
struct Builder {
6+
static func buildBlock(_ components: Int...) -> Int { 0 }
7+
static func buildEither(first component: Int) -> Int { 0 }
8+
static func buildEither(second component: Int) -> Int { 0 }
9+
static func buildOptional(_ component: Int?) -> Int { 0 }
10+
static func buildArray(_ components: [Int]) -> Int { 0 }
11+
}
12+
13+
@Builder
14+
func foo(_ x: String) -> Int {
15+
if .random() {
16+
switch x {
17+
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
18+
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
19+
0
20+
default:
21+
1
22+
}
23+
}
24+
}
25+
26+
@Builder
27+
func bar(_ x: String) -> Int {
28+
switch 0 {
29+
case 0:
30+
switch x {
31+
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
32+
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
33+
0
34+
default:
35+
1
36+
}
37+
default:
38+
0
39+
}
40+
}
41+
42+
@Builder
43+
func baz(_ x: String) -> Int {
44+
do {
45+
switch x {
46+
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
47+
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
48+
0
49+
default:
50+
1
51+
}
52+
}
53+
}
54+
55+
@Builder
56+
func qux(_ x: String) -> Int {
57+
for _ in 0 ... 0 {
58+
switch x {
59+
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
60+
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
61+
0
62+
default:
63+
1
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)