PYTHON-2462 Avoid connection storms: implement pool PAUSED state #531

Merged: 15 commits, Jan 6, 2021
5 changes: 5 additions & 0 deletions doc/changelog.rst
@@ -4,6 +4,11 @@ Changelog
Changes in Version 4.0
----------------------

Breaking Changes in 4.0
```````````````````````

- Removed :mod:`~pymongo.thread_util`.

Issues Resolved
...............

3 changes: 3 additions & 0 deletions pymongo/event_loggers.py
@@ -171,6 +171,9 @@ class ConnectionPoolLogger(monitoring.ConnectionPoolListener):
def pool_created(self, event):
logging.info("[pool {0.address}] pool created".format(event))

def pool_ready(self, event):
logging.info("[pool {0.address}] pool ready".format(event))

def pool_cleared(self, event):
logging.info("[pool {0.address}] pool cleared".format(event))

2 changes: 1 addition & 1 deletion pymongo/mongo_client.py
@@ -737,7 +737,7 @@ def target():

executor = periodic_executor.PeriodicExecutor(
interval=common.KILL_CURSOR_FREQUENCY,
min_interval=0.5,
min_interval=common.MIN_HEARTBEAT_INTERVAL,
target=target,
name="pymongo_kill_cursors_thread")

34 changes: 34 additions & 0 deletions pymongo/monitoring.py
@@ -255,6 +255,18 @@ def pool_created(self, event):
"""
raise NotImplementedError

def pool_ready(self, event):
"""Abstract method to handle a :class:`PoolReadyEvent`.

Emitted when a Connection Pool is marked ready.

:Parameters:
- `event`: An instance of :class:`PoolReadyEvent`.

.. versionadded:: 4.0
"""
raise NotImplementedError

def pool_cleared(self, event):
"""Abstract method to handle a `PoolClearedEvent`.

@@ -692,6 +704,18 @@ def __repr__(self):
self.__class__.__name__, self.address, self.__options)


class PoolReadyEvent(_PoolEvent):
"""Published when a Connection Pool is marked ready.

:Parameters:
- `address`: The address (host, port) pair of the server this Pool is
attempting to connect to.

.. versionadded:: 4.0
"""
__slots__ = ()


class PoolClearedEvent(_PoolEvent):
"""Published when a Connection Pool is cleared.

@@ -1475,6 +1499,16 @@ def publish_pool_created(self, address, options):
except Exception:
_handle_exception()

def publish_pool_ready(self, address):
"""Publish a :class:`PoolReadyEvent` to all pool listeners.
"""
event = PoolReadyEvent(address)
for subscriber in self.__cmap_listeners:
try:
subscriber.pool_ready(event)
except Exception:
_handle_exception()

def publish_pool_cleared(self, address):
"""Publish a :class:`PoolClearedEvent` to all pool listeners.
"""
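As context, a minimal sketch of how an application could observe the new PoolReadyEvent, assuming PyMongo 4.0+, a local mongod on the default port, and the existing monitoring.register API; the ConnectionPoolLogger updated in event_loggers.py above already implements pool_ready:

import logging

from pymongo import MongoClient, monitoring
from pymongo.event_loggers import ConnectionPoolLogger

logging.basicConfig(level=logging.INFO)

# Register globally so every client publishes CMAP events to this listener,
# including the PoolReadyEvent emitted when a pool leaves the PAUSED state.
monitoring.register(ConnectionPoolLogger())

# Listeners can also be passed per-client.
client = MongoClient("mongodb://localhost:27017",
                     event_listeners=[ConnectionPoolLogger()])
client.admin.command("ping")  # should trigger pool created/ready and check-out events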
135 changes: 107 additions & 28 deletions pymongo/pool.py
@@ -30,7 +30,7 @@
from bson import DEFAULT_CODEC_OPTIONS
from bson.py3compat import imap, itervalues, _unicode, PY3
from bson.son import SON
from pymongo import auth, helpers, thread_util, __version__
from pymongo import auth, helpers, __version__
from pymongo.client_session import _validate_session_write_concern
from pymongo.common import (MAX_BSON_SIZE,
MAX_CONNECTING,
@@ -46,6 +46,7 @@
CertificateError,
ConnectionFailure,
ConfigurationError,
ExceededMaxWaiters,
InvalidOperation,
DocumentTooLarge,
NetworkTimeout,
@@ -309,7 +310,8 @@ class PoolOptions(object):
'__wait_queue_timeout', '__wait_queue_multiple',
'__ssl_context', '__ssl_match_hostname', '__socket_keepalive',
'__event_listeners', '__appname', '__driver', '__metadata',
'__compression_settings', '__max_connecting')
'__compression_settings', '__max_connecting',
'__pause_enabled')

def __init__(self, max_pool_size=MAX_POOL_SIZE,
min_pool_size=MIN_POOL_SIZE,
@@ -318,7 +320,8 @@ def __init__(self, max_pool_size=MAX_POOL_SIZE,
wait_queue_multiple=None, ssl_context=None,
ssl_match_hostname=True, socket_keepalive=True,
event_listeners=None, appname=None, driver=None,
compression_settings=None, max_connecting=MAX_CONNECTING):
compression_settings=None, max_connecting=MAX_CONNECTING,
pause_enabled=True):

self.__max_pool_size = max_pool_size
self.__min_pool_size = min_pool_size
@@ -335,6 +338,7 @@ def __init__(self, max_pool_size=MAX_POOL_SIZE,
self.__driver = driver
self.__compression_settings = compression_settings
self.__max_connecting = max_connecting
self.__pause_enabled = pause_enabled
self.__metadata = copy.deepcopy(_METADATA)
if appname:
self.__metadata['application'] = {'name': appname}
@@ -406,6 +410,10 @@ def max_connecting(self):
"""
return self.__max_connecting

@property
def pause_enabled(self):
return self.__pause_enabled

@property
def max_idle_time_seconds(self):
"""The maximum number of seconds that a connection can remain
@@ -1058,6 +1066,12 @@ class _PoolClosedError(PyMongoError):
pass


class PoolState(object):
PAUSED = 1
READY = 2
CLOSED = 3


# Do *not* explicitly inherit from object or Jython won't call __del__
# http://bugs.jython.org/issue1057
class Pool:
@@ -1068,6 +1082,10 @@ def __init__(self, address, options, handshake=True):
- `options`: a PoolOptions instance
- `handshake`: whether to call ismaster for each new SocketInfo
"""
if options.pause_enabled:
self.state = PoolState.PAUSED
else:
self.state = PoolState.READY
# Check a socket's health with socket_closed() every once in a while.
# Can override for testing: 0 to always check, None to never check.
self._check_interval_seconds = 1
@@ -1079,7 +1097,6 @@ def __init__(self, address, options, handshake=True):
self.active_sockets = 0
# Monotonically increasing connection ID required for CMAP Events.
self.next_connection_id = 1
self.closed = False
# Track whether the sockets in this pool are writeable or not.
self.is_writable = None

@@ -1098,13 +1115,23 @@ def __init__(self, address, options, handshake=True):

if (self.opts.wait_queue_multiple is None or
self.opts.max_pool_size is None):
max_waiters = None
max_waiters = float('inf')
else:
max_waiters = (
self.opts.max_pool_size * self.opts.wait_queue_multiple)

self._socket_semaphore = thread_util.create_semaphore(
self.opts.max_pool_size, max_waiters)
# The first portion of the wait queue.
# Enforces: maxPoolSize and waitQueueMultiple
# Also used for: clearing the wait queue
self.size_cond = threading.Condition(self.lock)
self.requests = 0
self.max_pool_size = self.opts.max_pool_size
if self.max_pool_size is None:
self.max_pool_size = float('inf')
self.waiters = 0
self.max_waiters = max_waiters
# The second portion of the wait queue.
# Enforces: maxConnecting
# Also used for: clearing the wait queue
self._max_connecting_cond = threading.Condition(self.lock)
self._max_connecting = self.opts.max_connecting
self._pending = 0
@@ -1114,10 +1141,23 @@ def __init__(self, address, options, handshake=True):
# Similar to active_sockets but includes threads in the wait queue.
self.operation_count = 0

def _reset(self, close):
with self.lock:
def ready(self):
old_state, self.state = self.state, PoolState.READY
if old_state != PoolState.READY:
if self.enabled_for_cmap:
self.opts.event_listeners.publish_pool_ready(self.address)

@property
def closed(self):
return self.state == PoolState.CLOSED

def _reset(self, close, pause=True):
old_state = self.state
with self.size_cond:
if self.closed:
return
if self.opts.pause_enabled and pause:
old_state, self.state = self.state, PoolState.PAUSED
self.generation += 1
newpid = os.getpid()
if self.pid != newpid:
Expand All @@ -1126,7 +1166,10 @@ def _reset(self, close):
self.operation_count = 0
sockets, self.sockets = self.sockets, collections.deque()
if close:
self.closed = True
self.state = PoolState.CLOSED
# Clear the wait queue
self._max_connecting_cond.notify_all()
Review comment (Contributor):
So when this happens and there are threads waiting to create a connection, they will get notified (via either size_cond or max_connecting_cond) and end up raising an exception since the pool is not ready?

Reply (Member, author):
Yes, exactly that.

self.size_cond.notify_all()

listeners = self.opts.event_listeners
# CMAP spec says that close() MUST close sockets before publishing the
Expand All @@ -1138,7 +1181,7 @@ def _reset(self, close):
if self.enabled_for_cmap:
listeners.publish_pool_closed(self.address)
else:
if self.enabled_for_cmap:
if old_state != PoolState.PAUSED and self.enabled_for_cmap:
listeners.publish_pool_cleared(self.address)
for sock_info in sockets:
sock_info.close_socket(ConnectionClosedReason.STALE)
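Aside: a minimal, self-contained sketch (hypothetical class and names, not PyMongo's actual implementation) of the wait-queue clearing behavior discussed in the review comment above — close() flips the state and notifies all waiters, and each woken waiter re-checks the state and raises instead of proceeding:

import threading

class TinyPool:
    def __init__(self, max_size):
        self.size_cond = threading.Condition()
        self.requests = 0
        self.max_size = max_size
        self.closed = False

    def close(self):
        with self.size_cond:
            self.closed = True
            # Wake every thread blocked in check_out() so none is left waiting.
            self.size_cond.notify_all()

    def check_out(self):
        with self.size_cond:
            if self.closed:
                raise RuntimeError("pool closed")
            while self.requests >= self.max_size:
                self.size_cond.wait()
                if self.closed:
                    # Woken by close(): abandon the check-out and surface an error.
                    raise RuntimeError("pool closed")
            self.requests += 1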
@@ -1155,6 +1198,9 @@ def update_is_writable(self, is_writable):
def reset(self):
self._reset(close=False)

def reset_without_pause(self):
self._reset(close=False, pause=False)

def close(self):
self._reset(close=True)

@@ -1164,6 +1210,9 @@ def remove_stale_sockets(self, reference_generation, all_credentials):
`generation` at the point in time this operation was requested on the
pool.
"""
if self.state != PoolState.READY:
return

if self.opts.max_idle_time_seconds is not None:
with self.lock:
while (self.sockets and
@@ -1172,15 +1221,14 @@ def remove_stale_sockets(self, reference_generation, all_credentials):
sock_info.close_socket(ConnectionClosedReason.IDLE)

while True:
with self.lock:
with self.size_cond:
# There are enough sockets in the pool.
if (len(self.sockets) + self.active_sockets >=
self.opts.min_pool_size):
# There are enough sockets in the pool.
return

# We must acquire the semaphore to respect max_pool_size.
if not self._socket_semaphore.acquire(False):
return
if self.requests >= self.opts.min_pool_size:
return
self.requests += 1
incremented = False
try:
with self._max_connecting_cond:
@@ -1204,7 +1252,10 @@ def remove_stale_sockets(self, reference_generation, all_credentials):
with self._max_connecting_cond:
self._pending -= 1
self._max_connecting_cond.notify()
self._socket_semaphore.release()

with self.size_cond:
self.requests -= 1
self.size_cond.notify()

def connect(self, all_credentials=None):
"""Connect to Mongo and return a new SocketInfo.
@@ -1289,6 +1340,14 @@ def get_socket(self, all_credentials, checkout=False):
if not checkout:
self.return_socket(sock_info)

def _raise_if_not_ready(self, emit_event):
if self.state != PoolState.READY:
if self.enabled_for_cmap and emit_event:
self.opts.event_listeners.publish_connection_check_out_failed(
self.address, ConnectionCheckOutFailedReason.CONN_ERROR)
_raise_connection_failure(
self.address, AutoReconnect('connection pool paused'))

def _get_socket(self, all_credentials):
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
# We use the pid here to avoid issues with fork / multiprocessing.
@@ -1313,9 +1372,26 @@ def _get_socket(self, all_credentials):
deadline = _time() + self.opts.wait_queue_timeout
else:
deadline = None
if not self._socket_semaphore.acquire(
True, self.opts.wait_queue_timeout):
self._raise_wait_queue_timeout()

with self.size_cond:
self._raise_if_not_ready(emit_event=True)
if self.waiters >= self.max_waiters:
raise ExceededMaxWaiters(
'exceeded max waiters: %s threads already waiting' % (
self.waiters))
self.waiters += 1
try:
while not (self.requests < self.max_pool_size):
if not _cond_wait(self.size_cond, deadline):
# Timed out, notify the next thread to ensure a
# timeout doesn't consume the condition.
Review comment (Contributor):
I am not sure I follow the logic here. Under what circumstance would we enter this next if block? The check seems to directly contradict the entry condition of the while loop above. Also, if the number of requests did drop below max_pool_size while we were waiting to acquire the lock, shouldn't the condition variable already have been notified by whatever caused the reduction in requests (e.g. checking a connection back in)? Why are we manually notifying in this case? I know this is similar to how we notify the _max_connecting_cond condition, but it would be great if you could explain what is going on here.

Reply (Member, author):
The comment attempts to explain it. This extra logic fixes the following race:

1. Thread 1 and Thread 2 are waiting on this condition variable.
2. Thread 3 completes and calls self.size_cond.notify(), which wakes Thread 1.
3. Thread 1 wakes up and realizes that its timeout has expired, so it raises WaitQueueTimeout.
4. Thread 1 has "consumed" the notification but did not proceed to run an operation.
5. Thread 2 is stuck waiting for the next notification or timeout.

This change fixes the bug by notifying the next thread in the wait queue before raising a timeout.

if self.requests < self.max_pool_size:
self.size_cond.notify()
self._raise_wait_queue_timeout()
self._raise_if_not_ready(emit_event=True)
finally:
self.waiters -= 1
self.requests += 1

# We've now acquired the semaphore and must release it on error.
sock_info = None
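Aside: a standalone illustration of the race described in the review comment above, using hypothetical names rather than PyMongo's real classes — a waiter that times out re-checks the predicate and, if capacity actually became free, passes the notification on before raising, so the wake-up is not consumed by a thread that will never use it:

import threading
import time

class Gate:
    def __init__(self, capacity):
        self.cond = threading.Condition()
        self.capacity = capacity
        self.in_use = 0

    def acquire(self, timeout):
        deadline = time.monotonic() + timeout
        with self.cond:
            while self.in_use >= self.capacity:
                remaining = deadline - time.monotonic()
                if remaining <= 0 or not self.cond.wait(remaining):
                    # Timed out. If a slot opened up while we were waking, hand the
                    # notification to the next waiter before failing so it is not lost.
                    if self.in_use < self.capacity:
                        self.cond.notify()
                    raise TimeoutError("wait queue timeout")
            self.in_use += 1

    def release(self):
        with self.cond:
            self.in_use -= 1
            self.cond.notify()  # wakes at most one waiting acquire()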
@@ -1330,6 +1406,7 @@ def _get_socket(self, all_credentials):
# CMAP: we MUST wait for either maxConnecting OR for a socket
# to be checked back into the pool.
with self._max_connecting_cond:
self._raise_if_not_ready(emit_event=False)
while not (self.sockets or
self._pending < self._max_connecting):
if not _cond_wait(self._max_connecting_cond, deadline):
Expand All @@ -1340,6 +1417,7 @@ def _get_socket(self, all_credentials):
self._max_connecting_cond.notify()
emitted_event = True
self._raise_wait_queue_timeout()
self._raise_if_not_ready(emit_event=False)

try:
sock_info = self.sockets.popleft()
@@ -1361,11 +1439,11 @@ def _get_socket(self, all_credentials):
if sock_info:
# We checked out a socket but authentication failed.
sock_info.close_socket(ConnectionClosedReason.ERROR)
self._socket_semaphore.release()

if incremented:
with self.lock:
with self.size_cond:
self.requests -= 1
if incremented:
self.active_sockets -= 1
self.size_cond.notify()

if self.enabled_for_cmap and not emitted_event:
self.opts.event_listeners.publish_connection_check_out_failed(
@@ -1401,10 +1479,11 @@ def return_socket(self, sock_info):
# Notify any threads waiting to create a connection.
self._max_connecting_cond.notify()

self._socket_semaphore.release()
with self.lock:
with self.size_cond:
self.requests -= 1
self.active_sockets -= 1
self.operation_count -= 1
self.size_cond.notify()

def _perished(self, sock_info):
"""Return True and close the connection if it is "perished".
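As a summary of the structure this file introduces, here is a minimal sketch (hypothetical names, not PyMongo's actual classes) of the two-stage wait queue that replaces the old thread_util semaphore: size_cond enforces maxPoolSize for every check-out, while max_connecting_cond additionally caps how many threads may be establishing brand-new connections at once.

import threading

class TwoStageWaitQueue:
    def __init__(self, max_pool_size, max_connecting):
        lock = threading.Lock()
        # Both conditions share one lock, mirroring size_cond and
        # _max_connecting_cond in the diff above.
        self.size_cond = threading.Condition(lock)
        self.max_connecting_cond = threading.Condition(lock)
        self.requests = 0      # check-outs in progress or completed
        self.pending = 0       # threads currently creating a new connection
        self.idle = []         # connections available for reuse
        self.max_pool_size = max_pool_size
        self.max_connecting = max_connecting

    def check_out(self):
        # Stage 1: wait for a free slot (enforces maxPoolSize).
        with self.size_cond:
            while self.requests >= self.max_pool_size:
                self.size_cond.wait()
            self.requests += 1
        # Stage 2: reuse an idle connection, or wait for permission to connect.
        with self.max_connecting_cond:
            while not self.idle and self.pending >= self.max_connecting:
                self.max_connecting_cond.wait()
            if self.idle:
                return self.idle.pop()
            self.pending += 1
        try:
            conn = object()  # stand-in for the expensive connection handshake
        finally:
            with self.max_connecting_cond:
                self.pending -= 1
                self.max_connecting_cond.notify()
        return conn

    def check_in(self, conn):
        with self.max_connecting_cond:
            self.idle.append(conn)
            self.max_connecting_cond.notify()  # a connection became reusable
        with self.size_cond:
            self.requests -= 1
            self.size_cond.notify()            # a maxPoolSize slot opened up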