Skip to content

Commit b3e7893

Browse files
committed
Remove process-id checks from asyncio. Asyncio and fork() do not mix.
1 parent 19b55c6 commit b3e7893

File tree

2 files changed

+6
-120
lines changed

2 files changed

+6
-120
lines changed

redis/asyncio/connection.py

Lines changed: 6 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22
import copy
33
import enum
44
import inspect
5-
import os
65
import socket
76
import ssl
87
import sys
9-
import threading
108
import weakref
119
from abc import abstractmethod
1210
from itertools import chain
@@ -41,7 +39,6 @@
4139
from redis.exceptions import (
4240
AuthenticationError,
4341
AuthenticationWrongNumberOfArgsError,
44-
ChildDeadlockedError,
4542
ConnectionError,
4643
DataError,
4744
RedisError,
@@ -97,7 +94,6 @@ class AbstractConnection:
9794
"""Manages communication to and from a Redis server"""
9895

9996
__slots__ = (
100-
"pid",
10197
"db",
10298
"username",
10399
"client_name",
@@ -158,7 +154,6 @@ def __init__(
158154
"1. 'password' and (optional) 'username'\n"
159155
"2. 'credential_provider'"
160156
)
161-
self.pid = os.getpid()
162157
self.db = db
163158
self.client_name = client_name
164159
self.lib_name = lib_name
@@ -381,12 +376,11 @@ async def disconnect(self, nowait: bool = False) -> None:
381376
if not self.is_connected:
382377
return
383378
try:
384-
if os.getpid() == self.pid:
385-
self._writer.close() # type: ignore[union-attr]
386-
# wait for close to finish, except when handling errors and
387-
# forcefully disconnecting.
388-
if not nowait:
389-
await self._writer.wait_closed() # type: ignore[union-attr]
379+
self._writer.close() # type: ignore[union-attr]
380+
# wait for close to finish, except when handling errors and
381+
# forcefully disconnecting.
382+
if not nowait:
383+
await self._writer.wait_closed() # type: ignore[union-attr]
390384
except OSError:
391385
pass
392386
finally:
@@ -1004,15 +998,6 @@ def __init__(
1004998
self.connection_kwargs = connection_kwargs
1005999
self.max_connections = max_connections
10061000

1007-
# a lock to protect the critical section in _checkpid().
1008-
# this lock is acquired when the process id changes, such as
1009-
# after a fork. during this time, multiple threads in the child
1010-
# process could attempt to acquire this lock. the first thread
1011-
# to acquire the lock will reset the data structures and lock
1012-
# object of this pool. subsequent threads acquiring this lock
1013-
# will notice the first thread already did the work and simply
1014-
# release the lock.
1015-
self._fork_lock = threading.Lock()
10161001
self._lock = asyncio.Lock()
10171002
self._created_connections: int
10181003
self._available_connections: List[AbstractConnection]
@@ -1032,67 +1017,8 @@ def reset(self):
10321017
self._available_connections = []
10331018
self._in_use_connections = set()
10341019

1035-
# this must be the last operation in this method. while reset() is
1036-
# called when holding _fork_lock, other threads in this process
1037-
# can call _checkpid() which compares self.pid and os.getpid() without
1038-
# holding any lock (for performance reasons). keeping this assignment
1039-
# as the last operation ensures that those other threads will also
1040-
# notice a pid difference and block waiting for the first thread to
1041-
# release _fork_lock. when each of these threads eventually acquire
1042-
# _fork_lock, they will notice that another thread already called
1043-
# reset() and they will immediately release _fork_lock and continue on.
1044-
self.pid = os.getpid()
1045-
1046-
def _checkpid(self):
1047-
# _checkpid() attempts to keep ConnectionPool fork-safe on modern
1048-
# systems. this is called by all ConnectionPool methods that
1049-
# manipulate the pool's state such as get_connection() and release().
1050-
#
1051-
# _checkpid() determines whether the process has forked by comparing
1052-
# the current process id to the process id saved on the ConnectionPool
1053-
# instance. if these values are the same, _checkpid() simply returns.
1054-
#
1055-
# when the process ids differ, _checkpid() assumes that the process
1056-
# has forked and that we're now running in the child process. the child
1057-
# process cannot use the parent's file descriptors (e.g., sockets).
1058-
# therefore, when _checkpid() sees the process id change, it calls
1059-
# reset() in order to reinitialize the child's ConnectionPool. this
1060-
# will cause the child to make all new connection objects.
1061-
#
1062-
# _checkpid() is protected by self._fork_lock to ensure that multiple
1063-
# threads in the child process do not call reset() multiple times.
1064-
#
1065-
# there is an extremely small chance this could fail in the following
1066-
# scenario:
1067-
# 1. process A calls _checkpid() for the first time and acquires
1068-
# self._fork_lock.
1069-
# 2. while holding self._fork_lock, process A forks (the fork()
1070-
# could happen in a different thread owned by process A)
1071-
# 3. process B (the forked child process) inherits the
1072-
# ConnectionPool's state from the parent. that state includes
1073-
# a locked _fork_lock. process B will not be notified when
1074-
# process A releases the _fork_lock and will thus never be
1075-
# able to acquire the _fork_lock.
1076-
#
1077-
# to mitigate this possible deadlock, _checkpid() will only wait 5
1078-
# seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
1079-
# that time it is assumed that the child is deadlocked and a
1080-
# redis.ChildDeadlockedError error is raised.
1081-
if self.pid != os.getpid():
1082-
acquired = self._fork_lock.acquire(timeout=5)
1083-
if not acquired:
1084-
raise ChildDeadlockedError
1085-
# reset() the instance for the new process if another thread
1086-
# hasn't already done so
1087-
try:
1088-
if self.pid != os.getpid():
1089-
self.reset()
1090-
finally:
1091-
self._fork_lock.release()
1092-
10931020
async def get_connection(self, command_name, *keys, **options):
10941021
"""Get a connection from the pool"""
1095-
self._checkpid()
10961022
async with self._lock:
10971023
try:
10981024
connection = self._available_connections.pop()
@@ -1141,7 +1067,6 @@ def make_connection(self):
11411067

11421068
async def release(self, connection: AbstractConnection):
11431069
"""Releases the connection back to the pool"""
1144-
self._checkpid()
11451070
async with self._lock:
11461071
try:
11471072
self._in_use_connections.remove(connection)
@@ -1150,18 +1075,7 @@ async def release(self, connection: AbstractConnection):
11501075
# that the pool doesn't actually own
11511076
pass
11521077

1153-
if self.owns_connection(connection):
1154-
self._available_connections.append(connection)
1155-
else:
1156-
# pool doesn't own this connection. do not add it back
1157-
# to the pool and decrement the count so that another
1158-
# connection can take its place if needed
1159-
self._created_connections -= 1
1160-
await connection.disconnect()
1161-
return
1162-
1163-
def owns_connection(self, connection: AbstractConnection):
1164-
return connection.pid == self.pid
1078+
self._available_connections.append(connection)
11651079

11661080
async def disconnect(self, inuse_connections: bool = True):
11671081
"""
@@ -1171,7 +1085,6 @@ async def disconnect(self, inuse_connections: bool = True):
11711085
currently in use, potentially by other tasks. Otherwise only disconnect
11721086
connections that are idle in the pool.
11731087
"""
1174-
self._checkpid()
11751088
async with self._lock:
11761089
if inuse_connections:
11771090
connections: Iterable[AbstractConnection] = chain(
@@ -1259,17 +1172,6 @@ def reset(self):
12591172
# disconnect them later.
12601173
self._connections = []
12611174

1262-
# this must be the last operation in this method. while reset() is
1263-
# called when holding _fork_lock, other threads in this process
1264-
# can call _checkpid() which compares self.pid and os.getpid() without
1265-
# holding any lock (for performance reasons). keeping this assignment
1266-
# as the last operation ensures that those other threads will also
1267-
# notice a pid difference and block waiting for the first thread to
1268-
# release _fork_lock. when each of these threads eventually acquire
1269-
# _fork_lock, they will notice that another thread already called
1270-
# reset() and they will immediately release _fork_lock and continue on.
1271-
self.pid = os.getpid()
1272-
12731175
def make_connection(self):
12741176
"""Make a fresh connection."""
12751177
connection = self.connection_class(**self.connection_kwargs)
@@ -1288,8 +1190,6 @@ async def get_connection(self, command_name, *keys, **options):
12881190
create new connections when we need to, i.e.: the actual number of
12891191
connections will only increase in response to demand.
12901192
"""
1291-
# Make sure we haven't changed process.
1292-
self._checkpid()
12931193

12941194
# Try and get a connection from the pool. If one isn't available within
12951195
# self.timeout then raise a ``ConnectionError``.
@@ -1331,17 +1231,6 @@ async def get_connection(self, command_name, *keys, **options):
13311231

13321232
async def release(self, connection: AbstractConnection):
13331233
"""Releases the connection back to the pool."""
1334-
# Make sure we haven't changed process.
1335-
self._checkpid()
1336-
if not self.owns_connection(connection):
1337-
# pool doesn't own this connection. do not add it back
1338-
# to the pool. instead add a None value which is a placeholder
1339-
# that will cause the pool to recreate the connection if
1340-
it's needed.
1341-
await connection.disconnect()
1342-
self.pool.put_nowait(None)
1343-
return
1344-
13451234
# Put the connection back into the pool.
13461235
try:
13471236
self.pool.put_nowait(connection)
@@ -1352,7 +1241,6 @@ async def release(self, connection: AbstractConnection):
13521241

13531242
async def disconnect(self, inuse_connections: bool = True):
13541243
"""Disconnects all connections in the pool."""
1355-
self._checkpid()
13561244
async with self._lock:
13571245
resp = await asyncio.gather(
13581246
*(connection.disconnect() for connection in self._connections),

tests/test_asyncio/test_connection_pool.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import os
32
import re
43

54
import pytest
@@ -94,7 +93,6 @@ class DummyConnection(Connection):
9493

9594
def __init__(self, **kwargs):
9695
self.kwargs = kwargs
97-
self.pid = os.getpid()
9896

9997
async def connect(self):
10098
pass

0 commit comments

Comments
 (0)