|
1 | 1 | import asyncio
|
2 | 2 | import functools
|
| 3 | +import socket |
3 | 4 | from typing import Optional
|
4 | 5 |
|
5 | 6 | import async_timeout
|
|
11 | 12 | from redis.typing import EncodableT
|
12 | 13 | from tests.conftest import skip_if_server_version_lt
|
13 | 14 |
|
14 |
| -from .compat import mock |
| 15 | +from .compat import create_task, mock |
15 | 16 |
|
16 | 17 |
|
17 | 18 | def with_timeout(t):
|
@@ -786,3 +787,130 @@ def callback(message):
|
786 | 787 | "pattern": None,
|
787 | 788 | "type": "message",
|
788 | 789 | }
|
| 790 | + |
| 791 | + |
| 792 | +# @pytest.mark.xfail |
| 793 | +@pytest.mark.parametrize("method", ["get_message", pytest.param("listen", marks=pytest.mark.xfail)]) |
| 794 | +@pytest.mark.onlynoncluster |
| 795 | +class TestPubSubAutoReconnect: |
| 796 | + timeout = 2 |
| 797 | + |
| 798 | + async def mysetup(self, r, method): |
| 799 | + self.messages = asyncio.Queue() |
| 800 | + self.pubsub = r.pubsub() |
| 801 | + # State: 0 = initial state , 1 = after disconnect, 2 = ConnectionError is seen, |
| 802 | + # 3=successfully reconnected 4 = exit |
| 803 | + self.state = 0 |
| 804 | + self.cond = asyncio.Condition() |
| 805 | + if method == "get_message": |
| 806 | + self.get_message = self.loop_step_get_message |
| 807 | + else: |
| 808 | + self.get_message = self.loop_step_listen |
| 809 | + |
| 810 | + self. task = create_task(self.loop()) |
| 811 | + # get the initial connect message |
| 812 | + message = await self.messages.get() |
| 813 | + assert message == { |
| 814 | + "channel": b"foo", |
| 815 | + "data": 1, |
| 816 | + "pattern": None, |
| 817 | + "type": "subscribe", |
| 818 | + } |
| 819 | + |
| 820 | + async def mycleanup(self): |
| 821 | + message = await self.messages.get() |
| 822 | + assert message == { |
| 823 | + "channel": b"foo", |
| 824 | + "data": 1, |
| 825 | + "pattern": None, |
| 826 | + "type": "subscribe", |
| 827 | + } |
| 828 | + # kill thread |
| 829 | + async with self.cond: |
| 830 | + self.state = 4 # quit |
| 831 | + await self.task |
| 832 | + |
| 833 | + async def test_reconnect_socket_error(self, r: redis.Redis, method): |
| 834 | + """ |
| 835 | + Test that a socket error will cause reconnect |
| 836 | + """ |
| 837 | + async with async_timeout.timeout(self.timeout): |
| 838 | + await self.mysetup(r, method) |
| 839 | + # now, disconnect the connection, and wait for it to be re-established |
| 840 | + async with self.cond: |
| 841 | + assert self.state == 0 |
| 842 | + self.state = 1 |
| 843 | + with mock.patch.object(self.pubsub.connection, "_parser") as mockobj: |
| 844 | + mockobj.read_response.side_effect = socket.error |
| 845 | + mockobj.can_read.side_effect = socket.error |
| 846 | + # wait until task noticies the disconnect until we undo the patch |
| 847 | + await self.cond.wait_for(lambda: self.state >= 2) |
| 848 | + assert not self.pubsub.connection.is_connected |
| 849 | + # it is in a disconnecte state |
| 850 | + # wait for reconnect |
| 851 | + await self.cond.wait_for(lambda: self.pubsub.connection.is_connected) |
| 852 | + assert self.state == 3 |
| 853 | + |
| 854 | + await self.mycleanup() |
| 855 | + |
| 856 | + async def test_reconnect_disconnect(self, r: redis.Redis, method): |
| 857 | + """ |
| 858 | + Test that a manual disconnect() will cause reconnect |
| 859 | + """ |
| 860 | + async with async_timeout.timeout(self.timeout): |
| 861 | + await self.mysetup(r, method) |
| 862 | + # now, disconnect the connection, and wait for it to be re-established |
| 863 | + async with self.cond: |
| 864 | + self.state = 1 |
| 865 | + await self.pubsub.connection.disconnect() |
| 866 | + assert not self.pubsub.connection.is_connected # it is in a disconnecte state |
| 867 | + # wait for reconnect |
| 868 | + await self.cond.wait_for(lambda: self.pubsub.connection.is_connected) |
| 869 | + assert self.state == 3 |
| 870 | + |
| 871 | + await self.mycleanup() |
| 872 | + |
| 873 | + async def loop(self): |
| 874 | + # reader loop, performing state transitions as it |
| 875 | + # discovers disconnects and reconnects |
| 876 | + await self.pubsub.subscribe("foo") |
| 877 | + while True: |
| 878 | + await asyncio.sleep(0.01) # give main thread chance to get lock |
| 879 | + async with self.cond: |
| 880 | + old_state = self.state |
| 881 | + try: |
| 882 | + if self.state == 4: |
| 883 | + break |
| 884 | + # print("state a ", self.state) |
| 885 | + got_msg = await self.get_message() |
| 886 | + assert got_msg |
| 887 | + if self.state in (1, 2): |
| 888 | + self.state = 3 # successful reconnect |
| 889 | + except redis.ConnectionError: |
| 890 | + assert self.state in (1, 2) |
| 891 | + self.state = 2 # signal that we noticed the disconnect |
| 892 | + finally: |
| 893 | + self.cond.notify() |
| 894 | + # make sure that we did notice the connection error |
| 895 | + # or reconnected without any error |
| 896 | + if old_state == 1: |
| 897 | + assert self.state in (2, 3) |
| 898 | + |
| 899 | + async def loop_step_get_message(self): |
| 900 | + # get a single message via get_message |
| 901 | + message = await self.pubsub.get_message(timeout=0.1) |
| 902 | + # print(message) |
| 903 | + if message is not None: |
| 904 | + await self.messages.put(message) |
| 905 | + return True |
| 906 | + return False |
| 907 | + |
| 908 | + async def loop_step_listen(self): |
| 909 | + # get a single message via listen() |
| 910 | + try: |
| 911 | + async with async_timeout.timeout(0.1): |
| 912 | + async for message in self.pubsub.listen(): |
| 913 | + await self.messages.put(message) |
| 914 | + return True |
| 915 | + except asyncio.TimeoutError: |
| 916 | + return False |
0 commit comments