40
40
import io .rsocket .frame .RequestResponseFrameFlyweight ;
41
41
import io .rsocket .frame .RequestStreamFrameFlyweight ;
42
42
import io .rsocket .frame .decoder .PayloadDecoder ;
43
- import io .rsocket .internal .RateLimitableRequestPublisher ;
43
+ import io .rsocket .internal .FluxSwitchOnFirst ;
44
+ import io .rsocket .internal .RateLimitableRequestSubscriber ;
44
45
import io .rsocket .internal .SynchronizedIntObjectHashMap ;
45
46
import io .rsocket .internal .UnboundedProcessor ;
46
47
import io .rsocket .internal .UnicastMonoEmpty ;
51
52
import io .rsocket .lease .RequesterLeaseHandler ;
52
53
import io .rsocket .util .MonoLifecycleHandler ;
53
54
import java .nio .channels .ClosedChannelException ;
55
+ import java .util .concurrent .atomic .AtomicBoolean ;
54
56
import java .util .concurrent .atomic .AtomicReferenceFieldUpdater ;
55
57
import java .util .function .Consumer ;
56
58
import java .util .function .LongConsumer ;
60
62
import org .reactivestreams .Processor ;
61
63
import org .reactivestreams .Publisher ;
62
64
import org .reactivestreams .Subscriber ;
63
- import reactor . core . publisher . BaseSubscriber ;
65
+ import org . reactivestreams . Subscription ;
64
66
import reactor .core .publisher .Flux ;
65
67
import reactor .core .publisher .Mono ;
66
68
import reactor .core .publisher .SignalType ;
@@ -84,7 +86,7 @@ class RSocketRequester implements RSocket {
84
86
private final PayloadDecoder payloadDecoder ;
85
87
private final Consumer <Throwable > errorConsumer ;
86
88
private final StreamIdSupplier streamIdSupplier ;
87
- private final IntObjectMap <RateLimitableRequestPublisher > senders ;
89
+ private final IntObjectMap <RateLimitableRequestSubscriber > senders ;
88
90
private final IntObjectMap <Processor <Payload , Payload >> receivers ;
89
91
private final UnboundedProcessor <ByteBuf > sendProcessor ;
90
92
private final RequesterLeaseHandler leaseHandler ;
@@ -258,10 +260,12 @@ private Flux<Payload> handleRequestStream(final Payload payload) {
258
260
259
261
final UnboundedProcessor <ByteBuf > sendProcessor = this .sendProcessor ;
260
262
final UnicastProcessor <Payload > receiver = UnicastProcessor .create ();
263
+ final AtomicBoolean payloadReleasedFlag = new AtomicBoolean (false );
261
264
262
265
receivers .put (streamId , receiver );
263
266
264
267
return receiver
268
+ .log ()
265
269
.doOnRequest (
266
270
new LongConsumer () {
267
271
@@ -279,7 +283,9 @@ public void accept(long n) {
279
283
n ,
280
284
payload .sliceMetadata ().retain (),
281
285
payload .sliceData ().retain ()));
282
- payload .release ();
286
+ if (!payloadReleasedFlag .getAndSet (true )) {
287
+ payload .release ();
288
+ }
283
289
} else if (contains (streamId ) && !receiver .isDisposed ()) {
284
290
sendProcessor .onNext (RequestNFrameFlyweight .encode (allocator , streamId , n ));
285
291
}
@@ -293,6 +299,9 @@ public void accept(long n) {
293
299
})
294
300
.doOnCancel (
295
301
() -> {
302
+ if (!payloadReleasedFlag .getAndSet (true )) {
303
+ payload .release ();
304
+ }
296
305
if (contains (streamId ) && !receiver .isDisposed ()) {
297
306
sendProcessor .onNext (CancelFrameFlyweight .encode (allocator , streamId ));
298
307
}
@@ -306,10 +315,58 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
306
315
return Flux .error (err );
307
316
}
308
317
318
+ return request .transform (
319
+ f ->
320
+ new FluxSwitchOnFirst <>(
321
+ f ,
322
+ (s , flux ) -> {
323
+ Payload payload = s .get ();
324
+ if (payload != null ) {
325
+ return handleChannel (flux , payload );
326
+ } else {
327
+ return flux ;
328
+ }
329
+ },
330
+ false ));
331
+ }
332
+
333
+ private Flux <? extends Payload > handleChannel (Flux <Payload > inboundFlux , Payload initialPayload ) {
309
334
final UnboundedProcessor <ByteBuf > sendProcessor = this .sendProcessor ;
310
- final UnicastProcessor < Payload > receiver = UnicastProcessor . create ( );
335
+ final AtomicBoolean payloadReleasedFlag = new AtomicBoolean ( false );
311
336
final int streamId = streamIdSupplier .nextStreamId (receivers );
312
337
338
+ final UnicastProcessor <Payload > receiver = UnicastProcessor .create ();
339
+ final RateLimitableRequestSubscriber <Payload > upstreamSubscriber =
340
+ new RateLimitableRequestSubscriber <Payload >(Queues .SMALL_BUFFER_SIZE ) {
341
+
342
+ @ Override
343
+ protected void hookOnNext (Payload payload ) {
344
+ final ByteBuf frame =
345
+ PayloadFrameFlyweight .encode (allocator , streamId , false , false , true , payload );
346
+
347
+ sendProcessor .onNext (frame );
348
+ payload .release ();
349
+ }
350
+
351
+ @ Override
352
+ protected void hookOnComplete () {
353
+ ByteBuf frame = PayloadFrameFlyweight .encodeComplete (allocator , streamId );
354
+ sendProcessor .onNext (frame );
355
+ }
356
+
357
+ @ Override
358
+ protected void hookOnError (Throwable t ) {
359
+ ByteBuf frame = ErrorFrameFlyweight .encode (allocator , streamId , t );
360
+ sendProcessor .onNext (frame );
361
+ receiver .dispose ();
362
+ }
363
+
364
+ @ Override
365
+ protected void hookFinally (SignalType type ) {
366
+ senders .remove (streamId , this );
367
+ }
368
+ };
369
+
313
370
return receiver
314
371
.doOnRequest (
315
372
new LongConsumer () {
@@ -320,85 +377,49 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
320
377
public void accept (long n ) {
321
378
if (firstRequest ) {
322
379
firstRequest = false ;
323
- request
324
- .transform (
325
- f -> {
326
- RateLimitableRequestPublisher <Payload > wrapped =
327
- RateLimitableRequestPublisher .wrap (f , Queues .SMALL_BUFFER_SIZE );
328
- // Need to set this to one for first the frame
329
- wrapped .request (1 );
330
- senders .put (streamId , wrapped );
331
- receivers .put (streamId , receiver );
332
-
333
- return wrapped ;
334
- })
335
- .subscribe (
336
- new BaseSubscriber <Payload >() {
337
-
338
- boolean firstPayload = true ;
339
-
340
- @ Override
341
- protected void hookOnNext (Payload payload ) {
342
- final ByteBuf frame ;
343
-
344
- if (firstPayload ) {
345
- firstPayload = false ;
346
- frame =
347
- RequestChannelFrameFlyweight .encode (
348
- allocator ,
349
- streamId ,
350
- false ,
351
- false ,
352
- n ,
353
- payload .sliceMetadata ().retain (),
354
- payload .sliceData ().retain ());
355
- } else {
356
- frame =
357
- PayloadFrameFlyweight .encode (
358
- allocator , streamId , false , false , true , payload );
359
- }
360
-
361
- sendProcessor .onNext (frame );
362
- payload .release ();
363
- }
364
-
365
- @ Override
366
- protected void hookOnComplete () {
367
- if (contains (streamId ) && !receiver .isDisposed ()) {
368
- sendProcessor .onNext (
369
- PayloadFrameFlyweight .encodeComplete (allocator , streamId ));
370
- }
371
- if (firstPayload ) {
372
- receiver .onComplete ();
373
- }
374
- }
375
-
376
- @ Override
377
- protected void hookOnError (Throwable t ) {
378
- errorConsumer .accept (t );
379
- receiver .dispose ();
380
- }
381
- });
382
- } else {
383
- if (contains (streamId ) && !receiver .isDisposed ()) {
384
- sendProcessor .onNext (RequestNFrameFlyweight .encode (allocator , streamId , n ));
380
+ senders .put (streamId , upstreamSubscriber );
381
+ receivers .put (streamId , receiver );
382
+
383
+ inboundFlux .subscribe (upstreamSubscriber );
384
+
385
+ ByteBuf frame =
386
+ RequestChannelFrameFlyweight .encode (
387
+ allocator ,
388
+ streamId ,
389
+ false ,
390
+ false ,
391
+ n ,
392
+ initialPayload .sliceMetadata ().retain (),
393
+ initialPayload .sliceData ().retain ());
394
+
395
+ sendProcessor .onNext (frame );
396
+
397
+ if (!payloadReleasedFlag .getAndSet (true )) {
398
+ initialPayload .release ();
385
399
}
400
+ } else {
401
+ sendProcessor .onNext (RequestNFrameFlyweight .encode (allocator , streamId , n ));
386
402
}
387
403
}
388
404
})
389
405
.doOnError (
390
406
t -> {
391
- if (contains (streamId ) && ! receiver . isDisposed ( )) {
392
- sendProcessor . onNext ( ErrorFrameFlyweight . encode ( allocator , streamId , t ) );
407
+ if (receivers . remove (streamId , receiver )) {
408
+ upstreamSubscriber . cancel ( );
393
409
}
394
410
})
395
411
.doOnCancel (
396
412
() -> {
397
- if (contains (streamId ) && !receiver .isDisposed ()) {
413
+ if (!payloadReleasedFlag .getAndSet (true )) {
414
+ initialPayload .release ();
415
+ }
416
+ if (contains (streamId )) {
398
417
sendProcessor .onNext (CancelFrameFlyweight .encode (allocator , streamId ));
418
+ if (receivers .remove (streamId , receiver )) {
419
+ upstreamSubscriber .cancel ();
420
+ }
399
421
}
400
- })
401
- .doFinally (s -> removeStreamReceiverAndSender (streamId ));
422
+ });
402
423
}
403
424
404
425
private Mono <Void > handleMetadataPush (Payload payload ) {
@@ -487,7 +508,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) {
487
508
break ;
488
509
case CANCEL :
489
510
{
490
- RateLimitableRequestPublisher sender = senders .remove (streamId );
511
+ Subscription sender = senders .remove (streamId );
491
512
if (sender != null ) {
492
513
sender .cancel ();
493
514
}
@@ -498,7 +519,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) {
498
519
break ;
499
520
case REQUEST_N :
500
521
{
501
- RateLimitableRequestPublisher sender = senders .get (streamId );
522
+ Subscription sender = senders .get (streamId );
502
523
if (sender != null ) {
503
524
int n = RequestNFrameFlyweight .requestN (frame );
504
525
sender .request (n >= Integer .MAX_VALUE ? Long .MAX_VALUE : n );
@@ -606,18 +627,6 @@ private void removeStreamReceiver(int streamId) {
606
627
}
607
628
}
608
629
609
- private void removeStreamReceiverAndSender (int streamId ) {
610
- /*on termination senders & receivers are explicitly cleared to avoid removing from map while iterating over one
611
- of its views*/
612
- if (terminationError == null ) {
613
- receivers .remove (streamId );
614
- RateLimitableRequestPublisher <?> sender = senders .remove (streamId );
615
- if (sender != null ) {
616
- sender .cancel ();
617
- }
618
- }
619
- }
620
-
621
630
private void handleSendProcessorError (Throwable t ) {
622
631
connection .dispose ();
623
632
}
0 commit comments