Skip to content

Commit 7a50e1c

Browse files
committed
Change SyntaxRewriter to disallow type node type conversions for all types that are not ExprSyntax, StmtSyntax …
Intuitively, changing the kind of an expression in a rewriter makes sense but for all the other syntax kinds I can’t see any use case for changing their types. Thus, add some type safety to `SyntaxRewriter` by requiring that all `visit` methods for nodes that don’t have a base kind like `ExprSyntax` return a node of the same type. rdar://101355004
1 parent 587d3ad commit 7a50e1c

File tree

12 files changed

+2466
-2433
lines changed

12 files changed

+2466
-2433
lines changed

CodeGeneration/Sources/generate-swiftbasicformat/BasicFormatFile.swift

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,27 @@ let basicFormatFile = SourceFile {
4040
}
4141
}
4242

43+
private func createChildVisitCall(childType: SyntaxBuildableType, rewrittenExpr: ExprBuildable) -> ExprBuildable {
44+
let visitCall: FunctionCallExpr
45+
if childType.isOptional {
46+
visitCall = FunctionCallExpr("\(rewrittenExpr).map(self.visit)")
47+
} else {
48+
visitCall = FunctionCallExpr("self.visit(\(rewrittenExpr))")
49+
}
50+
if childType.baseType?.baseName != "Syntax", childType.baseType?.isSyntaxCollection != true, childType.baseType != nil {
51+
let optionalChained = childType.optionalChained(expr: visitCall).createExprBuildable()
52+
return FunctionCallExpr("\(optionalChained).cast(\(childType.syntaxBaseName).self)")
53+
} else {
54+
return visitCall
55+
}
56+
}
57+
4358
private func makeLayoutNodeRewriteFunc(node: Node) -> FunctionDecl {
4459
let rewriteResultType: String
45-
if node.isSyntaxCollection {
46-
rewriteResultType = "Syntax"
60+
if node.type.baseType?.syntaxKind == "Syntax" && node.type.syntaxKind != "Missing" {
61+
rewriteResultType = node.type.syntaxBaseName
4762
} else {
48-
rewriteResultType = node.type.baseType?.syntaxBaseName ?? "Syntax"
63+
rewriteResultType = node.type.baseType?.syntaxBaseName ?? node.type.syntaxBaseName
4964
}
5065
return FunctionDecl(
5166
leadingTrivia: .newline,
@@ -69,11 +84,7 @@ private func makeLayoutNodeRewriteFunc(node: Node) -> FunctionDecl {
6984
SequenceExpr("indentationLevel += 1")
7085
}
7186
let variableLetVar = child.requiresLeadingNewline ? "var" : "let"
72-
if child.isOptional {
73-
VariableDecl("\(variableLetVar) \(child.swiftName) = node.\(child.swiftName).map(self.visit)?.cast(\(child.type.syntaxBaseName).self)")
74-
} else {
75-
VariableDecl("\(variableLetVar) \(child.swiftName) = self.visit(node.\(child.swiftName)).cast(\(child.type.syntaxBaseName).self)")
76-
}
87+
VariableDecl("\(variableLetVar) \(child.swiftName) = \(createChildVisitCall(childType: child.type, rewrittenExpr: MemberAccessExpr("node.\(child.swiftName)")))")
7788
if child.requiresLeadingNewline {
7889
IfStmt(
7990
"""
@@ -95,17 +106,16 @@ private func makeLayoutNodeRewriteFunc(node: Node) -> FunctionDecl {
95106
)
96107
}
97108
}
98-
ReturnStmt("return \(rewriteResultType)(\(reconstructed))")
109+
if rewriteResultType != node.type.syntaxBaseName {
110+
ReturnStmt("return \(rewriteResultType)(\(reconstructed))")
111+
} else {
112+
ReturnStmt("return \(reconstructed)")
113+
}
99114
}
100115
}
101116

102117
private func makeSyntaxCollectionRewriteFunc(node: Node) -> FunctionDecl {
103-
let rewriteResultType: String
104-
if node.isSyntaxCollection {
105-
rewriteResultType = "Syntax"
106-
} else {
107-
rewriteResultType = node.type.baseType?.syntaxBaseName ?? "Syntax"
108-
}
118+
let rewriteResultType = node.type.syntaxBaseName
109119
return FunctionDecl(
110120
leadingTrivia: .newline,
111121
modifiers: [Token.open, Token(tokenSyntax: TokenSyntax.contextualKeyword("override", trailingTrivia: .space))],
@@ -126,8 +136,8 @@ private func makeSyntaxCollectionRewriteFunc(node: Node) -> FunctionDecl {
126136
let formattedChildrenVarLet = node.elementsSeparatedByNewline ? "var" : "let"
127137
VariableDecl(
128138
"""
129-
\(formattedChildrenVarLet) formattedChildren = node.children(viewMode: .all).map {
130-
self.visit($0).cast(\(node.collectionElementType.syntaxBaseName).self)
139+
\(formattedChildrenVarLet) formattedChildren = node.map {
140+
\(createChildVisitCall(childType: node.collectionElementType, rewrittenExpr: IdentifierExpr(identifier: .dollarIdentifier("$0"))))
131141
}
132142
"""
133143
)
@@ -144,7 +154,7 @@ private func makeSyntaxCollectionRewriteFunc(node: Node) -> FunctionDecl {
144154
"""
145155
)
146156
}
147-
ReturnStmt("return Syntax(\(node.type.syntaxBaseName)(formattedChildren))")
157+
ReturnStmt("return \(node.type.syntaxBaseName)(formattedChildren)")
148158
}
149159
}
150160

Sources/SwiftBasicFormat/Utils.swift

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
import SwiftSyntax
22

3-
extension SyntaxProtocol {
4-
func cast<S: SyntaxProtocol>(_ syntaxType: S.Type) -> S {
5-
return Syntax(self).as(S.self)!
6-
}
7-
}
8-
93
extension Trivia {
104
var containsNewline: Bool {
115
for piece in self {

Sources/SwiftBasicFormat/generated/BasicFormat.swift

Lines changed: 1942 additions & 1942 deletions
Large diffs are not rendered by default.

Sources/SwiftSyntax/Syntax.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ extension Syntax {
6767
public func `as`<S: SyntaxProtocol>(_ syntaxType: S.Type) -> S? {
6868
return S.init(self)
6969
}
70+
71+
func cast<S: SyntaxProtocol>(_ syntaxType: S.Type) -> S {
72+
return self.as(S.self)!
73+
}
7074
}
7175

7276
extension Syntax: CustomReflectable {

Sources/SwiftSyntax/SyntaxBaseNodes.swift.gyb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ public struct ${node.name}: ${node.name}Protocol, SyntaxHashable {
107107
return S.init(_syntaxNode)
108108
}
109109

110+
public func cast<S: ${node.name}Protocol>(_ syntaxType: S.Type) -> S {
111+
return self.as(S.self)!
112+
}
113+
110114
/// Syntax nodes always conform to `${node.name}Protocol`. This API is just
111115
/// added for consistency.
112116
/// Note that this will incur an existential conversion.

Sources/SwiftSyntax/SyntaxRewriter.swift.gyb

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,15 @@ open class SyntaxRewriter {
3333
/// Visit a `${node.name}`.
3434
/// - Parameter node: the node that is being visited
3535
/// - Returns: the rewritten node
36+
% if node.base_type == 'Syntax' and node.name != 'MissingSyntax':
37+
open func visit(_ node: ${node.name}) -> ${node.name} {
38+
return Syntax(visitChildren(node)).cast(${node.name}.self)
39+
}
40+
% else:
3641
open func visit(_ node: ${node.name}) -> ${node.base_type} {
3742
return ${node.base_type}(visitChildren(node))
3843
}
44+
% end
3945

4046
% end
4147
% end
@@ -84,8 +90,8 @@ open class SyntaxRewriter {
8490
/// Visit any ${base_kind}Syntax node.
8591
/// - Parameter node: the node that is being visited
8692
/// - Returns: the rewritten node
87-
public func visit(_ node: ${base_kind}Syntax) -> Syntax {
88-
return visit(node.data)
93+
public func visit(_ node: ${base_kind}Syntax) -> ${base_kind}Syntax {
94+
return visit(node.data).cast(${base_kind}Syntax.self)
8995
}
9096

9197
% end
@@ -106,11 +112,7 @@ open class SyntaxRewriter {
106112
visitPre(node._syntaxNode)
107113
defer { visitPost(node._syntaxNode) }
108114
if let newNode = visitAny(node._syntaxNode) { return newNode }
109-
% if node.base_type == 'Syntax':
110-
return visit(node)
111-
% else:
112115
return Syntax(visit(node))
113-
% end
114116
% end
115117
}
116118

Sources/SwiftSyntax/gyb_generated/SyntaxBaseNodes.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ public struct DeclSyntax: DeclSyntaxProtocol, SyntaxHashable {
8888
return S.init(_syntaxNode)
8989
}
9090

91+
public func cast<S: DeclSyntaxProtocol>(_ syntaxType: S.Type) -> S {
92+
return self.as(S.self)!
93+
}
94+
9195
/// Syntax nodes always conform to `DeclSyntaxProtocol`. This API is just
9296
/// added for consistency.
9397
/// Note that this will incur an existential conversion.
@@ -191,6 +195,10 @@ public struct ExprSyntax: ExprSyntaxProtocol, SyntaxHashable {
191195
return S.init(_syntaxNode)
192196
}
193197

198+
public func cast<S: ExprSyntaxProtocol>(_ syntaxType: S.Type) -> S {
199+
return self.as(S.self)!
200+
}
201+
194202
/// Syntax nodes always conform to `ExprSyntaxProtocol`. This API is just
195203
/// added for consistency.
196204
/// Note that this will incur an existential conversion.
@@ -294,6 +302,10 @@ public struct StmtSyntax: StmtSyntaxProtocol, SyntaxHashable {
294302
return S.init(_syntaxNode)
295303
}
296304

305+
public func cast<S: StmtSyntaxProtocol>(_ syntaxType: S.Type) -> S {
306+
return self.as(S.self)!
307+
}
308+
297309
/// Syntax nodes always conform to `StmtSyntaxProtocol`. This API is just
298310
/// added for consistency.
299311
/// Note that this will incur an existential conversion.
@@ -397,6 +409,10 @@ public struct TypeSyntax: TypeSyntaxProtocol, SyntaxHashable {
397409
return S.init(_syntaxNode)
398410
}
399411

412+
public func cast<S: TypeSyntaxProtocol>(_ syntaxType: S.Type) -> S {
413+
return self.as(S.self)!
414+
}
415+
400416
/// Syntax nodes always conform to `TypeSyntaxProtocol`. This API is just
401417
/// added for consistency.
402418
/// Note that this will incur an existential conversion.
@@ -500,6 +516,10 @@ public struct PatternSyntax: PatternSyntaxProtocol, SyntaxHashable {
500516
return S.init(_syntaxNode)
501517
}
502518

519+
public func cast<S: PatternSyntaxProtocol>(_ syntaxType: S.Type) -> S {
520+
return self.as(S.self)!
521+
}
522+
503523
/// Syntax nodes always conform to `PatternSyntaxProtocol`. This API is just
504524
/// added for consistency.
505525
/// Note that this will incur an existential conversion.

0 commit comments

Comments
 (0)