Skip to content

Commit 69caeae

Browse files
authored
Merge pull request #37189 from hamishknight/break-it-down-for-me
2 parents 1344900 + 2d75074 commit 69caeae

File tree

4 files changed

+169
-36
lines changed

4 files changed

+169
-36
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 71 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4547,6 +4547,7 @@ struct CallbackClassifier {
45474547
/// names from `Body`. Errors are added through `DiagEngine`, possibly
45484548
/// resulting in partially filled out blocks.
45494549
static void classifyInto(ClassifiedBlocks &Blocks,
4550+
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
45504551
DiagnosticEngine &DiagEngine,
45514552
ArrayRef<const ParamDecl *> SuccessParams,
45524553
const ParamDecl *ErrParam, HandlerType ResultType,
@@ -4558,25 +4559,30 @@ struct CallbackClassifier {
45584559
if (ErrParam)
45594560
ParamsSet.insert(ErrParam);
45604561

4561-
CallbackClassifier Classifier(Blocks, DiagEngine, ParamsSet, ErrParam,
4562+
CallbackClassifier Classifier(Blocks, HandledSwitches, DiagEngine,
4563+
ParamsSet, ErrParam,
45624564
ResultType == HandlerType::RESULT);
45634565
Classifier.classifyNodes(Body);
45644566
}
45654567

45664568
private:
45674569
ClassifiedBlocks &Blocks;
4570+
llvm::DenseSet<SwitchStmt *> &HandledSwitches;
45684571
DiagnosticEngine &DiagEngine;
45694572
ClassifiedBlock *CurrentBlock;
45704573
llvm::DenseSet<const Decl *> ParamsSet;
45714574
const ParamDecl *ErrParam;
45724575
bool IsResultParam;
45734576

4574-
CallbackClassifier(ClassifiedBlocks &Blocks, DiagnosticEngine &DiagEngine,
4577+
CallbackClassifier(ClassifiedBlocks &Blocks,
4578+
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
4579+
DiagnosticEngine &DiagEngine,
45754580
llvm::DenseSet<const Decl *> ParamsSet,
45764581
const ParamDecl *ErrParam, bool IsResultParam)
4577-
: Blocks(Blocks), DiagEngine(DiagEngine),
4578-
CurrentBlock(&Blocks.SuccessBlock), ParamsSet(ParamsSet),
4579-
ErrParam(ErrParam), IsResultParam(IsResultParam) {}
4582+
: Blocks(Blocks), HandledSwitches(HandledSwitches),
4583+
DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock),
4584+
ParamsSet(ParamsSet), ErrParam(ErrParam), IsResultParam(IsResultParam) {
4585+
}
45804586

45814587
void classifyNodes(ArrayRef<ASTNode> Nodes) {
45824588
for (auto I = Nodes.begin(), E = Nodes.end(); I < E; ++I) {
@@ -4710,6 +4716,7 @@ struct CallbackClassifier {
47104716
void classifySwitch(SwitchStmt *SS) {
47114717
if (!IsResultParam || singleSwitchSubject(SS) != ErrParam) {
47124718
CurrentBlock->addNode(SS);
4719+
return;
47134720
}
47144721

47154722
for (auto *CS : SS->getCases()) {
@@ -4747,6 +4754,8 @@ struct CallbackClassifier {
47474754
if (DiagEngine.hadAnyError())
47484755
return;
47494756
}
4757+
// Mark this switch statement as having been transformed.
4758+
HandledSwitches.insert(SS);
47504759
}
47514760
};
47524761

@@ -4805,6 +4814,9 @@ class AsyncConverter : private SourceEntityWalker {
48054814
// references to it
48064815
llvm::DenseMap<const Decl *, std::string> Names;
48074816

4817+
/// The switch statements that have been re-written by this transform.
4818+
llvm::DenseSet<SwitchStmt *> HandledSwitches;
4819+
48084820
// These are per-node (ie. are saved and restored on each convertNode call)
48094821
SourceLoc LastAddedLoc;
48104822
int NestedExprCount = 0;
@@ -4875,6 +4887,9 @@ class AsyncConverter : private SourceEntityWalker {
48754887
NestedExprCount++;
48764888
return true;
48774889
}
4890+
// Note we don't walk into any nested local function decls. If we start
4891+
// doing so in the future, be sure to update the logic that deals with
4892+
// converting unhandled returns into placeholders in walkToStmtPre.
48784893
return false;
48794894
}
48804895

@@ -4892,18 +4907,16 @@ class AsyncConverter : private SourceEntityWalker {
48924907
bool AddPlaceholder = Placeholders.count(D);
48934908
StringRef Name = newNameFor(D, false);
48944909
if (AddPlaceholder || !Name.empty())
4895-
return addCustom(DRE->getStartLoc(),
4896-
Lexer::getLocForEndOfToken(SM, DRE->getEndLoc()),
4897-
[&]() {
4898-
if (AddPlaceholder)
4899-
OS << PLACEHOLDER_START;
4900-
if (!Name.empty())
4901-
OS << Name;
4902-
else
4903-
D->getName().print(OS);
4904-
if (AddPlaceholder)
4905-
OS << PLACEHOLDER_END;
4906-
});
4910+
return addCustom(DRE->getSourceRange(), [&]() {
4911+
if (AddPlaceholder)
4912+
OS << PLACEHOLDER_START;
4913+
if (!Name.empty())
4914+
OS << Name;
4915+
else
4916+
D->getName().print(OS);
4917+
if (AddPlaceholder)
4918+
OS << PLACEHOLDER_END;
4919+
});
49074920
}
49084921
} else if (isa<ForceValueExpr>(E) || isa<BindOptionalExpr>(E)) {
49094922
// Remove a force unwrap or optional chain of a returned success value,
@@ -4917,26 +4930,57 @@ class AsyncConverter : private SourceEntityWalker {
49174930
// completely valid.
49184931
if (auto *D = E->getReferencedDecl().getDecl()) {
49194932
if (Unwraps.count(D))
4920-
return addCustom(E->getStartLoc(), E->getEndLoc().getAdvancedLoc(1),
4933+
return addCustom(E->getSourceRange(),
49214934
[&]() { OS << newNameFor(D, true); });
49224935
}
49234936
} else if (NestedExprCount == 0) {
49244937
if (CallExpr *CE = TopHandler.getAsHandlerCall(E))
4925-
return addCustom(CE->getStartLoc(), CE->getEndLoc().getAdvancedLoc(1),
4926-
[&]() { addHandlerCall(CE); });
4938+
return addCustom(CE->getSourceRange(), [&]() { addHandlerCall(CE); });
49274939

49284940
if (auto *CE = dyn_cast<CallExpr>(E)) {
49294941
auto HandlerDesc = AsyncHandlerDesc::find(
49304942
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
49314943
if (HandlerDesc.isValid())
4932-
return addCustom(CE->getStartLoc(), CE->getEndLoc().getAdvancedLoc(1),
4944+
return addCustom(CE->getSourceRange(),
49334945
[&]() { addAsyncAlternativeCall(CE, HandlerDesc); });
49344946
}
49354947
}
49364948

49374949
NestedExprCount++;
49384950
return true;
49394951
}
4952+
4953+
bool replaceRangeWithPlaceholder(SourceRange range) {
4954+
return addCustom(range, [&]() {
4955+
OS << PLACEHOLDER_START;
4956+
addRange(range, /*toEndOfToken*/ true);
4957+
OS << PLACEHOLDER_END;
4958+
});
4959+
}
4960+
4961+
bool walkToStmtPre(Stmt *S) override {
4962+
// Some break and return statements need to be turned into placeholders,
4963+
// as they may no longer perform the control flow that the user is
4964+
// expecting.
4965+
if (!S->isImplicit()) {
4966+
// For a break, if it's jumping out of a switch statement that we've
4967+
// re-written as a part of the transform, turn it into a placeholder, as
4968+
// it would have been lifted out of the switch statement.
4969+
if (auto *BS = dyn_cast<BreakStmt>(S)) {
4970+
if (auto *SS = dyn_cast<SwitchStmt>(BS->getTarget())) {
4971+
if (HandledSwitches.contains(SS))
4972+
replaceRangeWithPlaceholder(S->getSourceRange());
4973+
}
4974+
}
4975+
4976+
// For a return, if it's not nested inside another closure or function,
4977+
// turn it into a placeholder, as it will be lifted out of the callback.
4978+
if (isa<ReturnStmt>(S) && NestedExprCount == 0)
4979+
replaceRangeWithPlaceholder(S->getSourceRange());
4980+
}
4981+
return true;
4982+
}
4983+
49404984
#undef PLACEHOLDER_START
49414985
#undef PLACEHOLDER_END
49424986

@@ -4945,11 +4989,10 @@ class AsyncConverter : private SourceEntityWalker {
49454989
return true;
49464990
}
49474991

4948-
bool addCustom(SourceLoc End, SourceLoc NextAddedLoc,
4949-
std::function<void()> Custom = {}) {
4950-
addRange(LastAddedLoc, End);
4992+
bool addCustom(SourceRange Range, std::function<void()> Custom = {}) {
4993+
addRange(LastAddedLoc, Range.Start);
49514994
Custom();
4952-
LastAddedLoc = NextAddedLoc;
4995+
LastAddedLoc = Lexer::getLocForEndOfToken(SM, Range.End);
49534996
return false;
49544997
}
49554998

@@ -5140,9 +5183,9 @@ class AsyncConverter : private SourceEntityWalker {
51405183
if (!HandlerDesc.HasError) {
51415184
Blocks.SuccessBlock.addAllNodes(CallbackBody);
51425185
} else if (!CallbackBody.empty()) {
5143-
CallbackClassifier::classifyInto(Blocks, DiagEngine, SuccessParams,
5144-
ErrParam, HandlerDesc.Type,
5145-
CallbackBody);
5186+
CallbackClassifier::classifyInto(Blocks, HandledSwitches, DiagEngine,
5187+
SuccessParams, ErrParam,
5188+
HandlerDesc.Type, CallbackBody);
51465189
if (DiagEngine.hadAnyError()) {
51475190
// Can only fallback when the results are params, in which case only
51485191
// the names are used (defaulted to the names of the params if none)

test/refactoring/ConvertAsync/convert_function.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func asyncParams(arg: String, _ completion: (String?, Error?) -> Void) {
8787
// ASYNC-SIMPLE: func {{[a-zA-Z_]+}}(arg: String) async throws -> String {
8888
// ASYNC-SIMPLE-NEXT: let str = try await simpleErr(arg: arg)
8989
// ASYNC-SIMPLE-NEXT: print("simpleErr")
90-
// ASYNC-SIMPLE-NEXT: return str
90+
// ASYNC-SIMPLE-NEXT: {{^}}return str{{$}}
9191
// ASYNC-SIMPLE-NEXT: print("after")
9292
// ASYNC-SIMPLE-NEXT: }
9393

@@ -120,7 +120,7 @@ func asyncResNewErr(arg: String, _ completion: (Result<String, Error>) -> Void)
120120
// ASYNC-ERR-NEXT: do {
121121
// ASYNC-ERR-NEXT: let str = try await simpleErr(arg: arg)
122122
// ASYNC-ERR-NEXT: print("simpleErr")
123-
// ASYNC-ERR-NEXT: return str
123+
// ASYNC-ERR-NEXT: {{^}}return str{{$}}
124124
// ASYNC-ERR-NEXT: print("after")
125125
// ASYNC-ERR-NEXT: } catch let err {
126126
// ASYNC-ERR-NEXT: throw CustomError.Bad
@@ -142,11 +142,11 @@ func asyncUnhandledCompletion(_ completion: (String) -> Void) {
142142
// ASYNC-UNHANDLED: func asyncUnhandledCompletion() async -> String {
143143
// ASYNC-UNHANDLED-NEXT: let str = await simple()
144144
// ASYNC-UNHANDLED-NEXT: let success = run {
145-
// ASYNC-UNHANDLED-NEXT: <#completion#>(str)
146-
// ASYNC-UNHANDLED-NEXT: return true
145+
// ASYNC-UNHANDLED-NEXT: <#completion#>(str)
146+
// ASYNC-UNHANDLED-NEXT: {{^}} return true{{$}}
147147
// ASYNC-UNHANDLED-NEXT: }
148148
// ASYNC-UNHANDLED-NEXT: if !success {
149-
// ASYNC-UNHANDLED-NEXT: return "bad"
149+
// ASYNC-UNHANDLED-NEXT: {{^}} return "bad"{{$}}
150150
// ASYNC-UNHANDLED-NEXT: }
151151
// ASYNC-UNHANDLED-NEXT: }
152152

test/refactoring/ConvertAsync/convert_params_single.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ withError { res, err in
357357
// NESTEDRET-NEXT: let str = try await withError()
358358
// NESTEDRET-NEXT: print("before")
359359
// NESTEDRET-NEXT: if test(str) {
360-
// NESTEDRET-NEXT: return
360+
// NESTEDRET-NEXT: <#return#>
361361
// NESTEDRET-NEXT: }
362362
// NESTEDRET-NEXT: print("got result \(str)")
363363
// NESTEDRET-NEXT: print("after")

test/refactoring/ConvertAsync/convert_result.swift

Lines changed: 92 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
func simple(_ completion: (Result<String, Error>) -> Void) { }
2+
func simpleWithArg(_ arg: Int, _ completion: (Result<String, Error>) -> Void) { }
23
func noError(_ completion: (Result<String, Never>) -> Void) { }
34
func test(_ str: String) -> Bool { return false }
45

@@ -279,7 +280,7 @@ simple { res in
279280
// NESTEDRET-NEXT: let str = try await simple()
280281
// NESTEDRET-NEXT: print("before")
281282
// NESTEDRET-NEXT: if test(str) {
282-
// NESTEDRET-NEXT: return
283+
// NESTEDRET-NEXT: <#return#>
283284
// NESTEDRET-NEXT: }
284285
// NESTEDRET-NEXT: print("result \(str)")
285286
// NESTEDRET-NEXT: print("after")
@@ -303,7 +304,7 @@ simple { res in
303304
// NESTEDBREAK-NEXT: let str = try await simple()
304305
// NESTEDBREAK-NEXT: print("before")
305306
// NESTEDBREAK-NEXT: if test(str) {
306-
// NESTEDBREAK-NEXT: break
307+
// NESTEDBREAK-NEXT: <#break#>
307308
// NESTEDBREAK-NEXT: }
308309
// NESTEDBREAK-NEXT: print("result \(str)")
309310
// NESTEDBREAK-NEXT: print("after")
@@ -322,3 +323,92 @@ voidAndErrorResult { res in
322323
}
323324
// VOID-AND-ERROR-RESULT-CALL: {{^}}try await voidAndErrorResult()
324325
// VOID-AND-ERROR-RESULT-CALL: {{^}}print(<#res#>)
326+
327+
// Make sure we ignore an unrelated switch.
328+
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck -check-prefix=IGNORE-UNRELATED %s
329+
simple { res in
330+
print("before")
331+
switch Bool.random() {
332+
case true:
333+
break
334+
case false:
335+
break
336+
}
337+
print("after")
338+
}
339+
// IGNORE-UNRELATED: let res = try await simple()
340+
// IGNORE-UNRELATED-NEXT: print("before")
341+
// IGNORE-UNRELATED-NEXT: switch Bool.random() {
342+
// IGNORE-UNRELATED-NEXT: case true:
343+
// IGNORE-UNRELATED-NEXT: {{^}} break{{$}}
344+
// IGNORE-UNRELATED-NEXT: case false:
345+
// IGNORE-UNRELATED-NEXT: {{^}} break{{$}}
346+
// IGNORE-UNRELATED-NEXT: }
347+
// IGNORE-UNRELATED-NEXT: print("after")
348+
349+
// RUN: %refactor -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck -check-prefix=BREAK-RET-PLACEHOLDER %s
350+
simpleWithArg({ return 0 }()) { res in
351+
switch res {
352+
case .success:
353+
if .random() { break }
354+
x: if .random() { break x }
355+
case .failure:
356+
break
357+
}
358+
359+
func foo<T>(_ x: T) {
360+
if .random() { return }
361+
}
362+
foo(res)
363+
364+
let fn = {
365+
if .random() { return }
366+
return
367+
}
368+
fn()
369+
370+
_ = { return }()
371+
372+
switch Bool.random() {
373+
case true:
374+
break
375+
case false:
376+
if .random() { break }
377+
y: if .random() { break y }
378+
return
379+
}
380+
381+
x: if .random() {
382+
break x
383+
}
384+
if .random() { return }
385+
}
386+
387+
// Make sure we replace lifted break/returns with placeholders, but keep nested
388+
// break/returns in e.g closures or labelled control flow in place.
389+
390+
// BREAK-RET-PLACEHOLDER: let res = try await simpleWithArg({ return 0 }())
391+
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { <#break#> }
392+
// BREAK-RET-PLACEHOLDER-NEXT: x: if .random() { break x }
393+
// BREAK-RET-PLACEHOLDER-NEXT: func foo<T>(_ x: T) {
394+
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { return }
395+
// BREAK-RET-PLACEHOLDER-NEXT: }
396+
// BREAK-RET-PLACEHOLDER-NEXT: foo(<#res#>)
397+
// BREAK-RET-PLACEHOLDER-NEXT: let fn = {
398+
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { return }
399+
// BREAK-RET-PLACEHOLDER-NEXT: {{^}} return{{$}}
400+
// BREAK-RET-PLACEHOLDER-NEXT: }
401+
// BREAK-RET-PLACEHOLDER-NEXT: fn()
402+
// BREAK-RET-PLACEHOLDER-NEXT: _ = { return }()
403+
// BREAK-RET-PLACEHOLDER-NEXT: switch Bool.random() {
404+
// BREAK-RET-PLACEHOLDER-NEXT: case true:
405+
// BREAK-RET-PLACEHOLDER-NEXT: {{^}} break{{$}}
406+
// BREAK-RET-PLACEHOLDER-NEXT: case false:
407+
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { break }
408+
// BREAK-RET-PLACEHOLDER-NEXT: y: if .random() { break y }
409+
// BREAK-RET-PLACEHOLDER-NEXT: <#return#>
410+
// BREAK-RET-PLACEHOLDER-NEXT: }
411+
// BREAK-RET-PLACEHOLDER-NEXT: x: if .random() {
412+
// BREAK-RET-PLACEHOLDER-NEXT: {{^}} break x{{$}}
413+
// BREAK-RET-PLACEHOLDER-NEXT: }
414+
// BREAK-RET-PLACEHOLDER-NEXT: if .random() { <#return#> }

0 commit comments

Comments
 (0)