Skip to content

Commit 9172b47

Browse files
mostroverkhovkbahr
authored andcommitted
fix issue when current and future streams are not terminated after connection dispose (#541)
Signed-off-by: Maksym Ostroverkhov <[email protected]> Signed-off-by: Kyle Bahr <[email protected]>
1 parent 3643636 commit 9172b47

File tree

3 files changed

+98
-25
lines changed

3 files changed

+98
-25
lines changed

rsocket-core/src/main/java/io/rsocket/RSocketClient.java

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@
2121
import io.rsocket.framing.FrameType;
2222
import io.rsocket.internal.LimitableRequestPublisher;
2323
import io.rsocket.internal.UnboundedProcessor;
24+
25+
import java.nio.channels.ClosedChannelException;
2426
import java.time.Duration;
2527
import java.util.concurrent.atomic.AtomicBoolean;
28+
import java.util.concurrent.atomic.AtomicReference;
2629
import java.util.function.Consumer;
2730
import java.util.function.Function;
2831
import org.jctools.maps.NonBlockingHashMapLong;
@@ -72,7 +75,7 @@ class RSocketClient implements RSocket {
7275
// DO NOT Change the order here. The Send processor must be subscribed to before receiving
7376
this.sendProcessor = new UnboundedProcessor<>();
7477

75-
connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer);
78+
connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer);
7679

7780
connection
7881
.send(sendProcessor)
@@ -92,7 +95,9 @@ class RSocketClient implements RSocket {
9295
keepAlive -> {
9396
String message =
9497
String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis());
95-
errorConsumer.accept(new ConnectionErrorException(message));
98+
ConnectionErrorException err = new ConnectionErrorException(message);
99+
lifecycle.terminate(err);
100+
errorConsumer.accept(err);
96101
connection.dispose();
97102
});
98103
keepAliveHandler.send().subscribe(sendProcessor::onNext);
@@ -157,12 +162,7 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
157162

158163
@Override
159164
public Mono<Void> metadataPush(Payload payload) {
160-
return Mono.fromRunnable(
161-
() -> {
162-
final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1);
163-
payload.release();
164-
sendProcessor.onNext(requestFrame);
165-
});
165+
return handleMetadataPush(payload);
166166
}
167167

