Skip to content

Commit c9717b5

Browse files
authored
Add lock to Pubsub.execute_command to ensure only one connection is created (#19) (#22)
* Add lock to Pubsub.execute_command to ensure only one connection is created * Add tests
1 parent 2708435 commit c9717b5

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

redis/client.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,7 @@ def __init__(
13921392
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
13931393
else:
13941394
self.health_check_response = [b"pong", self.health_check_response_b]
1395+
self._connection_lock = threading.Lock()
13951396
self.reset()
13961397

13971398
def __enter__(self):
@@ -1465,12 +1466,14 @@ def execute_command(self, *args):
14651466
# subscribed to one or more channels
14661467

14671468
if self.connection is None:
1468-
self.connection = self.connection_pool.get_connection(
1469-
"pubsub", self.shard_hint
1470-
)
1471-
# register a callback that re-subscribes to any channels we
1472-
# were listening to when we were disconnected
1473-
self.connection.register_connect_callback(self.on_connect)
1469+
with self._connection_lock:
1470+
if self.connection is None:
1471+
self.connection = self.connection_pool.get_connection(
1472+
"pubsub", self.shard_hint
1473+
)
1474+
# register a callback that re-subscribes to any channels we
1475+
# were listening to when we were disconnected
1476+
self.connection.register_connect_callback(self.on_connect)
14741477
connection = self.connection
14751478
kwargs = {"check_health": not self.subscribed}
14761479
if not self.subscribed:

tests/test_pubsub.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,3 +1118,22 @@ def get_msg():
11181118

11191119
# the timeout on the read should not cause disconnect
11201120
assert is_connected()
1121+
1122+
1123+
@pytest.mark.onlynoncluster
1124+
class TestConnectionLeak:
1125+
def test_connection_leak(self, r: redis.Redis):
1126+
pubsub = r.pubsub()
1127+
1128+
def test():
1129+
tid = threading.get_ident()
1130+
pubsub.subscribe(f"foo{tid}")
1131+
1132+
threads = [threading.Thread(target=test) for _ in range(10)]
1133+
for thread in threads:
1134+
thread.start()
1135+
1136+
for thread in threads:
1137+
thread.join()
1138+
1139+
assert r.connection_pool._created_connections == 2

0 commit comments

Comments
 (0)