Skip to content

Fix current and future streams not terminating after connection dispose #541

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 28, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions rsocket-core/src/main/java/io/rsocket/RSocketClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import io.rsocket.framing.FrameType;
import io.rsocket.internal.LimitableRequestPublisher;
import io.rsocket.internal.UnboundedProcessor;

import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import org.jctools.maps.NonBlockingHashMapLong;
Expand Down Expand Up @@ -72,7 +75,7 @@ class RSocketClient implements RSocket {
// DO NOT Change the order here. The Send processor must be subscribed to before receiving
this.sendProcessor = new UnboundedProcessor<>();

connection.onClose().doFinally(signalType -> cleanup()).subscribe(null, errorConsumer);
connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer);

connection
.send(sendProcessor)
Expand All @@ -92,7 +95,9 @@ class RSocketClient implements RSocket {
keepAlive -> {
String message =
String.format("No keep-alive acks for %d ms", keepAlive.getTimeoutMillis());
errorConsumer.accept(new ConnectionErrorException(message));
ConnectionErrorException err = new ConnectionErrorException(message);
lifecycle.terminate(err);
errorConsumer.accept(err);
connection.dispose();
});
keepAliveHandler.send().subscribe(sendProcessor::onNext);
Expand Down Expand Up @@ -157,12 +162,7 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {

@Override
public Mono<Void> metadataPush(Payload payload) {
return Mono.fromRunnable(
() -> {
final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1);
payload.release();
sendProcessor.onNext(requestFrame);
});
return handleMetadataPush(payload);
}

@Override
Expand All @@ -187,7 +187,7 @@ public Mono<Void> onClose() {

private Mono<Void> handleFireAndForget(Payload payload) {
return lifecycle
.started()
.active()
.then(
Mono.fromRunnable(
() -> {
Expand All @@ -201,7 +201,7 @@ private Mono<Void> handleFireAndForget(Payload payload) {

private Flux<Payload> handleRequestStream(final Payload payload) {
return lifecycle
.started()
.active()
.thenMany(
Flux.defer(
() -> {
Expand Down Expand Up @@ -247,7 +247,7 @@ private Flux<Payload> handleRequestStream(final Payload payload) {

private Mono<Payload> handleRequestResponse(final Payload payload) {
return lifecycle
.started()
.active()
.then(
Mono.defer(
() -> {
Expand All @@ -274,7 +274,7 @@ private Mono<Payload> handleRequestResponse(final Payload payload) {

private Flux<Payload> handleChannel(Flux<Payload> request) {
return lifecycle
.started()
.active()
.thenMany(
Flux.defer(
() -> {
Expand Down Expand Up @@ -365,11 +365,25 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
}));
}

private Mono<Void> handleMetadataPush(Payload payload) {
return lifecycle
.active()
.then(Mono.fromRunnable(
() -> {
final Frame requestFrame = Frame.Request.from(0, FrameType.METADATA_PUSH, payload, 1);
payload.release();
sendProcessor.onNext(requestFrame);
}));
}

private boolean contains(int streamId) {
return receivers.containsKey(streamId);
}

protected void cleanup() {
protected void terminate() {

lifecycle.terminate(new ClosedChannelException());

if (keepAliveHandler != null) {
keepAliveHandler.dispose();
}
Expand Down Expand Up @@ -397,13 +411,8 @@ private synchronized void cleanUpLimitableRequestPublisher(
}

private synchronized void cleanUpSubscriber(UnicastProcessor<?> subscriber) {
Throwable err = lifecycle.terminationError();
try {
if (err != null) {
subscriber.onError(err);
} else {
subscriber.cancel();
}
subscriber.onError(lifecycle.terminationError());
} catch (Throwable t) {
errorConsumer.accept(t);
}
Expand Down Expand Up @@ -519,12 +528,12 @@ private void handleMissingResponseProcessor(int streamId, FrameType type, Frame

private static class Lifecycle {

private volatile Throwable terminationError;
private final AtomicReference<Throwable> terminationError = new AtomicReference<>();

public Mono<Void> started() {
public Mono<Void> active() {
return Mono.create(
sink -> {
Throwable err = terminationError;
Throwable err = terminationError();
if (err == null) {
sink.success();
} else {
Expand All @@ -534,11 +543,11 @@ public Mono<Void> started() {
}

public void terminate(Throwable err) {
this.terminationError = err;
this.terminationError.compareAndSet(null, err);
}

public Throwable terminationError() {
return terminationError;
return terminationError.get();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.rsocket;

import io.rsocket.RSocketClientTest.ClientSocketRule;
import io.rsocket.util.EmptyPayload;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.Arrays;
import java.util.function.Function;

@RunWith(Parameterized.class)
public class RSocketClientTerminationTest {

@Rule
public final ClientSocketRule rule = new ClientSocketRule();
private Function<RSocket, ? extends Publisher<?>> interaction;

public RSocketClientTerminationTest(Function<RSocket, ? extends Publisher<?>> interaction) {
this.interaction = interaction;
}

@Test
public void testCurrentStreamIsTerminatedOnConnectionClose() {
RSocketClient rSocket = rule.socket;

Mono.delay(Duration.ofSeconds(1))
.doOnNext(v -> rule.connection.dispose())
.subscribe();

StepVerifier.create(interaction.apply(rSocket))
.expectError(ClosedChannelException.class)
.verify(Duration.ofSeconds(5));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not specific to this PR: Are these tests "wall clock" slow? Should we tag someway to run in CI, but not in normal developer builds?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this will be great improvement for local builds - will do that in follow-up. Also I was thinking about VirtualTimeScheduler, but eventually preferred approach with no magic in It as this way tests look less fragile

}

@Test
public void testSubsequentStreamIsTerminatedAfterConnectionClose() {
RSocketClient rSocket = rule.socket;

rule.connection.dispose();
StepVerifier.create(interaction.apply(rSocket))
.expectError(ClosedChannelException.class)
.verify(Duration.ofSeconds(5));
}

@Parameterized.Parameters
public static Iterable<Function<RSocket, ? extends Publisher<?>>> rsocketInteractions() {
EmptyPayload payload = EmptyPayload.INSTANCE;
Publisher<Payload> payloadStream = Flux.just(payload);

Function<RSocket, Mono<Payload>> resp = rSocket -> rSocket.requestResponse(payload);
Function<RSocket, Flux<Payload>> stream = rSocket -> rSocket.requestStream(payload);
Function<RSocket, Flux<Payload>> channel = rSocket -> rSocket.requestChannel(payloadStream);

return Arrays.asList(resp, stream, channel);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ protected RSocketClient newRSocket() {
throwable -> errors.add(throwable),
StreamIdSupplier.clientSupplier(),
Duration.ofMillis(100),
Duration.ofMillis(100),
Duration.ofMillis(10_000),
4);
}

Expand Down