Skip to content

Cleanup TLSSocketChannelFragmentation ITs #369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import org.neo4j.driver.v1.exceptions.Neo4jException;
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
import org.neo4j.driver.v1.exceptions.TransientException;
import org.neo4j.driver.v1.util.DaemonThreadFactory;
import org.neo4j.driver.v1.util.TestNeo4j;

import static java.lang.String.format;
Expand Down Expand Up @@ -86,6 +85,7 @@
import static org.mockito.Mockito.verify;
import static org.neo4j.driver.internal.util.ServerVersion.v3_1_0;
import static org.neo4j.driver.v1.Values.parameters;
import static org.neo4j.driver.v1.util.DaemonThreadFactory.daemon;
import static org.neo4j.driver.v1.util.Neo4jRunner.DEFAULT_AUTH_TOKEN;

public class SessionIT
Expand Down Expand Up @@ -1450,7 +1450,7 @@ private static void assertDeadlockDetectedError( ExecutionException e )

private static <T> Future<T> executeInDifferentThread( Callable<T> callable )
{
ExecutorService executor = newSingleThreadExecutor( new DaemonThreadFactory( "test-thread-" ) );
ExecutorService executor = newSingleThreadExecutor( daemon( "test-thread-" ) );
return executor.submit( callable );
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,72 @@
*/
package org.neo4j.driver.v1.integration;

import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.security.GeneralSecurityException;
import java.security.KeyManagementException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.neo4j.driver.v1.util.DaemonThreadFactory.daemon;

/**
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
* can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16
* for the following variables:
*
* <p>
* - Network frame size
* - Bolt message size
* - Read buffer size
*
* <p>
* It tests every possible combination, and it does this currently only for the read path, expanding
* to the write path as well would be useful. For each size, it sets up a TLS server and tests the
* handshake, transferring the data, and verifying the data is correct after decryption.
*/
public abstract class TLSSocketChannelFragmentation
{
protected SSLContext sslCtx;
SSLContext sslCtx;
ServerSocket serverSocket;
volatile byte[] blobOfData;

private ExecutorService serverExecutor;
private Future<?> serverTask;

@Before
public void setup() throws Throwable
public void setUp() throws Throwable
{
sslCtx = createSSLContext();
serverSocket = createServerSocket( sslCtx );
serverExecutor = createServerExecutor();
serverTask = launchServer( serverExecutor, createServerRunnable( sslCtx ) );
}

@After
public void tearDown() throws Exception
{
createSSLContext();
createServer();
serverSocket.close();
serverExecutor.shutdownNow();
assertTrue( "Unable to terminate server socket", serverExecutor.awaitTermination( 30, SECONDS ) );

assertNull( serverTask.get( 30, SECONDS ) );
}

@Test
Expand All @@ -67,51 +92,104 @@ public void shouldHandleFuzziness() throws Throwable
// Given
int networkFrameSize, userBufferSize, blobOfDataSize;

for(int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude+=2 )
for ( int dataBlobMagnitude = 1; dataBlobMagnitude < 16; dataBlobMagnitude += 2 )
{
blobOfDataSize = (int) Math.pow( 2, dataBlobMagnitude );
blobOfData = blobOfData( blobOfDataSize );

for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude+=2 )
for ( int frameSizeMagnitude = 1; frameSizeMagnitude < 16; frameSizeMagnitude += 2 )
{
networkFrameSize = (int) Math.pow( 2, frameSizeMagnitude );
for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude+=2 )
for ( int userBufferMagnitude = 1; userBufferMagnitude < 16; userBufferMagnitude += 2 )
{
userBufferSize = (int) Math.pow( 2, userBufferMagnitude );
testForBufferSizes( blobOfDataSize, networkFrameSize, userBufferSize );
testForBufferSizes( blobOfData, networkFrameSize, userBufferSize );
}
}
}
}

protected void createSSLContext()
throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException,
UnrecoverableKeyException, KeyManagementException
protected abstract void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int userBufferSize )
throws Exception;

protected abstract Runnable createServerRunnable( SSLContext sslContext ) throws IOException;

