18
18
*/
19
19
package org .neo4j .driver .v1 .integration ;
20
20
21
+ import org .junit .After ;
21
22
import org .junit .Before ;
22
23
import org .junit .Test ;
23
24
24
25
import java .io .IOException ;
26
+ import java .net .ServerSocket ;
27
+ import java .net .Socket ;
28
+ import java .net .SocketException ;
25
29
import java .nio .ByteBuffer ;
26
30
import java .nio .channels .ByteChannel ;
27
- import java .security .GeneralSecurityException ;
28
- import java .security .KeyManagementException ;
29
31
import java .security .KeyStore ;
30
- import java .security .KeyStoreException ;
31
- import java .security .NoSuchAlgorithmException ;
32
- import java .security .UnrecoverableKeyException ;
33
32
import java .security .cert .CertificateException ;
34
33
import java .security .cert .X509Certificate ;
34
+ import java .util .concurrent .ExecutorService ;
35
+ import java .util .concurrent .Future ;
35
36
import javax .net .ssl .KeyManagerFactory ;
36
37
import javax .net .ssl .SSLContext ;
38
+ import javax .net .ssl .SSLServerSocketFactory ;
37
39
import javax .net .ssl .TrustManager ;
38
40
import javax .net .ssl .X509TrustManager ;
39
41
42
+ import static java .util .concurrent .Executors .newSingleThreadExecutor ;
43
+ import static java .util .concurrent .TimeUnit .SECONDS ;
44
+ import static org .junit .Assert .assertNull ;
45
+ import static org .junit .Assert .assertTrue ;
46
+ import static org .neo4j .driver .v1 .util .DaemonThreadFactory .daemon ;
47
+
40
48
/**
41
49
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
42
50
* can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16
43
51
* for the following variables:
44
- *
52
+ * <p>
45
53
* - Network frame size
46
54
* - Bolt message size
47
55
* - Read buffer size
48
- *
56
+ * <p>
49
57
* It tests every possible combination, and it does this currently only for the read path, expanding
50
58
* to the write path as well would be useful. For each size, it sets up a TLS server and tests the
51
59
* handshake, transferring the data, and verifying the data is correct after decryption.
52
60
*/
53
61
public abstract class TLSSocketChannelFragmentation
54
62
{
55
- protected SSLContext sslCtx ;
63
+ SSLContext sslCtx ;
64
+ ServerSocket serverSocket ;
65
+ volatile byte [] blobOfData ;
66
+
67
+ private ExecutorService serverExecutor ;
68
+ private Future <?> serverTask ;
56
69
57
70
@ Before
58
- public void setup () throws Throwable
71
+ public void setUp () throws Throwable
72
+ {
73
+ sslCtx = createSSLContext ();
74
+ serverSocket = createServerSocket ( sslCtx );
75
+ serverExecutor = createServerExecutor ();
76
+ serverTask = launchServer ( serverExecutor , createServerRunnable ( sslCtx ) );
77
+ }
78
+
79
+ @ After
80
+ public void tearDown () throws Exception
59
81
{
60
- createSSLContext ();
61
- createServer ();
82
+ serverSocket .close ();
83
+ serverExecutor .shutdownNow ();
84
+ assertTrue ( "Unable to terminate server socket" , serverExecutor .awaitTermination ( 30 , SECONDS ) );
85
+
86
+ assertNull ( serverTask .get ( 30 , SECONDS ) );
62
87
}
63
88
64
89
@ Test
@@ -67,51 +92,104 @@ public void shouldHandleFuzziness() throws Throwable
67
92
// Given
68
93
int networkFrameSize , userBufferSize , blobOfDataSize ;
69
94
70
- for ( int dataBlobMagnitude = 1 ; dataBlobMagnitude < 16 ; dataBlobMagnitude += 2 )
95
+ for ( int dataBlobMagnitude = 1 ; dataBlobMagnitude < 16 ; dataBlobMagnitude += 2 )
71
96
{
72
97
blobOfDataSize = (int ) Math .pow ( 2 , dataBlobMagnitude );
98
+ blobOfData = blobOfData ( blobOfDataSize );
73
99
74
- for ( int frameSizeMagnitude = 1 ; frameSizeMagnitude < 16 ; frameSizeMagnitude += 2 )
100
+ for ( int frameSizeMagnitude = 1 ; frameSizeMagnitude < 16 ; frameSizeMagnitude += 2 )
75
101
{
76
102
networkFrameSize = (int ) Math .pow ( 2 , frameSizeMagnitude );
77
- for ( int userBufferMagnitude = 1 ; userBufferMagnitude < 16 ; userBufferMagnitude += 2 )
103
+ for ( int userBufferMagnitude = 1 ; userBufferMagnitude < 16 ; userBufferMagnitude += 2 )
78
104
{
79
105
userBufferSize = (int ) Math .pow ( 2 , userBufferMagnitude );
80
- testForBufferSizes ( blobOfDataSize , networkFrameSize , userBufferSize );
106
+ testForBufferSizes ( blobOfData , networkFrameSize , userBufferSize );
81
107
}
82
108
}
83
109
}
84
110
}
85
111
86
- protected void createSSLContext ()
87
- throws KeyStoreException , IOException , NoSuchAlgorithmException , CertificateException ,
88
- UnrecoverableKeyException , KeyManagementException
112
+ protected abstract void testForBufferSizes ( byte [] blobOfData , int networkFrameSize , int userBufferSize )
113
+ throws Exception ;
114
+
115
+ protected abstract Runnable createServerRunnable ( SSLContext sslContext ) throws IOException ;
116
+
117
+ private static SSLContext createSSLContext () throws Exception
89
118
{
90
- KeyStore ks = KeyStore .getInstance ("JKS" );
119
+ KeyStore ks = KeyStore .getInstance ( "JKS" );
91
120
char [] password = "password" .toCharArray ();
92
- ks .load ( getClass () .getResourceAsStream ( "/keystore.jks" ), password );
93
- KeyManagerFactory kmf = KeyManagerFactory .getInstance ("SunX509" );
94
- kmf .init (ks , password );
121
+ ks .load ( TLSSocketChannelFragmentation . class .getResourceAsStream ( "/keystore.jks" ), password );
122
+ KeyManagerFactory kmf = KeyManagerFactory .getInstance ( "SunX509" );
123
+ kmf .init ( ks , password );
95
124
96
- sslCtx = SSLContext .getInstance ("TLS" );
97
- sslCtx .init ( kmf .getKeyManagers (), new TrustManager []{new X509TrustManager () {
98
- public void checkClientTrusted ( X509Certificate [] chain , String authType ) throws CertificateException
125
+ SSLContext sslCtx = SSLContext .getInstance ( "TLS" );
126
+ sslCtx .init ( kmf .getKeyManagers (), new TrustManager []{new X509TrustManager ()
127
+ {
128
+ @ Override
129
+ public void checkClientTrusted ( X509Certificate [] chain , String authType ) throws CertificateException
99
130
{
100
131
}
101
132
102
- public void checkServerTrusted (X509Certificate [] chain , String authType ) throws CertificateException {
133
+ @ Override
134
+ public void checkServerTrusted ( X509Certificate [] chain , String authType ) throws CertificateException
135
+ {
103
136
}
104
137
105
- public X509Certificate [] getAcceptedIssuers () {
138
+ @ Override
139
+ public X509Certificate [] getAcceptedIssuers ()
140
+ {
106
141
return null ;
107
142
}
108
143
}}, null );
144
+
145
+ return sslCtx ;
109
146
}
110
147
111
- protected abstract void testForBufferSizes ( int blobOfDataSize , int networkFrameSize , int userBufferSize ) throws IOException ,
112
- GeneralSecurityException ;
148
+ private static ServerSocket createServerSocket ( SSLContext sslContext ) throws IOException
149
+ {
150
+ SSLServerSocketFactory ssf = sslContext .getServerSocketFactory ();
151
+ return ssf .createServerSocket ( 0 );
152
+ }
153
+
154
+ private ExecutorService createServerExecutor ()
155
+ {
156
+ return newSingleThreadExecutor ( daemon ( getClass ().getSimpleName () + "-Server-" ) );
157
+ }
113
158
114
- protected abstract void createServer () throws IOException ;
159
+ private Future <?> launchServer ( ExecutorService executor , Runnable runnable )
160
+ {
161
+ return executor .submit ( runnable );
162
+ }
163
+
164
+ static byte [] blobOfData ( int dataBlobSize )
165
+ {
166
+ byte [] blobOfData = new byte [dataBlobSize ];
167
+ // If the blob is all zeros, we'd miss data corruption problems in assertions, so
168
+ // fill the data blob with different values.
169
+ for ( int i = 0 ; i < blobOfData .length ; i ++ )
170
+ {
171
+ blobOfData [i ] = (byte ) (i % 128 );
172
+ }
173
+
174
+ return blobOfData ;
175
+ }
176
+
177
+ static Socket accept ( ServerSocket serverSocket ) throws IOException
178
+ {
179
+ try
180
+ {
181
+ return serverSocket .accept ();
182
+ }
183
+ catch ( SocketException e )
184
+ {
185
+ String message = e .getMessage ();
186
+ if ( "Socket closed" .equalsIgnoreCase ( message ) )
187
+ {
188
+ return null ;
189
+ }
190
+ throw e ;
191
+ }
192
+ }
115
193
116
194
/**
117
195
* Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate
@@ -122,7 +200,7 @@ protected static class LittleAtATimeChannel implements ByteChannel
122
200
private final ByteChannel delegate ;
123
201
private final int maxFrameSize ;
124
202
125
- public LittleAtATimeChannel ( ByteChannel delegate , int maxFrameSize )
203
+ LittleAtATimeChannel ( ByteChannel delegate , int maxFrameSize )
126
204
{
127
205
128
206
this .delegate = delegate ;
@@ -152,7 +230,7 @@ public int write( ByteBuffer src ) throws IOException
152
230
}
153
231
finally
154
232
{
155
- src .limit (originalLimit );
233
+ src .limit ( originalLimit );
156
234
}
157
235
}
158
236
@@ -167,7 +245,7 @@ public int read( ByteBuffer dst ) throws IOException
167
245
}
168
246
finally
169
247
{
170
- dst .limit (originalLimit );
248
+ dst .limit ( originalLimit );
171
249
}
172
250
}
173
251
}
0 commit comments