Skip to content

Commit 3bf4c98

Browse files
committed
ensures behaviour of the first payload sending and racing with cancel
Signed-off-by: Oleh Dokuka <[email protected]>
1 parent d2b8f2a commit 3bf4c98

File tree

1 file changed

+121
-65
lines changed

1 file changed

+121
-65
lines changed

rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java

Lines changed: 121 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
import io.rsocket.util.MonoLifecycleHandler;
5555
import java.nio.channels.ClosedChannelException;
5656
import java.util.concurrent.CancellationException;
57-
import java.util.concurrent.atomic.AtomicBoolean;
57+
import java.util.concurrent.atomic.AtomicInteger;
5858
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
5959
import java.util.function.Consumer;
6060
import java.util.function.LongConsumer;
@@ -242,9 +242,14 @@ private Mono<Payload> handleRequestResponse(final Payload payload) {
242242
new MonoLifecycleHandler<Payload>() {
243243
@Override
244244
public void doOnSubscribe() {
245-
final ByteBuf requestFrame =
246-
RequestResponseFrameFlyweight.encodeReleasingPayload(
247-
allocator, streamId, payload);
245+
final ByteBuf requestFrame;
246+
try {
247+
requestFrame =
248+
RequestResponseFrameFlyweight.encodeReleasingPayload(
249+
allocator, streamId, payload);
250+
} catch (IllegalReferenceCountException e) {
251+
return;
252+
}
248253

249254
sendProcessor.onNext(requestFrame);
250255
}
@@ -260,6 +265,7 @@ public void doOnTerminal(
260265
removeStreamReceiver(streamId);
261266
}
262267
});
268+
263269
receivers.put(streamId, receiver);
264270

265271
return receiver.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);
@@ -281,7 +287,7 @@ private Flux<Payload> handleRequestStream(final Payload payload) {
281287

282288
final UnboundedProcessor<ByteBuf> sendProcessor = this.sendProcessor;
283289
final UnicastProcessor<Payload> receiver = UnicastProcessor.create();
284-
final AtomicBoolean payloadReleasedFlag = new AtomicBoolean(false);
290+
final AtomicInteger wip = new AtomicInteger(0);
285291

286292
receivers.put(streamId, receiver);
287293

@@ -295,30 +301,56 @@ private Flux<Payload> handleRequestStream(final Payload payload) {
295301
public void accept(long n) {
296302
if (firstRequest && !receiver.isDisposed()) {
297303
firstRequest = false;
298-
if (!payloadReleasedFlag.getAndSet(true)) {
299-
sendProcessor.onNext(
300-
RequestStreamFrameFlyweight.encodeReleasingPayload(
301-
allocator, streamId, n, payload));
304+
if (wip.getAndIncrement() != 0) {
305+
// no need to do anything.
306+
// stream was canceled and fist payload has already been discarded
307+
return;
308+
}
309+
int missed = 1;
310+
boolean firstHasBeenSent = false;
311+
for (; ; ) {
312+
if (!firstHasBeenSent) {
313+
ByteBuf frame;
314+
try {
315+
frame =
316+
RequestStreamFrameFlyweight.encodeReleasingPayload(
317+
allocator, streamId, n, payload);
318+
} catch (IllegalReferenceCountException e) {
319+
return;
320+
}
321+
322+
sendProcessor.onNext(frame);
323+
firstHasBeenSent = true;
324+
} else {
325+
// if first frame was sent but we cycling again, it means that wip was
326+
// incremented at doOnCancel
327+
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
328+
return;
329+
}
330+
331+
missed = wip.addAndGet(-missed);
332+
if (missed == 0) {
333+
return;
334+
}
302335
}
303-
} else if (contains(streamId) && !receiver.isDisposed()) {
336+
} else {
304337
sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n));
305338
}
306339
}
307340
})
308-
.doOnError(
309-
t -> {
310-
if (contains(streamId) && !receiver.isDisposed()) {
311-
sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t));
312-
}
313-
})
314341
.doOnCancel(
315342
() -> {
316-
if (!payloadReleasedFlag.getAndSet(true)) {
317-
payload.release();
343+
if (wip.getAndIncrement() != 0) {
344+
return;
318345
}
319-
if (contains(streamId) && !receiver.isDisposed()) {
320-
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
346+
347+
// check if we need to release payload
348+
// only applicable if the cancel appears earlier than actual request
349+
if (payload.refCnt() > 0) {
350+
payload.release();
321351
}
352+
353+
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
322354
})
323355
.doFinally(s -> removeStreamReceiver(streamId))
324356
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);
@@ -330,30 +362,32 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
330362
return Flux.error(err);
331363
}
332364

