Skip to content

Commit 4d604b8

Browse files
committed
Cleanup TLSSocketChannelFragmentation ITs
Make sure tests always perform assertions, terminate threads and close sockets.
1 parent 42e4a67 commit 4d604b8

File tree

5 files changed

+169
-123
lines changed

5 files changed

+169
-123
lines changed

driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
import org.neo4j.driver.v1.exceptions.Neo4jException;
5959
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
6060
import org.neo4j.driver.v1.exceptions.TransientException;
61-
import org.neo4j.driver.v1.util.DaemonThreadFactory;
6261
import org.neo4j.driver.v1.util.TestNeo4j;
6362

6463
import static java.lang.String.format;
@@ -86,6 +85,7 @@
8685
import static org.mockito.Mockito.verify;
8786
import static org.neo4j.driver.internal.util.ServerVersion.v3_1_0;
8887
import static org.neo4j.driver.v1.Values.parameters;
88+
import static org.neo4j.driver.v1.util.DaemonThreadFactory.daemon;
8989
import static org.neo4j.driver.v1.util.Neo4jRunner.DEFAULT_AUTH_TOKEN;
9090

9191
public class SessionIT
@@ -1450,7 +1450,7 @@ private static void assertDeadlockDetectedError( ExecutionException e )
14501450

14511451
private static <T> Future<T> executeInDifferentThread( Callable<T> callable )
14521452
{
1453-
ExecutorService executor = newSingleThreadExecutor( new DaemonThreadFactory( "test-thread-" ) );
1453+
ExecutorService executor = newSingleThreadExecutor( daemon( "test-thread-" ) );
14541454
return executor.submit( callable );
14551455
}
14561456

driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelFragmentation.java

Lines changed: 111 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,47 +18,72 @@
1818
*/
1919
package org.neo4j.driver.v1.integration;
2020

21+
import org.junit.After;
2122
import org.junit.Before;
2223
import org.junit.Test;
2324

2425
import java.io.IOException;
26+
import java.net.ServerSocket;
27+
import java.net.Socket;
28+
import java.net.SocketException;
2529
import java.nio.ByteBuffer;
2630
import java.nio.channels.ByteChannel;
27-
import java.security.GeneralSecurityException;
28-
import java.security.KeyManagementException;
2931
import java.security.KeyStore;
30-
import java.security.KeyStoreException;
31-
import java.security.NoSuchAlgorithmException;
32-
import java.security.UnrecoverableKeyException;
3332
import java.security.cert.CertificateException;
3433
import java.security.cert.X509Certificate;
34+
import java.util.concurrent.ExecutorService;
35+
import java.util.concurrent.Future;
3536
import javax.net.ssl.KeyManagerFactory;
3637
import javax.net.ssl.SSLContext;
38+
import javax.net.ssl.SSLServerSocketFactory;
3739
import javax.net.ssl.TrustManager;
3840
import javax.net.ssl.X509TrustManager;
3941

42+
import static java.util.concurrent.Executors.newSingleThreadExecutor;
43+
import static java.util.concurrent.TimeUnit.SECONDS;
44+
import static org.junit.Assert.assertNull;
45+
import static org.junit.Assert.assertTrue;
46+
import static org.neo4j.driver.v1.util.DaemonThreadFactory.daemon;
47+
4048
/**
4149
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
4250
* can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16
4351
* for the following variables:
44-
*
52+
* <p>
4553
* - Network frame size
4654
* - Bolt message size
4755
* - Read buffer size
48-
*
56+
* <p>
4957
* It tests every possible combination, and it does this currently only for the read path, expanding
5058
* to the write path as well would be useful. For each size, it sets up a TLS server and tests the
5159
* handshake, transferring the data, and verifying the data is correct after decryption.
5260
*/
5361
public abstract class TLSSocketChannelFragmentation
5462
{
55-
protected SSLContext sslCtx;
63+
SSLContext sslCtx;
64+
ServerSocket serverSocket;
65+
volatile byte[] blobOfData;
66+
67+
private ExecutorService serverExecutor;
68+
private Future<?> serverTask;
5669

5770
@Before
58-
public void setup() throws Throwable
71+
public void setUp() throws Throwable
72+
{
73+
sslCtx = createSSLContext();
74+
serverSocket = createServerSocket( sslCtx );
75+
serverExecutor = createServerExecutor();
76+
serverTask = launchServer( serverExecutor, createServerRunnable( sslCtx ) );
77+
}
78+
79+
@After
80+
public void tearDown() throws Exception
5981
{
60-
createSSLContext();
61-
createServer();
82+
serverSocket.close();
83+
serverExecutor.shutdownNow();
84+
assertTrue( "Unable to terminate server socket", serverExecutor.awaitTermination( 30, SECONDS ) );
85+
86+
assertNull( serverTask.get( 30, SECONDS ) );
6287
}
6388

6489
@Test
@@ -67,51 +92,104 @@ public void shouldHandleFuzziness() throws Throwable
6792
// Given
6893
int networkFrameSize, userBufferSize, blobOfDataSize;
6994

70-
for(int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude+=2 )
95+
for ( int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude += 2 )
7196
{
7297
blobOfDataSize = (int) Math.pow( 2, dataBlobMagnitude );
98+
blobOfData = blobOfData( blobOfDataSize );
7399

74-
for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude+=2 )
100+
for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude += 2 )
75101
{
76102
networkFrameSize = (int) Math.pow( 2, frameSizeMagnitude );
77-
for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude+=2 )
103+
for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude += 2 )
78104
{
79105
userBufferSize = (int) Math.pow( 2, userBufferMagnitude );
80-
testForBufferSizes( blobOfDataSize, networkFrameSize, userBufferSize );
106+
testForBufferSizes( blobOfData, networkFrameSize, userBufferSize );
81107
}
82108
}
83109
}
84110
}
85111

