|
11 | 11 | from _pytest.fixtures import FixtureRequest
|
12 | 12 |
|
13 | 13 | from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
|
14 |
| -from redis.asyncio.connection import Connection, SSLConnection |
| 14 | +from redis.asyncio.connection import Connection, SSLConnection, async_timeout |
15 | 15 | from redis.asyncio.parser import CommandsParser
|
16 | 16 | from redis.asyncio.retry import Retry
|
17 | 17 | from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
|
|
49 | 49 | ]
|
50 | 50 |
|
51 | 51 |
|
| 52 | +class NodeProxy: |
| 53 | + """A class to proxy a node connection to a different port""" |
| 54 | + |
| 55 | + def __init__(self, addr, redis_addr): |
| 56 | + self.addr = addr |
| 57 | + self.redis_addr = redis_addr |
| 58 | + self.send_event = asyncio.Event() |
| 59 | + self.server = None |
| 60 | + self.task = None |
| 61 | + self.n_connections = 0 |
| 62 | + |
| 63 | + async def start(self): |
| 64 | + # test that we can connect to redis |
| 65 | + async with async_timeout(2): |
| 66 | + _, redis_writer = await asyncio.open_connection(*self.redis_addr) |
| 67 | + redis_writer.close() |
| 68 | + self.server = await asyncio.start_server( |
| 69 | + self.handle, *self.addr, reuse_address=True |
| 70 | + ) |
| 71 | + self.task = asyncio.create_task(self.server.serve_forever()) |
| 72 | + |
| 73 | + async def handle(self, reader, writer): |
| 74 | + # establish connection to redis |
| 75 | + redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr) |
| 76 | + try: |
| 77 | + self.n_connections += 1 |
| 78 | + pipe1 = asyncio.create_task(self.pipe(reader, redis_writer)) |
| 79 | + pipe2 = asyncio.create_task(self.pipe(redis_reader, writer)) |
| 80 | + await asyncio.gather(pipe1, pipe2) |
| 81 | + finally: |
| 82 | + redis_writer.close() |
| 83 | + |
| 84 | + async def aclose(self): |
| 85 | + self.task.cancel() |
| 86 | + try: |
| 87 | + await self.task |
| 88 | + except asyncio.CancelledError: |
| 89 | + pass |
| 90 | + await self.server.wait_closed() |
| 91 | + |
| 92 | + async def pipe( |
| 93 | + self, |
| 94 | + reader: asyncio.StreamReader, |
| 95 | + writer: asyncio.StreamWriter, |
| 96 | + ): |
| 97 | + while True: |
| 98 | + data = await reader.read(1000) |
| 99 | + if not data: |
| 100 | + break |
| 101 | + writer.write(data) |
| 102 | + await writer.drain() |
| 103 | + |
| 104 | + |
| 105 | +@pytest.fixture |
| 106 | +def redis_addr(request): |
| 107 | + redis_url = request.config.getoption("--redis-url") |
| 108 | + scheme, netloc = urlparse(redis_url)[:2] |
| 109 | + assert scheme == "redis" |
| 110 | + if ":" in netloc: |
| 111 | + host, port = netloc.split(":") |
| 112 | + return host, int(port) |
| 113 | + else: |
| 114 | + return netloc, 6379 |
| 115 | + |
| 116 | + |
52 | 117 | @pytest_asyncio.fixture()
|
53 | 118 | async def slowlog(r: RedisCluster) -> None:
|
54 | 119 | """
|
@@ -809,6 +874,48 @@ async def test_default_node_is_replaced_after_exception(self, r):
|
809 | 874 | # Rollback to the old default node
|
810 | 875 | r.replace_default_node(curr_default_node)
|
811 | 876 |
|
| 877 | + async def test_host_port_remap(self, create_redis, redis_addr): |
| 878 | + """Test that we can create a rediscluster object with |
| 879 | + a host-port remapper and map connections through proxy objects |
| 880 | + """ |
| 881 | + |
| 882 | + # we remap the first n nodes |
| 883 | + offset = 1000 |
| 884 | + n = 6 |
| 885 | + ports = [redis_addr[1] + i for i in range(n)] |
| 886 | + |
| 887 | + def host_port_remap(host, port): |
| 888 | + # remap first three nodes to our local proxy |
| 889 | + old = host, port |
| 890 | + if int(port) in ports: |
| 891 | + host, port = "127.0.0.1", int(port) + offset |
| 892 | + # print(f"{old} {host, port}") |
| 893 | + return host, port |
| 894 | + |
| 895 | + # create the proxies |
| 896 | + proxies = [ |
| 897 | + NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port)) |
| 898 | + for port in ports |
| 899 | + ] |
| 900 | + await asyncio.gather(*[p.start() for p in proxies]) |
| 901 | + try: |
| 902 | + # create cluster: |
| 903 | + r = await create_redis( |
| 904 | + cls=RedisCluster, flushdb=False, host_port_remap=host_port_remap |
| 905 | + ) |
| 906 | + try: |
| 907 | + assert await r.ping() is True |
| 908 | + assert await r.set("byte_string", b"giraffe") |
| 909 | + assert await r.get("byte_string") == b"giraffe" |
| 910 | + finally: |
| 911 | + await r.close() |
| 912 | + finally: |
| 913 | + await asyncio.gather(*[p.aclose() for p in proxies]) |
| 914 | + |
| 915 | + # verify that the proxies were indeed used |
| 916 | + n_used = sum((1 if p.n_connections else 0) for p in proxies) |
| 917 | + assert n_used > 1 |
| 918 | + |
812 | 919 |
|
813 | 920 | class TestClusterRedisCommands:
|
814 | 921 | """
|
|
0 commit comments