Skip to content

Commit efe4d2a

Browse files
committed
Add a unittest for asyncio.RedisCluster
1 parent 3856991 commit efe4d2a

File tree

1 file changed

+108
-1
lines changed

1 file changed

+108
-1
lines changed

tests/test_asyncio/test_cluster.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from _pytest.fixtures import FixtureRequest
1212

1313
from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster
14-
from redis.asyncio.connection import Connection, SSLConnection
14+
from redis.asyncio.connection import Connection, SSLConnection, async_timeout
1515
from redis.asyncio.parser import CommandsParser
1616
from redis.asyncio.retry import Retry
1717
from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff
@@ -49,6 +49,71 @@
4949
]
5050

5151

52+
class NodeProxy:
53+
"""A class to proxy a node connection to a different port"""
54+
55+
def __init__(self, addr, redis_addr):
56+
self.addr = addr
57+
self.redis_addr = redis_addr
58+
self.send_event = asyncio.Event()
59+
self.server = None
60+
self.task = None
61+
self.n_connections = 0
62+
63+
async def start(self):
64+
# test that we can connect to redis
65+
async with async_timeout(2):
66+
_, redis_writer = await asyncio.open_connection(*self.redis_addr)
67+
redis_writer.close()
68+
self.server = await asyncio.start_server(
69+
self.handle, *self.addr, reuse_address=True
70+
)
71+
self.task = asyncio.create_task(self.server.serve_forever())
72+
73+
async def handle(self, reader, writer):
74+
# establish connection to redis
75+
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
76+
try:
77+
self.n_connections += 1
78+
pipe1 = asyncio.create_task(self.pipe(reader, redis_writer))
79+
pipe2 = asyncio.create_task(self.pipe(redis_reader, writer))
80+
await asyncio.gather(pipe1, pipe2)
81+
finally:
82+
redis_writer.close()
83+
84+
async def aclose(self):
85+
self.task.cancel()
86+
try:
87+
await self.task
88+
except asyncio.CancelledError:
89+
pass
90+
await self.server.wait_closed()
91+
92+
async def pipe(
93+
self,
94+
reader: asyncio.StreamReader,
95+
writer: asyncio.StreamWriter,
96+
):
97+
while True:
98+
data = await reader.read(1000)
99+
if not data:
100+
break
101+
writer.write(data)
102+
await writer.drain()
103+
104+
105+
@pytest.fixture
106+
def redis_addr(request):
107+
redis_url = request.config.getoption("--redis-url")
108+
scheme, netloc = urlparse(redis_url)[:2]
109+
assert scheme == "redis"
110+
if ":" in netloc:
111+
host, port = netloc.split(":")
112+
return host, int(port)
113+
else:
114+
return netloc, 6379
115+
116+
52117
@pytest_asyncio.fixture()
53118
async def slowlog(r: RedisCluster) -> None:
54119
"""
@@ -809,6 +874,48 @@ async def test_default_node_is_replaced_after_exception(self, r):
809874
# Rollback to the old default node
810875
r.replace_default_node(curr_default_node)
811876

877+
async def test_host_port_remap(self, create_redis, redis_addr):
878+
"""Test that we can create a rediscluster object with
879+
a host-port remapper and map connections through proxy objects
880+
"""
881+
882+
# we remap the first n nodes
883+
offset = 1000
884+
n = 6
885+
ports = [redis_addr[1] + i for i in range(n)]
886+
887+
def host_port_remap(host, port):
888+
# remap first three nodes to our local proxy
889+
old = host, port
890+
if int(port) in ports:
891+
host, port = "127.0.0.1", int(port) + offset
892+
# print(f"{old} {host, port}")
893+
return host, port
894+
895+
# create the proxies
896+
proxies = [
897+
NodeProxy(("127.0.0.1", port + offset), (redis_addr[0], port))
898+
for port in ports
899+
]
900+
await asyncio.gather(*[p.start() for p in proxies])
901+
try:
902+
# create cluster:
903+
r = await create_redis(
904+
cls=RedisCluster, flushdb=False, host_port_remap=host_port_remap
905+
)
906+
try:
907+
assert await r.ping() is True
908+
assert await r.set("byte_string", b"giraffe")
909+
assert await r.get("byte_string") == b"giraffe"
910+
finally:
911+
await r.close()
912+
finally:
913+
await asyncio.gather(*[p.aclose() for p in proxies])
914+
915+
# verify that the proxies were indeed used
916+
n_used = sum((1 if p.n_connections else 0) for p in proxies)
917+
assert n_used > 1
918+
812919

813920
class TestClusterRedisCommands:
814921
"""

0 commit comments

Comments
 (0)