Skip to content

Commit aa218e3

Browse files
authored
Add lock to Pubsub.execute_command to ensure only one connection is created (#19)
* Add lock to Pubsub.execute_command to ensure only one connection is created * Add tests
1 parent a0ad397 commit aa218e3

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

redis/client.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,7 @@ def __init__(
14021402
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
14031403
else:
14041404
self.health_check_response = [b"pong", self.health_check_response_b]
1405+
self._connection_lock = threading.Lock()
14051406
self.reset()
14061407

14071408
def __enter__(self):
@@ -1466,12 +1467,14 @@ def execute_command(self, *args):
14661467
# subscribed to one or more channels
14671468

14681469
if self.connection is None:
1469-
self.connection = self.connection_pool.get_connection(
1470-
"pubsub", self.shard_hint
1471-
)
1472-
# register a callback that re-subscribes to any channels we
1473-
# were listening to when we were disconnected
1474-
self.connection.register_connect_callback(self.on_connect)
1470+
with self._connection_lock:
1471+
if self.connection is None:
1472+
self.connection = self.connection_pool.get_connection(
1473+
"pubsub", self.shard_hint
1474+
)
1475+
# register a callback that re-subscribes to any channels we
1476+
# were listening to when we were disconnected
1477+
self.connection.register_connect_callback(self.on_connect)
14751478
connection = self.connection
14761479
kwargs = {"check_health": not self.subscribed}
14771480
if not self.subscribed:

tests/test_pubsub.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -777,3 +777,22 @@ def get_msg():
777777

778778
# the timeout on the read should not cause disconnect
779779
assert is_connected()
780+
781+
782+
@pytest.mark.onlynoncluster
783+
class TestConnectionLeak:
784+
def test_connection_leak(self, r: redis.Redis):
785+
pubsub = r.pubsub()
786+
787+
def test():
788+
tid = threading.get_ident()
789+
pubsub.subscribe(f"foo{tid}")
790+
791+
threads = [threading.Thread(target=test) for _ in range(10)]
792+
for thread in threads:
793+
thread.start()
794+
795+
for thread in threads:
796+
thread.join()
797+
798+
assert r.connection_pool._created_connections == 2

0 commit comments

Comments
 (0)