@@ -2,6 +2,8 @@ import NIOCore
2
2
import Logging
3
3
4
4
final class PSQLRowStream {
5
+ private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer < DataRow , Error , AdaptiveRowBuffer , PSQLRowStream > . Source
6
+
5
7
enum RowSource {
6
8
case stream( PSQLRowsDataSource )
7
9
case noRows( Result < String , Error > )
@@ -23,7 +25,7 @@ final class PSQLRowStream {
23
25
case consumed( Result < String , Error > )
24
26
25
27
#if canImport(_Concurrency)
26
- case asyncSequence( AsyncStreamConsumer , PSQLRowsDataSource )
28
+ case asyncSequence( AsyncSequenceSource , PSQLRowsDataSource )
27
29
#endif
28
30
}
29
31
@@ -71,26 +73,35 @@ final class PSQLRowStream {
71
73
preconditionFailure ( " Invalid state: \( self . downstreamState) " )
72
74
}
73
75
74
- let consumer = AsyncStreamConsumer (
75
- lookupTable: self . lookupTable,
76
- columns: self . rowDescription
76
+ let producer = NIOThrowingAsyncSequenceProducer . makeSequence (
77
+ elementType: DataRow . self,
78
+ failureType: Error . self,
79
+ backPressureStrategy: AdaptiveRowBuffer ( ) ,
80
+ delegate: self
77
81
)
82
+
83
+ let source = producer. source
78
84
79
85
switch bufferState {
80
86
case . streaming( let bufferedRows, let dataSource) :
81
- consumer. startStreaming ( bufferedRows, upstream: self )
82
- self . downstreamState = . asyncSequence( consumer, dataSource)
87
+ let yieldResult = source. yield ( contentsOf: bufferedRows)
88
+ self . downstreamState = . asyncSequence( source, dataSource)
89
+
90
+ self . eventLoop. execute {
91
+ self . executeActionBasedOnYieldResult ( yieldResult, source: dataSource)
92
+ }
83
93
84
94
case . finished( let buffer, let commandTag) :
85
- consumer. startCompleted ( buffer, commandTag: commandTag)
95
+ _ = source. yield ( contentsOf: buffer)
96
+ source. finish ( )
86
97
self . downstreamState = . consumed( . success( commandTag) )
87
98
88
99
case . failure( let error) :
89
- consumer . startFailed ( error)
100
+ source . finish ( error)
90
101
self . downstreamState = . consumed( . failure( error) )
91
102
}
92
103
93
- return PostgresRowSequence ( consumer )
104
+ return PostgresRowSequence ( producer . sequence , lookupTable : self . lookupTable , columns : self . rowDescription )
94
105
}
95
106
96
107
func demand( ) {
@@ -128,10 +139,8 @@ final class PSQLRowStream {
128
139
129
140
private func cancel0( ) {
130
141
switch self . downstreamState {
131
- case . asyncSequence( let consumer, let dataSource) :
132
- let error = PSQLError . connectionClosed
133
- self . downstreamState = . consumed( . failure( error) )
134
- consumer. receive ( completion: . failure( error) )
142
+ case . asyncSequence( _, let dataSource) :
143
+ self . downstreamState = . consumed( . failure( CancellationError ( ) ) )
135
144
dataSource. cancel ( for: self )
136
145
137
146
case . consumed:
@@ -305,8 +314,9 @@ final class PSQLRowStream {
305
314
dataSource. request ( for: self )
306
315
307
316
#if canImport(_Concurrency)
308
- case . asyncSequence( let consumer, _) :
309
- consumer. receive ( newRows)
317
+ case . asyncSequence( let consumer, let source) :
318
+ let yieldResult = consumer. yield ( contentsOf: newRows)
319
+ self . executeActionBasedOnYieldResult ( yieldResult, source: source)
310
320
#endif
311
321
312
322
case . consumed( . success) :
@@ -345,8 +355,8 @@ final class PSQLRowStream {
345
355
promise. succeed ( rows)
346
356
347
357
#if canImport(_Concurrency)
348
- case . asyncSequence( let consumer , _) :
349
- consumer . receive ( completion : . success ( commandTag ) )
358
+ case . asyncSequence( let source , _) :
359
+ source . finish ( )
350
360
self . downstreamState = . consumed( . success( commandTag) )
351
361
#endif
352
362
@@ -373,14 +383,30 @@ final class PSQLRowStream {
373
383
374
384
#if canImport(_Concurrency)
375
385
case . asyncSequence( let consumer, _) :
376
- consumer. receive ( completion : . failure ( error) )
386
+ consumer. finish ( error)
377
387
self . downstreamState = . consumed( . failure( error) )
378
388
#endif
379
389
380
390
case . consumed:
381
391
break
382
392
}
383
393
}
394
+
395
+ private func executeActionBasedOnYieldResult( _ yieldResult: AsyncSequenceSource . YieldResult , source: PSQLRowsDataSource ) {
396
+ self . eventLoop. preconditionInEventLoop ( )
397
+ switch yieldResult {
398
+ case . dropped:
399
+ // ignore
400
+ break
401
+
402
+ case . produceMore:
403
+ source. request ( for: self )
404
+
405
+ case . stopProducing:
406
+ // ignore
407
+ break
408
+ }
409
+ }
384
410
385
411
var commandTag : String {
386
412
guard case . consumed( . success( let commandTag) ) = self . downstreamState else {
@@ -390,14 +416,24 @@ final class PSQLRowStream {
390
416
}
391
417
}
392
418
419
+ extension PSQLRowStream : NIOAsyncSequenceProducerDelegate {
420
+ func produceMore( ) {
421
+ self . demand ( )
422
+ }
423
+
424
+ func didTerminate( ) {
425
+ self . cancel ( )
426
+ }
427
+ }
428
+
393
429
protocol PSQLRowsDataSource {
394
430
395
431
func request( for stream: PSQLRowStream )
396
432
func cancel( for stream: PSQLRowStream )
397
433
398
434
}
399
435
400
- #if swift(>=5.6 )
436
+ #if swift(>=5.5 )
401
437
// Thread safety is guaranteed in the RowStream through dispatching onto the NIO EventLoop.
402
438
extension PSQLRowStream : @unchecked Sendable { }
403
439
#endif
0 commit comments