Skip to content

Commit b005df3

Browse files
committed
[Macros] Add an overload of findSyntaxNodeInSourceFile that accepts
a predicate. The new overload is used during macro expansion to find the nearest syntax node that is either a `DeclSyntax` or a `ClosureExprSyntax`. This avoids an issue where calling the overload that accepts a specific type may find an outer, unrelated declaration or closure.
1 parent e2aed1f commit b005df3

File tree

2 files changed

+48
-21
lines changed

2 files changed

+48
-21
lines changed

lib/ASTGen/Sources/ASTGen/SourceFile.swift

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -249,13 +249,11 @@ public func emitParserDiagnostics(
249249
}
250250
}
251251

252-
/// Retrieve a syntax node in the given source file, with the given type.
253-
public func findSyntaxNodeInSourceFile<Node: SyntaxProtocol>(
254-
sourceFilePtr: UnsafeRawPointer,
255-
sourceLocationPtr: UnsafePointer<UInt8>?,
256-
type: Node.Type,
257-
wantOutermost: Bool = false
258-
) -> Node? {
252+
/// Find a token in the given source file at the given location.
253+
func findToken(
254+
in sourceFilePtr: UnsafeRawPointer,
255+
at sourceLocationPtr: UnsafePointer<UInt8>?
256+
) -> TokenSyntax? {
259257
guard let sourceLocationPtr = sourceLocationPtr else {
260258
return nil
261259
}
@@ -277,6 +275,20 @@ public func findSyntaxNodeInSourceFile<Node: SyntaxProtocol>(
277275
return nil
278276
}
279277

278+
return token
279+
}
280+
281+
/// Retrieve a syntax node in the given source file, with the given type.
282+
public func findSyntaxNodeInSourceFile<Node: SyntaxProtocol>(
283+
sourceFilePtr: UnsafeRawPointer,
284+
sourceLocationPtr: UnsafePointer<UInt8>?,
285+
type: Node.Type,
286+
wantOutermost: Bool = false
287+
) -> Node? {
288+
guard let token = findToken(in: sourceFilePtr, at: sourceLocationPtr) else {
289+
return nil
290+
}
291+
280292
var currentSyntax = Syntax(token)
281293
var resultSyntax: Node? = nil
282294
while let parentSyntax = currentSyntax.parent {
@@ -309,6 +321,28 @@ public func findSyntaxNodeInSourceFile<Node: SyntaxProtocol>(
309321
return resultSyntax
310322
}
311323

324+
/// Retrieve a syntax node in the given source file that satisfies the
325+
/// given predicate.
326+
public func findSyntaxNodeInSourceFile(
327+
sourceFilePtr: UnsafeRawPointer,
328+
sourceLocationPtr: UnsafePointer<UInt8>?,
329+
where predicate: (Syntax) -> Bool
330+
) -> Syntax? {
331+
guard let token = findToken(in: sourceFilePtr, at: sourceLocationPtr) else {
332+
return nil
333+
}
334+
335+
var currentSyntax = Syntax(token)
336+
while let parentSyntax = currentSyntax.parent {
337+
currentSyntax = parentSyntax
338+
if predicate(currentSyntax) {
339+
return currentSyntax
340+
}
341+
}
342+
343+
return nil
344+
}
345+
312346
@_cdecl("swift_ASTGen_virtualFiles")
313347
@usableFromInline
314348
func getVirtualFiles(

lib/ASTGen/Sources/MacroEvaluation/Macros.swift

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -583,22 +583,15 @@ func expandAttachedMacro(
583583
return 1
584584
}
585585

586-
func findNode<T: SyntaxProtocol>(type: T.Type) -> T? {
587-
findSyntaxNodeInSourceFile(
588-
sourceFilePtr: declarationSourceFilePtr,
589-
sourceLocationPtr: declarationSourceLocPointer,
590-
type: T.self
591-
)
592-
}
593-
594586
// Dig out the node for the closure or declaration to which the custom
595587
// attribute is attached.
596-
let node: Syntax
597-
if let closureNode = findNode(type: ClosureExprSyntax.self) {
598-
node = Syntax(closureNode)
599-
} else if let declNode = findNode(type: DeclSyntax.self) {
600-
node = Syntax(declNode)
601-
} else {
588+
let node = findSyntaxNodeInSourceFile(
589+
sourceFilePtr: declarationSourceFilePtr,
590+
sourceLocationPtr: declarationSourceLocPointer,
591+
where: { $0.is(DeclSyntax.self) || $0.is(ClosureExprSyntax.self) }
592+
)
593+
594+
guard let node else {
602595
return 1
603596
}
604597

0 commit comments

Comments
 (0)