|
1 | 1 | import platform
|
| 2 | +import queue |
| 3 | +import socket |
2 | 4 | import threading
|
3 | 5 | import time
|
4 | 6 | from unittest import mock
|
@@ -608,3 +610,121 @@ def test_pubsub_deadlock(self, master_host):
|
608 | 610 | p = r.pubsub()
|
609 | 611 | p.subscribe("my-channel-1", "my-channel-2")
|
610 | 612 | pool.reset()
|
| 613 | + |
| 614 | + |
| 615 | +@pytest.mark.timeout(5, method="thread") |
| 616 | +@pytest.mark.onlynoncluster |
| 617 | +class TestPubSubAutoReconnect: |
| 618 | + def test_reconnect_socket_error(self, r: redis.Redis): |
| 619 | + """ |
| 620 | + Test that a socket error will cause reconnect |
| 621 | + """ |
| 622 | + self.messages = queue.Queue() |
| 623 | + self.pubsub = r.pubsub() |
| 624 | + self.state = 0 |
| 625 | + self.cond = threading.Condition() |
| 626 | + |
| 627 | + thread = threading.Thread(target=self.loop) |
| 628 | + thread.start() |
| 629 | + # get the initial connect message |
| 630 | + message = self.messages.get(timeout=1) |
| 631 | + assert message == { |
| 632 | + "channel": b"foo", |
| 633 | + "data": 1, |
| 634 | + "pattern": None, |
| 635 | + "type": "subscribe", |
| 636 | + } |
| 637 | + # now, disconnect the connection, and wait for it to be re-established |
| 638 | + with self.cond: |
| 639 | + self.state = 1 |
| 640 | + with patch("socket.socket.recv") as mock: |
| 641 | + mock.side_effect = socket.error |
| 642 | + # wait until thread noticies the disconnect until we undo the patch |
| 643 | + self.cond.wait_for(lambda: self.state >= 2) |
| 644 | + assert ( |
| 645 | + self.pubsub.connection._sock is None |
| 646 | + ) # it is in a disconnecte state |
| 647 | + # wait for reconnect |
| 648 | + self.cond.wait_for(lambda: self.pubsub.connection._sock is not None) |
| 649 | + assert self.state == 3 |
| 650 | + |
| 651 | + message = self.messages.get(timeout=1) |
| 652 | + assert message == { |
| 653 | + "channel": b"foo", |
| 654 | + "data": 1, |
| 655 | + "pattern": None, |
| 656 | + "type": "subscribe", |
| 657 | + } |
| 658 | + # kill thread |
| 659 | + with self.cond: |
| 660 | + self.state = 4 # quit |
| 661 | + thread.join() |
| 662 | + |
| 663 | + def test_reconnect_disconnect(self, r: redis.Redis): |
| 664 | + """ |
| 665 | + Test that a socket error will cause reconnect |
| 666 | + """ |
| 667 | + self.messages = queue.Queue() |
| 668 | + self.pubsub = r.pubsub() |
| 669 | + self.state = 0 |
| 670 | + self.cond = threading.Condition() |
| 671 | + |
| 672 | + thread = threading.Thread(target=self.loop) |
| 673 | + thread.start() |
| 674 | + # get the initial connect message |
| 675 | + message = self.messages.get(timeout=1) |
| 676 | + assert message == { |
| 677 | + "channel": b"foo", |
| 678 | + "data": 1, |
| 679 | + "pattern": None, |
| 680 | + "type": "subscribe", |
| 681 | + } |
| 682 | + # now, disconnect the connection, and wait for it to be re-established |
| 683 | + with self.cond: |
| 684 | + self.state = 1 |
| 685 | + self.pubsub.connection.disconnect() |
| 686 | + assert self.pubsub.connection._sock is None # it is in a disconnecte state |
| 687 | + # wait for reconnect |
| 688 | + self.cond.wait_for(lambda: self.pubsub.connection._sock is not None) |
| 689 | + assert self.state == 3 |
| 690 | + |
| 691 | + message = self.messages.get(timeout=1) |
| 692 | + assert message == { |
| 693 | + "channel": b"foo", |
| 694 | + "data": 1, |
| 695 | + "pattern": None, |
| 696 | + "type": "subscribe", |
| 697 | + } |
| 698 | + # kill thread |
| 699 | + with self.cond: |
| 700 | + self.state = 4 # quit |
| 701 | + thread.join() |
| 702 | + |
| 703 | + def loop(self): |
| 704 | + # must make sure the task exits |
| 705 | + self.pubsub.subscribe("foo") |
| 706 | + while True: |
| 707 | + time.sleep(0.01) # give main thread chance to get lock |
| 708 | + with self.cond: |
| 709 | + try: |
| 710 | + if self.state == 1: |
| 711 | + self.state = 2 |
| 712 | + elif self.state == 4: |
| 713 | + break |
| 714 | + # print ('state, %s, sock %s' % (state, pubsub.connection._sock)) |
| 715 | + self.loop_step(0.1) |
| 716 | + if self.state == 2: |
| 717 | + self.state = 3 # successful reconnect |
| 718 | + except redis.ConnectionError: |
| 719 | + pass # we will reconnect |
| 720 | + |
| 721 | + finally: |
| 722 | + self.cond.notify() |
| 723 | + |
| 724 | + def loop_step(self, timeout): |
| 725 | + # get a single message via listen() |
| 726 | + message = self.pubsub.get_message(timeout=timeout) |
| 727 | + if message is not None: |
| 728 | + self.messages.put(message) |
| 729 | + return True |
| 730 | + return False |
0 commit comments