Skip to content

RSocketRequester: fix concurrent modification of senders & receivers … #706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 98 additions & 99 deletions rsocket-core/src/main/java/io/rsocket/RSocketRequester.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;
import java.util.function.LongConsumer;
import java.util.function.Supplier;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.reactivestreams.Processor;
Expand All @@ -56,6 +57,11 @@ class RSocketRequester implements RSocket {
private static final AtomicReferenceFieldUpdater<RSocketRequester, Throwable> TERMINATION_ERROR =
AtomicReferenceFieldUpdater.newUpdater(
RSocketRequester.class, Throwable.class, "terminationError");
private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException();

static {
CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]);
}

private final DuplexConnection connection;
private final PayloadDecoder payloadDecoder;
Expand Down Expand Up @@ -91,69 +97,25 @@ class RSocketRequester implements RSocket {
// DO NOT Change the order here. The Send processor must be subscribed to before receiving
this.sendProcessor = new UnboundedProcessor<>();

connection.onClose().doFinally(signalType -> terminate()).subscribe(null, errorConsumer);
connection
.send(sendProcessor)
.doFinally(this::handleSendProcessorCancel)
.subscribe(null, this::handleSendProcessorError);
.onClose()
.doFinally(signalType -> tryTerminateOnConnectionClose())
.subscribe(null, errorConsumer);
connection.send(sendProcessor).subscribe(null, this::handleSendProcessorError);

connection.receive().subscribe(this::handleIncomingFrames, errorConsumer);

if (keepAliveTickPeriod != 0 && keepAliveHandler != null) {
KeepAliveSupport keepAliveSupport =
new ClientKeepAliveSupport(allocator, keepAliveTickPeriod, keepAliveAckTimeout);
this.keepAliveFramesAcceptor =
keepAliveHandler.start(keepAliveSupport, sendProcessor::onNext, this::terminate);
keepAliveHandler.start(
keepAliveSupport, sendProcessor::onNext, this::tryTerminateOnKeepAlive);
} else {
keepAliveFramesAcceptor = null;
}
}

private void terminate(KeepAlive keepAlive) {
String message =
String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis());
ConnectionErrorException err = new ConnectionErrorException(message);
setTerminationError(err);
errorConsumer.accept(err);
connection.dispose();
}

private void handleSendProcessorError(Throwable t) {
Throwable terminationError = this.terminationError;
Throwable err = terminationError != null ? terminationError : t;
receivers
.values()
.forEach(
subscriber -> {
try {
subscriber.onError(err);
} catch (Throwable e) {
errorConsumer.accept(e);
}
});

senders.values().forEach(RateLimitableRequestPublisher::cancel);
}

private void handleSendProcessorCancel(SignalType t) {
if (SignalType.ON_ERROR == t) {
return;
}

receivers
.values()
.forEach(
subscriber -> {
try {
subscriber.onError(new Throwable("closed connection"));
} catch (Throwable e) {
errorConsumer.accept(e);
}
});

senders.values().forEach(RateLimitableRequestPublisher::cancel);
}

@Override
public Mono<Void> fireAndForget(Payload payload) {
return handleFireAndForget(payload);
Expand Down Expand Up @@ -263,8 +225,7 @@ public void acceptOnce(@Nonnull Subscription subscription) {
if (s == SignalType.CANCEL) {
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
}

receivers.remove(streamId);
removeStreamReceiver(streamId);
});
}

Expand Down Expand Up @@ -318,7 +279,7 @@ public void accept(long n) {
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
}
})
.doFinally(s -> receivers.remove(streamId));
.doFinally(s -> removeStreamReceiver(streamId));
}

private Flux<Payload> handleChannel(Flux<Payload> request) {
Expand Down Expand Up @@ -419,14 +380,7 @@ protected void hookOnError(Throwable t) {
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
}
})
.doFinally(
s -> {
receivers.remove(streamId);
RateLimitableRequestPublisher sender = senders.remove(streamId);
if (sender != null) {
sender.cancel();
}
});
.doFinally(s -> removeStreamReceiverAndSender(streamId));
}