168168
@Override
@@ -187,7 +187,7 @@ public Mono<Void> onClose() {
187187

188188
private Mono<Void> handleFireAndForget(Payload payload) {
189189
return lifecycle
190-
.started()
190+
.active()
191191
.then(
192192
Mono.fromRunnable(
193193
() -> {
@@ -201,7 +201,7 @@ private Mono<Void> handleFireAndForget(Payload payload) {
201201

202202
private Flux<Payload> handleRequestStream(final Payload payload) {
203203
return lifecycle
204-
.started()
204+
.active()
205205
.thenMany(
206206
Flux.defer(
207207
() -> {
@@ -247,7 +247,7 @@ private Flux<Payload> handleRequestStream(final Payload payload) {
247247

248248
private Mono<Payload> handleRequestResponse(final Payload payload) {
249249
return lifecycle
250-
.started()
250+
.active()
251251
.then(
252252
Mono.defer(
253253
() -> {
@@ -274,7 +274,7 @@ private Mono<Payload> handleRequestResponse(final Payload payload) {
274274

275275
private Flux<Payload> handleChannel(Flux<Payload> request) {
276276
return lifecycle
277-
.started()
277+
.active()
278278
.thenMany(
279279
Flux.defer(
280280
() -> {
@@ -365,11 +365,25 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
365365
}));
366366
}
367367

368+
private Mono<Void> handleMetadataPush(Payload payload) {
369+
return lifecycle
370+
.active()
371+
.then(Mono.fromRunnable(
372+
() -> {
373+
final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1);
374+
payload.release();
375+
sendProcessor.onNext(requestFrame);
376+
}));
377+
}
378+
368379
private boolean contains(int streamId) {
369380
return receivers.containsKey(streamId);
370381
}
371382

372-
protected void cleanup() {
383+
protected void terminate() {
384+
385+
lifecycle.terminate(new ClosedChannelException());
386+
373387
if (keepAliveHandler != null) {
374388
keepAliveHandler.dispose();
375389
}
@@ -397,13 +411,8 @@ private synchronized void cleanUpLimitableRequestPublisher(
397411
}
398412

399413
private synchronized void cleanUpSubscriber(UnicastProcessor<?> subscriber) {
400-
Throwable err = lifecycle.terminationError();
401414
try {
402-
if (err != null) {
403-
subscriber.onError(err);
404-
} else {
405-
subscriber.cancel();
406-
}
415+
subscriber.onError(lifecycle.terminationError());
407416
} catch (Throwable t) {
408417
errorConsumer.accept(t);
409418
}
@@ -519,12 +528,12 @@ private void handleMissingResponseProcessor(int streamId, FrameType type, Frame
519528

520529
private static class Lifecycle {
521530

522-
private volatile Throwable terminationError;
531+
private final AtomicReference<Throwable> terminationError = new AtomicReference<>();
523532

524-
public Mono<Void> started() {
533+
public Mono<Void> active() {
525534
return Mono.create(
526535
sink -> {
527-
Throwable err = terminationError;
536+
Throwable err = terminationError();
528537
if (err == null) {
529538
sink.success();
530539
} else {
@@ -534,11 +543,11 @@ public Mono<Void> started() {
534543
}
535544

536545
public void terminate(Throwable err) {
537-
this.terminationError = err;
546+
this.terminationError.compareAndSet(null, err);
538547
}
539548

540549
public Throwable terminationError() {
541-
return terminationError;
550+
return terminationError.get();
542551
}
543552
}
544553
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package io.rsocket;
2+
3+
import io.rsocket.RSocketClientTest.ClientSocketRule;
4+
import io.rsocket.util.EmptyPayload;
5+
import org.junit.Rule;
6+
import org.junit.Test;
7+
import org.junit.runner.RunWith;
8+
import org.junit.runners.Parameterized;
9+
import org.reactivestreams.Publisher;
10+
import reactor.core.publisher.Flux;
11+
import reactor.core.publisher.Mono;
12+
import reactor.test.StepVerifier;
13+
14+
import java.nio.channels.ClosedChannelException;
15+
import java.time.Duration;
16+
import java.util.Arrays;
17+
import java.util.function.Function;
18+
19+
@RunWith(Parameterized.class)
20+
public class RSocketClientTerminationTest {
21+
22+
@Rule
23+
public final ClientSocketRule rule = new ClientSocketRule();
24+
private Function<RSocket, ? extends Publisher<?>> interaction;
25+
26+
public RSocketClientTerminationTest(Function<RSocket, ? extends Publisher<?>> interaction) {
27+
this.interaction = interaction;
28+
}
29+
30+
@Test
31+
public void testCurrentStreamIsTerminatedOnConnectionClose() {
32+
RSocketClient rSocket = rule.socket;
33+
34+
Mono.delay(Duration.ofSeconds(1))
35+
.doOnNext(v -> rule.connection.dispose())
36+
.subscribe();
37+
38+
StepVerifier.create(interaction.apply(rSocket))
39+
.expectError(ClosedChannelException.class)
40+
.verify(Duration.ofSeconds(5));
41+
}
42+
43+
@Test
44+
public void testSubsequentStreamIsTerminatedAfterConnectionClose() {
45+
RSocketClient rSocket = rule.socket;
46+
47+
rule.connection.dispose();
48+
StepVerifier.create(interaction.apply(rSocket))
49+
.expectError(ClosedChannelException.class)
50+
.verify(Duration.ofSeconds(5));
51+
}
52+
53+
@Parameterized.Parameters
54+
public static Iterable<Function<RSocket, ? extends Publisher<?>>> rsocketInteractions() {
55+
EmptyPayload payload = EmptyPayload.INSTANCE;
56+
Publisher<Payload> payloadStream = Flux.just(payload);
57+
58+
Function<RSocket, Mono<Payload>> resp = rSocket -> rSocket.requestResponse(payload);
59+
Function<RSocket, Flux<Payload>> stream = rSocket -> rSocket.requestStream(payload);
60+
Function<RSocket, Flux<Payload>> channel = rSocket -> rSocket.requestChannel(payloadStream);
61+
62+
return Arrays.asList(resp, stream, channel);
63+
}
64+
}

rsocket-core/src/test/java/io/rsocket/RSocketClientTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ protected RSocketClient newRSocket() {
215215
throwable -> errors.add(throwable),
216216
StreamIdSupplier.clientSupplier(),
217217
Duration.ofMillis(100),
218-
Duration.ofMillis(100),
218+
Duration.ofMillis(10_000),
219219
4);
220220
}
221221

0 commit comments

Comments
 (0)