86-
protected void createSSLContext()
87-
throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException,
88-
UnrecoverableKeyException, KeyManagementException
112+
protected abstract void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int userBufferSize )
113+
throws Exception;
114+
115+
protected abstract Runnable createServerRunnable( SSLContext sslContext ) throws IOException;
116+
117+
private static SSLContext createSSLContext() throws Exception
89118
{
90-
KeyStore ks = KeyStore.getInstance("JKS");
119+
KeyStore ks = KeyStore.getInstance( "JKS" );
91120
char[] password = "password".toCharArray();
92-
ks.load( getClass().getResourceAsStream( "/keystore.jks" ), password );
93-
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
94-
kmf.init(ks, password);
121+
ks.load( TLSSocketChannelFragmentation.class.getResourceAsStream( "/keystore.jks" ), password );
122+
KeyManagerFactory kmf = KeyManagerFactory.getInstance( "SunX509" );
123+
kmf.init( ks, password );
95124

96-
sslCtx = SSLContext.getInstance("TLS");
97-
sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager() {
98-
public void checkClientTrusted( X509Certificate[] chain, String authType) throws CertificateException
125+
SSLContext sslCtx = SSLContext.getInstance( "TLS" );
126+
sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager()
127+
{
128+
@Override
129+
public void checkClientTrusted( X509Certificate[] chain, String authType ) throws CertificateException
99130
{
100131
}
101132

102-
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
133+
@Override
134+
public void checkServerTrusted( X509Certificate[] chain, String authType ) throws CertificateException
135+
{
103136
}
104137

105-
public X509Certificate[] getAcceptedIssuers() {
138+
@Override
139+
public X509Certificate[] getAcceptedIssuers()
140+
{
106141
return null;
107142
}
108143
}}, null );
144+
145+
return sslCtx;
109146
}
110147

111-
protected abstract void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException,
112-
GeneralSecurityException;
148+
private static ServerSocket createServerSocket( SSLContext sslContext ) throws IOException
149+
{
150+
SSLServerSocketFactory ssf = sslContext.getServerSocketFactory();
151+
return ssf.createServerSocket( 0 );
152+
}
153+
154+
private ExecutorService createServerExecutor()
155+
{
156+
return newSingleThreadExecutor( daemon( getClass().getSimpleName() + "-Server-" ) );
157+
}
113158

114-
protected abstract void createServer() throws IOException;
159+
private Future<?> launchServer( ExecutorService executor, Runnable runnable )
160+
{
161+
return executor.submit( runnable );
162+
}
163+
164+
static byte[] blobOfData( int dataBlobSize )
165+
{
166+
byte[] blobOfData = new byte[dataBlobSize];
167+
// If the blob is all zeros, we'd miss data corruption problems in assertions, so
168+
// fill the data blob with different values.
169+
for ( int i = 0; i < blobOfData.length; i++ )
170+
{
171+
blobOfData[i] = (byte) (i % 128);
172+
}
173+
174+
return blobOfData;
175+
}
176+
177+
static Socket accept( ServerSocket serverSocket ) throws IOException
178+
{
179+
try
180+
{
181+
return serverSocket.accept();
182+
}
183+
catch ( SocketException e )
184+
{
185+
String message = e.getMessage();
186+
if ( "Socket closed".equalsIgnoreCase( message ) )
187+
{
188+
return null;
189+
}
190+
throw e;
191+
}
192+
}
115193

