Skip to content

Commit 278aea7

Browse files
committed
[CompilerPlugin] Rewrite MessageConnection
Back to read(2)/write(2). Some cleanups.
1 parent 27fc293 commit 278aea7

File tree

1 file changed

+81
-78
lines changed

1 file changed

+81
-78
lines changed

Sources/SwiftCompilerPlugin/CompilerPlugin.swift

Lines changed: 81 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,13 @@ extension CompilerPlugin {
100100
}
101101
}
102102

103+
struct CompilerPluginError: Error, CustomStringConvertible {
104+
var description: String
105+
init(message: String) {
106+
self.description = message
107+
}
108+
}
109+
103110
struct MacroProviderAdapter<Plugin: CompilerPlugin>: PluginProvider {
104111
let plugin: Plugin
105112
init(plugin: Plugin) {
@@ -110,57 +117,61 @@ struct MacroProviderAdapter<Plugin: CompilerPlugin>: PluginProvider {
110117
}
111118
}
112119

120+
#if canImport(ucrt)
121+
private let dup = _dup(_:)
122+
private let fileno = _fileno(_:)
123+
private let dup2 = _dup2(_:_:)
124+
private let close = _close(_:)
125+
private let read = _read(_:_:_:)
126+
private let write = _write(_:_:_:)
127+
#endif
128+
113129
extension CompilerPlugin {
114130

115131
/// Main entry point of the plugin — sets up a communication channel with
116132
/// the plugin host and runs the main message loop.
117133
public static func main() throws {
118-
let _stdin = _ss_stdin()
119-
let _stdout = _ss_stdout()
120-
let _stderr = _ss_stderr()
134+
let stdin = _ss_stdin()
135+
let stdout = _ss_stdout()
136+
let stderr = _ss_stderr()
121137

122138
// Duplicate the `stdin` file descriptor, which we will then use for
123139
// receiving messages from the plugin host.
124-
let inputFD = dup(fileno(_stdin))
140+
let inputFD = dup(fileno(stdin))
125141
guard inputFD >= 0 else {
126142
internalError("Could not duplicate `stdin`: \(describe(errno: _ss_errno())).")
127143
}
128144

129145
// Having duplicated the original standard-input descriptor, we close
130146
// `stdin` so that attempts by the plugin to read console input (which
131147
// are usually a mistake) return errors instead of blocking.
132-
guard close(fileno(_stdin)) >= 0 else {
148+
guard close(fileno(stdin)) >= 0 else {
133149
internalError("Could not close `stdin`: \(describe(errno: _ss_errno())).")
134150
}
135151

136152
// Duplicate the `stdout` file descriptor, which we will then use for
137153
// sending messages to the plugin host.
138-
let outputFD = dup(fileno(_stdout))
154+
let outputFD = dup(fileno(stdout))
139155
guard outputFD >= 0 else {
140156
internalError("Could not dup `stdout`: \(describe(errno: _ss_errno())).")
141157
}
142158

143159
// Having duplicated the original standard-output descriptor, redirect
144160
// `stdout` to `stderr` so that all free-form text output goes there.
145-
guard dup2(fileno(_stderr), fileno(_stdout)) >= 0 else {
161+
guard dup2(fileno(stderr), fileno(stdout)) >= 0 else {
146162
internalError("Could not dup2 `stdout` to `stderr`: \(describe(errno: _ss_errno())).")
147163
}
148164

149-
// Turn off full buffering so printed text appears as soon as possible.
150-
// Windows is much less forgiving than other platforms. If line
151-
// buffering is enabled, we must provide a buffer and the size of the
152-
// buffer. As a result, on Windows, we completely disable all
153-
// buffering, which means that partial writes are possible.
154-
#if os(Windows)
155-
setvbuf(_stdout, nil, _IONBF, 0)
156-
#else
157-
setvbuf(_stdout, nil, _IOLBF, 0)
165+
#if canImport(ucrt)
166+
// Set I/O to binary mode. Avoid CRLF translation, and Ctrl+Z (0x1A) as EOF.
167+
_ = _setmode(inputFD, _O_BINARY)
168+
_ = _setmode(outputFD, _O_BINARY)
158169
#endif
159170

160171
// Open a message channel for communicating with the plugin host.
161172
let connection = PluginHostConnection(
162-
inputStream: fdopen(inputFD, "r"),
163-
outputStream: fdopen(outputFD, "w")
173+
inputStream: inputFD,
174+
outputStream: outputFD
164175
)
165176

166177
// Handle messages from the host until the input stream is closed,
@@ -181,12 +192,11 @@ extension CompilerPlugin {
181192
fputs("Internal Error: \(message)\n", _ss_stderr())
182193
exit(1)
183194
}
184-
185195
}
186196

187197
internal struct PluginHostConnection: MessageConnection {
188-
fileprivate let inputStream: _ss_ptr_FILE
189-
fileprivate let outputStream: _ss_ptr_FILE
198+
fileprivate let inputStream: CInt
199+
fileprivate let outputStream: CInt
190200

191201
func sendMessage<TX: Encodable>(_ message: TX) throws {
192202
// Encode the message as JSON.
@@ -195,83 +205,76 @@ internal struct PluginHostConnection: MessageConnection {
195205
// Write the header (a 64-bit length field in little endian byte order).
196206
let count = payload.count
197207
var header = UInt64(count).littleEndian
208+
try withUnsafeBytes(of: &header) { try _write(outputStream, contentsOf: $0) }
198209

199-
try withUnsafeBytes(of: &header) { buffer in
200-
precondition(buffer.count == 8)
201-
try _write(outputStream, contentsOf: buffer)
202-
}
203-
204-
try payload.withUnsafeBytes { buffer in
205-
try _write(outputStream, contentsOf: buffer)
206-
}
207-
208-
fflush(outputStream)
210+
// Write the JSON payload.
211+
try payload.withUnsafeBytes { try _write(outputStream, contentsOf: $0) }
209212
}
210213

211214
func waitForNextMessage<RX: Decodable>(_ ty: RX.Type) throws -> RX? {
212215
// Read the header (a 64-bit length field in little endian byte order).
213-
let count = try _reading(inputStream, count: 8) { buffer in
214-
return buffer.count == 8 ? UInt64(littleEndian: buffer.loadUnaligned(as: UInt64.self)) : 0
215-
}
216-
guard count >= 2 else {
217-
if count == 0 {
218-
// input stream is closed.
219-
return nil
220-
}
221-
throw PluginMessageError.invalidPayloadSize
216+
var header: UInt64 = 0
217+
do {
218+
try withUnsafeMutableBytes(of: &header) { try _read(inputStream, into: $0) }
219+
} catch IOError.readReachedEndOfInput {
220+
// Connection closed.
221+
return nil
222222
}
223223

224224
// Read the JSON payload.
225-
return try _reading(inputStream, count: Int(count)) { buffer -> RX in
226-
if buffer.count != Int(count) {
227-
throw PluginMessageError.truncatedPayload
228-
}
229-
// Decode and return the message.
230-
return try JSON.decode(RX.self, from: buffer.bindMemory(to: UInt8.self))
231-
}
232-
}
225+
let count = Int(UInt64(littleEndian: header))
226+
let data = UnsafeMutableRawBufferPointer.allocate(byteCount: count, alignment: 1)
227+
defer { data.deallocate() }
228+
try _read(inputStream, into: data)
233229

234-
enum PluginMessageError: Swift.Error {
235-
case invalidPayloadSize
236-
case truncatedPayload
230+
// Decode and return the message.
231+
return try JSON.decode(ty, from: UnsafeBufferPointer(data.bindMemory(to: UInt8.self)))
237232
}
238233
}
239234

240-
// Private function to construct an error message from an `errno` code.
241-
private func describe(errno: CInt) -> String {
242-
if let cStr = strerror(errno) { return String(cString: cStr) }
243-
return String(describing: errno)
235+
/// Write the buffer to the file descriptor. Throws an error on failure.
236+
private func _write(_ fd: CInt, contentsOf buffer: UnsafeRawBufferPointer) throws {
237+
guard var ptr = buffer.baseAddress else { return }
238+
let endPtr = ptr.advanced(by: buffer.count)
239+
while ptr != endPtr {
240+
switch write(fd, ptr, numericCast(endPtr - ptr)) {
241+
case -1: throw IOError.writeFailed(_ss_errno())
242+
case 0: throw IOError.writeFailed(0) /* unreachable */
243+
case let n: ptr += Int(n)
244+
}
245+
}
244246
}
245247

246-
private func _write(_ stream: _ss_ptr_FILE, contentsOf buffer: UnsafeRawBufferPointer) throws {
247-
let result = fwrite(buffer.baseAddress, 1, buffer.count, stream)
248-
if result < buffer.count {
249-
throw CompilerPluginError(message: "fwrite(3) failed: \(describe(errno: _ss_errno()))")
248+
/// Fill the buffer to the file descriptor. Throws an error on failure.
249+
/// If the file descriptor reached the end-of-file, throws IOError.readReachedEndOfInput
250+
private func _read(_ fd: CInt, into buffer: UnsafeMutableRawBufferPointer) throws {
251+
guard var ptr = buffer.baseAddress else { return }
252+
let endPtr = ptr.advanced(by: buffer.count)
253+
while ptr != endPtr {
254+
switch read(fd, ptr, numericCast(endPtr - ptr)) {
255+
case -1: throw IOError.readFailed(_ss_errno())
256+
case 0: throw IOError.readReachedEndOfInput
257+
case let n: ptr += Int(n)
258+
}
250259
}
251260
}
252261

253-
private func _reading<T>(_ stream: _ss_ptr_FILE, count: Int, _ fn: (UnsafeRawBufferPointer) throws -> T) throws -> T {
254-
guard count > 0 else {
255-
return try fn(UnsafeRawBufferPointer(start: nil, count: 0))
256-
}
257-
let buffer = UnsafeMutableRawBufferPointer.allocate(byteCount: count, alignment: 1)
258-
defer { buffer.deallocate() }
262+
private enum IOError: Error, CustomStringConvertible {
263+
case readReachedEndOfInput
264+
case readFailed(CInt)
265+
case writeFailed(CInt)
259266

260-
let result = fread(buffer.baseAddress, 1, count, stream)
261-
if result < count {
262-
if ferror(stream) == 0 {
263-
// Input is closed.
264-
return try fn(UnsafeRawBufferPointer(start: nil, count: 0))
265-
} else {
266-
throw CompilerPluginError(message: "fread(3) failed: \(describe(errno: _ss_errno()))")
267+
var description: String {
268+
switch self {
269+
case .readReachedEndOfInput: "read(2) reached end-of-file"
270+
case .readFailed(let errno): "read(2) failed: \(describe(errno: errno))"
271+
case .writeFailed(let errno): "write(2) failed: \(describe(errno: errno))"
267272
}
268273
}
269-
return try fn(UnsafeRawBufferPointer(buffer))
270274
}
271275

272-
struct CompilerPluginError: Error, CustomStringConvertible {
273-
var description: String
274-
init(message: String) {
275-
self.description = message
276-
}
276+
// Private function to construct an error message from an `errno` code.
277+
private func describe(errno: CInt) -> String {
278+
if let cStr = strerror(errno) { return String(cString: cStr) }
279+
return String(describing: errno)
277280
}

0 commit comments

Comments
 (0)