12
12
13
13
import SwiftSyntax
14
14
import SwiftParser
15
+ import SwiftDiagnostics
15
16
#if canImport(_CompilerPluginSupport)
16
17
import _CompilerPluginSupport
17
18
#endif
@@ -26,6 +27,16 @@ public protocol ExpressionMacro: Macro {
26
27
}
27
28
28
29
#if canImport(_CompilerPluginSupport)
30
+ extension _CompilerPluginSupport . _DiagnosticSeverity {
31
+ fileprivate init ( _ other: SwiftDiagnostics . DiagnosticSeverity ) {
32
+ switch other {
33
+ case . note: self = . note
34
+ case . warning: self = . warning
35
+ case . error: self = . error
36
+ }
37
+ }
38
+ }
39
+
29
40
extension ExpressionMacro {
30
41
public static func _kind( ) -> _CompilerPluginKind {
31
42
. expressionMacro
@@ -40,7 +51,9 @@ extension ExpressionMacro {
40
51
sourceFileTextCount: Int ,
41
52
localSourceText: UnsafePointer < UInt8 > ,
42
53
localSourceTextCount: Int
43
- ) -> ( UnsafePointer < UInt8 > ? , count: Int ) {
54
+ ) -> ( code: UnsafePointer < UInt8 > ? , codeLength: Int ,
55
+ diagnostics: UnsafePointer < _Diagnostic > ? ,
56
+ diagnosticCount: Int ) {
44
57
let targetModuleNameBuffer = UnsafeBufferPointer (
45
58
start: filePath, count: targetModuleNameCount)
46
59
let targetModuleName = String (
@@ -66,13 +79,25 @@ extension ExpressionMacro {
66
79
// Evaluate the macro.
67
80
let evalResult = apply ( mee, in: context)
68
81
82
+ let rawDiags = UnsafeMutablePointer< _Diagnostic> . allocate(
83
+ capacity: evalResult. diagnostics. count)
84
+ for (i, diag) in evalResult. diagnostics. enumerated ( ) {
85
+ rawDiags. advanced ( by: i) . initialize ( to: _makeDiagnostic (
86
+ message: diag. message,
87
+ position: diag. position. utf8Offset,
88
+ severity: . init( diag. diagMessage. severity) ) )
89
+ }
90
+
69
91
var resultString = " \( evalResult. rewritten) "
70
92
return resultString. withUTF8 { buffer in
71
93
let result = UnsafeMutableBufferPointer< UInt8> . allocate(
72
94
capacity: buffer. count + 1 )
73
95
_ = result. initialize ( from: buffer)
74
96
result [ buffer. count] = 0
75
- return ( UnsafePointer ( result. baseAddress) , buffer. count)
97
+ return (
98
+ code: UnsafePointer ( result. baseAddress) , codeLength: buffer. count,
99
+ diagnostics: UnsafePointer ? ( rawDiags) ,
100
+ diagnosticCount: evalResult. diagnostics. count)
76
101
}
77
102
}
78
103
}
0 commit comments