Skip to content

Fail early on auth errors in direct driver #373

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 18, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public class DirectConnectionProvider implements ConnectionProvider
{
this.address = address;
this.pool = pool;

verifyConnectivity();
}

@Override
Expand All @@ -55,4 +57,13 @@ public BoltServerAddress getAddress()
{
return address;
}

/**
* Acquires and releases a connection to verify connectivity so this connection provider fails fast. This is
* especially valuable when driver was created with incorrect credentials.
*/
private void verifyConnectivity()
{
acquireConnection( AccessMode.READ ).close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
import org.neo4j.driver.internal.spi.PooledConnection;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.RETURNS_MOCKS;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.only;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.neo4j.driver.v1.AccessMode.READ;
Expand All @@ -42,7 +45,7 @@ public void acquiresConnectionsFromThePool()
ConnectionPool pool = mock( ConnectionPool.class );
PooledConnection connection1 = mock( PooledConnection.class );
PooledConnection connection2 = mock( PooledConnection.class );
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connection1 ).thenReturn( connection2 );
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connection1, connection1, connection2 );

DirectConnectionProvider provider = newConnectionProvider( pool );

Expand All @@ -53,12 +56,12 @@ public void acquiresConnectionsFromThePool()
@Test
public void closesPool() throws Exception
{
ConnectionPool pool = mock( ConnectionPool.class );
ConnectionPool pool = mock( ConnectionPool.class, RETURNS_MOCKS );
DirectConnectionProvider provider = newConnectionProvider( pool );

provider.close();

verify( pool, only() ).close();
verify( pool ).close();
}

@Test
Expand All @@ -71,9 +74,42 @@ public void returnsCorrectAddress()
assertEquals( address, provider.getAddress() );
}

@Test
public void testsConnectivityOnCreation()
{
ConnectionPool pool = mock( ConnectionPool.class );
PooledConnection connection = mock( PooledConnection.class );
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connection );

assertNotNull( newConnectionProvider( pool ) );

verify( pool ).acquire( BoltServerAddress.LOCAL_DEFAULT );
verify( connection ).close();
}

@Test
public void throwsWhenTestConnectionThrows()
{
ConnectionPool pool = mock( ConnectionPool.class );
PooledConnection connection = mock( PooledConnection.class );
RuntimeException error = new RuntimeException();
doThrow( error ).when( connection ).close();
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connection );

try
{
newConnectionProvider( pool );
fail( "Exception expected" );
}
catch ( Exception e )
{
assertSame( error, e );
}
}

private static DirectConnectionProvider newConnectionProvider( BoltServerAddress address )
{
return new DirectConnectionProvider( address, mock( ConnectionPool.class ) );
return new DirectConnectionProvider( address, mock( ConnectionPool.class, RETURNS_MOCKS ) );
}

private static DirectConnectionProvider newConnectionProvider( ConnectionPool pool )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
package org.neo4j.driver.internal;

import org.junit.After;
import org.junit.ClassRule;
import org.junit.Test;

import java.net.URI;
Expand All @@ -28,6 +30,7 @@
import org.neo4j.driver.v1.Record;
import org.neo4j.driver.v1.Session;
import org.neo4j.driver.v1.util.StubServer;
import org.neo4j.driver.v1.util.TestNeo4j;

import static org.hamcrest.Matchers.is;
import static org.hamcrest.core.IsEqual.equalTo;
Expand All @@ -40,14 +43,28 @@