private static SSLContext createSSLContext() throws Exception
{
KeyStore ks = KeyStore.getInstance("JKS");
KeyStore ks = KeyStore.getInstance( "JKS" );
char[] password = "password".toCharArray();
ks.load( getClass().getResourceAsStream( "/keystore.jks" ), password );
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
kmf.init(ks, password);
ks.load( TLSSocketChannelFragmentation.class.getResourceAsStream( "/keystore.jks" ), password );
KeyManagerFactory kmf = KeyManagerFactory.getInstance( "SunX509" );
kmf.init( ks, password );

sslCtx = SSLContext.getInstance("TLS");
sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager() {
public void checkClientTrusted( X509Certificate[] chain, String authType) throws CertificateException
SSLContext sslCtx = SSLContext.getInstance( "TLS" );
sslCtx.init( kmf.getKeyManagers(), new TrustManager[]{new X509TrustManager()
{
@Override
public void checkClientTrusted( X509Certificate[] chain, String authType ) throws CertificateException
{
}

public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
@Override
public void checkServerTrusted( X509Certificate[] chain, String authType ) throws CertificateException
{
}

public X509Certificate[] getAcceptedIssuers() {
@Override
public X509Certificate[] getAcceptedIssuers()
{
return null;
}
}}, null );

return sslCtx;
}

protected abstract void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException,
GeneralSecurityException;
private static ServerSocket createServerSocket( SSLContext sslContext ) throws IOException
{
SSLServerSocketFactory ssf = sslContext.getServerSocketFactory();
return ssf.createServerSocket( 0 );
}

private ExecutorService createServerExecutor()
{
return newSingleThreadExecutor( daemon( getClass().getSimpleName() + "-Server-" ) );
}

protected abstract void createServer() throws IOException;
private Future<?> launchServer( ExecutorService executor, Runnable runnable )
{
return executor.submit( runnable );
}

static byte[] blobOfData( int dataBlobSize )
{
byte[] blobOfData = new byte[dataBlobSize];
// If the blob is all zeros, we'd miss data corruption problems in assertions, so
// fill the data blob with different values.
for ( int i = 0; i < blobOfData.length; i++ )
{
blobOfData[i] = (byte) (i % 128);
}

return blobOfData;
}

static Socket accept( ServerSocket serverSocket ) throws IOException
{
try
{
return serverSocket.accept();
}
catch ( SocketException e )
{
String message = e.getMessage();
if ( "Socket closed".equalsIgnoreCase( message ) )
{
return null;
}
throw e;
}
}

/**
* Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate
Expand All @@ -122,7 +200,7 @@ protected static class LittleAtATimeChannel implements ByteChannel
private final ByteChannel delegate;
private final int maxFrameSize;

public LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize )
LittleAtATimeChannel( ByteChannel delegate, int maxFrameSize )
{

this.delegate = delegate;
Expand Down Expand Up @@ -152,7 +230,7 @@ public int write( ByteBuffer src ) throws IOException
}
finally
{
src.limit(originalLimit);
src.limit( originalLimit );
}
}

Expand All @@ -167,7 +245,7 @@ public int read( ByteBuffer dst ) throws IOException
}
finally
{
dst.limit(originalLimit);
dst.limit( originalLimit );
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.SocketChannel;
import java.security.GeneralSecurityException;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLServerSocketFactory;

import org.neo4j.driver.internal.security.TLSSocketChannel;

Expand All @@ -40,40 +39,24 @@
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
* can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16
* for the following variables:
*
* <p>
* - Network frame size
* - Bolt message size
* - Read buffer size
*
* <p>
* It tests every possible combination, and it does this currently only for the read path, expanding
* to the write path as well would be useful. For each size, it sets up a TLS server and tests the
* handshake, transferring the data, and verifying the data is correct after decryption.
*/
public class TLSSocketChannelReadFragmentationIT extends TLSSocketChannelFragmentation
{
private byte[] blobOfData;
private ServerSocket server;



private void blobOfDataSize( int dataBlobSize )
{
blobOfData = new byte[dataBlobSize];
// If the blob is all zeros, we'd miss data corruption problems in assertions, so
// fill the data blob with different values.
for ( int i = 0; i < blobOfData.length; i++ )
{
blobOfData[i] = (byte) (i % 128);
}
}

protected void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int userBufferSize ) throws IOException, GeneralSecurityException
@Override
protected void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int userBufferSize ) throws Exception
{
blobOfDataSize(blobOfDataSize);
SSLEngine engine = sslCtx.createSSLEngine();
engine.setUseClientMode( true );
ByteChannel ch = SocketChannel.open( new InetSocketAddress( server.getInetAddress(), server.getLocalPort() ) );
ch = new LittleAtATimeChannel( ch, networkFrameSize );
SocketAddress address = new InetSocketAddress( serverSocket.getInetAddress(), serverSocket.getLocalPort() );
ByteChannel ch = new LittleAtATimeChannel( SocketChannel.open( address ), networkFrameSize );

try ( TLSSocketChannel channel = TLSSocketChannel.create( ch, DEV_NULL_LOGGER, engine ) )
{
Expand All @@ -88,34 +71,37 @@ protected void testForBufferSizes( int blobOfDataSize, int networkFrameSize, int
}
}

protected void createServer() throws IOException
@Override
protected Runnable createServerRunnable( SSLContext sslContext ) throws IOException
{
SSLServerSocketFactory ssf = sslCtx.getServerSocketFactory();
server = ssf.createServerSocket(0);

new Thread(new Runnable()
return new Runnable()
{
@Override
public void run()
{
try
{
//noinspection InfiniteLoopStatement
while(true)
// noinspection InfiniteLoopStatement
while ( true )
{
Socket client = server.accept();
Socket client = accept( serverSocket );
if ( client == null )
{
return;
}

OutputStream outputStream = client.getOutputStream();
outputStream.write( blobOfData );
outputStream.flush();
// client.close(); // TODO: Uncomment this, fix resulting error handling CLOSED event

client.close();
}
}
catch ( IOException e )
{
e.printStackTrace();
throw new RuntimeException( e );
}
}
}).start();
};
}

}
Loading