Skip to content

Commit 8cd677e

Browse files
Kacper20nkcsgexi
authored andcommitted
[Refactoring] SR-6051 Expansion of switch statement missing cases (#12281)
1 parent e53ff51 commit 8cd677e

File tree

14 files changed

+263
-63
lines changed

14 files changed

+263
-63
lines changed

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ CURSOR_REFACTORING(FillProtocolStub, "Add Missing Protocol Requirements", fillst
3434

3535
CURSOR_REFACTORING(ExpandDefault, "Expand Default", expand.default)
3636

37+
CURSOR_REFACTORING(ExpandSwitchCases, "Expand Switch Cases", expand.switch.cases)
38+
3739
CURSOR_REFACTORING(LocalizeString, "Localize String", localize.string)
3840

3941
CURSOR_REFACTORING(SimplifyNumberLiteral, "Simplify Long Number Literal", simplify.long.number.literal)

include/swift/IDE/Utils.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ struct ResolvedCursorInfo {
160160
ValueDecl *ValueD = nullptr;
161161
TypeDecl *CtorTyRef = nullptr;
162162
ExtensionDecl *ExtTyRef = nullptr;
163+
SourceFile *SF = nullptr;
163164
ModuleEntity Mod;
164165
SourceLoc Loc;
165166
bool IsRef = true;
@@ -174,6 +175,7 @@ struct ResolvedCursorInfo {
174175
ResolvedCursorInfo(ValueDecl *ValueD,
175176
TypeDecl *CtorTyRef,
176177
ExtensionDecl *ExtTyRef,
178+
SourceFile *SF,
177179
SourceLoc Loc,
178180
bool IsRef,
179181
Type Ty,
@@ -182,21 +184,26 @@ struct ResolvedCursorInfo {
182184
ValueD(ValueD),
183185
CtorTyRef(CtorTyRef),
184186
ExtTyRef(ExtTyRef),
187+
SF(SF),
185188
Loc(Loc),
186189
IsRef(IsRef),
187190
Ty(Ty),
188191
DC(ValueD->getDeclContext()),
189192
ContainerType(ContainerType) {}
190193
ResolvedCursorInfo(ModuleEntity Mod,
194+
SourceFile *SF,
191195
SourceLoc Loc) :
192196
Kind(CursorInfoKind::ModuleRef),
197+
SF(SF),
193198
Mod(Mod),
194-
Loc(Loc) { }
195-
ResolvedCursorInfo(Stmt *TrailingStmt) :
199+
Loc(Loc) {}
200+
ResolvedCursorInfo(Stmt *TrailingStmt, SourceFile *SF) :
196201
Kind(CursorInfoKind::StmtStart),
202+
SF(SF),
197203
TrailingStmt(TrailingStmt) {}
198-
ResolvedCursorInfo(Expr* TrailingExpr) :
204+
ResolvedCursorInfo(Expr* TrailingExpr, SourceFile *SF) :
199205
Kind(CursorInfoKind::ExprStart),
206+
SF(SF),
200207
TrailingExpr(TrailingExpr) {}
201208
bool isValid() const { return !isInvalid(); }
202209
bool isInvalid() const { return Kind == CursorInfoKind::Invalid; }

lib/IDE/Refactoring.cpp

Lines changed: 115 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,57 +1801,25 @@ collectAvailableRefactoringsAtCursor(SourceFile *SF, unsigned Line,
18011801
return collectAvailableRefactorings(SF, Tok, Scratch, /*Exclude rename*/false);
18021802
}
18031803

1804-
bool RefactoringActionExpandDefault::
1805-
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &Diag) {
1806-
auto Exit = [&](bool Applicable) {
1807-
if (!Applicable)
1808-
Diag.diagnose(SourceLoc(), diag::invalid_default_location);
1809-
return Applicable;
1810-
};
1811-
if (CursorInfo.Kind != CursorInfoKind::StmtStart)
1812-
return Exit(false);
1813-
if (auto *CS = dyn_cast<CaseStmt>(CursorInfo.TrailingStmt)) {
1814-
return Exit(CS->isDefault());
1815-
}
1816-
return Exit(false);
1817-
}
1818-
1819-
bool RefactoringActionExpandDefault::performChange() {
1820-
// Try to find the switch statement enclosing the default statement.
1821-
auto *CS = static_cast<CaseStmt*>(CursorInfo.TrailingStmt);
1822-
auto IsSwitch = [](ASTNode Node) {
1823-
return Node.is<Stmt*>() &&
1824-
Node.get<Stmt*>()->getKind() == StmtKind::Switch;
1825-
};
1826-
ContextFinder Finder(*TheFile, CS, IsSwitch);
1827-
Finder.resolve();
1828-
1829-
// If failed to find the switch statement, issue error.
1830-
if (Finder.getContexts().empty()) {
1831-
DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
1832-
return true;
1833-
}
1834-
auto *SwitchS = static_cast<SwitchStmt*>(Finder.getContexts().back().
1835-
get<Stmt*>());
1836-
1837-
// To find the subject enum decl for this switch statement; if failing,
1838-
// issue errors.
1839-
EnumDecl *SubjectED = nullptr;
1804+
static EnumDecl* getEnumDeclFromSwitchStmt(SwitchStmt *SwitchS) {
18401805
if (auto SubjectTy = SwitchS->getSubjectExpr()->getType()) {
1841-
SubjectED = SubjectTy->getAnyNominal()->getAsEnumOrEnumExtensionContext();
1842-
}
1843-
if (!SubjectED) {
1844-
DiagEngine.diagnose(CS->getStartLoc(), diag::no_subject_enum);
1845-
return true;
1806+
return SubjectTy->getAnyNominal()->getAsEnumOrEnumExtensionContext();
18461807
}
1808+
return nullptr;
1809+
}
18471810

1811+
static bool performCasesExpansionInSwitchStmt(SwitchStmt *SwitchS,
1812+
DiagnosticEngine &DiagEngine,
1813+
SourceLoc ExpandedStmtLoc,
1814+
EditorConsumerInsertStream &OS
1815+
) {
18481816
// Assume enum elements are not handled in the switch statement.
1817+
auto EnumDecl = getEnumDeclFromSwitchStmt(SwitchS);
1818+
assert(EnumDecl);
18491819
llvm::DenseSet<EnumElementDecl*> UnhandledElements;
1850-
SubjectED->getAllElements(UnhandledElements);
1851-
bool FoundDefault = false;
1820+
EnumDecl->getAllElements(UnhandledElements);
18521821
for (auto Current : SwitchS->getCases()) {
1853-
if (Current == CS) {
1854-
FoundDefault = true;
1822+
if (Current->isDefault()) {
18551823
continue;
18561824
}
18571825
// For each handled enum element, remove it from the bucket.
@@ -1862,27 +1830,119 @@ bool RefactoringActionExpandDefault::performChange() {
18621830
}
18631831
}
18641832

1865-
// If we've not seen the default statement inside the switch statement, issue
1866-
// error.
1867-
if (!FoundDefault) {
1868-
DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
1869-
return true;
1870-
}
1871-
18721833
// If all enum elements are handled in the switch statement, issue error.
18731834
if (UnhandledElements.empty()) {
1874-
DiagEngine.diagnose(CS->getStartLoc(), diag::no_remaining_cases);
1835+
DiagEngine.diagnose(ExpandedStmtLoc, diag::no_remaining_cases);
18751836
return true;
18761837
}
18771838

1878-
// Good to go, change the code!
1839+
printEnumElementsAsCases(UnhandledElements, OS);
1840+
return false;
1841+
}
1842+
1843+
// Finds SwitchStmt that contains given CaseStmt.
1844+
static SwitchStmt* findEnclosingSwitchStmt(CaseStmt *CS,
1845+
SourceFile *SF,
1846+
DiagnosticEngine &DiagEngine) {
1847+
auto IsSwitch = [](ASTNode Node) {
1848+
return Node.is<Stmt*>() &&
1849+
Node.get<Stmt*>()->getKind() == StmtKind::Switch;
1850+
};
1851+
ContextFinder Finder(*SF, CS, IsSwitch);
1852+
Finder.resolve();
1853+
1854+
// If failed to find the switch statement, issue error.
1855+
if (Finder.getContexts().empty()) {
1856+
DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
1857+
return nullptr;
1858+
}
1859+
auto *SwitchS = static_cast<SwitchStmt*>(Finder.getContexts().back().
1860+
get<Stmt*>());
1861+
// Make sure that CaseStmt is included in switch that was found.
1862+
auto Cases = SwitchS->getCases();
1863+
auto Default = std::find(Cases.begin(), Cases.end(), CS);
1864+
if (Default == Cases.end()) {
1865+
DiagEngine.diagnose(CS->getStartLoc(), diag::no_parent_switch);
1866+
return nullptr;
1867+
}
1868+
return SwitchS;
1869+
}
1870+
1871+
bool RefactoringActionExpandDefault::
1872+
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &Diag) {
1873+
auto Exit = [&](bool Applicable) {
1874+
if (!Applicable)
1875+
Diag.diagnose(SourceLoc(), diag::invalid_default_location);
1876+
return Applicable;
1877+
};
1878+
if (CursorInfo.Kind != CursorInfoKind::StmtStart)
1879+
return Exit(false);
1880+
if (auto *CS = dyn_cast<CaseStmt>(CursorInfo.TrailingStmt)) {
1881+
auto EnclosingSwitchStmt = findEnclosingSwitchStmt(CS,
1882+
CursorInfo.SF,
1883+
Diag);
1884+
if (!EnclosingSwitchStmt)
1885+
return false;
1886+
auto EnumD = getEnumDeclFromSwitchStmt(EnclosingSwitchStmt);
1887+
auto IsApplicable = CS->isDefault() && EnumD != nullptr;
1888+
return IsApplicable;
1889+
}
1890+
return Exit(false);
1891+
}
1892+
1893+
bool RefactoringActionExpandDefault::performChange() {
1894+
// If we've not seen the default statement inside the switch statement, issue
1895+
// error.
1896+
auto *CS = static_cast<CaseStmt*>(CursorInfo.TrailingStmt);
1897+
auto *SwitchS = findEnclosingSwitchStmt(CS, TheFile, DiagEngine);
1898+
assert(SwitchS);
18791899
EditorConsumerInsertStream OS(EditConsumer, SM,
18801900
Lexer::getCharSourceRangeFromSourceRange(SM,
18811901
CS->getLabelItemsRange()));
1882-
printEnumElementsAsCases(UnhandledElements, OS);
1902+
return performCasesExpansionInSwitchStmt(SwitchS,
1903+
DiagEngine,
1904+
CS->getStartLoc(),
1905+
OS);
1906+
}
1907+
1908+
bool RefactoringActionExpandSwitchCases::
1909+
isApplicable(ResolvedCursorInfo CursorInfo, DiagnosticEngine &DiagEngine) {
1910+
if (!CursorInfo.TrailingStmt)
1911+
return false;
1912+
if (auto *Switch = dyn_cast<SwitchStmt>(CursorInfo.TrailingStmt)) {
1913+
return getEnumDeclFromSwitchStmt(Switch);
1914+
}
18831915
return false;
18841916
}
18851917

1918+
bool RefactoringActionExpandSwitchCases::performChange() {
1919+
auto *SwitchS = dyn_cast<SwitchStmt>(CursorInfo.TrailingStmt);
1920+
assert(SwitchS);
1921+
1922+
auto InsertRange = CharSourceRange();
1923+
auto Cases = SwitchS->getCases();
1924+
auto Default = std::find_if(Cases.begin(), Cases.end(), [](CaseStmt *Stmt) {
1925+
return Stmt->isDefault();
1926+
});
1927+
if (Default != Cases.end()) {
1928+
auto DefaultRange = (*Default)->getLabelItemsRange();
1929+
InsertRange = Lexer::getCharSourceRangeFromSourceRange(SM, DefaultRange);
1930+
} else {
1931+
auto RBraceLoc = SwitchS->getRBraceLoc();
1932+
InsertRange = CharSourceRange(SM, RBraceLoc, RBraceLoc);
1933+
}
1934+
EditorConsumerInsertStream OS(EditConsumer, SM, InsertRange);
1935+
if (SM.getLineNumber(SwitchS->getLBraceLoc()) ==
1936+
SM.getLineNumber(SwitchS->getRBraceLoc())) {
1937+
OS << "\n";
1938+
}
1939+
auto Result = performCasesExpansionInSwitchStmt(SwitchS,
1940+
DiagEngine,
1941+
SwitchS->getStartLoc(),
1942+
OS);
1943+
return Result;
1944+
}
1945+
18861946
static Expr *findLocalizeTarget(ResolvedCursorInfo CursorInfo) {
18871947
if (CursorInfo.Kind != CursorInfoKind::ExprStart)
18881948
return nullptr;

lib/IDE/SwiftSourceDocInfo.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,15 +80,15 @@ bool CursorInfoResolver::tryResolve(ValueDecl *D, TypeDecl *CtorTyRef,
8080
return false;
8181

8282
if (Loc == LocToResolve) {
83-
CursorInfo = { D, CtorTyRef, ExtTyRef, Loc, IsRef, Ty, ContainerType };
83+
CursorInfo = { D, CtorTyRef, ExtTyRef, &SrcFile, Loc, IsRef, Ty, ContainerType };
8484
return true;
8585
}
8686
return false;
8787
}
8888

8989
bool CursorInfoResolver::tryResolve(ModuleEntity Mod, SourceLoc Loc) {
9090
if (Loc == LocToResolve) {
91-
CursorInfo = { Mod, Loc };
91+
CursorInfo = { Mod, &SrcFile, Loc };
9292
return true;
9393
}
9494
return false;
@@ -97,13 +97,13 @@ bool CursorInfoResolver::tryResolve(ModuleEntity Mod, SourceLoc Loc) {
9797
bool CursorInfoResolver::tryResolve(Stmt *St) {
9898
if (auto *LST = dyn_cast<LabeledStmt>(St)) {
9999
if (LST->getStartLoc() == LocToResolve) {
100-
CursorInfo = { St };
100+
CursorInfo = { St, &SrcFile };
101101
return true;
102102
}
103103
}
104104
if (auto *CS = dyn_cast<CaseStmt>(St)) {
105105
if (CS->getStartLoc() == LocToResolve) {
106-
CursorInfo = { St };
106+
CursorInfo = { St, &SrcFile };
107107
return true;
108108
}
109109
}
@@ -211,7 +211,7 @@ bool CursorInfoResolver::walkToExprPost(Expr *E) {
211211
return false;
212212
if (!TrailingExprStack.empty() && TrailingExprStack.back() == E) {
213213
// We return the outtermost expression in the token info.
214-
CursorInfo = { TrailingExprStack.front() };
214+
CursorInfo = { TrailingExprStack.front(), &SrcFile };
215215
return false;
216216
}
217217
return true;
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
enum E {
2+
case e1
3+
case e2
4+
case e3
5+
case e4
6+
}
7+
8+
func foo(e: E) -> Int {
9+
switch e {
10+
case .e1: <#code#>
11+
case .e2: <#code#>
12+
case .e3: <#code#>
13+
case .e4: <#code#>
14+
}
15+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
enum E {
2+
case e1
3+
case e2
4+
case e3
5+
case e4
6+
}
7+
8+
func foo(e: E) -> Int {
9+
switch e {
10+
case .e1: <#code#>
11+
case .e2: <#code#>
12+
case .e3: <#code#>
13+
case .e4: <#code#>
14+
}
15+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
enum E {
2+
case e1
3+
case e2
4+
case e3
5+
case e4
6+
}
7+
8+
func foo(e: E) -> Int {
9+
switch e {
10+
case .e1: return 5
11+
case .e2: <#code#>
12+
case .e3: <#code#>
13+
case .e4: <#code#>
14+
}
15+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
enum E {
2+
case e1
3+
case e2
4+
case e3
5+
case e4
6+
}
7+
8+
func foo(e: E) -> Int {
9+
switch e {
10+
case .e1: <#code#>
11+
case .e2: <#code#>
12+
case .e3: <#code#>
13+
case .e4: <#code#>
14+
return 3
15+
}
16+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
enum E {
2+
case e1
3+
case e2
4+
case e3
5+
case e4
6+
}
7+
8+
func foo(e: E) -> Int {
9+
switch e { }
10+
}
11+
// RUN: rm -rf %t.result && mkdir -p %t.result
12+
// RUN: %refactor -expand-switch-cases -source-filename %s -pos=9:8 >> %t.result/L10.swift
13+
// RUN: diff -u %S/Outputs/basic/L10.swift.expected %t.result/L10.swift
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
enum E {
2+
case e1
3+
case e2
4+
case e3
5+
case e4
6+
}
7+
8+
func foo(e: E) -> Int {
9+
switch e {}
10+
}
11+
// RUN: rm -rf %t.result && mkdir -p %t.result
12+
// RUN: %refactor -expand-switch-cases -source-filename %s -pos=9:8 >> %t.result/L10.swift
13+
// RUN: diff -u %S/Outputs/no_space_between_braces/L10.swift.expected %t.result/L10.swift

0 commit comments

Comments
 (0)