public class DirectDriverTest
{
@ClassRule
public static final TestNeo4j neo4j = new TestNeo4j();

private Driver driver;

@After
public void closeDriver() throws Exception
{
if ( driver != null )
{
driver.close();
}
}

@Test
public void shouldUseDefaultPortIfMissing()
{
// Given
URI uri = URI.create( "bolt://localhost" );

// When
Driver driver = GraphDatabase.driver( uri );
driver = GraphDatabase.driver( uri, neo4j.authToken() );

// Then
assertThat( driver, is( directDriverWithAddress( LOCAL_DEFAULT ) ) );
Expand All @@ -61,7 +78,7 @@ public void shouldAllowIPv6Address()
BoltServerAddress address = BoltServerAddress.from( uri );

// When
Driver driver = GraphDatabase.driver( uri );
driver = GraphDatabase.driver( uri, neo4j.authToken() );

// Then
assertThat( driver, is( directDriverWithAddress( address ) ) );
Expand All @@ -76,7 +93,7 @@ public void shouldRejectInvalidAddress()
// When & Then
try
{
Driver driver = GraphDatabase.driver( uri );
driver = GraphDatabase.driver( uri, neo4j.authToken() );
fail("Expecting error for wrong uri");
}
catch( IllegalArgumentException e )
Expand All @@ -93,7 +110,7 @@ public void shouldRegisterSingleServer()
BoltServerAddress address = BoltServerAddress.from( uri );

// When
Driver driver = GraphDatabase.driver( uri );
driver = GraphDatabase.driver( uri, neo4j.authToken() );

// Then
assertThat( driver, is( directDriverWithAddress( address ) ) );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.spi.ConnectionPool;
import org.neo4j.driver.internal.spi.ConnectionProvider;
import org.neo4j.driver.internal.spi.PooledConnection;
import org.neo4j.driver.v1.AuthToken;
import org.neo4j.driver.v1.AuthTokens;
import org.neo4j.driver.v1.Config;
Expand All @@ -45,9 +46,11 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.neo4j.driver.v1.AccessMode.READ;
import static org.neo4j.driver.v1.Config.defaultConfig;

Expand All @@ -69,7 +72,7 @@ public static List<URI> uris()
@Test
public void connectionPoolClosedWhenDriverCreationFails() throws Exception
{
ConnectionPool connectionPool = mock( ConnectionPool.class );
ConnectionPool connectionPool = connectionPoolMock();
DriverFactory factory = new ThrowingDriverFactory( connectionPool );

try
Expand All @@ -87,7 +90,7 @@ public void connectionPoolClosedWhenDriverCreationFails() throws Exception
@Test
public void connectionPoolCloseExceptionIsSupressedWhenDriverCreationFails() throws Exception
{
ConnectionPool connectionPool = mock( ConnectionPool.class );
ConnectionPool connectionPool = connectionPoolMock();
RuntimeException poolCloseError = new RuntimeException( "Pool close error" );
doThrow( poolCloseError ).when( connectionPool ).close();

Expand Down Expand Up @@ -142,6 +145,13 @@ private Driver createDriver( DriverFactory driverFactory, Config config )
return driverFactory.newInstance( uri, auth, routingSettings, RetrySettings.DEFAULT, config );
}

private static ConnectionPool connectionPoolMock()
{
ConnectionPool pool = mock( ConnectionPool.class );
when( pool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( mock( PooledConnection.class ) );
return pool;
}

private static class ThrowingDriverFactory extends DriverFactory
{
final ConnectionPool connectionPool;
Expand Down Expand Up @@ -196,5 +206,11 @@ protected SessionFactory createSessionFactory( ConnectionProvider connectionProv
capturedSessionFactory = sessionFactory;
return sessionFactory;
}

@Override
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Config config )
{
return connectionPoolMock();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,18 @@
public class GraphDatabaseTest
{
@Test
public void boltSchemeShouldInstantiateDirectDriver()
public void boltSchemeShouldInstantiateDirectDriver() throws Exception
{
// Given
URI uri = URI.create( "bolt://localhost:7687" );
StubServer server = StubServer.start( "dummy_connection.script", 9001 );
URI uri = URI.create( "bolt://localhost:9001" );

// When
Driver driver = GraphDatabase.driver( uri );
Driver driver = GraphDatabase.driver( uri, INSECURE_CONFIG );

// Then
assertThat( driver, is( directDriver() ) );
server.exit();
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.neo4j.driver.internal.spi.PooledConnection;
import org.neo4j.driver.internal.util.Clock;
import org.neo4j.driver.v1.AuthToken;
import org.neo4j.driver.v1.AuthTokens;
import org.neo4j.driver.v1.Config;
import org.neo4j.driver.v1.Driver;
import org.neo4j.driver.v1.Logging;
Expand Down Expand Up @@ -85,6 +84,7 @@ public void createDriver()
RetrySettings retrySettings = RetrySettings.DEFAULT;
driver = driverFactory.newInstance( neo4j.uri(), auth, routingSettings, retrySettings, defaultConfig() );
connectionPool = driverFactory.connectionPool;
connectionPool.startMemorizing(); // start memorizing connections after driver creation
}

@After
Expand Down Expand Up @@ -370,23 +370,34 @@ protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan
private static class MemorizingConnectionPool extends SocketConnectionPool
{
PooledConnection lastAcquiredConnectionSpy;
boolean memorize;

MemorizingConnectionPool( PoolSettings poolSettings, Connector connector, Clock clock, Logging logging )
{
super( poolSettings, connector, clock, logging );
}

void startMemorizing()
{
memorize = true;
}

@Override
public PooledConnection acquire( BoltServerAddress address )
{
PooledConnection connection = super.acquire( address );
// this connection pool returns spies so spies will be returned to the pool
// prevent spying on spies...
if ( !Mockito.mockingDetails( connection ).isSpy() )

if ( memorize )
{
connection = spy( connection );
// this connection pool returns spies so spies will be returned to the pool
// prevent spying on spies...
if ( !Mockito.mockingDetails( connection ).isSpy() )
{
connection = spy( connection );
}
lastAcquiredConnectionSpy = connection;
}
lastAcquiredConnectionSpy = connection;

return connection;
}
}
Expand Down
Loading