private Mono<Void> handleMetadataPush(Payload payload) {
Expand Down Expand Up @@ -472,40 +426,6 @@ private boolean contains(int streamId) {
return receivers.containsKey(streamId);
}

private void terminate() {
setTerminationError(new ClosedChannelException());
leaseHandler.dispose();
try {
receivers.values().forEach(this::cleanUpSubscriber);
senders.values().forEach(this::cleanUpLimitableRequestPublisher);
} finally {
senders.clear();
receivers.clear();
sendProcessor.dispose();
}
}

private void setTerminationError(Throwable error) {
TERMINATION_ERROR.compareAndSet(this, null, error);
}

private synchronized void cleanUpLimitableRequestPublisher(
RateLimitableRequestPublisher<?> limitableRequestPublisher) {
try {
limitableRequestPublisher.cancel();
} catch (Throwable t) {
errorConsumer.accept(t);
}
}

private synchronized void cleanUpSubscriber(Processor subscriber) {
try {
subscriber.onError(terminationError);
} catch (Throwable t) {
errorConsumer.accept(t);
}
}

private void handleIncomingFrames(ByteBuf frame) {
try {
int streamId = FrameHeaderFlyweight.streamId(frame);
Expand All @@ -525,10 +445,7 @@ private void handleIncomingFrames(ByteBuf frame) {
private void handleStreamZero(FrameType type, ByteBuf frame) {
switch (type) {
case ERROR:
RuntimeException error = Exceptions.from(frame);
setTerminationError(error);
errorConsumer.accept(error);
connection.dispose();
tryTerminateOnZeroError(frame);
break;
case LEASE:
leaseHandler.receive(frame);
Expand Down Expand Up @@ -614,4 +531,86 @@ private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBu
// receiving a frame after a given stream has been cancelled/completed,
// so ignore (cancellation is async so there is a race condition)
}

private void tryTerminateOnKeepAlive(KeepAlive keepAlive) {
tryTerminate(
() ->
new ConnectionErrorException(
String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis())));
}

private void tryTerminateOnConnectionClose() {
tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION);
}

private void tryTerminateOnZeroError(ByteBuf errorFrame) {
tryTerminate(() -> Exceptions.from(errorFrame));
}

private void tryTerminate(Supplier<Exception> errorSupplier) {
if (terminationError == null) {
Exception e = errorSupplier.get();
if (TERMINATION_ERROR.compareAndSet(this, null, e)) {
terminate(e);
}
}
}

private void terminate(Exception e) {
connection.dispose();
leaseHandler.dispose();

synchronized (receivers) {
receivers
.values()
.forEach(
receiver -> {
try {
receiver.onError(e);
} catch (Throwable t) {
errorConsumer.accept(t);
}
});
}
synchronized (senders) {
senders
.values()
.forEach(
sender -> {
try {
sender.cancel();
} catch (Throwable t) {
errorConsumer.accept(t);
}
});
}
senders.clear();
receivers.clear();
sendProcessor.dispose();
errorConsumer.accept(e);
}

private void removeStreamReceiver(int streamId) {
/*on termination receivers are explicitly cleared to avoid removing from map while iterating over one
of its views*/
if (terminationError == null) {
receivers.remove(streamId);
}
}

private void removeStreamReceiverAndSender(int streamId) {
/*on termination senders & receivers are explicitly cleared to avoid removing from map while iterating over one
of its views*/
if (terminationError == null) {
receivers.remove(streamId);
RateLimitableRequestPublisher<?> sender = senders.remove(streamId);
if (sender != null) {
sender.cancel();
}
}
}

private void handleSendProcessorError(Throwable t) {
connection.dispose();
}
}
21 changes: 10 additions & 11 deletions rsocket-core/src/test/java/io/rsocket/SetupRejectionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import org.junit.Ignore;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Mono;
import reactor.core.publisher.UnicastProcessor;
import reactor.test.StepVerifier;

@Ignore
public class SetupRejectionTest {

@Test
Expand Down Expand Up @@ -64,15 +62,16 @@ void requesterStreamsTerminatedOnZeroErrorFrame() {

String errorMsg = "error";

Mono.delay(Duration.ofMillis(100))
.doOnTerminate(
() ->
conn.addToReceivedBuffer(
ErrorFrameFlyweight.encode(
ByteBufAllocator.DEFAULT, 0, new RejectedSetupException(errorMsg))))
.subscribe();

StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test")))
StepVerifier.create(
rSocket
.requestResponse(DefaultPayload.create("test"))
.doOnRequest(
ignored ->
conn.addToReceivedBuffer(
ErrorFrameFlyweight.encode(
ByteBufAllocator.DEFAULT,
0,
new RejectedSetupException(errorMsg)))))
.expectErrorMatches(
err -> err instanceof RejectedSetupException && errorMsg.equals(err.getMessage()))
.verify(Duration.ofSeconds(5));
Expand Down