Skip to content

Commit fb10188

Browse files
OlegDokukamostroverkhov
authored andcommitted
fixes bug with incorrect WS framesize setup. (#610)
* fixes buf with incorrect WS framesize setup. This PR provides a websocket transport's frame size setup with regards of current RSocket fragment size. Current implementation provides the following setup rules: * if `mtu` is 0 (means no fragmentation) - then WS frame size is equal to maximum default frame size of RSocket frame which is `16_777_215`; * if `mtu` is GT > 0 and LT < 65536 (default for WS frame size) then the WS frame size will be its the default one (which is `65536`); * if `mtu` is GT > 65536 then the WS frame size will be set to the specified by that parameter size Signed-off-by: Oleh Dokuka <[email protected]> * rollbacks commented docs build * fixes format Signed-off-by: Oleh Dokuka <[email protected]> * cleanup pollution uses max frame size from `FrameLengthFlyweight` Signed-off-by: Oleh Dokuka <[email protected]> * simplifies logic of max frame size setup for ws transport Signed-off-by: Oleh Dokuka <[email protected]> * fixes google java format Signed-off-by: Oleh Dokuka <[email protected]> Signed-off-by: Maksym Ostroverkhov <[email protected]>
1 parent 3072d2e commit fb10188

File tree

10 files changed

+346
-7
lines changed

10 files changed

+346
-7
lines changed

build.gradle

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ subprojects {
3535
ext['netty.version'] = '4.1.31.Final'
3636
ext['netty-boringssl.version'] = '2.0.18.Final'
3737
ext['hdrhistogram.version'] = '2.1.10'
38-
ext['mockito.version'] = '2.23.0'
38+
ext['mockito.version'] = '2.25.1'
3939
ext['slf4j.version'] = '1.7.25'
4040
ext['jmh.version'] = '1.21'
4141
ext['junit.version'] = '5.1.0'
@@ -60,8 +60,12 @@ subprojects {
6060
dependency "io.micrometer:micrometer-core:${ext['micrometer.version']}"
6161
dependency "org.assertj:assertj-core:${ext['assertj.version']}"
6262
dependency "org.hdrhistogram:HdrHistogram:${ext['hdrhistogram.version']}"
63-
dependency "org.mockito:mockito-core:${ ext['mockito.version']}"
6463
dependency "org.slf4j:slf4j-api:${ext['slf4j.version']}"
64+
65+
dependencySet(group: 'org.mockito', version: ext['mockito.version']) {
66+
entry 'mockito-junit-jupiter'
67+
entry 'mockito-core'
68+
}
6569

6670
dependencySet(group: 'org.junit.jupiter', version: ext['junit.version']) {
6771
entry 'junit-jupiter-api'

rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io.netty.buffer.Unpooled;
66

77
public class FrameUtil {
8+
89
private FrameUtil() {}
910

1011
public static String toString(ByteBuf frame) {

rsocket-transport-netty/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ dependencies {
3636
testImplementation project(':rsocket-test')
3737
testImplementation 'io.projectreactor:reactor-test'
3838
testImplementation 'org.assertj:assertj-core'
39+
testImplementation 'org.mockito:mockito-core'
40+
testImplementation 'org.mockito:mockito-junit-jupiter'
3941
testImplementation 'org.junit.jupiter:junit-jupiter-api'
4042
testImplementation 'org.junit.jupiter:junit-jupiter-params'
4143

rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package io.rsocket.transport.netty.client;
1818

19+
import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;
1920
import static io.rsocket.transport.netty.UriUtils.getPort;
2021
import static io.rsocket.transport.netty.UriUtils.isSecure;
2122

@@ -42,6 +43,7 @@
4243
*/
4344
public final class WebsocketClientTransport implements ClientTransport, TransportHeaderAware {
4445

46+
private static final int DEFAULT_FRAME_SIZE = 65536;
4547
private static final String DEFAULT_PATH = "/";
4648

4749
private final HttpClient client;
@@ -151,7 +153,7 @@ private static TcpClient createClient(URI uri) {
151153
public Mono<DuplexConnection> connect(int mtu) {
152154
return client
153155
.headers(headers -> transportHeaders.get().forEach(headers::set))
154-
.websocket()
156+
.websocket(FRAME_LENGTH_MASK)
155157
.uri(path)
156158
.connect()
157159
.map(

rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,23 @@
1616

1717
package io.rsocket.transport.netty.server;
1818

19+
import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;
20+
1921
import io.netty.buffer.ByteBufAllocator;
22+
import io.netty.handler.codec.http.HttpMethod;
2023
import io.rsocket.Closeable;
2124
import io.rsocket.DuplexConnection;
2225
import io.rsocket.fragmentation.FragmentationDuplexConnection;
2326
import io.rsocket.transport.ServerTransport;
2427
import io.rsocket.transport.netty.WebsocketDuplexConnection;
28+
import java.util.ArrayList;
29+
import java.util.HashMap;
30+
import java.util.List;
31+
import java.util.Map;
2532
import java.util.Objects;
2633
import java.util.function.Consumer;
34+
import java.util.regex.Matcher;
35+
import java.util.regex.Pattern;
2736
import reactor.core.publisher.Mono;
2837
import reactor.netty.Connection;
2938
import reactor.netty.http.server.HttpServer;
@@ -35,7 +44,7 @@
3544
*/
3645
public final class WebsocketRouteTransport implements ServerTransport<Closeable> {
3746

38-
private final String path;
47+
private final UriPathTemplate template;
3948

4049
private final Consumer<? super HttpServerRoutes> routesBuilder;
4150

@@ -53,7 +62,7 @@ public WebsocketRouteTransport(
5362

5463
this.server = Objects.requireNonNull(server, "server must not be null");
5564
this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null");
56-
this.path = Objects.requireNonNull(path, "path must not be null");
65+
this.template = new UriPathTemplate(Objects.requireNonNull(path, "path must not be null"));
5766
}
5867

5968
@Override
@@ -65,7 +74,7 @@ public Mono<Closeable> start(ConnectionAcceptor acceptor, int mtu) {
6574
routes -> {
6675
routesBuilder.accept(routes);
6776
routes.ws(
68-
path,
77+
hsr -> hsr.method().equals(HttpMethod.GET) && template.matches(hsr.uri()),
6978
(in, out) -> {
7079
DuplexConnection connection = new WebsocketDuplexConnection((Connection) in);
7180
if (mtu > 0) {
@@ -74,9 +83,128 @@ public Mono<Closeable> start(ConnectionAcceptor acceptor, int mtu) {
7483
connection, ByteBufAllocator.DEFAULT, mtu, false);
7584
}
7685
return acceptor.apply(connection).then(out.neverComplete());
77-
});
86+
},
87+
null,
88+
FRAME_LENGTH_MASK);
7889
})
7990
.bind()
8091
.map(CloseableChannel::new);
8192
}
93+
94+
static final class UriPathTemplate {
95+
96+
private static final Pattern FULL_SPLAT_PATTERN = Pattern.compile("[\\*][\\*]");
97+
private static final String FULL_SPLAT_REPLACEMENT = ".*";
98+
99+
private static final Pattern NAME_SPLAT_PATTERN = Pattern.compile("\\{([^/]+?)\\}[\\*][\\*]");
100+
private static final String NAME_SPLAT_REPLACEMENT = "(?<%NAME%>.*)";
101+
102+
private static final Pattern NAME_PATTERN = Pattern.compile("\\{([^/]+?)\\}");
103+
private static final String NAME_REPLACEMENT = "(?<%NAME%>[^\\/]*)";
104+
105+
private final List<String> pathVariables = new ArrayList<>();
106+
private final HashMap<String, Matcher> matchers = new HashMap<>();
107+
private final HashMap<String, Map<String, String>> vars = new HashMap<>();
108+
109+
private final Pattern uriPattern;
110+
111+
static String filterQueryParams(String uri) {
112+
int hasQuery = uri.lastIndexOf("?");
113+
if (hasQuery != -1) {
114+
return uri.substring(0, hasQuery);
115+
} else {
116+
return uri;
117+
}
118+
}
119+
120+
/**
121+
* Creates a new {@code UriPathTemplate} from the given {@code uriPattern}.
122+
*
123+
* @param uriPattern The pattern to be used by the template
124+
*/
125+
UriPathTemplate(String uriPattern) {
126+
String s = "^" + filterQueryParams(uriPattern);
127+
128+
Matcher m = NAME_SPLAT_PATTERN.matcher(s);
129+
while (m.find()) {
130+
for (int i = 1; i <= m.groupCount(); i++) {
131+
String name = m.group(i);
132+
pathVariables.add(name);
133+
s = m.replaceFirst(NAME_SPLAT_REPLACEMENT.replaceAll("%NAME%", name));
134+
m.reset(s);
135+
}
136+
}
137+
138+
m = NAME_PATTERN.matcher(s);
139+
while (m.find()) {
140+
for (int i = 1; i <= m.groupCount(); i++) {
141+
String name = m.group(i);
142+
pathVariables.add(name);
143+
s = m.replaceFirst(NAME_REPLACEMENT.replaceAll("%NAME%", name));
144+
m.reset(s);
145+
}
146+
}
147+
148+
m = FULL_SPLAT_PATTERN.matcher(s);
149+
while (m.find()) {
150+
s = m.replaceAll(FULL_SPLAT_REPLACEMENT);
151+
m.reset(s);
152+
}
153+
154+
this.uriPattern = Pattern.compile(s + "$");
155+
}
156+
157+
/**
158+
* Tests the given {@code uri} against this template, returning {@code true} if the uri matches
159+
* the template, {@code false} otherwise.
160+
*
161+
* @param uri The uri to match
162+
* @return {@code true} if there's a match, {@code false} otherwise
163+
*/
164+
public boolean matches(String uri) {
165+
return matcher(uri).matches();
166+
}
167+
168+
/**
169+
* Matches the template against the given {@code uri} returning a map of path parameters
170+
* extracted from the uri, keyed by the names in the template. If the uri does not match, or
171+
* there are no path parameters, an empty map is returned.
172+
*
173+
* @param uri The uri to match
174+
* @return the path parameters from the uri. Never {@code null}.
175+
*/
176+
final Map<String, String> match(String uri) {
177+
Map<String, String> pathParameters = vars.get(uri);
178+
if (null != pathParameters) {
179+
return pathParameters;
180+
}
181+
182+
pathParameters = new HashMap<>();
183+
Matcher m = matcher(uri);
184+
if (m.matches()) {
185+
int i = 1;
186+
for (String name : pathVariables) {
187+
String val = m.group(i++);
188+
pathParameters.put(name, val);
189+
}
190+
}
191+
synchronized (vars) {
192+
vars.put(uri, pathParameters);
193+
}
194+
195+
return pathParameters;
196+
}
197+
198+
private Matcher matcher(String uri) {
199+
uri = filterQueryParams(uri);
200+
Matcher m = matchers.get(uri);
201+
if (null == m) {
202+
m = uriPattern.matcher(uri);
203+
synchronized (matchers) {
204+
matchers.put(uri, m);
205+
}
206+
}
207+
return m;
208+
}
209+
}
82210
}

rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package io.rsocket.transport.netty.server;
1818

19+
import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;
20+
1921
import io.netty.buffer.ByteBufAllocator;
2022
import io.rsocket.DuplexConnection;
2123
import io.rsocket.fragmentation.FragmentationDuplexConnection;
@@ -114,6 +116,8 @@ public Mono<CloseableChannel> start(ConnectionAcceptor acceptor, int mtu) {
114116
(request, response) -> {
115117
transportHeaders.get().forEach(response::addHeader);
116118
return response.sendWebsocket(
119+
null,
120+
FRAME_LENGTH_MASK,
117121
(in, out) -> {
118122
DuplexConnection connection = new WebsocketDuplexConnection((Connection) in);
119123
if (mtu > 0) {

rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,71 @@
1616

1717
package io.rsocket.transport.netty.client;
1818

19+
import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;
1920
import static org.assertj.core.api.Assertions.assertThat;
2021
import static org.assertj.core.api.Assertions.assertThatNullPointerException;
2122

2223
import io.rsocket.transport.netty.server.WebsocketServerTransport;
2324
import java.net.InetSocketAddress;
2425
import java.net.URI;
2526
import java.util.Collections;
27+
import org.assertj.core.api.Assertions;
2628
import org.junit.jupiter.api.DisplayName;
2729
import org.junit.jupiter.api.Test;
30+
import org.junit.jupiter.api.extension.ExtendWith;
31+
import org.mockito.ArgumentCaptor;
32+
import org.mockito.Mockito;
33+
import org.mockito.junit.jupiter.MockitoExtension;
2834
import reactor.core.publisher.Mono;
2935
import reactor.netty.http.client.HttpClient;
3036
import reactor.test.StepVerifier;
3137

38+
@ExtendWith(MockitoExtension.class)
3239
final class WebsocketClientTransportTest {
3340

41+
@Test
42+
public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() {
43+
ArgumentCaptor<Integer> captor = ArgumentCaptor.forClass(Integer.class);
44+
HttpClient httpClient = Mockito.spy(HttpClient.create());
45+
Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any());
46+
Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture());
47+
48+
WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, "");
49+
50+
clientTransport.connect(0).subscribe();
51+
52+
Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK);
53+
}
54+
55+
@Test
56+
public void testThatSetupWithSpecifiedFrameSizeButLowerThanWsDefaultShouldSetToWsDefault() {
57+
ArgumentCaptor<Integer> captor = ArgumentCaptor.forClass(Integer.class);
58+
HttpClient httpClient = Mockito.spy(HttpClient.create());
59+
Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any());
60+
Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture());
61+
62+
WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, "");
63+
64+
clientTransport.connect(65536 - 10000).subscribe();
65+
66+
Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK);
67+
}
68+
69+
@Test
70+
public void
71+
testThatSetupWithSpecifiedFrameSizeButHigherThanWsDefaultShouldSetToSpecifiedFrameSize() {
72+
ArgumentCaptor<Integer> captor = ArgumentCaptor.forClass(Integer.class);
73+
HttpClient httpClient = Mockito.spy(HttpClient.create());
74+
Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any());
75+
Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture());
76+
77+
WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, "");
78+
79+
clientTransport.connect(65536 + 10000).subscribe();
80+
81+
Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK);
82+
}
83+
3484
@DisplayName("connects to server")
3585
@Test
3686
void connect() {

0 commit comments

Comments
 (0)