Skip to content

Commit c5ec374

Browse files
mostroverkhovrobertroeser
authored andcommitted
fix websockets ping/pong control frames handling (#707)
* fix websockets ping/pong control frames handling Signed-off-by: Maksym Ostroverkhov <[email protected]> * remove odd tests Signed-off-by: Maksym Ostroverkhov <[email protected]>
1 parent dd1e44f commit c5ec374

File tree

5 files changed

+206
-116
lines changed

5 files changed

+206
-116
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package io.rsocket.transport.netty.server;
2+
3+
import static io.netty.channel.ChannelHandler.*;
4+
5+
import io.netty.channel.ChannelHandler;
6+
import io.netty.channel.ChannelHandlerContext;
7+
import io.netty.channel.ChannelInboundHandlerAdapter;
8+
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
9+
import io.netty.util.ReferenceCountUtil;
10+
import io.rsocket.Closeable;
11+
import io.rsocket.transport.ServerTransport;
12+
import java.util.function.Function;
13+
import org.slf4j.Logger;
14+
import org.slf4j.LoggerFactory;
15+
import reactor.netty.http.server.HttpServer;
16+
17+
abstract class BaseWebsocketServerTransport<T extends Closeable> implements ServerTransport<T> {
18+
private static final Logger logger = LoggerFactory.getLogger(BaseWebsocketServerTransport.class);
19+
private static final ChannelHandler pongHandler = new PongHandler();
20+
21+
static Function<HttpServer, HttpServer> serverConfigurer =
22+
server ->
23+
server.tcpConfiguration(
24+
tcpServer ->
25+
tcpServer.doOnConnection(connection -> connection.addHandlerLast(pongHandler)));
26+
27+
@Sharable
28+
private static class PongHandler extends ChannelInboundHandlerAdapter {
29+
@Override
30+
public void channelRead(ChannelHandlerContext ctx, Object msg) {
31+
if (msg instanceof PongWebSocketFrame) {
32+
logger.debug("received WebSocket Pong Frame");
33+
ReferenceCountUtil.safeRelease(msg);
34+
ctx.read();
35+
} else {
36+
ctx.fireChannelRead(msg);
37+
}
38+
}
39+
}
40+
}

rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
* An implementation of {@link ServerTransport} that connects via Websocket and listens on specified
4747
* routes.
4848
*/
49-
public final class WebsocketRouteTransport implements ServerTransport<Closeable> {
49+
public final class WebsocketRouteTransport extends BaseWebsocketServerTransport<Closeable> {
5050

5151
private final UriPathTemplate template;
5252

@@ -63,8 +63,7 @@ public final class WebsocketRouteTransport implements ServerTransport<Closeable>
6363
*/
6464
public WebsocketRouteTransport(
6565
HttpServer server, Consumer<? super HttpServerRoutes> routesBuilder, String path) {
66-
67-
this.server = Objects.requireNonNull(server, "server must not be null");
66+
this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null"));
6867
this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null");
6968
this.template = new UriPathTemplate(Objects.requireNonNull(path, "path must not be null"));
7069
}

rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,6 @@
1919
import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;
2020

2121
import io.netty.buffer.ByteBufAllocator;
22-
import io.netty.buffer.Unpooled;
23-
import io.netty.channel.ChannelHandlerContext;
24-
import io.netty.channel.ChannelInboundHandlerAdapter;
25-
import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame;
26-
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
27-
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
2822
import io.rsocket.DuplexConnection;
2923
import io.rsocket.fragmentation.FragmentationDuplexConnection;
3024
import io.rsocket.transport.ClientTransport;
@@ -46,16 +40,16 @@
4640
* An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a
4741
* Websocket.
4842
*/
49-
public final class WebsocketServerTransport
50-
implements ServerTransport<CloseableChannel>, TransportHeaderAware {
43+
public final class WebsocketServerTransport extends BaseWebsocketServerTransport<CloseableChannel>
44+
implements TransportHeaderAware {
5145
private static final Logger logger = LoggerFactory.getLogger(WebsocketServerTransport.class);
5246

5347
private final HttpServer server;
5448

5549
private Supplier<Map<String, String>> transportHeaders = Collections::emptyMap;
5650

5751
private WebsocketServerTransport(HttpServer server) {
58-
this.server = server;
52+
this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null"));
5953
}
6054

6155
/**
@@ -107,33 +101,7 @@ public static WebsocketServerTransport create(InetSocketAddress address) {
107101
public static WebsocketServerTransport create(final HttpServer server) {
108102
Objects.requireNonNull(server, "server must not be null");
109103

110-
return new WebsocketServerTransport(
111-
server.tcpConfiguration(
112-
tcpServer ->
113-
tcpServer.doOnConnection(
114-
connection ->
115-
connection.addHandlerLast(
116-
new ChannelInboundHandlerAdapter() {
117-
@Override
118-
public void channelRead(ChannelHandlerContext ctx, Object msg)
119-
throws Exception {
120-
if (msg instanceof PongWebSocketFrame) {
121-
logger.debug("received WebSocket Pong Frame");
122-
} else if (msg instanceof PingWebSocketFrame) {
123-
logger.debug(
124-
"received WebSocket Ping Frame - sending Pong Frame");
125-
PongWebSocketFrame pongWebSocketFrame =
126-
new PongWebSocketFrame(Unpooled.EMPTY_BUFFER);
127-
ctx.writeAndFlush(pongWebSocketFrame);
128-
} else if (msg instanceof CloseWebSocketFrame) {
129-
logger.warn(
130-
"received WebSocket Close Frame - connection is closing");
131-
ctx.close();
132-
} else {
133-
ctx.fireChannelRead(msg);
134-
}
135-
}
136-
}))));
104+
return new WebsocketServerTransport(server);
137105
}
138106

139107
@Override
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package io.rsocket.transport.netty;
2+
3+
import io.netty.buffer.Unpooled;
4+
import io.netty.channel.Channel;
5+
import io.netty.channel.ChannelHandlerContext;
6+
import io.netty.channel.ChannelInboundHandlerAdapter;
7+
import io.netty.handler.codec.http.websocketx.PingWebSocketFrame;
8+
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame;
9+
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
10+
import io.netty.util.ReferenceCountUtil;
11+
import io.rsocket.*;
12+
import io.rsocket.transport.ServerTransport;
13+
import io.rsocket.transport.netty.client.WebsocketClientTransport;
14+
import io.rsocket.transport.netty.server.WebsocketRouteTransport;
15+
import io.rsocket.transport.netty.server.WebsocketServerTransport;
16+
import io.rsocket.util.DefaultPayload;
17+
import java.nio.charset.StandardCharsets;
18+
import java.time.Duration;
19+
import java.util.stream.Stream;
20+
import org.junit.jupiter.api.AfterEach;
21+
import org.junit.jupiter.params.ParameterizedTest;
22+
import org.junit.jupiter.params.provider.Arguments;
23+
import org.junit.jupiter.params.provider.MethodSource;
24+
import reactor.core.publisher.Mono;
25+
import reactor.core.publisher.MonoProcessor;
26+
import reactor.netty.http.client.HttpClient;
27+
import reactor.netty.http.server.HttpServer;
28+
import reactor.test.StepVerifier;
29+
30+
public class WebsocketPingPongIntegrationTest {
31+
private static final String host = "localhost";
32+
private static final int port = 8088;
33+
34+
private Closeable server;
35+
36+
@AfterEach
37+
void tearDown() {
38+
server.dispose();
39+
}
40+
41+
@ParameterizedTest
42+
@MethodSource("provideServerTransport")
43+
void webSocketPingPong(ServerTransport<Closeable> serverTransport) {
44+
server =
45+
RSocketFactory.receive()
46+
.acceptor((setup, sendingSocket) -> Mono.just(new EchoRSocket()))
47+
.transport(serverTransport)
48+
.start()
49+
.block();
50+
51+
String expectedData = "data";
52+
String expectedPing = "ping";
53+
54+
PingSender pingSender = new PingSender();
55+
56+
HttpClient httpClient =
57+
HttpClient.create()
58+
.tcpConfiguration(
59+
tcpClient ->
60+
tcpClient
61+
.doOnConnected(b -> b.addHandlerLast(pingSender))
62+
.host(host)
63+
.port(port));
64+
65+
RSocket rSocket =
66+
RSocketFactory.connect()
67+
.transport(WebsocketClientTransport.create(httpClient, "/"))
68+
.start()
69+
.block();
70+
71+
rSocket
72+
.requestResponse(DefaultPayload.create(expectedData))
73+
.delaySubscription(pingSender.sendPing(expectedPing))
74+
.as(StepVerifier::create)
75+
.expectNextMatches(p -> expectedData.equals(p.getDataUtf8()))
76+
.expectComplete()
77+
.verify(Duration.ofSeconds(5));
78+
79+
pingSender
80+
.receivePong()
81+
.as(StepVerifier::create)
82+
.expectNextMatches(expectedPing::equals)
83+
.expectComplete()
84+
.verify(Duration.ofSeconds(5));
85+
86+
rSocket
87+
.requestResponse(DefaultPayload.create(expectedData))
88+
.delaySubscription(pingSender.sendPong())
89+
.as(StepVerifier::create)
90+
.expectNextMatches(p -> expectedData.equals(p.getDataUtf8()))
91+
.expectComplete()
92+
.verify(Duration.ofSeconds(5));
93+
}
94+
95+
private static Stream<Arguments> provideServerTransport() {
96+
return Stream.of(
97+
Arguments.of(WebsocketServerTransport.create(host, port)),
98+
Arguments.of(
99+
new WebsocketRouteTransport(
100+
HttpServer.create().host(host).port(port), routes -> {}, "/")));
101+
}
102+
103+
private static class EchoRSocket extends AbstractRSocket {
104+
@Override
105+
public Mono<Payload> requestResponse(Payload payload) {
106+
return Mono.just(payload);
107+
}
108+
}
109+
110+
private static class PingSender extends ChannelInboundHandlerAdapter {
111+
private final MonoProcessor<Channel> channel = MonoProcessor.create();
112+
private final MonoProcessor<String> pong = MonoProcessor.create();
113+
114+
@Override
115+
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
116+
if (msg instanceof PongWebSocketFrame) {
117+
pong.onNext(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8));
118+
ReferenceCountUtil.safeRelease(msg);
119+
ctx.read();
120+
} else {
121+
super.channelRead(ctx, msg);
122+
}
123+
}
124+
125+
@Override
126+
public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
127+
Channel ch = ctx.channel();
128+
if (!channel.isTerminated() && ch.isWritable()) {
129+
channel.onNext(ctx.channel());
130+
}
131+
super.channelWritabilityChanged(ctx);
132+
}
133+
134+
@Override
135+
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
136+
Channel ch = ctx.channel();
137+
if (ch.isWritable()) {
138+
channel.onNext(ch);
139+
}
140+
super.handlerAdded(ctx);
141+
}
142+
143+
public Mono<Void> sendPing(String data) {
144+
return send(
145+
new PingWebSocketFrame(Unpooled.wrappedBuffer(data.getBytes(StandardCharsets.UTF_8))));
146+
}
147+
148+
public Mono<Void> sendPong() {
149+
return send(new PongWebSocketFrame());
150+
}
151+
152+
public Mono<String> receivePong() {
153+
return pong;
154+
}
155+
156+
private Mono<Void> send(WebSocketFrame webSocketFrame) {
157+
return channel.doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then();
158+
}
159+
}
160+
}

rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,93 +16,16 @@
1616

1717
package io.rsocket.transport.netty.server;
1818

19-
import static io.rsocket.frame.FrameLengthFlyweight.FRAME_LENGTH_MASK;
2019
import static org.assertj.core.api.Assertions.assertThatNullPointerException;
2120

22-
import java.util.function.BiFunction;
23-
import java.util.function.Consumer;
24-
import java.util.function.Predicate;
2521
import org.junit.jupiter.api.DisplayName;
2622
import org.junit.jupiter.api.Test;
27-
import org.mockito.ArgumentCaptor;
28-
import org.mockito.Mockito;
2923
import reactor.core.publisher.Mono;
3024
import reactor.netty.http.server.HttpServer;
31-
import reactor.netty.http.server.HttpServerRoutes;
3225
import reactor.test.StepVerifier;
3326

3427
final class WebsocketRouteTransportTest {
3528

36-
@Test
37-
public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() {
38-
ArgumentCaptor<Consumer> captor = ArgumentCaptor.forClass(Consumer.class);
39-
HttpServer httpServer = Mockito.spy(HttpServer.create());
40-
HttpServerRoutes routes = Mockito.mock(HttpServerRoutes.class);
41-
Mockito.doAnswer(a -> httpServer).when(httpServer).route(captor.capture());
42-
Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind();
43-
44-
WebsocketRouteTransport serverTransport =
45-
new WebsocketRouteTransport(httpServer, (r) -> {}, "");
46-
47-
serverTransport.start(c -> Mono.empty(), 0).subscribe();
48-
49-
captor.getValue().accept(routes);
50-
51-
Mockito.verify(routes)
52-
.ws(
53-
Mockito.any(Predicate.class),
54-
Mockito.any(BiFunction.class),
55-
Mockito.nullable(String.class),
56-
Mockito.eq(FRAME_LENGTH_MASK));
57-
}
58-
59-
@Test
60-
public void testThatSetupWithSpecifiedFrameSizeButLowerThanWsDefaultShouldSetToWsDefault() {
61-
ArgumentCaptor<Consumer> captor = ArgumentCaptor.forClass(Consumer.class);
62-
HttpServer httpServer = Mockito.spy(HttpServer.create());
63-
HttpServerRoutes routes = Mockito.mock(HttpServerRoutes.class);
64-
Mockito.doAnswer(a -> httpServer).when(httpServer).route(captor.capture());
65-
Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind();
66-
67-
WebsocketRouteTransport serverTransport =
68-
new WebsocketRouteTransport(httpServer, (r) -> {}, "");
69-
70-
serverTransport.start(c -> Mono.empty(), 1000).subscribe();
71-
72-
captor.getValue().accept(routes);
73-
74-
Mockito.verify(routes)
75-
.ws(
76-
Mockito.any(Predicate.class),
77-
Mockito.any(BiFunction.class),
78-
Mockito.nullable(String.class),
79-
Mockito.eq(FRAME_LENGTH_MASK));
80-
}
81-
82-
@Test
83-
public void
84-
testThatSetupWithSpecifiedFrameSizeButHigherThanWsDefaultShouldSetToSpecifiedFrameSize() {
85-
ArgumentCaptor<Consumer> captor = ArgumentCaptor.forClass(Consumer.class);
86-
HttpServer httpServer = Mockito.spy(HttpServer.create());
87-
HttpServerRoutes routes = Mockito.mock(HttpServerRoutes.class);
88-
Mockito.doAnswer(a -> httpServer).when(httpServer).route(captor.capture());
89-
Mockito.doAnswer(a -> Mono.empty()).when(httpServer).bind();
90-
91-
WebsocketRouteTransport serverTransport =
92-
new WebsocketRouteTransport(httpServer, (r) -> {}, "");
93-
94-
serverTransport.start(c -> Mono.empty(), 65536 + 1000).subscribe();
95-
96-
captor.getValue().accept(routes);
97-
98-
Mockito.verify(routes)
99-
.ws(
100-
Mockito.any(Predicate.class),
101-
Mockito.any(BiFunction.class),
102-
Mockito.nullable(String.class),
103-
Mockito.eq(FRAME_LENGTH_MASK));
104-
}
105-
10629
@DisplayName("creates server")
10730
@Test
10831
void constructor() {

0 commit comments

Comments
 (0)