Skip to content

Commit 0b34078

Browse files
authored
Merge pull request #12268 from rintaro/refactoring-trailingclosure
[refactoring] Implement "Convert to Trailing Closure" refactoring action
2 parents a70e857 + a57199c commit 0b34078

File tree

14 files changed

+376
-51
lines changed

14 files changed

+376
-51
lines changed

include/swift/AST/Expr.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2712,6 +2712,12 @@ class ImplicitConversionExpr : public Expr {
27122712
Expr *getSubExpr() const { return SubExpr; }
27132713
void setSubExpr(Expr *e) { SubExpr = e; }
27142714

2715+
Expr *getSyntacticSubExpr() const {
2716+
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(SubExpr))
2717+
return ICE->getSyntacticSubExpr();
2718+
return SubExpr;
2719+
}
2720+
27152721
static bool classof(const Expr *E) {
27162722
return E->getKind() >= ExprKind::First_ImplicitConversionExpr &&
27172723
E->getKind() <= ExprKind::Last_ImplicitConversionExpr;

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ CURSOR_REFACTORING(CollapseNestedIfExpr, "Collapse Nested If Expression", collap
4444

4545
CURSOR_REFACTORING(ConvertToDoCatch, "Convert To Do/Catch", convert.do.catch)
4646

47+
CURSOR_REFACTORING(TrailingClosure, "Convert To Trailing Closure", trailingclosure)
48+
4749
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
4850

4951
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

include/swift/IDE/Utils.h

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,12 @@ enum class CursorInfoKind {
157157

158158
struct ResolvedCursorInfo {
159159
CursorInfoKind Kind = CursorInfoKind::Invalid;
160+
SourceFile *SF;
161+
SourceLoc Loc;
160162
ValueDecl *ValueD = nullptr;
161163
TypeDecl *CtorTyRef = nullptr;
162164
ExtensionDecl *ExtTyRef = nullptr;
163-
SourceFile *SF = nullptr;
164165
ModuleEntity Mod;
165-
SourceLoc Loc;
166166
bool IsRef = true;
167167
bool IsKeywordArgument = false;
168168
Type Ty;
@@ -172,39 +172,36 @@ struct ResolvedCursorInfo {
172172
Expr *TrailingExpr = nullptr;
173173

174174
ResolvedCursorInfo() = default;
175-
ResolvedCursorInfo(ValueDecl *ValueD,
176-
TypeDecl *CtorTyRef,
177-
ExtensionDecl *ExtTyRef,
178-
SourceFile *SF,
179-
SourceLoc Loc,
180-
bool IsRef,
181-
Type Ty,
182-
Type ContainerType) :
183-
Kind(CursorInfoKind::ValueRef),
184-
ValueD(ValueD),
185-
CtorTyRef(CtorTyRef),
186-
ExtTyRef(ExtTyRef),
187-
SF(SF),
188-
Loc(Loc),
189-
IsRef(IsRef),
190-
Ty(Ty),
191-
DC(ValueD->getDeclContext()),
192-
ContainerType(ContainerType) {}
193-
ResolvedCursorInfo(ModuleEntity Mod,
194-
SourceFile *SF,
195-
SourceLoc Loc) :
196-
Kind(CursorInfoKind::ModuleRef),
197-
SF(SF),
198-
Mod(Mod),
199-
Loc(Loc) {}
200-
ResolvedCursorInfo(Stmt *TrailingStmt, SourceFile *SF) :
201-
Kind(CursorInfoKind::StmtStart),
202-
SF(SF),
203-
TrailingStmt(TrailingStmt) {}
204-
ResolvedCursorInfo(Expr* TrailingExpr, SourceFile *SF) :
205-
Kind(CursorInfoKind::ExprStart),
206-
SF(SF),
207-
TrailingExpr(TrailingExpr) {}
175+
ResolvedCursorInfo(SourceFile *SF) : SF(SF) {}
176+
177+
void setValueRef(ValueDecl *ValueD,
178+
TypeDecl *CtorTyRef,
179+
ExtensionDecl *ExtTyRef,
180+
bool IsRef,
181+
Type Ty,
182+
Type ContainerType) {
183+
Kind = CursorInfoKind::ValueRef;
184+
this->ValueD = ValueD;
185+
this->CtorTyRef = CtorTyRef;
186+
this->ExtTyRef = ExtTyRef;
187+
this->IsRef = IsRef;
188+
this->Ty = Ty;
189+
this->DC = ValueD->getDeclContext();
190+
this->ContainerType = ContainerType;
191+
}
192+
void setModuleRef(ModuleEntity Mod) {
193+
Kind = CursorInfoKind::ModuleRef;
194+
this->Mod = Mod;
195+
}
196+
void setTrailingStmt(Stmt *TrailingStmt) {
197+
Kind = CursorInfoKind::StmtStart;
198+
this->TrailingStmt = TrailingStmt;
199+
}
200+
void setTrailingExpr(Expr* TrailingExpr) {
201+
Kind = CursorInfoKind::ExprStart;
202+
this->TrailingExpr = TrailingExpr;
203+
}
204+
208205
bool isValid() const { return !isInvalid(); }
209206
bool isInvalid() const { return Kind == CursorInfoKind::Invalid; }
210207
};
@@ -217,7 +214,8 @@ class CursorInfoResolver : public SourceEntityWalker {
217214
llvm::SmallVector<Expr*, 4> TrailingExprStack;
218215

219216
public:
220-
explicit CursorInfoResolver(SourceFile &SrcFile) : SrcFile(SrcFile) { }
217+
explicit CursorInfoResolver(SourceFile &SrcFile) :
218+
SrcFile(SrcFile), CursorInfo(&SrcFile) {}
221219
ResolvedCursorInfo resolve(SourceLoc Loc);
222220
SourceManager &getSourceMgr() const;
223221
private:

lib/IDE/Refactoring.cpp

Lines changed: 119 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,28 @@ class ContextFinder : public SourceEntityWalker {
4040
SourceFile &SF;
4141
ASTContext &Ctx;
4242
SourceManager &SM;
43-
ASTNode Target;
43+
SourceRange Target;
4444
llvm::function_ref<bool(ASTNode)> IsContext;
4545
SmallVector<ASTNode, 4> AllContexts;
4646
bool contains(ASTNode Enclosing) {
47-
auto Result = SM.rangeContains(Enclosing.getSourceRange(),
48-
Target.getSourceRange());
47+
auto Result = SM.rangeContains(Enclosing.getSourceRange(), Target);
4948
if (Result && IsContext(Enclosing))
5049
AllContexts.push_back(Enclosing);
5150
return Result;
5251
}
5352
public:
54-
ContextFinder(SourceFile &SF, ASTNode Target,
53+
ContextFinder(SourceFile &SF, ASTNode TargetNode,
5554
llvm::function_ref<bool(ASTNode)> IsContext =
56-
[](ASTNode N) { return true; }) : SF(SF),
57-
Ctx(SF.getASTContext()), SM(Ctx.SourceMgr), Target(Target),
58-
IsContext(IsContext) {}
55+
[](ASTNode N) { return true; }) :
56+
SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
57+
Target(TargetNode.getSourceRange()), IsContext(IsContext) {}
58+
ContextFinder(SourceFile &SF, SourceLoc TargetLoc,
59+
llvm::function_ref<bool(ASTNode)> IsContext =
60+
[](ASTNode N) { return true; }) :
61+
SF(SF), Ctx(SF.getASTContext()), SM(Ctx.SourceMgr),
62+
Target(TargetLoc), IsContext(IsContext) {
63+
assert(TargetLoc.isValid() && "Invalid loc to find");
64+
}
5965
bool walkToDeclPre(Decl *D, CharSourceRange Range) override { return contains(D); }
6066
bool walkToStmtPre(Stmt *S) override { return contains(S); }
6167
bool walkToExprPre(Expr *E) override { return contains(E); }
@@ -1722,7 +1728,8 @@ class FillProtocolStubContext {
17221728

17231729
FillProtocolStubContext FillProtocolStubContext::
17241730
getContextFromCursorInfo(ResolvedCursorInfo CursorInfo) {
1725-
assert(CursorInfo.isValid());
1731+
if(!CursorInfo.isValid())
1732+
return FillProtocolStubContext();
17261733
if (!CursorInfo.IsRef) {
17271734
// If the type name is on the declared nominal, e.g. "class A {}"
17281735
if (auto ND = dyn_cast<NominalTypeDecl>(CursorInfo.ValueD)) {
@@ -2127,6 +2134,110 @@ bool RefactoringActionSimplifyNumberLiteral::performChange() {
21272134
return true;
21282135
}
21292136

2137+
static CallExpr *findTrailingClosureTarget(SourceManager &SM,
2138+
ResolvedCursorInfo CursorInfo) {
2139+
if (CursorInfo.Kind == CursorInfoKind::StmtStart)
2140+
// StmtStart postion can't be a part of CallExpr.
2141+
return nullptr;
2142+
2143+
// Find inner most CallExpr
2144+
ContextFinder
2145+
Finder(*CursorInfo.SF, CursorInfo.Loc,
2146+
[](ASTNode N) {
2147+
return N.isStmt(StmtKind::Brace) || N.isExpr(ExprKind::Call);
2148+
});
2149+
Finder.resolve();
2150+
if (Finder.getContexts().empty()
2151+
|| !Finder.getContexts().back().is<Expr*>())
2152+
return nullptr;
2153+
CallExpr *CE = cast<CallExpr>(Finder.getContexts().back().get<Expr*>());
2154+
2155+
// The last arugment is a closure?
2156+
Expr *Args = CE->getArg();
2157+
if (!Args)
2158+
return nullptr;
2159+
Expr *LastArg;
2160+
if (auto *TSE = dyn_cast<TupleShuffleExpr>(Args))
2161+
Args = TSE->getSubExpr();
2162+
if (auto *PE = dyn_cast<ParenExpr>(Args)) {
2163+
LastArg = PE->getSubExpr();
2164+
} else {
2165+
auto *TE = cast<TupleExpr>(Args);
2166+
if (TE->getNumElements() == 0)
2167+
return nullptr;
2168+
LastArg = TE->getElements().back();
2169+
}
2170+
2171+
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(LastArg))
2172+
LastArg = ICE->getSyntacticSubExpr();
2173+
2174+
if (isa<ClosureExpr>(LastArg) || isa<CaptureListExpr>(LastArg))
2175+
return CE;
2176+
return nullptr;
2177+
}
2178+
2179+
bool RefactoringActionTrailingClosure::
2180+
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &Diag) {
2181+
SourceManager &SM = CursorInfo.SF->getASTContext().SourceMgr;
2182+
return findTrailingClosureTarget(SM, CursorInfo);
2183+
}
2184+
2185+
bool RefactoringActionTrailingClosure::performChange() {
2186+
auto *CE = findTrailingClosureTarget(SM, CursorInfo);
2187+
if (!CE)
2188+
return true;
2189+
Expr *Args = CE->getArg();
2190+
if (auto *TSE = dyn_cast<TupleShuffleExpr>(Args))
2191+
Args = TSE;
2192+
2193+
Expr *ClosureArg = nullptr;
2194+
Expr *PrevArg = nullptr;
2195+
SourceLoc LPLoc, RPLoc;
2196+
2197+
if (auto *PE = dyn_cast<ParenExpr>(Args)) {
2198+
ClosureArg = PE->getSubExpr();
2199+
LPLoc = PE->getLParenLoc();
2200+
RPLoc = PE->getRParenLoc();
2201+
} else {
2202+
auto *TE = cast<TupleExpr>(Args);
2203+
auto NumArgs = TE->getNumElements();
2204+
if (NumArgs == 0)
2205+
return true;
2206+
LPLoc = TE->getLParenLoc();
2207+
RPLoc = TE->getRParenLoc();
2208+
ClosureArg = TE->getElement(NumArgs - 1);
2209+
if (NumArgs > 1)
2210+
PrevArg = TE->getElement(NumArgs - 2);
2211+
}
2212+
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(ClosureArg))
2213+
ClosureArg = ICE->getSyntacticSubExpr();
2214+
2215+
if (LPLoc.isInvalid() || RPLoc.isInvalid())
2216+
return true;
2217+
2218+
// Replace:
2219+
// * Open paren with ' ' if the closure is sole argument.
2220+
// * Comma with ') ' otherwise.
2221+
if (PrevArg) {
2222+
CharSourceRange PreRange(
2223+
SM,
2224+
Lexer::getLocForEndOfToken(SM, PrevArg->getEndLoc()),
2225+
ClosureArg->getStartLoc());
2226+
EditConsumer.accept(SM, PreRange, ") ");
2227+
} else {
2228+
CharSourceRange PreRange(
2229+
SM, LPLoc, ClosureArg->getStartLoc());
2230+
EditConsumer.accept(SM, PreRange, " ");
2231+
}
2232+
// Remove original closing paren.
2233+
CharSourceRange PostRange(
2234+
SM,
2235+
Lexer::getLocForEndOfToken(SM, ClosureArg->getEndLoc()),
2236+
Lexer::getLocForEndOfToken(SM, RPLoc));
2237+
EditConsumer.remove(SM, PostRange);
2238+
return false;
2239+
}
2240+
21302241
static bool rangeStartMayNeedRename(ResolvedRangeInfo Info) {
21312242
switch(Info.Kind) {
21322243
case RangeKind::SingleExpression: {

lib/IDE/SwiftSourceDocInfo.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ bool CursorInfoResolver::tryResolve(ValueDecl *D, TypeDecl *CtorTyRef,
8080
return false;
8181

8282
if (Loc == LocToResolve) {
83-
CursorInfo = { D, CtorTyRef, ExtTyRef, &SrcFile, Loc, IsRef, Ty, ContainerType };
83+
CursorInfo.setValueRef(D, CtorTyRef, ExtTyRef, IsRef, Ty, ContainerType);
8484
return true;
8585
}
8686
return false;
8787
}
8888

8989
bool CursorInfoResolver::tryResolve(ModuleEntity Mod, SourceLoc Loc) {
9090
if (Loc == LocToResolve) {
91-
CursorInfo = { Mod, &SrcFile, Loc };
91+
CursorInfo.setModuleRef(Mod);
9292
return true;
9393
}
9494
return false;
@@ -97,13 +97,13 @@ bool CursorInfoResolver::tryResolve(ModuleEntity Mod, SourceLoc Loc) {
9797
bool CursorInfoResolver::tryResolve(Stmt *St) {
9898
if (auto *LST = dyn_cast<LabeledStmt>(St)) {
9999
if (LST->getStartLoc() == LocToResolve) {
100-
CursorInfo = { St, &SrcFile };
100+
CursorInfo.setTrailingStmt(St);
101101
return true;
102102
}
103103
}
104104
if (auto *CS = dyn_cast<CaseStmt>(St)) {
105105
if (CS->getStartLoc() == LocToResolve) {
106-
CursorInfo = { St, &SrcFile };
106+
CursorInfo.setTrailingStmt(St);
107107
return true;
108108
}
109109
}
@@ -120,7 +120,7 @@ bool CursorInfoResolver::visitSubscriptReference(ValueDecl *D, CharSourceRange R
120120
ResolvedCursorInfo CursorInfoResolver::resolve(SourceLoc Loc) {
121121
assert(Loc.isValid());
122122
LocToResolve = Loc;
123-
CursorInfo = ResolvedCursorInfo();
123+
CursorInfo.Loc = Loc;
124124
walk(SrcFile);
125125
return CursorInfo;
126126
}
@@ -211,7 +211,7 @@ bool CursorInfoResolver::walkToExprPost(Expr *E) {
211211
return false;
212212
if (!TrailingExprStack.empty() && TrailingExprStack.back() == E) {
213213
// We return the outtermost expression in the token info.
214-
CursorInfo = { TrailingExprStack.front(), &SrcFile };
214+
CursorInfo.setTrailingExpr(TrailingExprStack.front());
215215
return false;
216216
}
217217
return true;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
struct Foo {
2+
static func foo(a: () -> Int) {}
3+
func qux(x: Int, y: () -> Int ) {}
4+
}
5+
6+
func testTrailingClosure() -> String {
7+
Foo.foo(a: { 1 })
8+
Foo.bar(a: { print(3); return 1 })
9+
Foo().qux(x: 1, y: { 1 })
10+
let _ = Foo().quux(x: 1, y: { 1 })
11+
12+
[1,2,3]
13+
.filter({ $0 % 2 == 0 })
14+
.map({ $0 + 1 })
15+
}
16+
17+
// RUN: %refactor -source-filename %s -pos=7:3 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
18+
// RUN: %refactor -source-filename %s -pos=7:6 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
19+
// RUN: %refactor -source-filename %s -pos=7:7 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
20+
// RUN: %refactor -source-filename %s -pos=7:10 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
21+
// RUN: %refactor -source-filename %s -pos=7:11 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
22+
// RUN: %refactor -source-filename %s -pos=7:12 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
23+
// RUN: %refactor -source-filename %s -pos=7:14 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
24+
// RUN: %refactor -source-filename %s -pos=7:16 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
25+
// RUN: %refactor -source-filename %s -pos=7:18 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
26+
// RUN: %refactor -source-filename %s -pos=7:19 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
27+
28+
// RUN: %refactor -source-filename %s -pos=8:3 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
29+
// RUN: %refactor -source-filename %s -pos=8:11 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
30+
31+
// RUN: %refactor -source-filename %s -pos=9:3 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
32+
// RUN: %refactor -source-filename %s -pos=9:8 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
33+
34+
// RUN: %refactor -source-filename %s -pos=10:3 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
35+
// RUN: %refactor -source-filename %s -pos=10:9 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
36+
// RUN: %refactor -source-filename %s -pos=10:17 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
37+
38+
// RUN: %refactor -source-filename %s -pos=12:4 | %FileCheck %s -check-prefix=CHECK-NO-TRAILING-CLOSURE
39+
// RUN: %refactor -source-filename %s -pos=13:5 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
40+
// RUN: %refactor -source-filename %s -pos=14:5 | %FileCheck %s -check-prefix=CHECK-TRAILING-CLOSURE
41+
42+
// CHECK-TRAILING-CLOSURE: Convert To Trailing Closure
43+
44+
// CHECK-NO-TRAILING-CLOSURE: Action begins
45+
// CHECK-NO-TRAILING-CLOSURE-NOT: Convert To Trailing Closure
46+
// CHECK-NO-TRAILING-CLOSURE: Action ends
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
struct Foo {
2+
static func foo(a: () -> Int) {}
3+
func qux(x: Int, y: () -> Int ) {}
4+
}
5+
6+
func testTrailingClosure() -> String {
7+
Foo.foo(a: { 1 })
8+
Foo.bar(a: { print(3); return 1 })
9+
Foo().qux(x: 1, y: { 1 })
10+
let _ = Foo().quux(x: 1) { 1 }
11+
12+
[1,2,3]
13+
.filter({ $0 % 2 == 0 })
14+
.map({ $0 + 1 })
15+
}
16+
17+
18+
19+
20+
21+

0 commit comments

Comments
 (0)