Skip to content

Change SyntaxRewriter to disallow type node type conversions for all types that are not ExprSyntax, StmtSyntax … #1003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,27 @@ let basicFormatFile = SourceFile {
}
}

private func createChildVisitCall(childType: SyntaxBuildableType, rewrittenExpr: ExprBuildable) -> ExprBuildable {
let visitCall: FunctionCallExpr
if childType.isOptional {
visitCall = FunctionCallExpr("\(rewrittenExpr).map(self.visit)")
} else {
visitCall = FunctionCallExpr("self.visit(\(rewrittenExpr))")
}
if childType.baseType?.baseName != "Syntax", childType.baseType?.isSyntaxCollection != true, childType.baseType != nil {
let optionalChained = childType.optionalChained(expr: visitCall).createExprBuildable()
return FunctionCallExpr("\(optionalChained).cast(\(childType.syntaxBaseName).self)")
} else {
return visitCall
}
}

private func makeLayoutNodeRewriteFunc(node: Node) -> FunctionDecl {
let rewriteResultType: String
if node.isSyntaxCollection {
rewriteResultType = "Syntax"
if node.type.baseType?.syntaxKind == "Syntax" && node.type.syntaxKind != "Missing" {
rewriteResultType = node.type.syntaxBaseName
} else {
rewriteResultType = node.type.baseType?.syntaxBaseName ?? "Syntax"
rewriteResultType = node.type.baseType?.syntaxBaseName ?? node.type.syntaxBaseName
}
return FunctionDecl(
leadingTrivia: .newline,
Expand All @@ -69,11 +84,7 @@ private func makeLayoutNodeRewriteFunc(node: Node) -> FunctionDecl {
SequenceExpr("indentationLevel += 1")
}
let variableLetVar = child.requiresLeadingNewline ? "var" : "let"
if child.isOptional {
VariableDecl("\(variableLetVar) \(child.swiftName) = node.\(child.swiftName).map(self.visit)?.cast(\(child.type.syntaxBaseName).self)")
} else {
VariableDecl("\(variableLetVar) \(child.swiftName) = self.visit(node.\(child.swiftName)).cast(\(child.type.syntaxBaseName).self)")
}
VariableDecl("\(variableLetVar) \(child.swiftName) = \(createChildVisitCall(childType: child.type, rewrittenExpr: MemberAccessExpr("node.\(child.swiftName)")))")
if child.requiresLeadingNewline {
IfStmt(
"""
Expand All @@ -95,17 +106,16 @@ private func makeLayoutNodeRewriteFunc(node: Node) -> FunctionDecl {
)
}
}
ReturnStmt("return \(rewriteResultType)(\(reconstructed))")
if rewriteResultType != node.type.syntaxBaseName {
ReturnStmt("return \(rewriteResultType)(\(reconstructed))")
} else {
ReturnStmt("return \(reconstructed)")
}
}
}

private func makeSyntaxCollectionRewriteFunc(node: Node) -> FunctionDecl {
let rewriteResultType: String
if node.isSyntaxCollection {
rewriteResultType = "Syntax"
} else {
rewriteResultType = node.type.baseType?.syntaxBaseName ?? "Syntax"
}
let rewriteResultType = node.type.syntaxBaseName
return FunctionDecl(
leadingTrivia: .newline,
modifiers: [Token.open, Token(tokenSyntax: TokenSyntax.contextualKeyword("override", trailingTrivia: .space))],
Expand All @@ -126,8 +136,8 @@ private func makeSyntaxCollectionRewriteFunc(node: Node) -> FunctionDecl {
let formattedChildrenVarLet = node.elementsSeparatedByNewline ? "var" : "let"
VariableDecl(
"""
\(formattedChildrenVarLet) formattedChildren = node.children(viewMode: .all).map {
self.visit($0).cast(\(node.collectionElementType.syntaxBaseName).self)
\(formattedChildrenVarLet) formattedChildren = node.map {
\(createChildVisitCall(childType: node.collectionElementType, rewrittenExpr: IdentifierExpr(identifier: .dollarIdentifier("$0"))))
}
"""
)
Expand All @@ -144,7 +154,7 @@ private func makeSyntaxCollectionRewriteFunc(node: Node) -> FunctionDecl {
"""
)
}
ReturnStmt("return Syntax(\(node.type.syntaxBaseName)(formattedChildren))")
ReturnStmt("return \(node.type.syntaxBaseName)(formattedChildren)")
}
}

Expand Down
6 changes: 0 additions & 6 deletions Sources/SwiftBasicFormat/Utils.swift
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
import SwiftSyntax

extension SyntaxProtocol {
func cast<S: SyntaxProtocol>(_ syntaxType: S.Type) -> S {
return Syntax(self).as(S.self)!
}
}

extension Trivia {
var containsNewline: Bool {
for piece in self {
Expand Down
3,884 changes: 1,942 additions & 1,942 deletions Sources/SwiftBasicFormat/generated/BasicFormat.swift

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions Sources/SwiftSyntax/Syntax.swift
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ extension Syntax {
public func `as`<S: SyntaxProtocol>(_ syntaxType: S.Type) -> S? {
return S.init(self)
}

func cast<S: SyntaxProtocol>(_ syntaxType: S.Type) -> S {
return self.as(S.self)!
}
}

extension Syntax: CustomReflectable {
Expand Down
4 changes: 4 additions & 0 deletions Sources/SwiftSyntax/SyntaxBaseNodes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ public struct ${node.name}: ${node.name}Protocol, SyntaxHashable {
return S.init(_syntaxNode)
}

public func cast<S: ${node.name}Protocol>(_ syntaxType: S.Type) -> S {
return self.as(S.self)!
}

/// Syntax nodes always conform to `${node.name}Protocol`. This API is just
/// added for consistency.
/// Note that this will incur an existential conversion.
Expand Down
14 changes: 8 additions & 6 deletions Sources/SwiftSyntax/SyntaxRewriter.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@ open class SyntaxRewriter {
/// Visit a `${node.name}`.
/// - Parameter node: the node that is being visited
/// - Returns: the rewritten node
% if node.base_type == 'Syntax' and node.name != 'MissingSyntax':
open func visit(_ node: ${node.name}) -> ${node.name} {
return Syntax(visitChildren(node)).cast(${node.name}.self)
}
% else:
open func visit(_ node: ${node.name}) -> ${node.base_type} {
return ${node.base_type}(visitChildren(node))
}
% end

% end
% end
Expand Down Expand Up @@ -84,8 +90,8 @@ open class SyntaxRewriter {
/// Visit any ${base_kind}Syntax node.
/// - Parameter node: the node that is being visited
/// - Returns: the rewritten node
public func visit(_ node: ${base_kind}Syntax) -> Syntax {
return visit(node.data)
public func visit(_ node: ${base_kind}Syntax) -> ${base_kind}Syntax {
return visit(node.data).cast(${base_kind}Syntax.self)
}

% end
Expand All @@ -106,11 +112,7 @@ open class SyntaxRewriter {
visitPre(node._syntaxNode)
defer { visitPost(node._syntaxNode) }
if let newNode = visitAny(node._syntaxNode) { return newNode }
% if node.base_type == 'Syntax':
return visit(node)
% else:
return Syntax(visit(node))
% end
% end
}

Expand Down
20 changes: 20 additions & 0 deletions Sources/SwiftSyntax/gyb_generated/SyntaxBaseNodes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ public struct DeclSyntax: DeclSyntaxProtocol, SyntaxHashable {
return S.init(_syntaxNode)
}

public func cast<S: DeclSyntaxProtocol>(_ syntaxType: S.Type) -> S {
return self.as(S.self)!
}

/// Syntax nodes always conform to `DeclSyntaxProtocol`. This API is just
/// added for consistency.
/// Note that this will incur an existential conversion.
Expand Down Expand Up @@ -191,6 +195,10 @@ public struct ExprSyntax: ExprSyntaxProtocol, SyntaxHashable {
return S.init(_syntaxNode)
}

public func cast<S: ExprSyntaxProtocol>(_ syntaxType: S.Type) -> S {
return self.as(S.self)!
}

/// Syntax nodes always conform to `ExprSyntaxProtocol`. This API is just
/// added for consistency.
/// Note that this will incur an existential conversion.
Expand Down Expand Up @@ -294,6 +302,10 @@ public struct StmtSyntax: StmtSyntaxProtocol, SyntaxHashable {
return S.init(_syntaxNode)
}

public func cast<S: StmtSyntaxProtocol>(_ syntaxType: S.Type) -> S {
return self.as(S.self)!
}

/// Syntax nodes always conform to `StmtSyntaxProtocol`. This API is just
/// added for consistency.
/// Note that this will incur an existential conversion.
Expand Down Expand Up @@ -397,6 +409,10 @@ public struct TypeSyntax: TypeSyntaxProtocol, SyntaxHashable {
return S.init(_syntaxNode)
}

public func cast<S: TypeSyntaxProtocol>(_ syntaxType: S.Type) -> S {
return self.as(S.self)!
}

/// Syntax nodes always conform to `TypeSyntaxProtocol`. This API is just
/// added for consistency.
/// Note that this will incur an existential conversion.
Expand Down Expand Up @@ -500,6 +516,10 @@ public struct PatternSyntax: PatternSyntaxProtocol, SyntaxHashable {
return S.init(_syntaxNode)
}

public func cast<S: PatternSyntaxProtocol>(_ syntaxType: S.Type) -> S {
return self.as(S.self)!
}

/// Syntax nodes always conform to `PatternSyntaxProtocol`. This API is just
/// added for consistency.
/// Note that this will incur an existential conversion.
Expand Down
Loading