Skip to content

Fixes behavior of RequestChannel #736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 9, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 92 additions & 85 deletions rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import io.rsocket.frame.RequestResponseFrameFlyweight;
import io.rsocket.frame.RequestStreamFrameFlyweight;
import io.rsocket.frame.decoder.PayloadDecoder;
import io.rsocket.internal.RateLimitableRequestPublisher;
import io.rsocket.internal.SynchronizedIntObjectHashMap;
import io.rsocket.internal.UnboundedProcessor;
import io.rsocket.internal.UnicastMonoEmpty;
Expand All @@ -51,6 +50,7 @@
import io.rsocket.lease.RequesterLeaseHandler;
import io.rsocket.util.MonoLifecycleHandler;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;
import java.util.function.LongConsumer;
Expand All @@ -60,6 +60,7 @@
import org.reactivestreams.Processor;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -84,7 +85,7 @@ class RSocketRequester implements RSocket {
private final PayloadDecoder payloadDecoder;
private final Consumer<Throwable> errorConsumer;
private final StreamIdSupplier streamIdSupplier;
private final IntObjectMap<RateLimitableRequestPublisher> senders;
private final IntObjectMap<Subscription> senders;
private final IntObjectMap<Processor<Payload, Payload>> receivers;
private final UnboundedProcessor<ByteBuf> sendProcessor;
private final RequesterLeaseHandler leaseHandler;
Expand Down Expand Up @@ -258,6 +259,7 @@ private Flux<Payload> handleRequestStream(final Payload payload) {

final UnboundedProcessor<ByteBuf> sendProcessor = this.sendProcessor;
final UnicastProcessor<Payload> receiver = UnicastProcessor.create();
final AtomicBoolean payloadReleasedFlag = new AtomicBoolean(false);

receivers.put(streamId, receiver);

Expand All @@ -279,7 +281,9 @@ public void accept(long n) {
n,
payload.sliceMetadata().retain(),
payload.sliceData().retain()));
payload.release();
if (!payloadReleasedFlag.getAndSet(true)) {
payload.release();
}
} else if (contains(streamId) && !receiver.isDisposed()) {
sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n));
}
Expand All @@ -293,6 +297,9 @@ public void accept(long n) {
})
.doOnCancel(
() -> {
if (!payloadReleasedFlag.getAndSet(true)) {
payload.release();
}
if (contains(streamId) && !receiver.isDisposed()) {
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
}
Expand All @@ -306,10 +313,60 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
return Flux.error(err);
}

return request.switchOnFirst(
(s, flux) -> {
Payload payload = s.get();
if (payload != null) {
return handleChannel(payload, flux.skip(1));
} else {
return flux;
}
},
false);
}

