Skip to content

Added support of goodbye message before closing the driver. #529

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,9 @@ public class ConnectionPoolImpl implements ConnectionPool
private final ConcurrentMap<BoltServerAddress,ChannelPool> pools = new ConcurrentHashMap<>();
private final AtomicBoolean closed = new AtomicBoolean();

public ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, PoolSettings settings,
MetricsListener metricsListener, Logging logging, Clock clock )
public ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, PoolSettings settings, MetricsListener metricsListener, Logging logging, Clock clock )
{
this( connector, bootstrap, new NettyChannelTracker( metricsListener, logging ), settings, metricsListener, logging, clock );
this( connector, bootstrap, new NettyChannelTracker( metricsListener, bootstrap.config().group().next(), logging ), settings, metricsListener, logging, clock );
}

ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, NettyChannelTracker nettyChannelTracker,
Expand Down Expand Up @@ -153,11 +152,11 @@ public CompletionStage<Void> close()
{
try
{
nettyChannelTracker.prepareToCloseChannels();
for ( Map.Entry<BoltServerAddress,ChannelPool> entry : pools.entrySet() )
{
BoltServerAddress address = entry.getKey();
ChannelPool pool = entry.getValue();

log.info( "Closing connection pool towards %s", address );
pool.close();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ public class NettyChannelPool extends FixedChannelPool
private final ChannelConnector connector;
private final NettyChannelTracker handler;

public NettyChannelPool( BoltServerAddress address, ChannelConnector connector, Bootstrap bootstrap,
NettyChannelTracker handler, ChannelHealthChecker healthCheck, long acquireTimeoutMillis,
int maxConnections )
public NettyChannelPool( BoltServerAddress address, ChannelConnector connector, Bootstrap bootstrap, NettyChannelTracker handler,
ChannelHealthChecker healthCheck, long acquireTimeoutMillis, int maxConnections )
{
super( bootstrap, handler, healthCheck, AcquireTimeoutAction.FAIL, acquireTimeoutMillis, maxConnections,
MAX_PENDING_ACQUIRES, RELEASE_HEALTH_CHECK );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@

import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.pool.ChannelPoolHandler;
import io.netty.util.concurrent.EventExecutor;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.messaging.BoltProtocol;
import org.neo4j.driver.internal.metrics.ListenerEvent;
import org.neo4j.driver.internal.metrics.MetricsListener;
import org.neo4j.driver.v1.Logger;
Expand All @@ -41,11 +45,18 @@ public class NettyChannelTracker implements ChannelPoolHandler
private final Logger log;
private final MetricsListener metricsListener;
private final ChannelFutureListener closeListener = future -> channelClosed( future.channel() );
private final ChannelGroup allChannels;

public NettyChannelTracker( MetricsListener metricsListener, Logging logging )
public NettyChannelTracker( MetricsListener metricsListener, EventExecutor eventExecutor, Logging logging )
{
this( metricsListener, new DefaultChannelGroup( "all-connections", eventExecutor ), logging );
}

public NettyChannelTracker( MetricsListener metricsListener, ChannelGroup channels, Logging logging )
{
this.metricsListener = metricsListener;
this.log = logging.getLog( getClass().getSimpleName() );
this.allChannels = channels;
}

@Override
Expand Down Expand Up @@ -77,6 +88,8 @@ public void channelCreated( Channel channel, ListenerEvent creatingEvent )
log.debug( "Channel %s created", channel );
incrementInUse( channel );
metricsListener.afterCreated( serverAddress( channel ), creatingEvent );

allChannels.add( channel );
}

public ListenerEvent channelCreating( BoltServerAddress address )
Expand Down Expand Up @@ -109,6 +122,24 @@ public int idleChannelCount( BoltServerAddress address )
return count == null ? 0 : count.get();
}

public void prepareToCloseChannels()
{
for ( Channel channel : allChannels )
{
BoltProtocol protocol = BoltProtocol.forChannel( channel );
try
{
protocol.prepareToCloseChannel( channel );
}
catch ( Throwable e )
{
// only logging it
log.debug( "Failed to prepare to close Channel %s due to error %s. " +
"It is safe to ignore this error as the channel will be closed despite if it is successfully prepared to close or not.", channel, e.getMessage() );
}
}
}

private void incrementInUse( Channel channel )
{
increment( channel, addressToInUseChannelCount );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ public interface BoltProtocol
*/
void initializeChannel( String userAgent, Map<String,Value> authToken, ChannelPromise channelInitializedPromise );

/**
* Prepare to close channel before it is closed.
* @param channel the channel to close.
*/
void prepareToCloseChannel( Channel channel );

/**
* Begin an explicit transaction.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public class GoodbyeMessage implements Message
{
public final static byte SIGNATURE = 0x02;

public static final Message GOODBYE = new GoodbyeMessage();
public static final GoodbyeMessage GOODBYE = new GoodbyeMessage();

private GoodbyeMessage()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ public void initializeChannel( String userAgent, Map<String,Value> authToken, Ch
channel.writeAndFlush( message, channel.voidPromise() );
}

@Override
public void prepareToCloseChannel( Channel channel )
{
// left empty on purpose.
}

@Override
public CompletionStage<Void> beginTransaction( Connection connection, Bookmarks bookmarks, TransactionConfig config )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.neo4j.driver.internal.messaging.Message;
import org.neo4j.driver.internal.messaging.MessageFormat;
import org.neo4j.driver.internal.messaging.request.BeginMessage;
import org.neo4j.driver.internal.messaging.request.GoodbyeMessage;
import org.neo4j.driver.internal.messaging.request.HelloMessage;
import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage;
import org.neo4j.driver.internal.spi.Connection;
Expand Down Expand Up @@ -84,6 +85,14 @@ public void initializeChannel( String userAgent, Map<String,Value> authToken, Ch
channel.writeAndFlush( message, channel.voidPromise() );
}

@Override
public void prepareToCloseChannel( Channel channel )
{
GoodbyeMessage message = GoodbyeMessage.GOODBYE;
messageDispatcher( channel ).enqueue( NoOpResponseHandler.INSTANCE );
channel.writeAndFlush( message, channel.voidPromise() );
}

@Override
public CompletionStage<Void> beginTransaction( Connection connection, Bookmarks bookmarks, TransactionConfig config )
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ protected InternalDriver createRoutingDriver( SecurityPlan securityPlan, BoltSer
}

@Override
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap, MetricsListener metrics, Config config )
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap,
MetricsListener metrics, Config config )
{
return connectionPool;
}
Expand Down Expand Up @@ -255,7 +256,8 @@ protected SessionFactory createSessionFactory( ConnectionProvider connectionProv
}

@Override
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap, MetricsListener metrics, Config config )
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap,
MetricsListener metrics, Config config )
{
return connectionPoolMock();
}
Expand All @@ -277,7 +279,8 @@ protected Bootstrap createBootstrap()
}

@Override
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap, MetricsListener metrics, Config config )
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap,
MetricsListener metrics, Config config )
{
return connectionPoolMock();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.neo4j.driver.internal.security.InternalAuthToken;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.util.FakeClock;
import org.neo4j.driver.internal.util.ImmediateSchedulingEventExecutor;
import org.neo4j.driver.v1.AuthToken;
import org.neo4j.driver.v1.AuthTokens;
import org.neo4j.driver.v1.Value;
Expand Down Expand Up @@ -170,7 +171,7 @@ void shouldLimitNumberOfConcurrentConnections() throws Exception
@Test
void shouldTrackActiveChannels() throws Exception
{
NettyChannelTracker tracker = new NettyChannelTracker( DEV_NULL_METRICS, DEV_NULL_LOGGING );
NettyChannelTracker tracker = new NettyChannelTracker( DEV_NULL_METRICS, new ImmediateSchedulingEventExecutor(), DEV_NULL_LOGGING );

poolHandler = tracker;
pool = newPool( neo4j.authToken() );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,33 @@

import io.netty.channel.Channel;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.group.ChannelGroup;
import org.bouncycastle.util.Arrays;
import org.junit.jupiter.api.Test;

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher;
import org.neo4j.driver.internal.messaging.request.GoodbyeMessage;
import org.neo4j.driver.internal.messaging.v3.BoltProtocolV3;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.neo4j.driver.internal.async.ChannelAttributes.setMessageDispatcher;
import static org.neo4j.driver.internal.async.ChannelAttributes.setProtocolVersion;
import static org.neo4j.driver.internal.async.ChannelAttributes.setServerAddress;
import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING;
import static org.neo4j.driver.internal.metrics.InternalAbstractMetrics.DEV_NULL_METRICS;

class NettyChannelTrackerTest
{
private final BoltServerAddress address = BoltServerAddress.LOCAL_DEFAULT;
private final NettyChannelTracker tracker = new NettyChannelTracker( DEV_NULL_METRICS, DEV_NULL_LOGGING );
private final NettyChannelTracker tracker = new NettyChannelTracker( DEV_NULL_METRICS, mock( ChannelGroup.class ), DEV_NULL_LOGGING );

@Test
void shouldIncrementInUseCountWhenChannelCreated()
Expand Down Expand Up @@ -164,10 +177,53 @@ void shouldReturnZeroActiveCountForUnknownAddress()
assertEquals( 0, tracker.inUseChannelCount( address ) );
}

@Test
void shouldAddChannelToGroupWhenChannelCreated()
{
Channel channel = newChannel();
Channel anotherChannel = newChannel();
ChannelGroup group = mock( ChannelGroup.class );
NettyChannelTracker tracker = new NettyChannelTracker( DEV_NULL_METRICS, group, DEV_NULL_LOGGING );

tracker.channelCreated( channel, null );
tracker.channelCreated( anotherChannel, null );

verify( group ).add( channel );
verify( group ).add( anotherChannel );
}

@Test
void shouldDelegateToProtocolPrepareToClose()
{
EmbeddedChannel channel = newChannelWithProtocolV3();
EmbeddedChannel anotherChannel = newChannelWithProtocolV3();
ChannelGroup group = mock( ChannelGroup.class );
when( group.iterator() ).thenReturn( new Arrays.Iterator<>( new Channel[]{channel, anotherChannel} ) );

NettyChannelTracker tracker = new NettyChannelTracker( DEV_NULL_METRICS, group, DEV_NULL_LOGGING );

tracker.prepareToCloseChannels();

assertThat( channel.outboundMessages().size(), equalTo( 1 ) );
assertThat( channel.outboundMessages(), hasItem( GoodbyeMessage.GOODBYE ) );

assertThat( anotherChannel.outboundMessages().size(), equalTo( 1 ) );
assertThat( anotherChannel.outboundMessages(), hasItem( GoodbyeMessage.GOODBYE ) );
}

private Channel newChannel()
{
EmbeddedChannel channel = new EmbeddedChannel();
setServerAddress( channel, address );
return channel;
}

private EmbeddedChannel newChannelWithProtocolV3()
{
EmbeddedChannel channel = new EmbeddedChannel();
setServerAddress( channel, address );
setProtocolVersion( channel, BoltProtocolV3.VERSION );
setMessageDispatcher( channel, mock( InboundMessageDispatcher.class ) );
return channel;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.neo4j.driver.internal.messaging.BoltProtocol;
import org.neo4j.driver.internal.messaging.request.BeginMessage;
import org.neo4j.driver.internal.messaging.request.CommitMessage;
import org.neo4j.driver.internal.messaging.request.GoodbyeMessage;
import org.neo4j.driver.internal.messaging.request.HelloMessage;
import org.neo4j.driver.internal.messaging.request.PullAllMessage;
import org.neo4j.driver.internal.messaging.request.RollbackMessage;
Expand Down Expand Up @@ -132,6 +133,16 @@ void shouldInitializeChannel()
assertTrue( promise.isSuccess() );
}

@Test
void shouldPrepareToCloseChannel()
{
protocol.prepareToCloseChannel( channel );

assertThat( channel.outboundMessages(), hasSize( 1 ) );
assertThat( channel.outboundMessages().poll(), instanceOf( GoodbyeMessage.class ) );
assertEquals( 1, messageDispatcher.queuedHandlersCount() );
}

@Test
void shouldFailToInitializeChannelWhenErrorIsReceived()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ protected final ChannelConnector createConnector( ConnectionSettings settings, S
}

@Override
protected final ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap, MetricsListener metrics,
Config config )
protected final ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap,
MetricsListener metrics, Config config )
{
pool = super.createConnectionPool( authToken, securityPlan, bootstrap, metrics, config );
return pool;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ public class FailingConnectionDriverFactory extends DriverFactory
private final AtomicReference<Throwable> nextRunFailure = new AtomicReference<>();

@Override
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap, MetricsListener metrics, Config config )
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap,
MetricsListener metrics, Config config )
{
ConnectionPool pool = super.createConnectionPool( authToken, securityPlan, bootstrap, metrics, config );
return new ConnectionPoolWithFailingConnections( pool, nextRunFailure );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ private static class DriverFactoryWithConnectionPool extends DriverFactory
MemorizingConnectionPool connectionPool;

@Override
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap, MetricsListener metrics, Config config )
protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap,
MetricsListener metrics, Config config )
{
ConnectionSettings connectionSettings = new ConnectionSettings( authToken, 1000 );
PoolSettings poolSettings = new PoolSettings( config.maxConnectionPoolSize(),
Expand Down
Loading