Skip to content

Commit 8f7bcce

Browse files
committed
provides payload validation
Signed-off-by: Oleh Dokuka <[email protected]>
1 parent 53a09fb commit 8f7bcce

File tree

10 files changed

+245
-10
lines changed

10 files changed

+245
-10
lines changed

rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,11 @@ private Mono<Void> handleMetadataPush(Payload payload) {
464464
return Mono.error(err);
465465
}
466466

467+
if (!FragmentationUtils.isValid(this.mtu, payload)) {
468+
payload.release();
469+
return Mono.error(new IllegalArgumentException("Too big Payload size"));
470+
}
471+
467472
return UnicastMonoEmpty.newInstance(
468473
() -> {
469474
ByteBuf metadataPushFrame =

rsocket-core/src/main/java/io/rsocket/fragmentation/FrameFragmenter.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,11 @@
2828
import io.rsocket.frame.RequestFireAndForgetFrameFlyweight;
2929
import io.rsocket.frame.RequestResponseFrameFlyweight;
3030
import io.rsocket.frame.RequestStreamFrameFlyweight;
31+
import java.util.function.Consumer;
3132
import org.reactivestreams.Publisher;
3233
import reactor.core.publisher.Flux;
3334
import reactor.core.publisher.SynchronousSink;
3435

35-
import java.util.function.Consumer;
36-
3736
/**
3837
* The implementation of the RSocket fragmentation behavior.
3938
*

rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ static RSocketState requester(int tickPeriod, int timeout) {
6161
DefaultPayload::create,
6262
errors,
6363
StreamIdSupplier.clientSupplier(),
64+
0,
6465
tickPeriod,
6566
timeout,
6667
new DefaultKeepAliveHandler(connection),
@@ -86,6 +87,7 @@ static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) {
8687
DefaultPayload::create,
8788
errors,
8889
StreamIdSupplier.clientSupplier(),
90+
0,
8991
tickPeriod,
9092
timeout,
9193
new ResumableKeepAliveHandler(resumableConnection),

rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ void setUp() {
9494
StreamIdSupplier.clientSupplier(),
9595
0,
9696
0,
97+
0,
9798
null,
9899
requesterLeaseHandler);
99100

@@ -111,7 +112,8 @@ void setUp() {
111112
mockRSocketHandler,
112113
payloadDecoder,
113114
err -> {},
114-
responderLeaseHandler);
115+
responderLeaseHandler,
116+
0);
115117
}
116118

117119
@Test

rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,19 @@
1717
package io.rsocket.core;
1818

1919
import static io.rsocket.frame.FrameHeaderFlyweight.frameType;
20-
import static io.rsocket.frame.FrameType.*;
20+
import static io.rsocket.frame.FrameType.CANCEL;
21+
import static io.rsocket.frame.FrameType.KEEPALIVE;
22+
import static io.rsocket.frame.FrameType.REQUEST_CHANNEL;
23+
import static io.rsocket.frame.FrameType.REQUEST_RESPONSE;
24+
import static io.rsocket.frame.FrameType.REQUEST_STREAM;
2125
import static org.hamcrest.MatcherAssert.assertThat;
22-
import static org.hamcrest.Matchers.*;
26+
import static org.hamcrest.Matchers.contains;
27+
import static org.hamcrest.Matchers.equalTo;
28+
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
29+
import static org.hamcrest.Matchers.hasSize;
30+
import static org.hamcrest.Matchers.instanceOf;
31+
import static org.hamcrest.Matchers.is;
32+
import static org.hamcrest.Matchers.not;
2333
import static org.mockito.ArgumentMatchers.any;
2434
import static org.mockito.Mockito.verify;
2535

@@ -29,7 +39,15 @@
2939
import io.netty.util.CharsetUtil;
3040
import io.rsocket.exceptions.ApplicationErrorException;
3141
import io.rsocket.exceptions.RejectedSetupException;
32-
import io.rsocket.frame.*;
42+
import io.rsocket.frame.CancelFrameFlyweight;
43+
import io.rsocket.frame.ErrorFrameFlyweight;
44+
import io.rsocket.frame.FrameHeaderFlyweight;
45+
import io.rsocket.frame.FrameLengthFlyweight;
46+
import io.rsocket.frame.FrameType;
47+
import io.rsocket.frame.PayloadFrameFlyweight;
48+
import io.rsocket.frame.RequestChannelFrameFlyweight;
49+
import io.rsocket.frame.RequestNFrameFlyweight;
50+
import io.rsocket.frame.RequestStreamFrameFlyweight;
3351
import io.rsocket.lease.RequesterLeaseHandler;
3452
import io.rsocket.test.util.TestSubscriber;
3553
import io.rsocket.util.DefaultPayload;
@@ -39,7 +57,10 @@
3957
import java.util.ArrayList;
4058
import java.util.Iterator;
4159
import java.util.List;
60+
import java.util.concurrent.ThreadLocalRandom;
61+
import java.util.function.BiFunction;
4262
import java.util.stream.Collectors;
63+
import java.util.stream.Stream;
4364
import org.assertj.core.api.Assertions;
4465
import org.junit.Rule;
4566
import org.junit.Test;
@@ -51,6 +72,7 @@
5172
import reactor.core.publisher.Mono;
5273
import reactor.core.publisher.MonoProcessor;
5374
import reactor.core.publisher.UnicastProcessor;
75+
import reactor.test.StepVerifier;
5476

5577
public class RSocketRequesterTest {
5678

@@ -262,6 +284,62 @@ protected void hookOnSubscribe(Subscription subscription) {}
262284
Assertions.assertThat(iterator.hasNext()).isFalse();
263285
}
264286

287+
@Test
288+
public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() {
289+
prepareCalls()
290+
.forEach(
291+
generator -> {
292+
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
293+
byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
294+
ThreadLocalRandom.current().nextBytes(metadata);
295+
ThreadLocalRandom.current().nextBytes(data);
296+
StepVerifier.create(
297+
generator.apply(rule.socket, DefaultPayload.create(data, metadata)))
298+
.expectSubscription()
299+
.expectErrorSatisfies(
300+
t ->
301+
Assertions.assertThat(t)
302+
.isInstanceOf(IllegalArgumentException.class)
303+
.hasMessage("Too big Payload size"))
304+
.verify();
305+
});
306+
}
307+
308+
@Test
309+
public void
310+
shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase() {
311+
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
312+
byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
313+
ThreadLocalRandom.current().nextBytes(metadata);
314+
ThreadLocalRandom.current().nextBytes(data);
315+
StepVerifier.create(
316+
rule.socket.requestChannel(
317+
Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata))))
318+
.expectSubscription()
319+
.then(
320+
() ->
321+
rule.connection.addToReceivedBuffer(
322+
RequestNFrameFlyweight.encode(
323+
ByteBufAllocator.DEFAULT,
324+
rule.getStreamIdForRequestType(REQUEST_CHANNEL),
325+
1)))
326+
.expectErrorSatisfies(
327+
t ->
328+
Assertions.assertThat(t)
329+
.isInstanceOf(IllegalArgumentException.class)
330+
.hasMessage("Too big Payload size"))
331+
.verify();
332+
}
333+
334+
static Stream<BiFunction<RSocket, Payload, Publisher<?>>> prepareCalls() {
335+
return Stream.of(
336+
RSocket::fireAndForget,
337+
RSocket::requestResponse,
338+
RSocket::requestStream,
339+
(rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)),
340+
RSocket::metadataPush);
341+
}
342+
265343
public int sendRequestResponse(Publisher<Payload> response) {
266344
Subscriber<Payload> sub = TestSubscriber.create();
267345
response.subscribe(sub);
@@ -285,6 +363,7 @@ protected RSocketRequester newRSocket() {
285363
StreamIdSupplier.clientSupplier(),
286364
0,
287365
0,
366+
0,
288367
null,
289368
RequesterLeaseHandler.None);
290369
}

rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,15 @@
3434
import io.rsocket.util.EmptyPayload;
3535
import java.util.Collection;
3636
import java.util.concurrent.ConcurrentLinkedQueue;
37+
import java.util.concurrent.ThreadLocalRandom;
3738
import java.util.concurrent.atomic.AtomicBoolean;
39+
import org.assertj.core.api.Assertions;
3840
import org.junit.Ignore;
3941
import org.junit.Rule;
4042
import org.junit.Test;
43+
import org.reactivestreams.Publisher;
4144
import org.reactivestreams.Subscriber;
45+
import reactor.core.publisher.Flux;
4246
import reactor.core.publisher.Mono;
4347

4448
public class RSocketResponderTest {
@@ -110,6 +114,58 @@ public Mono<Payload> requestResponse(Payload payload) {
110114
assertThat("Subscription not cancelled.", cancelled.get(), is(true));
111115
}
112116

117+
@Test
118+
public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation() {
119+
final int streamId = 4;
120+
final AtomicBoolean cancelled = new AtomicBoolean();
121+
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
122+
byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
123+
ThreadLocalRandom.current().nextBytes(metadata);
124+
ThreadLocalRandom.current().nextBytes(data);
125+
final Payload payload = DefaultPayload.create(data, metadata);
126+
final AbstractRSocket acceptingSocket =
127+
new AbstractRSocket() {
128+
@Override
129+
public Mono<Payload> requestResponse(Payload p) {
130+
return Mono.just(payload).doOnCancel(() -> cancelled.set(true));
131+
}
132+
133+
@Override
134+
public Flux<Payload> requestStream(Payload p) {
135+
return Flux.just(payload).doOnCancel(() -> cancelled.set(true));
136+
}
137+
138+
@Override
139+
public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
140+
return Flux.just(payload).doOnCancel(() -> cancelled.set(true));
141+
}
142+
};
143+
rule.setAcceptingSocket(acceptingSocket);
144+
145+
final Runnable[] runnables = {
146+
() -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE),
147+
() -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM),
148+
() -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL)
149+
};
150+
151+
for (Runnable runnable : runnables) {
152+
runnable.run();
153+
Assertions.assertThat(rule.errors)
154+
.first()
155+
.isInstanceOf(IllegalArgumentException.class)
156+
.hasToString("java.lang.IllegalArgumentException: Too big Payload size");
157+
Assertions.assertThat(rule.connection.getSent())
158+
.hasSize(1)
159+
.first()
160+
.matches(bb -> FrameHeaderFlyweight.frameType(bb) == FrameType.ERROR)
161+
.matches(bb -> ErrorFrameFlyweight.dataUtf8(bb).contains("Too big Payload size"));
162+
163+
assertThat("Subscription not cancelled.", cancelled.get(), is(true));
164+
rule.init();
165+
rule.setAcceptingSocket(acceptingSocket);
166+
}
167+
}
168+
113169
public static class ServerSocketRule extends AbstractSocketRule<RSocketResponder> {
114170

115171
private RSocket acceptingSocket;
@@ -151,7 +207,8 @@ protected RSocketResponder newRSocket() {
151207
acceptingSocket,
152208
DefaultPayload::create,
153209
throwable -> errors.add(throwable),
154-
ResponderLeaseHandler.None);
210+
ResponderLeaseHandler.None,
211+
0);
155212
}
156213

157214
private void sendRequest(int streamId, FrameType frameType) {

rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,8 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
222222
requestAcceptor,
223223
DefaultPayload::create,
224224
throwable -> serverErrors.add(throwable),
225-
ResponderLeaseHandler.None);
225+
ResponderLeaseHandler.None,
226+
0);
226227

227228
crs =
228229
new RSocketRequester(
@@ -233,6 +234,7 @@ public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
233234
StreamIdSupplier.clientSupplier(),
234235
0,
235236
0,
237+
0,
236238
null,
237239
RequesterLeaseHandler.None);
238240
}

rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ void requesterStreamsTerminatedOnZeroErrorFrame() {
5858
StreamIdSupplier.clientSupplier(),
5959
0,
6060
0,
61+
0,
6162
null,
6263
RequesterLeaseHandler.None);
6364

@@ -93,6 +94,7 @@ void requesterNewStreamsTerminatedAfterZeroErrorFrame() {
9394
StreamIdSupplier.clientSupplier(),
9495
0,
9596
0,
97+
0,
9698
null,
9799
RequesterLeaseHandler.None);
98100

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package io.rsocket.fragmentation;
2+
3+
import static org.junit.jupiter.api.Assertions.*;
4+
5+
import io.rsocket.Payload;
6+
import io.rsocket.frame.FrameHeaderFlyweight;
7+
import io.rsocket.frame.FrameLengthFlyweight;
8+
import io.rsocket.util.DefaultPayload;
9+
import java.util.concurrent.ThreadLocalRandom;
10+
import org.assertj.core.api.Assertions;
11+
import org.junit.jupiter.api.Test;
12+
13+
class FragmentationUtilsTest {
14+
15+
@Test
16+
void shouldValidFrameWithNoFragmentation() {
17+
byte[] data =
18+
new byte
19+
[FrameLengthFlyweight.FRAME_LENGTH_MASK
20+
- FrameLengthFlyweight.FRAME_LENGTH_SIZE
21+
- FrameHeaderFlyweight.size()];
22+
ThreadLocalRandom.current().nextBytes(data);
23+
final Payload payload = DefaultPayload.create(data);
24+
25+
Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue();
26+
}
27+
28+
@Test
29+
void shouldValidFrameWithNoFragmentation0() {
30+
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK / 2];
31+
byte[] data =
32+
new byte
33+
[FrameLengthFlyweight.FRAME_LENGTH_MASK / 2
34+
- FrameLengthFlyweight.FRAME_LENGTH_SIZE
35+
- FrameHeaderFlyweight.size()
36+
- FrameHeaderFlyweight.size()];
37+
ThreadLocalRandom.current().nextBytes(data);
38+
ThreadLocalRandom.current().nextBytes(metadata);
39+
final Payload payload = DefaultPayload.create(data, metadata);
40+
41+
Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue();
42+
}
43+
44+
@Test
45+
void shouldValidFrameWithNoFragmentation1() {
46+
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
47+
byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
48+
ThreadLocalRandom.current().nextBytes(metadata);
49+
ThreadLocalRandom.current().nextBytes(data);
50+
final Payload payload = DefaultPayload.create(data, metadata);
51+
52+
Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isFalse();
53+
}
54+
55+
@Test
56+
void shouldValidFrameWithNoFragmentation2() {
57+
byte[] metadata = new byte[1];
58+
byte[] data = new byte[1];
59+
ThreadLocalRandom.current().nextBytes(metadata);
60+
ThreadLocalRandom.current().nextBytes(data);
61+
final Payload payload = DefaultPayload.create(data, metadata);
62+
63+
Assertions.assertThat(FragmentationUtils.isValid(0, payload)).isTrue();
64+
}
65+
66+
@Test
67+
void shouldValidFrameWithNoFragmentation3() {
68+
byte[] metadata = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
69+
byte[] data = new byte[FrameLengthFlyweight.FRAME_LENGTH_MASK];
70+
ThreadLocalRandom.current().nextBytes(metadata);
71+
ThreadLocalRandom.current().nextBytes(data);
72+
final Payload payload = DefaultPayload.create(data, metadata);
73+
74+
Assertions.assertThat(FragmentationUtils.isValid(64, payload)).isTrue();
75+
}
76+
77+
@Test
78+
void shouldValidFrameWithNoFragmentation4() {
79+
byte[] metadata = new byte[1];
80+
byte[] data = new byte[1];
81+
ThreadLocalRandom.current().nextBytes(metadata);
82+
ThreadLocalRandom.current().nextBytes(data);
83+
final Payload payload = DefaultPayload.create(data, metadata);
84+
85+
Assertions.assertThat(FragmentationUtils.isValid(64, payload)).isTrue();
86+
}
87+
}

0 commit comments

Comments
 (0)