Skip to content

[Refactoring] Replace lifted breaks/returns with placeholder for async transform #37189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 71 additions & 28 deletions lib/IDE/Refactoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4517,6 +4517,7 @@ struct CallbackClassifier {
/// names from `Body`. Errors are added through `DiagEngine`, possibly
/// resulting in partially filled out blocks.
static void classifyInto(ClassifiedBlocks &Blocks,
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
DiagnosticEngine &DiagEngine,
ArrayRef<const ParamDecl *> SuccessParams,
const ParamDecl *ErrParam, HandlerType ResultType,
Expand All @@ -4528,25 +4529,30 @@ struct CallbackClassifier {
if (ErrParam)
ParamsSet.insert(ErrParam);

CallbackClassifier Classifier(Blocks, DiagEngine, ParamsSet, ErrParam,
CallbackClassifier Classifier(Blocks, HandledSwitches, DiagEngine,
ParamsSet, ErrParam,
ResultType == HandlerType::RESULT);
Classifier.classifyNodes(Body);
}

private:
ClassifiedBlocks &Blocks;
llvm::DenseSet<SwitchStmt *> &HandledSwitches;
DiagnosticEngine &DiagEngine;
ClassifiedBlock *CurrentBlock;
llvm::DenseSet<const Decl *> ParamsSet;
const ParamDecl *ErrParam;
bool IsResultParam;

CallbackClassifier(ClassifiedBlocks &Blocks, DiagnosticEngine &DiagEngine,
CallbackClassifier(ClassifiedBlocks &Blocks,
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
DiagnosticEngine &DiagEngine,
llvm::DenseSet<const Decl *> ParamsSet,
const ParamDecl *ErrParam, bool IsResultParam)
: Blocks(Blocks), DiagEngine(DiagEngine),
CurrentBlock(&Blocks.SuccessBlock), ParamsSet(ParamsSet),
ErrParam(ErrParam), IsResultParam(IsResultParam) {}
: Blocks(Blocks), HandledSwitches(HandledSwitches),
DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock),
ParamsSet(ParamsSet), ErrParam(ErrParam), IsResultParam(IsResultParam) {
}

void classifyNodes(ArrayRef<ASTNode> Nodes) {
for (auto I = Nodes.begin(), E = Nodes.end(); I < E; ++I) {
Expand Down Expand Up @@ -4680,6 +4686,7 @@ struct CallbackClassifier {
void classifySwitch(SwitchStmt *SS) {
if (!IsResultParam || singleSwitchSubject(SS) != ErrParam) {
CurrentBlock->addNode(SS);
return;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oof. Thanks for catching this and adding the test.

}

for (auto *CS : SS->getCases()) {
Expand Down Expand Up @@ -4717,6 +4724,8 @@ struct CallbackClassifier {
if (DiagEngine.hadAnyError())
return;
}
// Mark this switch statement as having been transformed.
HandledSwitches.insert(SS);
}
};

Expand Down Expand Up @@ -4775,6 +4784,9 @@ class AsyncConverter : private SourceEntityWalker {
// references to it
llvm::DenseMap<const Decl *, std::string> Names;

/// The switch statements that have been re-written by this transform.
llvm::DenseSet<SwitchStmt *> HandledSwitches;

// These are per-node (ie. are saved and restored on each convertNode call)
SourceLoc LastAddedLoc;
int NestedExprCount = 0;
Expand Down Expand Up @@ -4845,6 +4857,9 @@ class AsyncConverter : private SourceEntityWalker {
NestedExprCount++;
return true;
}
// Note we don't walk into any nested local function decls. If we start
// doing so in the future, be sure to update the logic that deals with
// converting unhandled returns into placeholders in walkToStmtPre.
return false;
}

Expand All @@ -4862,18 +4877,16 @@ class AsyncConverter : private SourceEntityWalker {
bool AddPlaceholder = Placeholders.count(D);
StringRef Name = newNameFor(D, false);
if (AddPlaceholder || !Name.empty())
return addCustom(DRE->getStartLoc(),
Lexer::getLocForEndOfToken(SM, DRE->getEndLoc()),
[&]() {
if (AddPlaceholder)
OS << PLACEHOLDER_START;
if (!Name.empty())
OS << Name;
else
D->getName().print(OS);
if (AddPlaceholder)
OS << PLACEHOLDER_END;
});
return addCustom(DRE->getSourceRange(), [&]() {
if (AddPlaceholder)
OS << PLACEHOLDER_START;
if (!Name.empty())
OS << Name;
else
D->getName().print(OS);
if (AddPlaceholder)
OS << PLACEHOLDER_END;
});
}
} else if (isa<ForceValueExpr>(E) || isa<BindOptionalExpr>(E)) {
// Remove a force unwrap or optional chain of a returned success value,
Expand All @@ -4887,26 +4900,57 @@ class AsyncConverter : private SourceEntityWalker {
// completely valid.
if (auto *D = E->getReferencedDecl().getDecl()) {
if (Unwraps.count(D))
return addCustom(E->getStartLoc(), E->getEndLoc().getAdvancedLoc(1),
return addCustom(E->getSourceRange(),
[&]() { OS << newNameFor(D, true); });
}
} else if (NestedExprCount == 0) {
if (CallExpr *CE = TopHandler.getAsHandlerCall(E))
return addCustom(CE->getStartLoc(), CE->getEndLoc().getAdvancedLoc(1),
[&]() { addHandlerCall(CE); });
return addCustom(CE->getSourceRange(), [&]() { addHandlerCall(CE); });

if (auto *CE = dyn_cast<CallExpr>(E)) {
auto HandlerDesc = AsyncHandlerDesc::find(
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
if (HandlerDesc.isValid())
return addCustom(CE->getStartLoc(), CE->getEndLoc().getAdvancedLoc(1),
return addCustom(CE->getSourceRange(),
[&]() { addAsyncAlternativeCall(CE, HandlerDesc); });
}
}

NestedExprCount++;
return true;
}

bool replaceRangeWithPlaceholder(SourceRange range) {
return addCustom(range, [&]() {
OS << PLACEHOLDER_START;
addRange(range, /*toEndOfToken*/ true);
OS << PLACEHOLDER_END;
});
}

bool walkToStmtPre(Stmt *S) override {
// Some break and return statements need to be turned into placeholders,
// as they may no longer perform the control flow that the user is
// expecting.
if (!S->isImplicit()) {
// For a break, if it's jumping out of a switch statement that we've
// re-written as a part of the transform, turn it into a placeholder, as
// it would have been lifted out of the switch statement.
if (auto *BS = dyn_cast<BreakStmt>(S)) {
if (auto *SS = dyn_cast<SwitchStmt>(BS->getTarget())) {
if (HandledSwitches.contains(SS))
replaceRangeWithPlaceholder(S->getSourceRange());
}
}

// For a return, if it's not nested inside another closure or function,
// turn it into a placeholder, as it will be lifted out of the callback.
if (isa<ReturnStmt>(S) && NestedExprCount == 0)
replaceRangeWithPlaceholder(S->getSourceRange());
}
return true;
}

#undef PLACEHOLDER_START
#undef PLACEHOLDER_END

Expand All @@ -4915,11 +4959,10 @@ class AsyncConverter : private SourceEntityWalker {
return true;
}

bool addCustom(SourceLoc End, SourceLoc NextAddedLoc,
std::function<void()> Custom = {}) {
addRange(LastAddedLoc, End);
bool addCustom(SourceRange Range, std::function<void()> Custom = {}) {
addRange(LastAddedLoc, Range.Start);
Custom();
LastAddedLoc = NextAddedLoc;
LastAddedLoc = Lexer::getLocForEndOfToken(SM, Range.End);
return false;
}

Expand Down Expand Up @@ -5110,9 +5153,9 @@ class AsyncConverter : private SourceEntityWalker {
if (!HandlerDesc.HasError) {
Blocks.SuccessBlock.addAllNodes(CallbackBody);
} else if (!CallbackBody.empty()) {
CallbackClassifier::classifyInto(Blocks, DiagEngine, SuccessParams,
ErrParam, HandlerDesc.Type,
CallbackBody);
CallbackClassifier::classifyInto(Blocks, HandledSwitches, DiagEngine,
SuccessParams, ErrParam,
HandlerDesc.Type, CallbackBody);
if (DiagEngine.hadAnyError()) {
// Can only fallback when the results are params, in which case only
// the names are used (defaulted to the names of the params if none)
Expand Down
10 changes: 5 additions & 5 deletions test/refactoring/ConvertAsync/convert_function.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func asyncParams(arg: String, _ completion: (String?, Error?) -> Void) {
// ASYNC-SIMPLE: func {{[a-zA-Z_]+}}(arg: String) async throws -> String {
// ASYNC-SIMPLE-NEXT: let str = try await simpleErr(arg: arg)
// ASYNC-SIMPLE-NEXT: print("simpleErr")
// ASYNC-SIMPLE-NEXT: return str
// ASYNC-SIMPLE-NEXT: {{^}}return str{{$}}
// ASYNC-SIMPLE-NEXT: print("after")
// ASYNC-SIMPLE-NEXT: }

Expand Down Expand Up @@ -120,7 +120,7 @@ func asyncResNewErr(arg: String, _ completion: (Result<String, Error>) -> Void)
// ASYNC-ERR-NEXT: do {
// ASYNC-ERR-NEXT: let str = try await simpleErr(arg: arg)
// ASYNC-ERR-NEXT: print("simpleErr")
// ASYNC-ERR-NEXT: return str
// ASYNC-ERR-NEXT: {{^}}return str{{$}}
// ASYNC-ERR-NEXT: print("after")
// ASYNC-ERR-NEXT: } catch let err {
// ASYNC-ERR-NEXT: throw CustomError.Bad
Expand All @@ -142,11 +142,11 @@ func asyncUnhandledCompletion(_ completion: (String) -> Void) {
// ASYNC-UNHANDLED: func asyncUnhandledCompletion() async -> String {
// ASYNC-UNHANDLED-NEXT: let str = await simple()
// ASYNC-UNHANDLED-NEXT: let success = run {
// ASYNC-UNHANDLED-NEXT: <#completion#>(str)
// ASYNC-UNHANDLED-NEXT: return true
// ASYNC-UNHANDLED-NEXT: <#completion#>(str)
// ASYNC-UNHANDLED-NEXT: {{^}} return true{{$}}
// ASYNC-UNHANDLED-NEXT: }
// ASYNC-UNHANDLED-NEXT: if !success {
// ASYNC-UNHANDLED-NEXT: return "bad"
// ASYNC-UNHANDLED-NEXT: {{^}} return "bad"{{$}}
// ASYNC-UNHANDLED-NEXT: }
// ASYNC-UNHANDLED-NEXT: }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ withError { res, err in
// NESTEDRET-NEXT: let str = try await withError()
// NESTEDRET-NEXT: print("before")
// NESTEDRET-NEXT: if test(str) {
// NESTEDRET-NEXT: return
// NESTEDRET-NEXT: <#return#>
// NESTEDRET-NEXT: }
// NESTEDRET-NEXT: print("got result \(str)")
// NESTEDRET-NEXT: print("after")
Expand Down
94 changes: 92 additions & 2 deletions test/refactoring/ConvertAsync/convert_result.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
func simple(_ completion: (Result<String, Error>) -> Void) { }
func simpleWithArg(_ arg: Int, _ completion: (Result<String, Error>) -> Void) { }
func noError(_ completion: (Result<String, Never>) -> Void) { }
func test(_ str: String) -> Bool { return false }

Expand Down Expand Up @@ -279,7 +280,7 @@ simple { res in
// NESTEDRET-NEXT: let str = try await simple()
// NESTEDRET-NEXT: print("before")
// NESTEDRET-NEXT: if test(str) {
// NESTEDRET-NEXT: return
// NESTEDRET-NEXT: <#return#>
// NESTEDRET-NEXT: }
// NESTEDRET-NEXT: print("result \(str)")
// NESTEDRET-NEXT: print("after")
Expand All @@ -303,7 +304,7 @@ simple { res in
// NESTEDBREAK-NEXT: let str = try await simple()
// NESTEDBREAK-NEXT: print("before")
// NESTEDBREAK-NEXT: if test(str) {
// NESTEDBREAK-NEXT: break
// NESTEDBREAK-NEXT: <#break#>
// NESTEDBREAK-NEXT: }
// NESTEDBREAK-NEXT: print("result \(str)")
// NESTEDBREAK-NEXT: print("after")
Expand All @@ -322,3 +323,92 @@ voidAndErrorResult { res in
}
// VOID-AND-ERROR-RESULT-CALL: {{^}}try await voidAndErrorResult()
// VOID-AND-ERROR-RESULT-CALL: {{^}}print(<#res#>)

// Make sure we ignore an unrelated switch.
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck -check-prefix=IGNORE-UNRELATED %s
simple { res in
print("before")
switch Bool.random() {
case true:
break
case false:
break
}
print("after")
}
// IGNORE-UNRELATED: let res = try await simple()
// IGNORE-UNRELATED-NEXT: print("before")
// IGNORE-UNRELATED-NEXT: switch Bool.random() {
// IGNORE-UNRELATED-NEXT: case true:
// IGNORE-UNRELATED-NEXT: {{^}} break{{$}}
// IGNORE-UNRELATED-NEXT: case false:
// IGNORE-UNRELATED-NEXT: {{^}} break{{$}}
// IGNORE-UNRELATED-NEXT: }
// IGNORE-UNRELATED-NEXT: print("after")

// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck -check-prefix=BREAK-RET-PLACEHOLDER %s
simpleWithArg({ return 0 }()) { res in
switch res {
case .success:
if .random() { break }
x: if .random() { break x }
case .failure:
break
}

func foo<T>(_ x: T) {
if .random() { return }
}
foo(res)

let fn = {
if .random() { return }
return
}
fn()

_ = { return }()

switch Bool.random() {
case true:
break
case false:
if .random() { break }
y: if .random() { break y }
return
}

x: if .random() {
break x
}
if .random() { return }
}

// Make sure we replace lifted break/returns with placeholders, but keep nested
// break/returns in e.g closures or labelled control flow in place.

// BREAK-RET-PLACEHOLDER: let res = try await simpleWithArg({ return 0 }())
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { <#break#> }
// BREAK-RET-PLACEHOLDER-NEXT: x: if .random() { break x }
// BREAK-RET-PLACEHOLDER-NEXT: func foo<T>(_ x: T) {
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { return }
// BREAK-RET-PLACEHOLDER-NEXT: }
// BREAK-RET-PLACEHOLDER-NEXT: foo(<#res#>)
// BREAK-RET-PLACEHOLDER-NEXT: let fn = {
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { return }
// BREAK-RET-PLACEHOLDER-NEXT: {{^}} return{{$}}
// BREAK-RET-PLACEHOLDER-NEXT: }
// BREAK-RET-PLACEHOLDER-NEXT: fn()
// BREAK-RET-PLACEHOLDER-NEXT: _ = { return }()
// BREAK-RET-PLACEHOLDER-NEXT: switch Bool.random() {
// BREAK-RET-PLACEHOLDER-NEXT: case true:
// BREAK-RET-PLACEHOLDER-NEXT: {{^}} break{{$}}
// BREAK-RET-PLACEHOLDER-NEXT: case false:
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { break }
// BREAK-RET-PLACEHOLDER-NEXT: y: if .random() { break y }
// BREAK-RET-PLACEHOLDER-NEXT: <#return#>
// BREAK-RET-PLACEHOLDER-NEXT: }
// BREAK-RET-PLACEHOLDER-NEXT: x: if .random() {
// BREAK-RET-PLACEHOLDER-NEXT: {{^}} break x{{$}}
// BREAK-RET-PLACEHOLDER-NEXT: }
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { <#return#> }