Skip to content

Commit af021d9

Browse files
adds message counting to protect against malicious overflow (#1067)
Co-authored-by: Rossen Stoyanchev <[email protected]>
1 parent 52f4583 commit af021d9

File tree

6 files changed

+425
-4
lines changed

6 files changed

+425
-4
lines changed

rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ final class RequestChannelRequesterFlux extends Flux<Payload>
8686
Context cachedContext;
8787
CoreSubscriber<? super Payload> inboundSubscriber;
8888
boolean inboundDone;
89+
long requested;
90+
long produced;
8991

9092
CompositeByteBuf frames;
9193

@@ -138,6 +140,8 @@ public final void request(long n) {
138140
return;
139141
}
140142

143+
this.requested = Operators.addCap(this.requested, n);
144+
141145
long previousState = addRequestN(STATE, this, n, this.requesterLeaseTracker == null);
142146
if (isTerminated(previousState)) {
143147
return;
@@ -706,6 +710,27 @@ public final void handlePayload(Payload value) {
706710
return;
707711
}
708712

713+
final long produced = this.produced;
714+
if (this.requested == produced) {
715+
value.release();
716+
if (!tryCancel()) {
717+
return;
718+
}
719+
720+
final Throwable cause =
721+
Exceptions.failWithOverflow(
722+
"The number of messages received exceeds the number requested");
723+
final RequestInterceptor requestInterceptor = this.requestInterceptor;
724+
if (requestInterceptor != null) {
725+
requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause);
726+
}
727+
728+
this.inboundSubscriber.onError(cause);
729+
return;
730+
}
731+
732+
this.produced = produced + 1;
733+
709734
this.inboundSubscriber.onNext(value);
710735
}
711736
}

rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ final class RequestChannelResponderSubscriber extends Flux<Payload>
8888

8989
boolean inboundDone;
9090
boolean outboundDone;
91+
long requested;
92+
long produced;
9193

9294
public RequestChannelResponderSubscriber(
9395
int streamId,
@@ -179,6 +181,8 @@ public void request(long n) {
179181
return;
180182
}
181183

184+
this.requested = Operators.addCap(this.requested, n);
185+
182186
long previousState = StateUtils.addRequestN(STATE, this, n);
183187
if (isTerminated(previousState)) {
184188
// full termination can be the result of both sides completion / cancelFrame / remote or local
@@ -196,6 +200,9 @@ public void request(long n) {
196200
Payload firstPayload = this.firstPayload;
197201
if (firstPayload != null) {
198202
this.firstPayload = null;
203+
204+
this.produced++;
205+
199206
inboundSubscriber.onNext(firstPayload);
200207
}
201208

@@ -216,6 +223,8 @@ public void request(long n) {
216223
final Payload firstPayload = this.firstPayload;
217224
this.firstPayload = null;
218225

226+
this.produced++;
227+
219228
inboundSubscriber.onNext(firstPayload);
220229
inboundSubscriber.onComplete();
221230

@@ -238,6 +247,9 @@ public void request(long n) {
238247

239248
final Payload firstPayload = this.firstPayload;
240249
this.firstPayload = null;
250+
251+
this.produced++;
252+
241253
inboundSubscriber.onNext(firstPayload);
242254

243255
previousState = markFirstFrameSent(STATE, this);
@@ -416,6 +428,58 @@ final void handlePayload(Payload p) {
416428
return;
417429
}
418430

431+
final long produced = this.produced;
432+
if (this.requested == produced) {
433+
p.release();
434+
435+
this.inboundDone = true;
436+
437+
final Throwable cause =
438+
Exceptions.failWithOverflow(
439+
"The number of messages received exceeds the number requested");
440+
boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, cause);
441+
442+
long previousState = markTerminated(STATE, this);
443+
if (isTerminated(previousState)) {
444+
if (!wasThrowableAdded) {
445+
Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext());
446+
}
447+
return;
448+
}
449+
450+
this.requesterResponderSupport.remove(this.streamId, this);
451+
452+
this.connection.sendFrame(
453+
streamId,
454+
ErrorFrameCodec.encode(
455+
this.allocator, streamId, new CanceledException(cause.getMessage())));
456+
457+
if (!isSubscribed(previousState)) {
458+
final Payload firstPayload = this.firstPayload;
459+
this.firstPayload = null;
460+
firstPayload.release();
461+
} else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) {
462+
Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this);
463+
if (inboundError != TERMINATED) {
464+
//noinspection ConstantConditions
465+
this.inboundSubscriber.onError(inboundError);
466+
}
467+
}
468+
469+
// this is downstream subscription so need to cancel it just in case error signal has not
470+
// reached it
471+
// needs for disconnected upstream and downstream case
472+
this.outboundSubscription.cancel();
473+
474+
final RequestInterceptor interceptor = requestInterceptor;
475+
if (interceptor != null) {
476+
interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause);
477+
}
478+
return;
479+
}
480+
481+
this.produced = produced + 1;
482+
419483
this.inboundSubscriber.onNext(p);
420484
}
421485
}

rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ final class RequestStreamRequesterFlux extends Flux<Payload>
6565
CoreSubscriber<? super Payload> inboundSubscriber;
6666
CompositeByteBuf frames;
6767
boolean done;
68+
long requested;
69+
long produced;
6870

