Skip to content

Commit 7daf026

Browse files
authored
Use NIOThrowingAsyncSequenceProducer (#317)
1 parent a365a9b commit 7daf026

File tree

4 files changed

+93
-547
lines changed

4 files changed

+93
-547
lines changed

Sources/PostgresNIO/New/Messages/DataRow.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,6 @@ extension DataRow {
117117
}
118118
}
119119

120-
#if swift(>=5.6)
120+
#if swift(>=5.5)
121121
extension DataRow: Sendable {}
122122
#endif

Sources/PostgresNIO/New/PSQLRowStream.swift

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ import NIOCore
22
import Logging
33

44
final class PSQLRowStream {
5+
private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer<DataRow, Error, AdaptiveRowBuffer, PSQLRowStream>.Source
6+
57
enum RowSource {
68
case stream(PSQLRowsDataSource)
79
case noRows(Result<String, Error>)
@@ -23,7 +25,7 @@ final class PSQLRowStream {
2325
case consumed(Result<String, Error>)
2426

2527
#if canImport(_Concurrency)
26-
case asyncSequence(AsyncStreamConsumer, PSQLRowsDataSource)
28+
case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource)
2729
#endif
2830
}
2931

@@ -71,26 +73,35 @@ final class PSQLRowStream {
7173
preconditionFailure("Invalid state: \(self.downstreamState)")
7274
}
7375

74-
let consumer = AsyncStreamConsumer(
75-
lookupTable: self.lookupTable,
76-
columns: self.rowDescription
76+
let producer = NIOThrowingAsyncSequenceProducer.makeSequence(
77+
elementType: DataRow.self,
78+
failureType: Error.self,
79+
backPressureStrategy: AdaptiveRowBuffer(),
80+
delegate: self
7781
)
82+
83+
let source = producer.source
7884

7985
switch bufferState {
8086
case .streaming(let bufferedRows, let dataSource):
81-
consumer.startStreaming(bufferedRows, upstream: self)
82-
self.downstreamState = .asyncSequence(consumer, dataSource)
87+
let yieldResult = source.yield(contentsOf: bufferedRows)
88+
self.downstreamState = .asyncSequence(source, dataSource)
89+
90+
self.eventLoop.execute {
91+
self.executeActionBasedOnYieldResult(yieldResult, source: dataSource)
92+
}
8393

8494
case .finished(let buffer, let commandTag):
85-
consumer.startCompleted(buffer, commandTag: commandTag)
95+
_ = source.yield(contentsOf: buffer)
96+
source.finish()
8697
self.downstreamState = .consumed(.success(commandTag))
8798

8899
case .failure(let error):
89-
consumer.startFailed(error)
100+
source.finish(error)
90101
self.downstreamState = .consumed(.failure(error))
91102
}
92103

93-
return PostgresRowSequence(consumer)
104+
return PostgresRowSequence(producer.sequence, lookupTable: self.lookupTable, columns: self.rowDescription)
94105
}
95106

96107
func demand() {
@@ -128,10 +139,8 @@ final class PSQLRowStream {
128139

129140
private func cancel0() {
130141
switch self.downstreamState {
131-
case .asyncSequence(let consumer, let dataSource):
132-
let error = PSQLError.connectionClosed
133-
self.downstreamState = .consumed(.failure(error))
134-
consumer.receive(completion: .failure(error))
142+
case .asyncSequence(_, let dataSource):
143+
self.downstreamState = .consumed(.failure(CancellationError()))
135144
dataSource.cancel(for: self)
136145

137146
case .consumed:
@@ -305,8 +314,9 @@ final class PSQLRowStream {
305314
dataSource.request(for: self)
306315

307316
#if canImport(_Concurrency)
308-
case .asyncSequence(let consumer, _):
309-
consumer.receive(newRows)
317+
case .asyncSequence(let consumer, let source):
318+
let yieldResult = consumer.yield(contentsOf: newRows)
319+
self.executeActionBasedOnYieldResult(yieldResult, source: source)
310320
#endif
311321

312322
case .consumed(.success):
@@ -345,8 +355,8 @@ final class PSQLRowStream {
345355
promise.succeed(rows)
346356

347357
#if canImport(_Concurrency)
348-
case .asyncSequence(let consumer, _):
349-
consumer.receive(completion: .success(commandTag))
358+
case .asyncSequence(let source, _):
359+
source.finish()
350360
self.downstreamState = .consumed(.success(commandTag))
351361
#endif
352362

@@ -373,14 +383,30 @@ final class PSQLRowStream {
373383

374384
#if canImport(_Concurrency)
375385
case .asyncSequence(let consumer, _):
376-
consumer.receive(completion: .failure(error))
386+
consumer.finish(error)
377387
self.downstreamState = .consumed(.failure(error))
378388
#endif
379389

380390
case .consumed:
381391
break
382392
}
383393
}
394+
395+
private func executeActionBasedOnYieldResult(_ yieldResult: AsyncSequenceSource.YieldResult, source: PSQLRowsDataSource) {
396+
self.eventLoop.preconditionInEventLoop()
397+
switch yieldResult {
398+
case .dropped:
399+
// ignore
400+
break
401+
402+
case .produceMore:
403+
source.request(for: self)
404+
405+
case .stopProducing:
406+
// ignore
407+
break
408+
}
409+
}
384410

385411
var commandTag: String {
386412
guard case .consumed(.success(let commandTag)) = self.downstreamState else {
@@ -390,14 +416,24 @@ final class PSQLRowStream {
390416
}
391417
}
392418

419+
extension PSQLRowStream: NIOAsyncSequenceProducerDelegate {
420+
func produceMore() {
421+
self.demand()
422+
}
423+
424+
func didTerminate() {
425+
self.cancel()
426+
}
427+
}
428+
393429
protocol PSQLRowsDataSource {
394430

395431
func request(for stream: PSQLRowStream)
396432
func cancel(for stream: PSQLRowStream)
397433

398434
}
399435

400-
#if swift(>=5.6)
436+
#if swift(>=5.5)
401437
// Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop.
402438
extension PSQLRowStream: @unchecked Sendable {}
403439
#endif

0 commit comments

Comments
 (0)