|
1 | 1 | import platform
|
| 2 | +import queue |
| 3 | +import socket |
2 | 4 | import threading
|
3 | 5 | import time
|
4 | 6 | from unittest import mock
|
@@ -608,3 +610,129 @@ def test_pubsub_deadlock(self, master_host):
|
608 | 610 | p = r.pubsub()
|
609 | 611 | p.subscribe("my-channel-1", "my-channel-2")
|
610 | 612 | pool.reset()
|
| 613 | + |
| 614 | + |
| 615 | +@pytest.mark.timeout(5, method="thread") |
| 616 | +@pytest.mark.parametrize("method", ["get_message", pytest.param("listen", marks=pytest.mark.xfail)]) |
| 617 | +@pytest.mark.onlynoncluster |
| 618 | +class TestPubSubAutoReconnect: |
| 619 | + def mysetup(self, r, method): |
| 620 | + self.messages = queue.Queue() |
| 621 | + self.pubsub = r.pubsub() |
| 622 | + self.state = 0 |
| 623 | + self.cond = threading.Condition() |
| 624 | + if method == "get_message": |
| 625 | + self.get_message = self.loop_step_get_message |
| 626 | + else: |
| 627 | + self.get_message = self.loop_step_listen |
| 628 | + |
| 629 | + self.thread = threading.Thread(target=self.loop) |
| 630 | + self.thread.daemon = True |
| 631 | + self.thread.start() |
| 632 | + # get the initial connect message |
| 633 | + message = self.messages.get(timeout=1) |
| 634 | + assert message == { |
| 635 | + "channel": b"foo", |
| 636 | + "data": 1, |
| 637 | + "pattern": None, |
| 638 | + "type": "subscribe", |
| 639 | + } |
| 640 | + |
| 641 | + def wait_for_reconnect(self): |
| 642 | + self.cond.wait_for( |
| 643 | + lambda: self.pubsub.connection._sock is not None, timeout=2) |
| 644 | + assert self.pubsub.connection._sock is not None # we didn't time out |
| 645 | + assert self.state == 3 |
| 646 | + |
| 647 | + message = self.messages.get(timeout=1) |
| 648 | + assert message == { |
| 649 | + "channel": b"foo", |
| 650 | + "data": 1, |
| 651 | + "pattern": None, |
| 652 | + "type": "subscribe", |
| 653 | + } |
| 654 | + |
| 655 | + def mycleanup(self): |
| 656 | + # kill thread |
| 657 | + with self.cond: |
| 658 | + self.state = 4 # quit |
| 659 | + self.cond.notify() |
| 660 | + self.thread.join() |
| 661 | + |
| 662 | + def test_reconnect_socket_error(self, r: redis.Redis, method): |
| 663 | + """ |
| 664 | + Test that a socket error will cause reconnect |
| 665 | + """ |
| 666 | + self.mysetup(r, method) |
| 667 | + try: |
| 668 | + # now, disconnect the connection, and wait for it to be re-established |
| 669 | + with self.cond: |
| 670 | + self.state = 1 |
| 671 | + with mock.patch.object(self.pubsub.connection, "_parser") as mockobj: |
| 672 | + mockobj.read_response.side_effect = socket.error |
| 673 | + mockobj.can_read.side_effect = socket.error |
| 674 | + # wait until thread notices the disconnect until we undo the patch |
| 675 | + self.cond.wait_for(lambda: self.state >= 2) |
| 676 | + assert ( |
| 677 | + self.pubsub.connection._sock is None |
| 678 | + ) # it is in a disconnected state |
| 679 | + self.wait_for_reconnect() |
| 680 | + |
| 681 | + finally: |
| 682 | + self.mycleanup() |
| 683 | + |
| 684 | + def test_reconnect_disconnect(self, r: redis.Redis, method): |
| 685 | + """ |
| 686 | + Test that a manual disconnect() will cause reconnect |
| 687 | + """ |
| 688 | + self.mysetup(r, method) |
| 689 | + try: |
| 690 | + # now, disconnect the connection, and wait for it to be re-established |
| 691 | + with self.cond: |
| 692 | + self.state = 1 |
| 693 | + self.pubsub.connection.disconnect() |
| 694 | + assert self.pubsub.connection._sock is None # it is in a disconnected state |
| 695 | + # wait for reconnect |
| 696 | + self.wait_for_reconnect() |
| 697 | + finally: |
| 698 | + self.mycleanup() |
| 699 | + |
| 700 | + def loop(self): |
| 701 | + # reader loop, performing state transitions as it |
| 702 | + # discovers disconnects and reconnects |
| 703 | + self.pubsub.subscribe("foo") |
| 704 | + while True: |
| 705 | + time.sleep(0.01) # give main thread chance to get lock |
| 706 | + with self.cond: |
| 707 | + old_state = self.state |
| 708 | + try: |
| 709 | + if self.state == 4: |
| 710 | + break |
| 711 | + # print ('state, %s, sock %s' % (state, pubsub.connection._sock)) |
| 712 | + got_msg = self.get_message() |
| 713 | + assert got_msg |
| 714 | + if self.state in (1, 2): |
| 715 | + self.state = 3 # successful reconnect |
| 716 | + except redis.ConnectionError: |
| 717 | + assert self.state in (1, 2) |
| 718 | + self.state = 2 |
| 719 | + finally: |
| 720 | + self.cond.notify() |
| 721 | + # assert that we noticed a connect error, or automatically |
| 722 | + # reconnected without error |
| 723 | + if old_state == 1: |
| 724 | + assert self.state in (2, 3) |
| 725 | + |
| 726 | + def loop_step_get_message(self): |
| 727 | + # get a single message via listen() |
| 728 | + message = self.pubsub.get_message(timeout=0.1) |
| 729 | + if message is not None: |
| 730 | + self.messages.put(message) |
| 731 | + return True |
| 732 | + return False |
| 733 | + |
| 734 | + def loop_step_listen(self): |
| 735 | + # get a single message via listen() |
| 736 | + for message in self.pubsub.listen(): |
| 737 | + self.messages.put(message) |
| 738 | + return True |
0 commit comments