6971
RequestStreamRequesterFlux(Payload payload, RequesterResponderSupport requesterResponderSupport) {
7072
this.allocator = requesterResponderSupport.getAllocator();
@@ -134,6 +136,8 @@ public final void request(long n) {
134136
return;
135137
}
136138

139+
this.requested = Operators.addCap(this.requested, n);
140+
137141
final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker;
138142
final boolean leaseEnabled = requesterLeaseTracker != null;
139143
final long previousState = addRequestN(STATE, this, n, !leaseEnabled);
@@ -295,6 +299,34 @@ public final void handlePayload(Payload p) {
295299
return;
296300
}
297301

302+
final long produced = this.produced;
303+
if (this.requested == produced) {
304+
p.release();
305+
306+
long previousState = markTerminated(STATE, this);
307+
if (isTerminated(previousState)) {
308+
return;
309+
}
310+
311+
final int streamId = this.streamId;
312+
this.requesterResponderSupport.remove(streamId, this);
313+
314+
final IllegalStateException cause =
315+
Exceptions.failWithOverflow(
316+
"The number of messages received exceeds the number requested");
317+
this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId));
318+
319+
final RequestInterceptor requestInterceptor = this.requestInterceptor;
320+
if (requestInterceptor != null) {
321+
requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause);
322+
}
323+
324+
this.inboundSubscriber.onError(cause);
325+
return;
326+
}
327+
328+
this.produced = produced + 1;
329+
298330
this.inboundSubscriber.onNext(p);
299331
}
300332

rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package io.rsocket.core;
1717

1818
import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK;
19+
import static io.rsocket.frame.FrameType.CANCEL;
1920

2021
import io.netty.buffer.ByteBuf;
2122
import io.netty.buffer.Unpooled;
@@ -40,6 +41,7 @@
4041
import java.util.stream.Stream;
4142
import org.assertj.core.api.Assertions;
4243
import org.junit.jupiter.api.BeforeAll;
44+
import org.junit.jupiter.api.Test;
4345
import org.junit.jupiter.params.ParameterizedTest;
4446
import org.junit.jupiter.params.provider.Arguments;
4547
import org.junit.jupiter.params.provider.MethodSource;
@@ -513,6 +515,77 @@ public void errorShouldTerminateExecution(String terminationMode) {
513515
stateAssert.isTerminated();
514516
}
515517

518+
@Test
519+
public void failOnOverflow() {
520+
final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client();
521+
final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator();
522+
final TestDuplexConnection sender = activeStreams.getDuplexConnection();
523+
final TestPublisher<Payload> publisher = TestPublisher.create();
524+
525+
final RequestChannelRequesterFlux requestChannelRequesterFlux =
526+
new RequestChannelRequesterFlux(publisher, activeStreams);
527+
final StateAssert<RequestChannelRequesterFlux> stateAssert =
528+
StateAssert.assertThat(requestChannelRequesterFlux);
529+
530+
// state machine check
531+
532+
stateAssert.isUnsubscribed();
533+
activeStreams.assertNoActiveStreams();
534+
535+
final AssertSubscriber<Payload> assertSubscriber =
536+
requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0));
537+
activeStreams.assertNoActiveStreams();
538+
539+
// state machine check
540+
stateAssert.hasSubscribedFlagOnly();
541+
542+
assertSubscriber.request(1);
543+
stateAssert.hasSubscribedFlag().hasRequestN(1).hasNoFirstFrameSentFlag();
544+
activeStreams.assertNoActiveStreams();
545+
546+
Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator);
547+
548+
publisher.next(payload1.retain());
549+
550+
FrameAssert.assertThat(sender.awaitFrame())
551+
.typeOf(FrameType.REQUEST_CHANNEL)
552+
.hasPayload(payload1)
553+
.hasRequestN(1)
554+
.hasNoLeaks();
555+
payload1.release();
556+
557+
stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag();
558+
activeStreams.assertHasStream(1, requestChannelRequesterFlux);
559+
560+
publisher.assertMaxRequested(1);
561+
562+
Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator);
563+
requestChannelRequesterFlux.handlePayload(nextPayload);
564+
565+
Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator);
566+
requestChannelRequesterFlux.handlePayload(unrequestedPayload);
567+
568+
final ByteBuf cancelFrame = sender.awaitFrame();
569+
FrameAssert.assertThat(cancelFrame)
570+
.isNotNull()
571+
.typeOf(CANCEL)
572+
.hasClientSideStreamId()
573+
.hasStreamId(1)
574+
.hasNoLeaks();
575+
576+
assertSubscriber
577+
.assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks())
578+
.assertError()
579+
.assertErrorMessage("The number of messages received exceeds the number requested");
580+
581+
publisher.assertWasCancelled();
582+
583+
activeStreams.assertNoActiveStreams();
584+
// state machine check
585+
stateAssert.isTerminated();
586+
Assertions.assertThat(sender.isEmpty()).isTrue();
587+
}
588+
516589
/*
517590
* +--------------------------------+
518591
* | Racing Test Cases |

0 commit comments

Comments
 (0)