private Flux<? extends Payload> handleChannel(Payload initialPayload, Flux<Payload> inboundFlux) {
final UnboundedProcessor<ByteBuf> sendProcessor = this.sendProcessor;
final UnicastProcessor<Payload> receiver = UnicastProcessor.create();
final AtomicBoolean payloadReleasedFlag = new AtomicBoolean(false);
final int streamId = streamIdSupplier.nextStreamId(receivers);

final UnicastProcessor<Payload> receiver = UnicastProcessor.create();
final BaseSubscriber<Payload> upstreamSubscriber =
new BaseSubscriber<Payload>() {

@Override
protected void hookOnSubscribe(Subscription subscription) {
// noops
}

@Override
protected void hookOnNext(Payload payload) {
final ByteBuf frame =
PayloadFrameFlyweight.encode(allocator, streamId, false, false, true, payload);

sendProcessor.onNext(frame);
payload.release();
}

@Override
protected void hookOnComplete() {
ByteBuf frame = PayloadFrameFlyweight.encodeComplete(allocator, streamId);
sendProcessor.onNext(frame);
}

@Override
protected void hookOnError(Throwable t) {
ByteBuf frame = ErrorFrameFlyweight.encode(allocator, streamId, t);
sendProcessor.onNext(frame);
receiver.onError(t);
}

@Override
protected void hookFinally(SignalType type) {
senders.remove(streamId, this);
}
};

return receiver
.doOnRequest(
new LongConsumer() {
Expand All @@ -320,85 +377,47 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
public void accept(long n) {
if (firstRequest) {
firstRequest = false;
request
.transform(
f -> {
RateLimitableRequestPublisher<Payload> wrapped =
RateLimitableRequestPublisher.wrap(f, Queues.SMALL_BUFFER_SIZE);
// Need to set this to one for first the frame
wrapped.request(1);
senders.put(streamId, wrapped);
receivers.put(streamId, receiver);

return wrapped;
})
.subscribe(
new BaseSubscriber<Payload>() {

boolean firstPayload = true;

@Override
protected void hookOnNext(Payload payload) {
final ByteBuf frame;

if (firstPayload) {
firstPayload = false;
frame =
RequestChannelFrameFlyweight.encode(
allocator,
streamId,
false,
false,
n,
payload.sliceMetadata().retain(),
payload.sliceData().retain());
} else {
frame =
PayloadFrameFlyweight.encode(
allocator, streamId, false, false, true, payload);
}

sendProcessor.onNext(frame);
payload.release();
}

@Override
protected void hookOnComplete() {
if (contains(streamId) && !receiver.isDisposed()) {
sendProcessor.onNext(
PayloadFrameFlyweight.encodeComplete(allocator, streamId));
}
if (firstPayload) {
receiver.onComplete();
}
}

@Override
protected void hookOnError(Throwable t) {
errorConsumer.accept(t);
receiver.dispose();
}
});
} else {
if (contains(streamId) && !receiver.isDisposed()) {
sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n));
senders.put(streamId, upstreamSubscriber);
receivers.put(streamId, receiver);

inboundFlux.limitRate(Queues.SMALL_BUFFER_SIZE).subscribe(upstreamSubscriber);
if (!payloadReleasedFlag.getAndSet(true)) {
ByteBuf frame =
RequestChannelFrameFlyweight.encode(
allocator,
streamId,
false,
false,
n,
initialPayload.sliceMetadata().retain(),
initialPayload.sliceData().retain());

sendProcessor.onNext(frame);

initialPayload.release();
}
} else {
sendProcessor.onNext(RequestNFrameFlyweight.encode(allocator, streamId, n));
}
}
})
.doOnError(
t -> {
if (contains(streamId) && !receiver.isDisposed()) {
sendProcessor.onNext(ErrorFrameFlyweight.encode(allocator, streamId, t));
if (receivers.remove(streamId, receiver)) {
upstreamSubscriber.cancel();
}
})
.doOnComplete(() -> receivers.remove(streamId, receiver))
.doOnCancel(
() -> {
if (contains(streamId) && !receiver.isDisposed()) {
if (!payloadReleasedFlag.getAndSet(true)) {
initialPayload.release();
}
if (receivers.remove(streamId, receiver)) {
sendProcessor.onNext(CancelFrameFlyweight.encode(allocator, streamId));
upstreamSubscriber.cancel();
}
})
.doFinally(s -> removeStreamReceiverAndSender(streamId));
});
}

private Mono<Void> handleMetadataPush(Payload payload) {
Expand Down Expand Up @@ -487,7 +506,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) {
break;
case CANCEL:
{
RateLimitableRequestPublisher sender = senders.remove(streamId);
Subscription sender = senders.remove(streamId);
if (sender != null) {
sender.cancel();
}
Expand All @@ -498,7 +517,7 @@ private void handleFrame(int streamId, FrameType type, ByteBuf frame) {
break;
case REQUEST_N:
{
RateLimitableRequestPublisher sender = senders.get(streamId);
Subscription sender = senders.get(streamId);
if (sender != null) {
int n = RequestNFrameFlyweight.requestN(frame);
sender.request(n >= Integer.MAX_VALUE ? Long.MAX_VALUE : n);
Expand Down Expand Up @@ -606,18 +625,6 @@ private void removeStreamReceiver(int streamId) {
}
}

private void removeStreamReceiverAndSender(int streamId) {
/*on termination senders & receivers are explicitly cleared to avoid removing from map while iterating over one
of its views*/
if (terminationError == null) {
receivers.remove(streamId);
RateLimitableRequestPublisher<?> sender = senders.remove(streamId);
if (sender != null) {
sender.cancel();
}
}
}

private void handleSendProcessorError(Throwable t) {
connection.dispose();
}
Expand Down
Loading