Skip to content

Commit 5afae5f

Browse files
committed
fixes responder to handle errors according to the spec
Signed-off-by: Oleh Dokuka <[email protected]>
1 parent 05fe717 commit 5afae5f

File tree

3 files changed

+134
-3
lines changed

3 files changed

+134
-3
lines changed

rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,26 @@ protected void hookOnComplete() {
497497
@Override
498498
protected void hookOnError(Throwable throwable) {
499499
if (sendingSubscriptions.remove(streamId, this)) {
500+
// specifically for requestChannel case so when Payload is invalid we will not be
501+
// sending CancelFrame and ErrorFrame
502+
// Note: CancelFrame is redundant and due to spec
503+
// (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel)
504+
// Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream
505+
// is
506+
// terminated on both Requester and Responder.
507+
// Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is
508+
// terminated on both the Requester and Responder.
509+
if (requestChannel != null && !requestChannel.isDisposed()) {
510+
if (channelProcessors.remove(streamId, requestChannel)) {
511+
try {
512+
requestChannel.dispose();
513+
} catch (Throwable e) {
514+
// ignore to ensure it does not blows up if it racing with async
515+
// cancel
516+
}
517+
}
518+
}
519+
500520
handleError(streamId, throwable);
501521
}
502522
}
@@ -535,6 +555,11 @@ public void accept(long l) {
535555
if (channelProcessors.remove(streamId, frames)) {
536556
if (signalType == SignalType.CANCEL) {
537557
sendProcessor.onNext(CancelFrameCodec.encode(allocator, streamId));
558+
} else if (signalType == SignalType.ON_ERROR) {
559+
Subscription subscription = sendingSubscriptions.remove(streamId);
560+
if (subscription != null) {
561+
subscription.cancel();
562+
}
538563
}
539564
}
540565
})

rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import io.rsocket.util.DefaultPayload;
6060
import io.rsocket.util.EmptyPayload;
6161
import java.util.Collection;
62+
import java.util.concurrent.CancellationException;
6263
import java.util.concurrent.ThreadLocalRandom;
6364
import java.util.concurrent.atomic.AtomicBoolean;
6465
import java.util.stream.Stream;
@@ -428,6 +429,24 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
428429

429430
ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE);
430431

432+
ByteBuf m1 = allocator.buffer();
433+
m1.writeCharSequence("m1", CharsetUtil.UTF_8);
434+
ByteBuf d1 = allocator.buffer();
435+
d1.writeCharSequence("d1", CharsetUtil.UTF_8);
436+
Payload np1 = ByteBufPayload.create(d1, m1);
437+
438+
ByteBuf m2 = allocator.buffer();
439+
m2.writeCharSequence("m2", CharsetUtil.UTF_8);
440+
ByteBuf d2 = allocator.buffer();
441+
d2.writeCharSequence("d2", CharsetUtil.UTF_8);
442+
Payload np2 = ByteBufPayload.create(d2, m2);
443+
444+
ByteBuf m3 = allocator.buffer();
445+
m3.writeCharSequence("m3", CharsetUtil.UTF_8);
446+
ByteBuf d3 = allocator.buffer();
447+
d3.writeCharSequence("d3", CharsetUtil.UTF_8);
448+
Payload np3 = ByteBufPayload.create(d3, m3);
449+
431450
FluxSink<Payload> sink = sinks[0];
432451
RaceTestUtils.race(
433452
() ->
@@ -436,14 +455,18 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
436455
() -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3),
437456
parallel),
438457
() -> {
439-
sink.next(ByteBufPayload.create("d1", "m1"));
440-
sink.next(ByteBufPayload.create("d2", "m2"));
441-
sink.next(ByteBufPayload.create("d3", "m3"));
458+
sink.next(np1);
459+
sink.next(np2);
460+
sink.next(np3);
442461
sink.error(new RuntimeException());
443462
},
444463
parallel);
445464

446465
Assertions.assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release);
466+
467+
assertSubscriber.assertTerminated()
468+
.assertError(CancellationException.class)
469+
.assertErrorMessage("Disposed");
447470
Assertions.assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release);
448471
rule.assertHasNoLeaks();
449472
}

rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import io.rsocket.util.EmptyPayload;
3434
import java.time.Duration;
3535
import java.util.List;
36+
import java.util.concurrent.CancellationException;
3637
import java.util.concurrent.atomic.AtomicReference;
3738
import org.assertj.core.api.Assertions;
3839
import org.junit.Rule;
@@ -287,6 +288,46 @@ public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion2()
287288
requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber);
288289
}
289290

291+
@Test
292+
public void requestChannelCase_ErrorFromResponderShouldTerminatesStreamsOnBothSides() {
293+
TestPublisher<Payload> requesterPublisher = TestPublisher.create();
294+
AssertSubscriber<Payload> requesterSubscriber = new AssertSubscriber<>(0);
295+
296+
AssertSubscriber<Payload> responderSubscriber = new AssertSubscriber<>(0);
297+
TestPublisher<Payload> responderPublisher = TestPublisher.create();
298+
299+
initRequestChannelCase(
300+
requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber);
301+
302+
nextFromResponderPublisher(responderPublisher, requesterSubscriber);
303+
304+
nextFromRequesterPublisher(requesterPublisher, responderSubscriber);
305+
306+
// ensures both sides are terminated
307+
errorFromResponderPublisher(
308+
requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber);
309+
}
310+
311+
@Test
312+
public void requestChannelCase_ErrorFromRequesterShouldTerminatesStreamsOnBothSides() {
313+
TestPublisher<Payload> requesterPublisher = TestPublisher.create();
314+
AssertSubscriber<Payload> requesterSubscriber = new AssertSubscriber<>(0);
315+
316+
AssertSubscriber<Payload> responderSubscriber = new AssertSubscriber<>(0);
317+
TestPublisher<Payload> responderPublisher = TestPublisher.create();
318+
319+
initRequestChannelCase(
320+
requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber);
321+
322+
nextFromResponderPublisher(responderPublisher, requesterSubscriber);
323+
324+
nextFromRequesterPublisher(requesterPublisher, responderSubscriber);
325+
326+
// ensures both sides are terminated
327+
errorFromRequesterPublisher(
328+
requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber);
329+
}
330+
290331
void initRequestChannelCase(
291332
TestPublisher<Payload> requesterPublisher,
292333
AssertSubscriber<Payload> requesterSubscriber,
@@ -405,6 +446,48 @@ void cancelFromRequesterSubscriber(
405446
requesterPublisher.assertNoSubscribers();
406447
}
407448

449+
static final CustomRSocketException EXCEPTION = new CustomRSocketException(123456, "test");
450+
451+
void errorFromResponderPublisher(
452+
TestPublisher<Payload> requesterPublisher,
453+
AssertSubscriber<Payload> requesterSubscriber,
454+
TestPublisher<Payload> responderPublisher,
455+
AssertSubscriber<Payload> responderSubscriber) {
456+
// ensures that after sending cancel the whole requestChannel is terminated
457+
responderPublisher.error(EXCEPTION);
458+
// error should be propagated
459+
responderSubscriber.assertTerminated().assertError(CancellationException.class);
460+
requesterSubscriber
461+
.assertTerminated()
462+
.assertError(CustomRSocketException.class)
463+
.assertErrorMessage("test");
464+
// ensures that cancellation is propagated to the actual upstream
465+
requesterPublisher.assertWasCancelled();
466+
requesterPublisher.assertNoSubscribers();
467+
}
468+
469+
void errorFromRequesterPublisher(
470+
TestPublisher<Payload> requesterPublisher,
471+
AssertSubscriber<Payload> requesterSubscriber,
472+
TestPublisher<Payload> responderPublisher,
473+
AssertSubscriber<Payload> responderSubscriber) {
474+
// ensures that after sending cancel the whole requestChannel is terminated
475+
requesterPublisher.error(EXCEPTION);
476+
// error should be propagated
477+
responderSubscriber
478+
.assertTerminated()
479+
.assertError(CustomRSocketException.class)
480+
.assertErrorMessage("test");
481+
requesterSubscriber
482+
.assertTerminated()
483+
.assertError(CustomRSocketException.class)
484+
.assertErrorMessage("test");
485+
486+
// ensures that cancellation is propagated to the actual upstream
487+
responderPublisher.assertWasCancelled();
488+
responderPublisher.assertNoSubscribers();
489+
}
490+
408491
public static class SocketRule extends ExternalResource {
409492

410493
DirectProcessor<ByteBuf> serverProcessor;

0 commit comments

Comments
 (0)