Skip to content

fixes bug with incorrect WS framesize setup. #610

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ subprojects {
ext['netty.version'] = '4.1.31.Final'
ext['netty-boringssl.version'] = '2.0.18.Final'
ext['hdrhistogram.version'] = '2.1.10'
ext['mockito.version'] = '2.23.0'
ext['mockito.version'] = '2.25.1'
ext['slf4j.version'] = '1.7.25'
ext['jmh.version'] = '1.21'
ext['junit.version'] = '5.1.0'
Expand All @@ -60,8 +60,12 @@ subprojects {
dependency "io.micrometer:micrometer-core:${ext['micrometer.version']}"
dependency "org.assertj:assertj-core:${ext['assertj.version']}"
dependency "org.hdrhistogram:HdrHistogram:${ext['hdrhistogram.version']}"
dependency "org.mockito:mockito-core:${ ext['mockito.version']}"
dependency "org.slf4j:slf4j-api:${ext['slf4j.version']}"

dependencySet(group: 'org.mockito', version: ext['mockito.version']) {
entry 'mockito-junit-jupiter'
entry 'mockito-core'
}

dependencySet(group: 'org.junit.jupiter', version: ext['junit.version']) {
entry 'junit-jupiter-api'
Expand Down
1 change: 1 addition & 0 deletions rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import io.netty.buffer.Unpooled;

public class FrameUtil {

private FrameUtil() {}

public static String toString(ByteBuf frame) {
Expand Down
2 changes: 2 additions & 0 deletions rsocket-transport-netty/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ dependencies {
testImplementation project(':rsocket-test')
testImplementation 'io.projectreactor:reactor-test'
testImplementation 'org.assertj:assertj-core'
testImplementation 'org.mockito:mockito-core'
testImplementation 'org.mockito:mockito-junit-jupiter'
testImplementation 'org.junit.jupiter:junit-jupiter-api'
testImplementation 'org.junit.jupiter:junit-jupiter-params'

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.rsocket.transport.netty.client;

import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;
import static io.rsocket.transport.netty.UriUtils.getPort;
import static io.rsocket.transport.netty.UriUtils.isSecure;

Expand All @@ -42,6 +43,7 @@
*/
public final class WebsocketClientTransport implements ClientTransport, TransportHeaderAware {

private static final int DEFAULT_FRAME_SIZE = 65536;
private static final String DEFAULT_PATH = "/";

private final HttpClient client;
Expand Down Expand Up @@ -151,7 +153,7 @@ private static TcpClient createClient(URI uri) {
public Mono<DuplexConnection> connect(int mtu) {
return client
.headers(headers -> transportHeaders.get().forEach(headers::set))
.websocket()
.websocket(FRAME_LENGTH_MASK)
.uri(path)
.connect()
.map(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,23 @@

package io.rsocket.transport.netty.server;

import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;

import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.codec.http.HttpMethod;
import io.rsocket.Closeable;
import io.rsocket.DuplexConnection;
import io.rsocket.fragmentation.FragmentationDuplexConnection;
import io.rsocket.transport.ServerTransport;
import io.rsocket.transport.netty.WebsocketDuplexConnection;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import reactor.core.publisher.Mono;
import reactor.netty.Connection;
import reactor.netty.http.server.HttpServer;
Expand All @@ -35,7 +44,7 @@
*/
public final class WebsocketRouteTransport implements ServerTransport<Closeable> {

private final String path;
private final UriPathTemplate template;

private final Consumer<? super HttpServerRoutes> routesBuilder;

Expand All @@ -53,7 +62,7 @@ public WebsocketRouteTransport(

this.server = Objects.requireNonNull(server, "server must not be null");
this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null");
this.path = Objects.requireNonNull(path, "path must not be null");
this.template = new UriPathTemplate(Objects.requireNonNull(path, "path must not be null"));
}

@Override
Expand All @@ -65,7 +74,7 @@ public Mono<Closeable> start(ConnectionAcceptor acceptor, int mtu) {
routes -> {
routesBuilder.accept(routes);
routes.ws(
path,
hsr -> hsr.method().equals(HttpMethod.GET) && template.matches(hsr.uri()),
(in, out) -> {
DuplexConnection connection = new WebsocketDuplexConnection((Connection) in);
if (mtu > 0) {
Expand All @@ -74,9 +83,128 @@ public Mono<Closeable> start(ConnectionAcceptor acceptor, int mtu) {
connection, ByteBufAllocator.DEFAULT, mtu, false);
}
return acceptor.apply(connection).then(out.neverComplete());
});
},
null,
FRAME_LENGTH_MASK);
})
.bind()
.map(CloseableChannel::new);
}

static final class UriPathTemplate {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is port from Reactor-Netty internals

cc/
@violetagg


private static final Pattern FULL_SPLAT_PATTERN = Pattern.compile("[\\*][\\*]");
private static final String FULL_SPLAT_REPLACEMENT = ".*";

private static final Pattern NAME_SPLAT_PATTERN = Pattern.compile("\\{([^/]+?)\\}[\\*][\\*]");
private static final String NAME_SPLAT_REPLACEMENT = "(?<%NAME%>.*)";

private static final Pattern NAME_PATTERN = Pattern.compile("\\{([^/]+?)\\}");
private static final String NAME_REPLACEMENT = "(?<%NAME%>[^\\/]*)";

private final List<String> pathVariables = new ArrayList<>();
private final HashMap<String, Matcher> matchers = new HashMap<>();
private final HashMap<String, Map<String, String>> vars = new HashMap<>();

private final Pattern uriPattern;

static String filterQueryParams(String uri) {
int hasQuery = uri.lastIndexOf("?");
if (hasQuery != -1) {
return uri.substring(0, hasQuery);
} else {
return uri;
}
}

/**
* Creates a new {@code UriPathTemplate} from the given {@code uriPattern}.
*
* @param uriPattern The pattern to be used by the template
*/
UriPathTemplate(String uriPattern) {
String s = "^" + filterQueryParams(uriPattern);

Matcher m = NAME_SPLAT_PATTERN.matcher(s);
while (m.find()) {
for (int i = 1; i <= m.groupCount(); i++) {
String name = m.group(i);
pathVariables.add(name);
s = m.replaceFirst(NAME_SPLAT_REPLACEMENT.replaceAll("%NAME%", name));
m.reset(s);
}
}

m = NAME_PATTERN.matcher(s);
while (m.find()) {
for (int i = 1; i <= m.groupCount(); i++) {
String name = m.group(i);
pathVariables.add(name);
s = m.replaceFirst(NAME_REPLACEMENT.replaceAll("%NAME%", name));
m.reset(s);
}
}

m = FULL_SPLAT_PATTERN.matcher(s);
while (m.find()) {
s = m.replaceAll(FULL_SPLAT_REPLACEMENT);
m.reset(s);
}

this.uriPattern = Pattern.compile(s + "$");
}

/**
* Tests the given {@code uri} against this template, returning {@code true} if the uri matches
* the template, {@code false} otherwise.
*
* @param uri The uri to match
* @return {@code true} if there's a match, {@code false} otherwise
*/
public boolean matches(String uri) {
return matcher(uri).matches();
}

/**
* Matches the template against the given {@code uri} returning a map of path parameters
* extracted from the uri, keyed by the names in the template. If the uri does not match, or
* there are no path parameters, an empty map is returned.
*
* @param uri The uri to match
* @return the path parameters from the uri. Never {@code null}.
*/
final Map<String, String> match(String uri) {
Map<String, String> pathParameters = vars.get(uri);
if (null != pathParameters) {
return pathParameters;
}

pathParameters = new HashMap<>();
Matcher m = matcher(uri);
if (m.matches()) {
int i = 1;
for (String name : pathVariables) {
String val = m.group(i++);
pathParameters.put(name, val);
}
}
synchronized (vars) {
vars.put(uri, pathParameters);
}

return pathParameters;
}

private Matcher matcher(String uri) {
uri = filterQueryParams(uri);
Matcher m = matchers.get(uri);
if (null == m) {
m = uriPattern.matcher(uri);
synchronized (matchers) {
matchers.put(uri, m);
}
}
return m;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package io.rsocket.transport.netty.server;

import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;

import io.netty.buffer.ByteBufAllocator;
import io.rsocket.DuplexConnection;
import io.rsocket.fragmentation.FragmentationDuplexConnection;
Expand Down Expand Up @@ -114,6 +116,8 @@ public Mono<CloseableChannel> start(ConnectionAcceptor acceptor, int mtu) {
(request, response) -> {
transportHeaders.get().forEach(response::addHeader);
return response.sendWebsocket(
null,
FRAME_LENGTH_MASK,
(in, out) -> {
DuplexConnection connection = new WebsocketDuplexConnection((Connection) in);
if (mtu > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,71 @@

package io.rsocket.transport.netty.client;

import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNullPointerException;

import io.rsocket.transport.netty.server.WebsocketServerTransport;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Collections;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
import reactor.core.publisher.Mono;
import reactor.netty.http.client.HttpClient;
import reactor.test.StepVerifier;

@ExtendWith(MockitoExtension.class)
final class WebsocketClientTransportTest {

@Test
public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() {
ArgumentCaptor<Integer> captor = ArgumentCaptor.forClass(Integer.class);
HttpClient httpClient = Mockito.spy(HttpClient.create());
Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any());
Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture());

WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, "");

clientTransport.connect(0).subscribe();

Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK);
}

@Test
public void testThatSetupWithSpecifiedFrameSizeButLowerThanWsDefaultShouldSetToWsDefault() {
ArgumentCaptor<Integer> captor = ArgumentCaptor.forClass(Integer.class);
HttpClient httpClient = Mockito.spy(HttpClient.create());
Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any());
Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture());

WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, "");

clientTransport.connect(65536 - 10000).subscribe();

Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK);
}

@Test
public void
testThatSetupWithSpecifiedFrameSizeButHigherThanWsDefaultShouldSetToSpecifiedFrameSize() {
ArgumentCaptor<Integer> captor = ArgumentCaptor.forClass(Integer.class);
HttpClient httpClient = Mockito.spy(HttpClient.create());
Mockito.doAnswer(a -> httpClient).when(httpClient).headers(Mockito.any());
Mockito.doCallRealMethod().when(httpClient).websocket(captor.capture());

WebsocketClientTransport clientTransport = WebsocketClientTransport.create(httpClient, "");

clientTransport.connect(65536 + 10000).subscribe();

Assertions.assertThat(captor.getValue()).isEqualTo(FRAME_LENGTH_MASK);
}

@DisplayName("connects to server")
@Test
void connect() {
Expand Down
Loading