Skip to content

Commit 006dbbe

Browse files
committed
provides extra encoding safety
Signed-off-by: Oleh Dokuka <[email protected]>
1 parent 12fd301 commit 006dbbe

File tree

4 files changed

+118
-10
lines changed

4 files changed

+118
-10
lines changed

rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ public Mono<Void> onClose() {
195195
}
196196

197197
private Mono<Void> handleFireAndForget(Payload payload) {
198+
if (payload.refCnt() <= 0) {
199+
return Mono.error(new IllegalReferenceCountException());
200+
}
201+
198202
Throwable err = checkAvailable();
199203
if (err != null) {
200204
payload.release();
@@ -227,6 +231,10 @@ private Mono<Void> handleFireAndForget(Payload payload) {
227231
}
228232

229233
private Mono<Payload> handleRequestResponse(final Payload payload) {
234+
if (payload.refCnt() <= 0) {
235+
return Mono.error(new IllegalReferenceCountException());
236+
}
237+
230238
Throwable err = checkAvailable();
231239
if (err != null) {
232240
payload.release();
@@ -289,6 +297,10 @@ public void hookOnTerminal(SignalType signalType) {
289297
}
290298

291299
private Flux<Payload> handleRequestStream(final Payload payload) {
300+
if (payload.refCnt() <= 0) {
301+
return Flux.error(new IllegalReferenceCountException());
302+
}
303+
292304
Throwable err = checkAvailable();
293305
if (err != null) {
294306
payload.release();
@@ -371,6 +383,10 @@ private Flux<Payload> handleChannel(Flux<Payload> request) {
371383
(s, flux) -> {
372384
Payload payload = s.get();
373385
if (payload != null) {
386+
if (payload.refCnt() <= 0) {
387+
return Mono.error(new IllegalReferenceCountException());
388+
}
389+
374390
if (!PayloadValidationUtils.isValid(mtu, payload)) {
375391
payload.release();
376392
final IllegalArgumentException t =
@@ -509,6 +525,10 @@ public void cancel() {
509525
}
510526

511527
private Mono<Void> handleMetadataPush(Payload payload) {
528+
if (payload.refCnt() <= 0) {
529+
return Mono.error(new IllegalReferenceCountException());
530+
}
531+
512532
Throwable err = this.terminationError;
513533
if (err != null) {
514534
payload.release();

rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,32 @@ protected void hookOnSubscribe(Subscription s) {
450450

451451
@Override
452452
protected void hookOnNext(Payload payload) {
453-
if (!PayloadValidationUtils.isValid(mtu, payload)) {
454-
payload.release();
453+
try {
454+
if (!PayloadValidationUtils.isValid(mtu, payload)) {
455+
payload.release();
456+
// specifically for requestChannel case so when Payload is invalid we will not be
457+
// sending CancelFrame and ErrorFrame
458+
// Note: CancelFrame is redundant and due to spec
459+
// (https://github.com/rsocket/rsocket/blob/master/Protocol.md#request-channel)
460+
// Upon receiving an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream
461+
// is
462+
// terminated on both Requester and Responder.
463+
// Upon sending an ERROR[APPLICATION_ERROR|REJECTED|CANCELED|INVALID], the stream is
464+
// terminated on both the Requester and Responder.
465+
if (requestChannel != null) {
466+
channelProcessors.remove(streamId, requestChannel);
467+
}
468+
cancel();
469+
final IllegalArgumentException t =
470+
new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
471+
handleError(streamId, t);
472+
return;
473+
}
474+
475+
ByteBuf byteBuf =
476+
PayloadFrameFlyweight.encodeNextReleasingPayload(allocator, streamId, payload);
477+
sendProcessor.onNext(byteBuf);
478+
} catch (Throwable e) {
455479
// specifically for requestChannel case so when Payload is invalid we will not be
456480
// sending CancelFrame and ErrorFrame
457481
// Note: CancelFrame is redundant and due to spec
@@ -464,15 +488,8 @@ protected void hookOnNext(Payload payload) {
464488
channelProcessors.remove(streamId, requestChannel);
465489
}
466490
cancel();
467-
final IllegalArgumentException t =
468-
new IllegalArgumentException(INVALID_PAYLOAD_ERROR_MESSAGE);
469-
handleError(streamId, t);
470-
return;
491+
handleError(streamId, e);
471492
}
472-
473-
ByteBuf byteBuf =
474-
PayloadFrameFlyweight.encodeNextReleasingPayload(allocator, streamId, payload);
475-
sendProcessor.onNext(byteBuf);
476493
}
477494

478495
@Override

rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import io.netty.buffer.ByteBufUtil;
3838
import io.netty.buffer.Unpooled;
3939
import io.netty.util.CharsetUtil;
40+
import io.netty.util.IllegalReferenceCountException;
4041
import io.netty.util.ReferenceCountUtil;
4142
import io.netty.util.ReferenceCounted;
4243
import io.rsocket.Payload;
@@ -775,6 +776,30 @@ static Stream<Arguments> encodeDecodePayloadCases() {
775776
Arguments.of(REQUEST_CHANNEL, 5, 5));
776777
}
777778

779+
@ParameterizedTest
780+
@MethodSource("refCntCases")
781+
public void ensureSendsErrorOnIllegalRefCntPayload(
782+
BiFunction<Payload, RSocket, Publisher<?>> sourceProducer) {
783+
Payload invalidPayload = ByteBufPayload.create("test", "test");
784+
invalidPayload.release();
785+
786+
Publisher<?> source = sourceProducer.apply(invalidPayload, rule.socket);
787+
788+
StepVerifier.create(source, 0)
789+
.expectError(IllegalReferenceCountException.class)
790+
.verify(Duration.ofMillis(100));
791+
}
792+
793+
private static Stream<BiFunction<Payload, RSocket, Publisher<?>>> refCntCases() {
794+
return Stream.of(
795+
(p, r) -> r.fireAndForget(p),
796+
(p, r) -> r.requestResponse(p),
797+
(p, r) -> r.requestStream(p),
798+
(p, r) -> r.requestChannel(Mono.just(p)),
799+
(p, r) ->
800+
r.requestChannel(Flux.just(EmptyPayload.INSTANCE, p).doOnSubscribe(s -> s.request(1))));
801+
}
802+
778803
@Test
779804
public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() {
780805
Payload payload1 = ByteBufPayload.create("abc1");

rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE;
2020
import static io.rsocket.frame.FrameHeaderFlyweight.frameType;
21+
import static io.rsocket.frame.FrameType.ERROR;
2122
import static io.rsocket.frame.FrameType.REQUEST_CHANNEL;
2223
import static io.rsocket.frame.FrameType.REQUEST_FNF;
2324
import static io.rsocket.frame.FrameType.REQUEST_N;
@@ -711,6 +712,51 @@ static Stream<Arguments> encodeDecodePayloadCases() {
711712
Arguments.of(REQUEST_CHANNEL, 5, 5));
712713
}
713714

715+
@ParameterizedTest
716+
@MethodSource("refCntCases")
717+
public void ensureSendsErrorOnIllegalRefCntPayload(FrameType frameType) {
718+
rule.setAcceptingSocket(
719+
new RSocket() {
720+
@Override
721+
public Mono<Payload> requestResponse(Payload payload) {
722+
Payload invalidPayload = ByteBufPayload.create("test", "test");
723+
invalidPayload.release();
724+
return Mono.just(invalidPayload);
725+
}
726+
727+
@Override
728+
public Flux<Payload> requestStream(Payload payload) {
729+
Payload invalidPayload = ByteBufPayload.create("test", "test");
730+
invalidPayload.release();
731+
return Flux.just(invalidPayload);
732+
}
733+
734+
@Override
735+
public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
736+
Payload invalidPayload = ByteBufPayload.create("test", "test");
737+
invalidPayload.release();
738+
return Flux.just(invalidPayload);
739+
}
740+
});
741+
742+
rule.sendRequest(1, frameType);
743+
744+
Assertions.assertThat(rule.connection.getSent())
745+
.hasSize(1)
746+
.first()
747+
.matches(
748+
bb -> frameType(bb) == ERROR,
749+
"Expect frame type to be {"
750+
+ ERROR
751+
+ "} but was {"
752+
+ frameType(rule.connection.getSent().iterator().next())
753+
+ "}");
754+
}
755+
756+
private static Stream<FrameType> refCntCases() {
757+
return Stream.of(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL);
758+
}
759+
714760
public static class ServerSocketRule extends AbstractSocketRule<RSocketResponder> {
715761

716762
private RSocket acceptingSocket;

0 commit comments

Comments
 (0)