116194
/**
117195
* Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate
@@ -122,7 +200,7 @@ protected static class LittleAtATimeChannel implements ByteChannel
122200
private final ByteChannel delegate;
123201
private final int maxFrameSize;
124202

125-
public LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize )
203+
LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize )
126204
{
127205

128206
this.delegate = delegate;
@@ -152,7 +230,7 @@ public int write( ByteBuffer src ) throws IOException
152230
}
153231
finally
154232
{
155-
src.limit(originalLimit);
233+
src.limit( originalLimit );
156234
}
157235
}
158236

@@ -167,7 +245,7 @@ public int read( ByteBuffer dst ) throws IOException
167245
}
168246
finally
169247
{
170-
dst.limit(originalLimit);
248+
dst.limit( originalLimit );
171249
}
172250
}
173251
}

driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelReadFragmentationIT.java

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
import java.io.IOException;
2222
import java.io.OutputStream;
2323
import java.net.InetSocketAddress;
24-
import java.net.ServerSocket;
2524
import java.net.Socket;
25+
import java.net.SocketAddress;
2626
import java.nio.ByteBuffer;
2727
import java.nio.channels.ByteChannel;
2828
import java.nio.channels.SocketChannel;
29-
import java.security.GeneralSecurityException;
29+
import javax.net.ssl.SSLContext;
3030
import javax.net.ssl.SSLEngine;
31-
import javax.net.ssl.SSLServerSocketFactory;
3231

3332
import org.neo4j.driver.internal.security.TLSSocketChannel;
3433

@@ -40,40 +39,24 @@
4039
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
4140
* can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16
4241
* for the following variables:
43-
*
42+
* <p>
4443
* - Network frame size
4544
* - Bolt message size
4645
* - Read buffer size
47-
*
46+
* <p>
4847
* It tests every possible combination, and it does this currently only for the read path, expanding
4948
* to the write path as well would be useful. For each size, it sets up a TLS server and tests the
5049
* handshake, transferring the data, and verifying the data is correct after decryption.
5150
*/
5251
public class TLSSocketChannelReadFragmentationIT extends TLSSocketChannelFragmentation
5352
{
54-
private byte[] blobOfData;
55-
private ServerSocket server;
56-
57-
58-
59-
private void blobOfDataSize( int dataBlobSize )
60-
{
61-
blobOfData = new byte[dataBlobSize];
62-
// If the blob is all zeros, we'd miss data corruption problems in assertions, so
63-
// fill the data blob with different values.
64-
for ( int i = 0; i < blobOfData.length; i++ )
65-
{
66-
blobOfData[i] = (byte) (i % 128);
67-
}
68-
}
69-
70-
protected void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, GeneralSecurityException
53+
@Override
54+
protected void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int userBufferSize ) throws Exception
7155
{
72-
blobOfDataSize(blobOfDataSize);
7356
SSLEngine engine = sslCtx.createSSLEngine();
7457
engine.setUseClientMode( true );
75-
ByteChannel ch = SocketChannel.open( new InetSocketAddress( server.getInetAddress(), server.getLocalPort() ) );
76-
ch = new LittleAtATimeChannel( ch, networkFrameSize );
58+
SocketAddress address = new InetSocketAddress( serverSocket.getInetAddress(), serverSocket.getLocalPort() );
59+
ByteChannel ch = new LittleAtATimeChannel( SocketChannel.open( address ), networkFrameSize );
7760

7861
try ( TLSSocketChannel channel = TLSSocketChannel.create( ch, DEV_NULL_LOGGER, engine ) )
7962
{
@@ -88,34 +71,37 @@ protected void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int
8871
}
8972
}
9073

91-
protected void createServer() throws IOException
74+
@Override
75+
protected Runnable createServerRunnable( SSLContext sslContext ) throws IOException
9276
{
93-
SSLServerSocketFactory ssf = sslCtx.getServerSocketFactory();
94-
server = ssf.createServerSocket(0);
95-
96-
new Thread(new Runnable()
77+
return new Runnable()
9778
{
9879
@Override
9980
public void run()
10081
{
10182
try
10283
{
103-
//noinspection InfiniteLoopStatement
104-
while(true)
84+
// noinspection InfiniteLoopStatement
85+
while ( true )
10586
{
106-
Socket client = server.accept();
87+
Socket client = accept( serverSocket );
88+
if ( client == null )
89+
{
90+
return;
91+
}
92+
10793
OutputStream outputStream = client.getOutputStream();
10894
outputStream.write( blobOfData );
10995
outputStream.flush();
110-
// client.close(); // TODO: Uncomment this, fix resulting error handling CLOSED event
96+
97+
client.close();
11198
}
11299
}
113100
catch ( IOException e )
114101
{
115-
e.printStackTrace();
102+
throw new RuntimeException( e );
116103
}
117104
}
118-
}).start();
105+
};
119106
}
120-
121107
}

0 commit comments

Comments
 (0)