Skip to content

Commit 780a510

Browse files
authored
Refactor PSQLRowStream to make async/await easier (#201)
### Motivation `PSQLRowStream`'s current implementation is interesting. It should be better tested and easier to follow for async/await support later. ### Changes - Make `PSQLRowStream`'s implementation more sensible - Add unit tests for `PSQLRowStream` ### Result Adding async/await support becomes easier.
1 parent 81ca909 commit 780a510

File tree

3 files changed

+429
-104
lines changed

3 files changed

+429
-104
lines changed

Sources/PostgresNIO/New/PSQLRow.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ struct PSQLRow {
1616
}
1717
}
1818

19+
extension PSQLRow: Equatable {
20+
static func ==(lhs: Self, rhs: Self) -> Bool {
21+
lhs.data == rhs.data && lhs.columns == rhs.columns
22+
}
23+
}
24+
1925
extension PSQLRow {
2026
/// Access the data in the provided column and decode it into the target type.
2127
///

Sources/PostgresNIO/New/PSQLRowStream.swift

Lines changed: 100 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import NIOCore
22
import Logging
33

44
final class PSQLRowStream {
5-
65
enum RowSource {
76
case stream(PSQLRowsDataSource)
87
case noRows(Result<String, Error>)
@@ -11,23 +10,21 @@ final class PSQLRowStream {
1110
let eventLoop: EventLoop
1211
let logger: Logger
1312

14-
private enum UpstreamState {
13+
private enum BufferState {
1514
case streaming(buffer: CircularBuffer<DataRow>, dataSource: PSQLRowsDataSource)
1615
case finished(buffer: CircularBuffer<DataRow>, commandTag: String)
1716
case failure(Error)
18-
case consumed(Result<String, Error>)
19-
case modifying
2017
}
2118

2219
private enum DownstreamState {
23-
case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise<Void>)
24-
case waitingForAll(EventLoopPromise<[PSQLRow]>)
25-
case consuming
20+
case waitingForConsumer(BufferState)
21+
case iteratingRows(onRow: (PSQLRow) throws -> (), EventLoopPromise<Void>, PSQLRowsDataSource)
22+
case waitingForAll([PSQLRow], EventLoopPromise<[PSQLRow]>, PSQLRowsDataSource)
23+
case consumed(Result<String, Error>)
2624
}
2725

2826
internal let rowDescription: [RowDescription.Column]
2927
private let lookupTable: [String: Int]
30-
private var upstreamState: UpstreamState
3128
private var downstreamState: DownstreamState
3229
private let jsonDecoder: PSQLJSONDecoder
3330

@@ -36,30 +33,33 @@ final class PSQLRowStream {
3633
eventLoop: EventLoop,
3734
rowSource: RowSource)
3835
{
39-
let buffer = CircularBuffer<DataRow>()
40-
41-
self.downstreamState = .consuming
36+
let bufferState: BufferState
4237
switch rowSource {
4338
case .stream(let dataSource):
44-
self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource)
39+
bufferState = .streaming(buffer: .init(), dataSource: dataSource)
4540
case .noRows(.success(let commandTag)):
46-
self.upstreamState = .finished(buffer: .init(), commandTag: commandTag)
41+
bufferState = .finished(buffer: .init(), commandTag: commandTag)
4742
case .noRows(.failure(let error)):
48-
self.upstreamState = .failure(error)
43+
bufferState = .failure(error)
4944
}
5045

46+
self.downstreamState = .waitingForConsumer(bufferState)
47+
5148
self.eventLoop = eventLoop
5249
self.logger = queryContext.logger
5350
self.jsonDecoder = queryContext.jsonDecoder
5451

5552
self.rowDescription = rowDescription
53+
5654
var lookup = [String: Int]()
5755
lookup.reserveCapacity(rowDescription.count)
5856
rowDescription.enumerated().forEach { (index, column) in
5957
lookup[column.name] = index
6058
}
6159
self.lookupTable = lookup
6260
}
61+
62+
// MARK: Consume in array
6363

6464
func all() -> EventLoopFuture<[PSQLRow]> {
6565
if self.eventLoop.inEventLoop {
@@ -74,40 +74,37 @@ final class PSQLRowStream {
7474
private func all0() -> EventLoopFuture<[PSQLRow]> {
7575
self.eventLoop.preconditionInEventLoop()
7676

77-
guard case .consuming = self.downstreamState else {
78-
preconditionFailure("Invalid state")
77+
guard case .waitingForConsumer(let bufferState) = self.downstreamState else {
78+
preconditionFailure("Invalid state: \(self.downstreamState)")
7979
}
8080

81-
switch self.upstreamState {
82-
case .streaming(_, let dataSource):
83-
dataSource.request(for: self)
81+
switch bufferState {
82+
case .streaming(let bufferedRows, let dataSource):
8483
let promise = self.eventLoop.makePromise(of: [PSQLRow].self)
85-
self.downstreamState = .waitingForAll(promise)
84+
let rows = bufferedRows.map { data in
85+
PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder)
86+
}
87+
self.downstreamState = .waitingForAll(rows, promise, dataSource)
88+
// immediately request more
89+
dataSource.request(for: self)
8690
return promise.futureResult
8791

8892
case .finished(let buffer, let commandTag):
89-
self.upstreamState = .modifying
90-
9193
let rows = buffer.map {
9294
PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder)
9395
}
9496

95-
self.downstreamState = .consuming
96-
self.upstreamState = .consumed(.success(commandTag))
97+
self.downstreamState = .consumed(.success(commandTag))
9798
return self.eventLoop.makeSucceededFuture(rows)
9899

99-
case .consumed:
100-
preconditionFailure("We already signaled, that the stream has completed, why are we asked again?")
101-
102-
case .modifying:
103-
preconditionFailure("Invalid state")
104-
105100
case .failure(let error):
106-
self.upstreamState = .consumed(.failure(error))
101+
self.downstreamState = .consumed(.failure(error))
107102
return self.eventLoop.makeFailedFuture(error)
108103
}
109104
}
110105

106+
// MARK: Consume on EventLoop
107+
111108
func onRow(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture<Void> {
112109
if self.eventLoop.inEventLoop {
113110
return self.onRow0(onRow)
@@ -121,7 +118,11 @@ final class PSQLRowStream {
121118
private func onRow0(_ onRow: @escaping (PSQLRow) throws -> ()) -> EventLoopFuture<Void> {
122119
self.eventLoop.preconditionInEventLoop()
123120

124-
switch self.upstreamState {
121+
guard case .waitingForConsumer(let bufferState) = self.downstreamState else {
122+
preconditionFailure("Invalid state: \(self.downstreamState)")
123+
}
124+
125+
switch bufferState {
125126
case .streaming(var buffer, let dataSource):
126127
let promise = self.eventLoop.makePromise(of: Void.self)
127128
do {
@@ -136,12 +137,11 @@ final class PSQLRowStream {
136137
}
137138

138139
buffer.removeAll()
139-
self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource)
140-
self.downstreamState = .iteratingRows(onRow: onRow, promise)
140+
self.downstreamState = .iteratingRows(onRow: onRow, promise, dataSource)
141141
// immediately request more
142142
dataSource.request(for: self)
143143
} catch {
144-
self.upstreamState = .failure(error)
144+
self.downstreamState = .consumed(.failure(error))
145145
dataSource.cancel(for: self)
146146
promise.fail(error)
147147
}
@@ -160,22 +160,15 @@ final class PSQLRowStream {
160160
try onRow(row)
161161
}
162162

163-
self.upstreamState = .consumed(.success(commandTag))
164-
self.downstreamState = .consuming
163+
self.downstreamState = .consumed(.success(commandTag))
165164
return self.eventLoop.makeSucceededVoidFuture()
166165
} catch {
167-
self.upstreamState = .consumed(.failure(error))
166+
self.downstreamState = .consumed(.failure(error))
168167
return self.eventLoop.makeFailedFuture(error)
169168
}
170169

171-
case .consumed:
172-
preconditionFailure("We already signaled, that the stream has completed, why are we asked again?")
173-
174-
case .modifying:
175-
preconditionFailure("Invalid state")
176-
177170
case .failure(let error):
178-
self.upstreamState = .consumed(.failure(error))
171+
self.downstreamState = .consumed(.failure(error))
179172
return self.eventLoop.makeFailedFuture(error)
180173
}
181174
}
@@ -193,13 +186,15 @@ final class PSQLRowStream {
193186
"row_count": "\(newRows.count)"
194187
])
195188

196-
guard case .streaming(var buffer, let dataSource) = self.upstreamState else {
197-
preconditionFailure("Invalid state")
198-
}
199-
200189
switch self.downstreamState {
201-
case .iteratingRows(let onRow, let promise):
202-
precondition(buffer.isEmpty)
190+
case .waitingForConsumer(.streaming(buffer: var buffer, dataSource: let dataSource)):
191+
buffer.append(contentsOf: newRows)
192+
self.downstreamState = .waitingForConsumer(.streaming(buffer: buffer, dataSource: dataSource))
193+
194+
case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
195+
preconditionFailure("How can new rows be received, if an end was already signalled?")
196+
197+
case .iteratingRows(let onRow, let promise, let dataSource):
203198
do {
204199
for data in newRows {
205200
let row = PSQLRow(
@@ -214,82 +209,83 @@ final class PSQLRowStream {
214209
dataSource.request(for: self)
215210
} catch {
216211
dataSource.cancel(for: self)
217-
self.upstreamState = .failure(error)
212+
self.downstreamState = .consumed(.failure(error))
218213
promise.fail(error)
219214
return
220215
}
221-
case .waitingForAll:
222-
self.upstreamState = .modifying
223-
buffer.append(contentsOf: newRows)
224-
self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource)
225-
216+
217+
case .waitingForAll(var rows, let promise, let dataSource):
218+
newRows.forEach { data in
219+
let row = PSQLRow(data: data, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder)
220+
rows.append(row)
221+
}
222+
self.downstreamState = .waitingForAll(rows, promise, dataSource)
226223
// immediately request more
227224
dataSource.request(for: self)
228225

229-
case .consuming:
230-
// this might happen, if the query has finished while the user is consuming data
231-
// we don't need to ask for more since the user is consuming anyway
232-
self.upstreamState = .modifying
233-
buffer.append(contentsOf: newRows)
234-
self.upstreamState = .streaming(buffer: buffer, dataSource: dataSource)
226+
case .consumed(.success):
227+
preconditionFailure("How can we receive further rows, if we are supposed to be done")
228+
229+
case .consumed(.failure):
230+
break
235231
}
236232
}
237233

238234
internal func receive(completion result: Result<String, Error>) {
239235
self.eventLoop.preconditionInEventLoop()
240236

241-
guard case .streaming(let oldBuffer, _) = self.upstreamState else {
242-
preconditionFailure("Invalid state")
237+
switch result {
238+
case .success(let commandTag):
239+
self.receiveEnd(commandTag)
240+
case .failure(let error):
241+
self.receiveError(error)
243242
}
243+
}
244244

245+
private func receiveEnd(_ commandTag: String) {
245246
switch self.downstreamState {
246-
case .iteratingRows(_, let promise):
247-
precondition(oldBuffer.isEmpty)
248-
self.downstreamState = .consuming
249-
self.upstreamState = .consumed(result)
250-
switch result {
251-
case .success:
252-
promise.succeed(())
253-
case .failure(let error):
254-
promise.fail(error)
255-
}
247+
case .waitingForConsumer(.streaming(buffer: let buffer, _)):
248+
self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag))
256249

250+
case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
251+
preconditionFailure("How can we get another end, if an end was already signalled?")
257252

258-
case .consuming:
259-
switch result {
260-
case .success(let commandTag):
261-
self.upstreamState = .finished(buffer: oldBuffer, commandTag: commandTag)
262-
case .failure(let error):
263-
self.upstreamState = .failure(error)
264-
}
265-
266-
case .waitingForAll(let promise):
267-
switch result {
268-
case .failure(let error):
269-
self.upstreamState = .consumed(.failure(error))
270-
promise.fail(error)
271-
case .success(let commandTag):
272-
let rows = oldBuffer.map {
273-
PSQLRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription, jsonDecoder: self.jsonDecoder)
274-
}
275-
self.upstreamState = .consumed(.success(commandTag))
276-
promise.succeed(rows)
277-
}
253+
case .iteratingRows(_, let promise, _):
254+
self.downstreamState = .consumed(.success(commandTag))
255+
promise.succeed(())
256+
257+
case .waitingForAll(let rows, let promise, _):
258+
self.downstreamState = .consumed(.success(commandTag))
259+
promise.succeed(rows)
260+
261+
case .consumed:
262+
break
278263
}
279264
}
280-
281-
func cancel() {
282-
guard case .streaming(_, let dataSource) = self.upstreamState else {
283-
// We don't need to cancel any upstream resource. All needed data is already
284-
// included in this
285-
return
286-
}
287265

288-
dataSource.cancel(for: self)
266+
private func receiveError(_ error: Error) {
267+
switch self.downstreamState {
268+
case .waitingForConsumer(.streaming):
269+
self.downstreamState = .waitingForConsumer(.failure(error))
270+
271+
case .waitingForConsumer(.finished), .waitingForConsumer(.failure):
272+
preconditionFailure("How can we get another end, if an end was already signalled?")
273+
274+
case .iteratingRows(_, let promise, _):
275+
self.downstreamState = .consumed(.failure(error))
276+
promise.fail(error)
277+
278+
case .waitingForAll(_, let promise, _):
279+
self.downstreamState = .consumed(.failure(error))
280+
promise.fail(error)
281+
282+
case .consumed:
283+
break
284+
}
289285
}
290286

291287
var commandTag: String {
292-
guard case .consumed(.success(let commandTag)) = self.upstreamState else {
288+
guard case .consumed(.success(let commandTag)) = self.downstreamState else {
293289
preconditionFailure("commandTag may only be called if all rows have been consumed")
294290
}
295291
return commandTag

0 commit comments

Comments
 (0)