Skip to content

[5.5] [CS] Better handle null in BuilderClosureRewriter #38843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 68 additions & 52 deletions lib/Sema/BuilderTransform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,8 @@ struct ResultBuilderTarget {
/// Handles the rewrite of the body of a closure to which a result builder
/// has been applied.
class BuilderClosureRewriter
: public StmtVisitor<BuilderClosureRewriter, Stmt *, ResultBuilderTarget> {
: public StmtVisitor<BuilderClosureRewriter, NullablePtr<Stmt>,
ResultBuilderTarget> {
ASTContext &ctx;
const Solution &solution;
DeclContext *dc;
Expand Down Expand Up @@ -1124,8 +1125,9 @@ class BuilderClosureRewriter
solution(solution), dc(dc), builderTransform(builderTransform),
rewriteTarget(rewriteTarget) { }

Stmt *visitBraceStmt(BraceStmt *braceStmt, ResultBuilderTarget target,
Optional<ResultBuilderTarget> innerTarget = None) {
NullablePtr<Stmt>
visitBraceStmt(BraceStmt *braceStmt, ResultBuilderTarget target,
Optional<ResultBuilderTarget> innerTarget = None) {
std::vector<ASTNode> newElements;

// If there is an "inner" target corresponding to this brace, declare
Expand Down Expand Up @@ -1165,7 +1167,7 @@ class BuilderClosureRewriter
// "throw" statements produce no value. Transform them directly.
if (auto throwStmt = dyn_cast<ThrowStmt>(stmt)) {
if (auto newStmt = visitThrowStmt(throwStmt)) {
newElements.push_back(stmt);
newElements.push_back(newStmt.get());
}
continue;
}
Expand All @@ -1176,7 +1178,7 @@ class BuilderClosureRewriter

declareTemporaryVariable(captured.first, newElements);

Stmt *finalStmt = visit(
auto finalStmt = visit(
stmt,
ResultBuilderTarget{ResultBuilderTarget::TemporaryVar,
std::move(captured)});
Expand All @@ -1186,7 +1188,7 @@ class BuilderClosureRewriter
if (!finalStmt)
return nullptr;

newElements.push_back(finalStmt);
newElements.push_back(finalStmt.get());
continue;
}

Expand Down Expand Up @@ -1236,7 +1238,7 @@ class BuilderClosureRewriter
braceStmt->getRBraceLoc());
}

Stmt *visitIfStmt(IfStmt *ifStmt, ResultBuilderTarget target) {
NullablePtr<Stmt> visitIfStmt(IfStmt *ifStmt, ResultBuilderTarget target) {
// Rewrite the condition.
if (auto condition = rewriteTarget(
SolutionApplicationTarget(ifStmt->getCond(), dc)))
Expand All @@ -1252,7 +1254,10 @@ class BuilderClosureRewriter
temporaryVar, {target.captured.second[0]}),
ResultBuilderTarget::forAssign(
capturedThen.first, {capturedThen.second.front()}));
ifStmt->setThenStmt(newThen);
if (!newThen)
return nullptr;

ifStmt->setThenStmt(newThen.get());

// Look for a #available condition. If there is one, we need to check
// that the resulting type of the "then" doesn't refer to any types that
Expand Down Expand Up @@ -1315,23 +1320,28 @@ class BuilderClosureRewriter
dyn_cast_or_null<BraceStmt>(ifStmt->getElseStmt())) {
// Translate the "else" branch when it's a stmt-brace.
auto capturedElse = takeCapturedStmt(elseBraceStmt);
Stmt *newElse = visitBraceStmt(
auto newElse = visitBraceStmt(
elseBraceStmt,
ResultBuilderTarget::forAssign(
temporaryVar, {target.captured.second[1]}),
ResultBuilderTarget::forAssign(
capturedElse.first, {capturedElse.second.front()}));
ifStmt->setElseStmt(newElse);
if (!newElse)
return nullptr;

ifStmt->setElseStmt(newElse.get());
} else if (auto elseIfStmt = cast_or_null<IfStmt>(ifStmt->getElseStmt())){
// Translate the "else" branch when it's an else-if.
auto capturedElse = takeCapturedStmt(elseIfStmt);
std::vector<ASTNode> newElseElements;
declareTemporaryVariable(capturedElse.first, newElseElements);
newElseElements.push_back(
visitIfStmt(
elseIfStmt,
ResultBuilderTarget::forAssign(
capturedElse.first, capturedElse.second)));
auto newElseElt =
visitIfStmt(elseIfStmt, ResultBuilderTarget::forAssign(
capturedElse.first, capturedElse.second));
if (!newElseElt)
return nullptr;

newElseElements.push_back(newElseElt.get());
newElseElements.push_back(
initializeTarget(
ResultBuilderTarget::forAssign(
Expand All @@ -1355,23 +1365,25 @@ class BuilderClosureRewriter
return ifStmt;
}

Stmt *visitDoStmt(DoStmt *doStmt, ResultBuilderTarget target) {
NullablePtr<Stmt> visitDoStmt(DoStmt *doStmt, ResultBuilderTarget target) {
// Each statement turns into a (potential) temporary variable
// binding followed by the statement itself.
auto body = cast<BraceStmt>(doStmt->getBody());
auto captured = takeCapturedStmt(body);

auto newInnerBody = cast<BraceStmt>(
visitBraceStmt(
body,
target,
ResultBuilderTarget::forAssign(
captured.first, {captured.second.front()})));
doStmt->setBody(newInnerBody);
auto newInnerBody =
visitBraceStmt(body, target,
ResultBuilderTarget::forAssign(
captured.first, {captured.second.front()}));
if (!newInnerBody)
return nullptr;

doStmt->setBody(cast<BraceStmt>(newInnerBody.get()));
return doStmt;
}

Stmt *visitSwitchStmt(SwitchStmt *switchStmt, ResultBuilderTarget target) {
NullablePtr<Stmt> visitSwitchStmt(SwitchStmt *switchStmt,
ResultBuilderTarget target) {
// Translate the subject expression.
ConstraintSystem &cs = solution.getConstraintSystem();
auto subjectTarget =
Expand Down Expand Up @@ -1416,7 +1428,8 @@ class BuilderClosureRewriter
return switchStmt;
}

Stmt *visitCaseStmt(CaseStmt *caseStmt, ResultBuilderTarget target) {
NullablePtr<Stmt> visitCaseStmt(CaseStmt *caseStmt,
ResultBuilderTarget target) {
// Translate the patterns and guard expressions for each case label item.
for (auto &caseLabelItem : caseStmt->getMutableCaseLabelItems()) {
SolutionApplicationTarget caseLabelTarget(&caseLabelItem, dc);
Expand All @@ -1427,19 +1440,19 @@ class BuilderClosureRewriter
// Transform the body of the case.
auto body = cast<BraceStmt>(caseStmt->getBody());
auto captured = takeCapturedStmt(body);
auto newInnerBody = cast<BraceStmt>(
visitBraceStmt(
body,
target,
ResultBuilderTarget::forAssign(
captured.first, {captured.second.front()})));
caseStmt->setBody(newInnerBody);
auto newInnerBody =
visitBraceStmt(body, target,
ResultBuilderTarget::forAssign(
captured.first, {captured.second.front()}));
if (!newInnerBody)
return nullptr;

caseStmt->setBody(cast<BraceStmt>(newInnerBody.get()));
return caseStmt;
}

Stmt *visitForEachStmt(
ForEachStmt *forEachStmt, ResultBuilderTarget target) {
NullablePtr<Stmt> visitForEachStmt(ForEachStmt *forEachStmt,
ResultBuilderTarget target) {
// Translate the for-each loop header.
ConstraintSystem &cs = solution.getConstraintSystem();
auto forEachTarget =
Expand Down Expand Up @@ -1469,13 +1482,14 @@ class BuilderClosureRewriter
// will append the result of executing the loop body to the array.
auto body = forEachStmt->getBody();
auto capturedBody = takeCapturedStmt(body);
auto newBody = cast<BraceStmt>(
visitBraceStmt(
body,
ResultBuilderTarget::forExpression(arrayAppendCall),
ResultBuilderTarget::forAssign(
capturedBody.first, {capturedBody.second.front()})));
forEachStmt->setBody(newBody);
auto newBody = visitBraceStmt(
body, ResultBuilderTarget::forExpression(arrayAppendCall),
ResultBuilderTarget::forAssign(capturedBody.first,
{capturedBody.second.front()}));
if (!newBody)
return nullptr;

forEachStmt->setBody(cast<BraceStmt>(newBody.get()));
outerBodySteps.push_back(forEachStmt);

// Step 3. Perform the buildArray() call to turn the array of results
Expand All @@ -1487,11 +1501,11 @@ class BuilderClosureRewriter

// Form a brace statement to put together the three main steps for the
// for-each loop translation outlined above.
return BraceStmt::create(
ctx, forEachStmt->getStartLoc(), outerBodySteps, newBody->getEndLoc());
return BraceStmt::create(ctx, forEachStmt->getStartLoc(), outerBodySteps,
newBody.get()->getEndLoc());
}

Stmt *visitThrowStmt(ThrowStmt *throwStmt) {
NullablePtr<Stmt> visitThrowStmt(ThrowStmt *throwStmt) {
// Rewrite the error.
auto target = *solution.getConstraintSystem()
.getSolutionApplicationTarget(throwStmt);
Expand All @@ -1503,12 +1517,14 @@ class BuilderClosureRewriter
return throwStmt;
}

Stmt *visitThrowStmt(ThrowStmt *throwStmt, ResultBuilderTarget target) {
NullablePtr<Stmt> visitThrowStmt(ThrowStmt *throwStmt,
ResultBuilderTarget target) {
llvm_unreachable("Throw statements produce no value");
}

#define UNHANDLED_RESULT_BUILDER_STMT(STMT) \
Stmt *visit##STMT##Stmt(STMT##Stmt *stmt, ResultBuilderTarget target) { \
NullablePtr<Stmt> \
visit##STMT##Stmt(STMT##Stmt *stmt, ResultBuilderTarget target) { \
llvm_unreachable("Function builders do not allow statement of kind " \
#STMT); \
}
Expand Down Expand Up @@ -1540,12 +1556,12 @@ BraceStmt *swift::applyResultBuilderTransform(
rewriteTarget) {
BuilderClosureRewriter rewriter(solution, dc, applied, rewriteTarget);
auto captured = rewriter.takeCapturedStmt(body);
return cast_or_null<BraceStmt>(
rewriter.visitBraceStmt(
body,
ResultBuilderTarget::forReturn(applied.returnExpr),
ResultBuilderTarget::forAssign(
captured.first, captured.second)));
auto result = rewriter.visitBraceStmt(
body, ResultBuilderTarget::forReturn(applied.returnExpr),
ResultBuilderTarget::forAssign(captured.first, captured.second));
if (!result)
return nullptr;
return cast<BraceStmt>(result.get());
}

Optional<BraceStmt *> TypeChecker::applyResultBuilderBodyTransform(
Expand Down
66 changes: 66 additions & 0 deletions test/Constraints/result_builder_invalid_stmts.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: %target-typecheck-verify-swift
// rdar://81228221

@resultBuilder
struct Builder {
static func buildBlock(_ components: Int...) -> Int { 0 }
static func buildEither(first component: Int) -> Int { 0 }
static func buildEither(second component: Int) -> Int { 0 }
static func buildOptional(_ component: Int?) -> Int { 0 }
static func buildArray(_ components: [Int]) -> Int { 0 }
}

@Builder
func foo(_ x: String) -> Int {
if .random() {
switch x {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
0
default:
1
}
}
}

@Builder
func bar(_ x: String) -> Int {
switch 0 {
case 0:
switch x {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
0
default:
1
}
default:
0
}
}

@Builder
func baz(_ x: String) -> Int {
do {
switch x {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
0
default:
1
}
}
}

@Builder
func qux(_ x: String) -> Int {
for _ in 0 ... 0 {
switch x {
case 1: // expected-error {{expression pattern of type 'Int' cannot match values of type 'String'}}
// expected-note@-1 {{overloads for '~=' exist with these partially matching parameter lists}}
0
default:
1
}
}
}