37
37
import io .rsocket .frame .RequestResponseFrameFlyweight ;
38
38
import io .rsocket .frame .RequestStreamFrameFlyweight ;
39
39
import io .rsocket .frame .decoder .PayloadDecoder ;
40
- import io .rsocket .internal .RateLimitableRequestPublisher ;
40
+ import io .rsocket .internal .FluxSwitchOnFirst ;
41
+ import io .rsocket .internal .RateLimitableRequestSubscriber ;
41
42
import io .rsocket .internal .SynchronizedIntObjectHashMap ;
42
43
import io .rsocket .internal .UnboundedProcessor ;
43
44
import io .rsocket .internal .UnicastMonoEmpty ;
48
49
import io .rsocket .lease .RequesterLeaseHandler ;
49
50
import io .rsocket .util .MonoLifecycleHandler ;
50
51
import java .nio .channels .ClosedChannelException ;
52
+ import java .util .concurrent .atomic .AtomicBoolean ;
51
53
import java .util .concurrent .atomic .AtomicReferenceFieldUpdater ;
52
54
import java .util .function .Consumer ;
53
55
import java .util .function .LongConsumer ;
57
59
import org .reactivestreams .Processor ;
58
60
import org .reactivestreams .Publisher ;
59
61
import org .reactivestreams .Subscriber ;
60
- import reactor . core . publisher . BaseSubscriber ;
62
+ import org . reactivestreams . Subscription ;
61
63
import reactor .core .publisher .Flux ;
62
64
import reactor .core .publisher .Mono ;
63
65
import reactor .core .publisher .SignalType ;
@@ -81,7 +83,7 @@ class RSocketRequester implements RSocket {
81
83
private final PayloadDecoder payloadDecoder ;
82
84
private final Consumer <Throwable > errorConsumer ;
83
85
private final StreamIdSupplier streamIdSupplier ;
84
- private final IntObjectMap <RateLimitableRequestPublisher > senders ;
86
+ private final IntObjectMap <RateLimitableRequestSubscriber > senders ;
85
87
private final IntObjectMap <Processor <Payload , Payload >> receivers ;
86
88
private final UnboundedProcessor <ByteBuf > sendProcessor ;
87
89
private final RequesterLeaseHandler leaseHandler ;
@@ -255,10 +257,12 @@ private Flux<Payload> handleRequestStream(final Payload payload) {
255
257
256
258
final UnboundedProcessor <ByteBuf > sendProcessor = this .sendProcessor ;
257
259
final UnicastProcessor <Payload > receiver = UnicastProcessor .create ();
260
+ final AtomicBoolean payloadReleasedFlag = new AtomicBoolean (false );
258
261
259
262
receivers .put (streamId , receiver );
260
263
261
264
return receiver
265
+ .log ()
262
266
.doOnRequest (
263
267
new LongConsumer () {
264
268
@@ -276,7 +280,9 @@ public void accept(long n) {
276
280
n ,
277
281
payload .sliceMetadata ().retain (),
278
282
payload .sliceData ().retain ()));
279
- payload .release ();
283
+ if (!payloadReleasedFlag .getAndSet (true )) {
284
+ payload .release ();
285
+ }
280
286
} else if (contains (streamId ) && !receiver .isDisposed ()) {
281
287
sendProcessor .onNext (RequestNFrameFlyweight .encode (allocator , streamId , n ));
282
288
}
@@ -290,6 +296,9 @@ public void accept(long n) {
290
296
})
291
297
.doOnCancel (
292
298
() -> {
299
+ if (!payloadReleasedFlag .getAndSet (true )) {
300
+ payload .release ();
301
+ }
293
302
if (contains (streamId ) && !receiver .isDisposed ()) {
294
303
sendProcessor .onNext (CancelFrameFlyweight .encode (allocator , streamId ));
295
304
}
@@ -303,10 +312,58 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
303
312
return Flux .error (err );
304
313
}
305
314
315
+ return request .transform (
316
+ f ->
317
+ new FluxSwitchOnFirst <>(
318
+ f ,
319
+ (s , flux ) -> {
320
+ Payload payload = s .get ();
321
+ if (payload != null ) {
322
+ return handleChannel (flux , payload );
323
+ } else {
324
+ return flux ;
325
+ }
326
+ },
327
+ false ));
328
+ }
329
+
330
+ private Flux <? extends Payload > handleChannel (Flux <Payload > inboundFlux , Payload initialPayload ) {
306
331
final UnboundedProcessor <ByteBuf > sendProcessor = this .sendProcessor ;
307
- final UnicastProcessor < Payload > receiver = UnicastProcessor . create ( );
332
+ final AtomicBoolean payloadReleasedFlag = new AtomicBoolean ( false );
308
333
final int streamId = streamIdSupplier .nextStreamId (receivers );
309
334
335
+ final UnicastProcessor <Payload > receiver = UnicastProcessor .create ();
336
+ final RateLimitableRequestSubscriber <Payload > upstreamSubscriber =
337
+ new RateLimitableRequestSubscriber <Payload >(Queues .SMALL_BUFFER_SIZE ) {
338
+
339
+ @ Override
340
+ protected void hookOnNext (Payload payload ) {
341
+ final ByteBuf frame =
342
+ PayloadFrameFlyweight .encode (allocator , streamId , false , false , true , payload );
343
+
344
+ sendProcessor .onNext (frame );
345
+ payload .release ();
346
+ }
347
+
348
+ @ Override
349
+ protected void hookOnComplete () {
350
+ ByteBuf frame = PayloadFrameFlyweight .encodeComplete (allocator , streamId );
351
+ sendProcessor .onNext (frame );
352
+ }
353
+
354
+ @ Override
355
+ protected void hookOnError (Throwable t ) {
356
+ ByteBuf frame = ErrorFrameFlyweight .encode (allocator , streamId , t );
357
+ sendProcessor .onNext (frame );
358
+ receiver .dispose ();
359
+ }
360
+
361
+ @ Override
362
+ protected void hookFinally (SignalType type ) {
363
+ senders .remove (streamId , this );
364
+ }
365
+ };
366
+
310
367
return receiver
311
368
.doOnRequest (
312
369
new LongConsumer () {
@@ -317,85 +374,49 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
317
374
public void accept (long n ) {
318
375
if (firstRequest ) {
319
376
firstRequest = false ;
320
- request
321
- .transform (
322
- f -> {
323
- RateLimitableRequestPublisher <Payload > wrapped =
324
- RateLimitableRequestPublisher .wrap (f , Queues .SMALL_BUFFER_SIZE );
325
- // Need to set this to one for first the frame
326
- wrapped .request (1 );
327
- senders .put (streamId , wrapped );
328
- receivers .put (streamId , receiver );
329
-
330
- return wrapped ;
331
- })
332
- .subscribe (
333
- new BaseSubscriber <Payload >() {
334
-
335
- boolean firstPayload = true ;
336
-
337
- @ Override
338
- protected void hookOnNext (Payload payload ) {
339
- final ByteBuf frame ;
340
-
341
- if (firstPayload ) {
342
- firstPayload = false ;
343
- frame =
344
- RequestChannelFrameFlyweight .encode (
345
- allocator ,
346
- streamId ,
347
- false ,
348
- false ,
349
- n ,
350
- payload .sliceMetadata ().retain (),
351
- payload .sliceData ().retain ());
352
- } else {
353
- frame =
354
- PayloadFrameFlyweight .encode (
355
- allocator , streamId , false , false , true , payload );
356
- }
357
-
358
- sendProcessor .onNext (frame );
359
- payload .release ();
360
- }
361
-
362
- @ Override
363
- protected void hookOnComplete () {
364
- if (contains (streamId ) && !receiver .isDisposed ()) {
365
- sendProcessor .onNext (
366
- PayloadFrameFlyweight .encodeComplete (allocator , streamId ));
367
- }
368
- if (firstPayload ) {
369
- receiver .onComplete ();
370
- }
371
- }
372
-
373
- @ Override
374
- protected void hookOnError (Throwable t ) {
375
- errorConsumer .accept (t );
376
- receiver .dispose ();
377
- }
378
- });
379
- } else {
380
- if (contains (streamId ) && !receiver .isDisposed ()) {
381
- sendProcessor .onNext (RequestNFrameFlyweight .encode (allocator , streamId , n ));
377
+ senders .put (streamId , upstreamSubscriber );
378
+ receivers .put (streamId , receiver );
379
+
380
+ inboundFlux .subscribe (upstreamSubscriber );
381
+
382
+ ByteBuf frame =
383
+ RequestChannelFrameFlyweight .encode (
384
+ allocator ,
385
+ streamId ,
386
+ false ,
387
+ false ,
388
+ n ,
389
+ initialPayload .sliceMetadata ().retain (),
390
+ initialPayload .sliceData ().retain ());
391
+
392
+ sendProcessor .onNext (frame );
393
+
394
+ if (!payloadReleasedFlag .getAndSet (true )) {
395
+ initialPayload .release ();
382
396
}
397
+ } else {
398
+ sendProcessor .onNext (RequestNFrameFlyweight .encode (allocator , streamId , n ));
383
399
}
384
400
}
385
401
})
386
402
.doOnError (
387
403
t -> {
388
- if (contains (streamId ) && ! receiver . isDisposed ( )) {
389
- sendProcessor . onNext ( ErrorFrameFlyweight . encode ( allocator , streamId , t ) );
404
+ if (receivers . remove (streamId , receiver )) {
405
+ upstreamSubscriber . cancel ( );
390
406
}
391
407
})
392
408
.doOnCancel (
393
409
() -> {
394
- if (contains (streamId ) && !receiver .isDisposed ()) {
410
+ if (!payloadReleasedFlag .getAndSet (true )) {
411
+ initialPayload .release ();
412
+ }
413
+ if (contains (streamId )) {
395
414
sendProcessor .onNext (CancelFrameFlyweight .encode (allocator , streamId ));
415
+ if (receivers .remove (streamId , receiver )) {
416
+ upstreamSubscriber .cancel ();
417
+ }
396
418
}
397
- })
398
- .doFinally (s -> removeStreamReceiverAndSender (streamId ));
419
+ });
399
420
}
400
421
401
422
private Mono <Void > handleMetadataPush (Payload payload ) {
@@ -484,7 +505,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) {
484
505
break ;
485
506
case CANCEL :
486
507
{
487
- RateLimitableRequestPublisher sender = senders .remove (streamId );
508
+ Subscription sender = senders .remove (streamId );
488
509
if (sender != null ) {
489
510
sender .cancel ();
490
511
}
@@ -495,7 +516,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) {
495
516
break ;
496
517
case REQUEST_N :
497
518
{
498
- RateLimitableRequestPublisher sender = senders .get (streamId );
519
+ Subscription sender = senders .get (streamId );
499
520
if (sender != null ) {
500
521
int n = RequestNFrameFlyweight .requestN (frame );
501
522
sender .request (n >= Integer .MAX_VALUE ? Long .MAX_VALUE : n );
@@ -603,18 +624,6 @@ private void removeStreamReceiver(int streamId) {
603
624
}
604
625
}
605
626
606
- private void removeStreamReceiverAndSender (int streamId ) {
607
- /*on termination senders & receivers are explicitly cleared to avoid removing from map while iterating over one
608
- of its views*/
609
- if (terminationError == null ) {
610
- receivers .remove (streamId );
611
- RateLimitableRequestPublisher <?> sender = senders .remove (streamId );
612
- if (sender != null ) {
613
- sender .cancel ();
614
- }
615
- }
616
- }
617
-
618
627
private void handleSendProcessorError (Throwable t ) {
619
628
connection .dispose ();
620
629
}
0 commit comments