Skip to content

Commit 836ac61

Browse files
committed
Add unittests verifying that (non-async) PubSub will automatically reconnect
1 parent 7989d1e commit 836ac61

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed

tests/test_pubsub.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import platform
2+
import queue
3+
import socket
24
import threading
35
import time
46
from unittest import mock
@@ -608,3 +610,129 @@ def test_pubsub_deadlock(self, master_host):
608610
p = r.pubsub()
609611
p.subscribe("my-channel-1", "my-channel-2")
610612
pool.reset()
613+
614+
615+
@pytest.mark.timeout(5, method="thread")
616+
@pytest.mark.parametrize("method", ["get_message", pytest.param("listen", marks=pytest.mark.xfail)])
617+
@pytest.mark.onlynoncluster
618+
class TestPubSubAutoReconnect:
619+
def mysetup(self, r, method):
620+
self.messages = queue.Queue()
621+
self.pubsub = r.pubsub()
622+
self.state = 0
623+
self.cond = threading.Condition()
624+
if method == "get_message":
625+
self.get_message = self.loop_step_get_message
626+
else:
627+
self.get_message = self.loop_step_listen
628+
629+
self.thread = threading.Thread(target=self.loop)
630+
self.thread.daemon = True
631+
self.thread.start()
632+
# get the initial connect message
633+
message = self.messages.get(timeout=1)
634+
assert message == {
635+
"channel": b"foo",
636+
"data": 1,
637+
"pattern": None,
638+
"type": "subscribe",
639+
}
640+
641+
def wait_for_reconnect(self):
642+
self.cond.wait_for(
643+
lambda: self.pubsub.connection._sock is not None, timeout=2)
644+
assert self.pubsub.connection._sock is not None # we didn't time out
645+
assert self.state == 3
646+
647+
message = self.messages.get(timeout=1)
648+
assert message == {
649+
"channel": b"foo",
650+
"data": 1,
651+
"pattern": None,
652+
"type": "subscribe",
653+
}
654+
655+
def mycleanup(self):
656+
# kill thread
657+
with self.cond:
658+
self.state = 4 # quit
659+
self.cond.notify()
660+
self.thread.join()
661+
662+
def test_reconnect_socket_error(self, r: redis.Redis, method):
663+
"""
664+
Test that a socket error will cause reconnect
665+
"""
666+
self.mysetup(r, method)
667+
try:
668+
# now, disconnect the connection, and wait for it to be re-established
669+
with self.cond:
670+
self.state = 1
671+
with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
672+
mockobj.read_response.side_effect = socket.error
673+
mockobj.can_read.side_effect = socket.error
674+
# wait until thread notices the disconnect until we undo the patch
675+
self.cond.wait_for(lambda: self.state >= 2)
676+
assert (
677+
self.pubsub.connection._sock is None
678+
) # it is in a disconnected state
679+
self.wait_for_reconnect()
680+
681+
finally:
682+
self.mycleanup()
683+
684+
def test_reconnect_disconnect(self, r: redis.Redis, method):
685+
"""
686+
Test that a manual disconnect() will cause reconnect
687+
"""
688+
self.mysetup(r, method)
689+
try:
690+
# now, disconnect the connection, and wait for it to be re-established
691+
with self.cond:
692+
self.state = 1
693+
self.pubsub.connection.disconnect()
694+
assert self.pubsub.connection._sock is None # it is in a disconnected state
695+
# wait for reconnect
696+
self.wait_for_reconnect()
697+
finally:
698+
self.mycleanup()
699+
700+
def loop(self):
701+
# reader loop, performing state transitions as it
702+
# discovers disconnects and reconnects
703+
self.pubsub.subscribe("foo")
704+
while True:
705+
time.sleep(0.01) # give main thread chance to get lock
706+
with self.cond:
707+
old_state = self.state
708+
try:
709+
if self.state == 4:
710+
break
711+
# print ('state, %s, sock %s' % (state, pubsub.connection._sock))
712+
got_msg = self.get_message()
713+
assert got_msg
714+
if self.state in (1, 2):
715+
self.state = 3 # successful reconnect
716+
except redis.ConnectionError:
717+
assert self.state in (1, 2)
718+
self.state = 2
719+
finally:
720+
self.cond.notify()
721+
# assert that we noticed a connect error, or automatically
722+
# reconnected without error
723+
if old_state == 1:
724+
assert self.state in (2, 3)
725+
726+
def loop_step_get_message(self):
727+
# get a single message via listen()
728+
message = self.pubsub.get_message(timeout=0.1)
729+
if message is not None:
730+
self.messages.put(message)
731+
return True
732+
return False
733+
734+
def loop_step_listen(self):
735+
# get a single message via listen()
736+
for message in self.pubsub.listen():
737+
self.messages.put(message)
738+
return True

0 commit comments

Comments
 (0)