Skip to content

Commit b2fa3a8

Browse files
committed
Add tests for asyncio pubsub subsciription auto-reconnect
1 parent 836ac61 commit b2fa3a8

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

tests/test_asyncio/compat.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
import asyncio
2+
import sys
13
from unittest import mock
24

35
try:
46
mock.AsyncMock
57
except AttributeError:
68
import mock
9+
10+
11+
def create_task(coroutine):
12+
if sys.version_info[:2] >= (3, 7):
13+
return asyncio.create_task(coroutine)
14+
else:
15+
return asyncio.ensure_future(coroutine)

tests/test_asyncio/test_pubsub.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import functools
3+
import socket
34
from typing import Optional
45

56
import async_timeout
@@ -11,7 +12,7 @@
1112
from redis.typing import EncodableT
1213
from tests.conftest import skip_if_server_version_lt
1314

14-
from .compat import mock
15+
from .compat import create_task, mock
1516

1617

1718
def with_timeout(t):
@@ -786,3 +787,130 @@ def callback(message):
786787
"pattern": None,
787788
"type": "message",
788789
}
790+
791+
792+
# @pytest.mark.xfail
793+
@pytest.mark.parametrize("method", ["get_message", pytest.param("listen", marks=pytest.mark.xfail)])
794+
@pytest.mark.onlynoncluster
795+
class TestPubSubAutoReconnect:
796+
timeout = 2
797+
798+
async def mysetup(self, r, method):
799+
self.messages = asyncio.Queue()
800+
self.pubsub = r.pubsub()
801+
# State: 0 = initial state , 1 = after disconnect, 2 = ConnectionError is seen,
802+
# 3=successfully reconnected 4 = exit
803+
self.state = 0
804+
self.cond = asyncio.Condition()
805+
if method == "get_message":
806+
self.get_message = self.loop_step_get_message
807+
else:
808+
self.get_message = self.loop_step_listen
809+
810+
self. task = create_task(self.loop())
811+
# get the initial connect message
812+
message = await self.messages.get()
813+
assert message == {
814+
"channel": b"foo",
815+
"data": 1,
816+
"pattern": None,
817+
"type": "subscribe",
818+
}
819+
820+
async def mycleanup(self):
821+
message = await self.messages.get()
822+
assert message == {
823+
"channel": b"foo",
824+
"data": 1,
825+
"pattern": None,
826+
"type": "subscribe",
827+
}
828+
# kill thread
829+
async with self.cond:
830+
self.state = 4 # quit
831+
await self.task
832+
833+
async def test_reconnect_socket_error(self, r: redis.Redis, method):
834+
"""
835+
Test that a socket error will cause reconnect
836+
"""
837+
async with async_timeout.timeout(self.timeout):
838+
await self.mysetup(r, method)
839+
# now, disconnect the connection, and wait for it to be re-established
840+
async with self.cond:
841+
assert self.state == 0
842+
self.state = 1
843+
with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
844+
mockobj.read_response.side_effect = socket.error
845+
mockobj.can_read.side_effect = socket.error
846+
# wait until task noticies the disconnect until we undo the patch
847+
await self.cond.wait_for(lambda: self.state >= 2)
848+
assert not self.pubsub.connection.is_connected
849+
# it is in a disconnecte state
850+
# wait for reconnect
851+
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
852+
assert self.state == 3
853+
854+
await self.mycleanup()
855+
856+
async def test_reconnect_disconnect(self, r: redis.Redis, method):
857+
"""
858+
Test that a manual disconnect() will cause reconnect
859+
"""
860+
async with async_timeout.timeout(self.timeout):
861+
await self.mysetup(r, method)
862+
# now, disconnect the connection, and wait for it to be re-established
863+
async with self.cond:
864+
self.state = 1
865+
await self.pubsub.connection.disconnect()
866+
assert not self.pubsub.connection.is_connected # it is in a disconnecte state
867+
# wait for reconnect
868+
await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
869+
assert self.state == 3
870+
871+
await self.mycleanup()
872+
873+
async def loop(self):
874+
# reader loop, performing state transitions as it
875+
# discovers disconnects and reconnects
876+
await self.pubsub.subscribe("foo")
877+
while True:
878+
await asyncio.sleep(0.01) # give main thread chance to get lock
879+
async with self.cond:
880+
old_state = self.state
881+
try:
882+
if self.state == 4:
883+
break
884+
# print("state a ", self.state)
885+
got_msg = await self.get_message()
886+
assert got_msg
887+
if self.state in (1, 2):
888+
self.state = 3 # successful reconnect
889+
except redis.ConnectionError:
890+
assert self.state in (1, 2)
891+
self.state = 2 # signal that we noticed the disconnect
892+
finally:
893+
self.cond.notify()
894+
# make sure that we did notice the connection error
895+
# or reconnected without any error
896+
if old_state == 1:
897+
assert self.state in (2, 3)
898+
899+
async def loop_step_get_message(self):
900+
# get a single message via get_message
901+
message = await self.pubsub.get_message(timeout=0.1)
902+
# print(message)
903+
if message is not None:
904+
await self.messages.put(message)
905+
return True
906+
return False
907+
908+
async def loop_step_listen(self):
909+
# get a single message via listen()
910+
try:
911+
async with async_timeout.timeout(0.1):
912+
async for message in self.pubsub.listen():
913+
await self.messages.put(message)
914+
return True
915+
except asyncio.TimeoutError:
916+
return False

0 commit comments

Comments
 (0)