333-
return request.switchOnFirst(
334-
(s, flux) -> {
335-
Payload payload = s.get();
336-
if (payload != null) {
337-
if (!PayloadValidationUtils.isValid(mtu, payload)) {
338-
payload.release();
339-
final IllegalArgumentException t =
340-
new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
341-
errorConsumer.accept(t);
342-
return Mono.error(t);
343-
}
344-
return handleChannel(payload, flux);
345-
} else {
346-
return flux;
347-
}
348-
},
349-
false);
365+
return request
366+
.switchOnFirst(
367+
(s, flux) -> {
368+
Payload payload = s.get();
369+
if (payload != null) {
370+
if (!PayloadValidationUtils.isValid(mtu, payload)) {
371+
payload.release();
372+
final IllegalArgumentException t =
373+
new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
374+
errorConsumer.accept(t);
375+
return Mono.error(t);
376+
}
377+
return handleChannel(payload, flux);
378+
} else {
379+
return flux;
380+
}
381+
},
382+
false)
383+
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);
350384
}
351385

352386
private Flux<? extends Payload> handleChannel(Payload initialPayload, Flux<Payload> inboundFlux) {
353387
final UnboundedProcessor<ByteBuf> sendProcessor = this.sendProcessor;
354-
final AtomicBoolean payloadReleasedFlag = new AtomicBoolean(false);
355388
final int streamId = streamIdSupplier.nextStreamId(receivers);
356389

390+
final AtomicInteger wip = new AtomicInteger(0);
357391
final UnicastProcessor<Payload> receiver = UnicastProcessor.create();
358392
final BaseSubscriber<Payload> upstreamSubscriber =
359393
new BaseSubscriber<Payload>() {
@@ -421,43 +455,65 @@ protected void hookFinally(SignalType type) {
421455
public void accept(long n) {
422456
if (firstRequest) {
423457
firstRequest = false;
424-
senders.put(streamId, upstreamSubscriber);
425-
receivers.put(streamId, receiver);
426-
427-
inboundFlux
428-
.limitRate(Queues.SMALL_BUFFER_SIZE)
429-
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER)
430-
.subscribe(upstreamSubscriber);
431-
if (!payloadReleasedFlag.getAndSet(true)) {
432-
ByteBuf frame =
433-
RequestChannelFrameFlyweight.encodeReleasingPayload(
434-
allocator, streamId, false, n, initialPayload);
435-
436-
sendProcessor.onNext(frame);
458+
if (wip.getAndIncrement() != 0) {
459+
// no need to do anything.
460+
// stream was canceled and fist payload has already been discarded
461+
return;
462+
}
463+
int missed = 1;
464+
boolean firstHasBeenSent = false;
465+
for (; ; ) {
466+
if (!firstHasBeenSent) {
467+
ByteBuf frame;
468+
try {
469+
frame =
470+
RequestChannelFrameFlyweight.encodeReleasingPayload(
471+
allocator, streamId, false, n, initialPayload);
472+
} catch (IllegalReferenceCountException e) {
473+
return;
474+
}
475+
476+
senders.put(streamId, upstreamSubscriber);
477+
receivers.put(streamId, receiver);
478+
479+
inboundFlux
480+
.limitRate(Queues.SMALL_BUFFER_SIZE)
481+
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER)
482+
.subscribe(upstreamSubscriber);
483+
484+
sendProcessor.onNext(frame);
485+
firstHasBeenSent = true;
486+
} else {
487+
// if first frame was sent but we cycling again, it means that wip was
488+
// incremented at doOnCancel
489+
senders.remove(streamId, upstreamSubscriber);
490+
receivers.remove(streamId, receiver);
491+
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
492+
return;
493+
}
494+
495+
missed = wip.addAndGet(-missed);
496+
if (missed == 0) {
497+
return;
498+
}
437499
}
438500
} else {
439501
sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n));
440502
}
441503
}
442504
})
443-
.doOnError(
444-
t -> {
445-
if (receivers.remove(streamId, receiver)) {
446-
upstreamSubscriber.cancel();
447-
}
448-
})
449-
.doOnComplete(() -> receivers.remove(streamId, receiver))
505+
.doOnError(t -> upstreamSubscriber.cancel())
450506
.doOnCancel(
451507
() -> {
452-
if (!payloadReleasedFlag.getAndSet(true)) {
453-
initialPayload.release();
454-
}
455-
if (receivers.remove(streamId, receiver)) {
456-
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
457-
upstreamSubscriber.cancel();
508+
upstreamSubscriber.cancel();
509+
if (wip.getAndIncrement() != 0) {
510+
return;
458511
}
512+
513+
// need to send frame only if RequestChannelFrame was sent
514+
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
459515
})
460-
.doOnDiscard(ReferenceCounted.class, DROPPED_ELEMENTS_CONSUMER);
516+
.doFinally(__ -> receivers.remove(streamId, receiver));
461517
}
462518

463519
private Mono<Void> handleMetadataPush(Payload payload) {

0 commit comments

Comments
 (0)