Skip to content

Commit 87409f6

Browse files
committed
fix tests
1 parent 7038e00 commit 87409f6

File tree

1 file changed

+76
-40
lines changed

1 file changed

+76
-40
lines changed

tests/test_asyncio/test_cwe_404.py

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import contextlib
23
import sys
34

45
import pytest
@@ -8,13 +9,19 @@
89

910

1011
async def pipe(
11-
reader: asyncio.StreamReader, writer: asyncio.StreamWriter, delay: float, name=""
12+
reader: asyncio.StreamReader,
13+
writer: asyncio.StreamWriter,
14+
proxy: "DelayProxy",
15+
name="",
16+
event: asyncio.Event = None,
1217
):
1318
while True:
1419
data = await reader.read(1000)
1520
if not data:
1621
break
17-
await asyncio.sleep(delay)
22+
if event:
23+
event.set()
24+
await asyncio.sleep(proxy.delay)
1825
writer.write(data)
1926
await writer.drain()
2027

@@ -24,18 +31,32 @@ def __init__(self, addr, redis_addr, delay: float):
2431
self.addr = addr
2532
self.redis_addr = redis_addr
2633
self.delay = delay
34+
self.send_event = asyncio.Event()
2735

2836
async def start(self):
2937
self.server = await asyncio.start_server(self.handle, *self.addr)
3038
self.ROUTINE = asyncio.create_task(self.server.serve_forever())
3139

40+
@contextlib.contextmanager
41+
def override(self, delay: float = 0.0):
42+
"""
43+
Allow to override the delay for parts of tests which aren't time dependent,
44+
to speed up execution.
45+
"""
46+
old = self.delay
47+
self.delay = delay
48+
try:
49+
yield
50+
finally:
51+
self.delay = old
52+
3253
async def handle(self, reader, writer):
3354
# establish connection to redis
3455
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
35-
pipe1 = asyncio.create_task(pipe(reader, redis_writer, self.delay, "to redis:"))
36-
pipe2 = asyncio.create_task(
37-
pipe(redis_reader, writer, self.delay, "from redis:")
56+
pipe1 = asyncio.create_task(
57+
pipe(reader, redis_writer, self, "to redis:", self.send_event)
3858
)
59+
pipe2 = asyncio.create_task(pipe(redis_reader, writer, self, "from redis:"))
3960
await asyncio.gather(pipe1, pipe2)
4061

4162
async def stop(self):
@@ -60,23 +81,26 @@ async def test_standalone(delay):
6081
# note that we connect to proxy, rather than to Redis directly
6182
async with Redis(host="localhost", port=5380, single_connection_client=b) as r:
6283

63-
await r.set("foo", "foo")
64-
await r.set("bar", "bar")
84+
with dp.override():
85+
await r.set("foo", "foo")
86+
await r.set("bar", "bar")
6587

88+
dp.send_event.clear()
6689
t = asyncio.create_task(r.get("foo"))
67-
await asyncio.sleep(delay)
90+
# wait until the task has sent, and then some, to make sure it has settled on
91+
# reading.
92+
await dp.send_event.wait()
93+
await asyncio.sleep(0.05)
6894
t.cancel()
69-
try:
95+
with pytest.raises(asyncio.CancelledError):
7096
await t
71-
sys.stderr.write("try again, we did not cancel the task in time\n")
72-
except asyncio.CancelledError:
73-
sys.stderr.write(
74-
"canceled task, connection is left open with unread response\n"
75-
)
7697

77-
assert await r.get("bar") == b"bar"
78-
assert await r.ping()
79-
assert await r.get("foo") == b"foo"
98+
# make sure that our previous request, cancelled while waiting for a repsponse,
99+
# didn't leave the connection in a bad state
100+
with dp.override():
101+
assert await r.get("bar") == b"bar"
102+
assert await r.ping()
103+
assert await r.get("foo") == b"foo"
80104

81105
await dp.stop()
82106

@@ -90,8 +114,9 @@ async def test_standalone_pipeline(delay):
90114
await dp.start()
91115
for b in [True, False]:
92116
async with Redis(host="localhost", port=5380, single_connection_client=b) as r:
93-
await r.set("foo", "foo")
94-
await r.set("bar", "bar")
117+
with dp.override():
118+
await r.set("foo", "foo")
119+
await r.set("bar", "bar")
95120

96121
pipe = r.pipeline()
97122

@@ -100,23 +125,32 @@ async def test_standalone_pipeline(delay):
100125
pipe2.ping()
101126
pipe2.get("foo")
102127

128+
dp.send_event.clear()
103129
t = asyncio.create_task(pipe.get("foo").execute())
104-
await asyncio.sleep(delay)
130+
# wait until task has settled on the read
131+
await dp.send_event.wait()
132+
await asyncio.sleep(0.05)
105133
t.cancel()
134+
with pytest.raises(asyncio.CancelledError):
135+
await t
106136

107-
pipe.get("bar")
108-
pipe.ping()
109-
pipe.get("foo")
110-
pipe.reset()
137+
# we have now cancelled the pieline in the middle of a request, make sure
138+
# that the connection is still usable
139+
with dp.override():
140+
pipe.get("bar")
141+
pipe.ping()
142+
pipe.get("foo")
143+
await pipe.reset()
111144

112-
assert await pipe.execute() is None
145+
# check that the pipeline is empty after reset
146+
assert await pipe.execute() == []
113147

114-
# validating that the pipeline can be used as it could previously
115-
pipe.get("bar")
116-
pipe.ping()
117-
pipe.get("foo")
118-
assert await pipe.execute() == [b"bar", True, b"foo"]
119-
assert await pipe2.execute() == [b"bar", True, b"foo"]
148+
# validating that the pipeline can be used as it could previously
149+
pipe.get("bar")
150+
pipe.ping()
151+
pipe.get("foo")
152+
assert await pipe.execute() == [b"bar", True, b"foo"]
153+
assert await pipe2.execute() == [b"bar", True, b"foo"]
120154

121155
await dp.stop()
122156

@@ -129,19 +163,21 @@ async def test_cluster(request):
129163

130164
r = RedisCluster.from_url("redis://localhost:5381")
131165
await r.initialize()
132-
await r.set("foo", "foo")
133-
await r.set("bar", "bar")
166+
with dp.override():
167+
await r.set("foo", "foo")
168+
await r.set("bar", "bar")
134169

170+
dp.send_event.clear()
135171
t = asyncio.create_task(r.get("foo"))
136-
await asyncio.sleep(0.050)
172+
await dp.send_event.wait()
173+
await asyncio.sleep(0.05)
137174
t.cancel()
138-
try:
175+
with pytest.raises(asyncio.CancelledError):
139176
await t
140-
except asyncio.CancelledError:
141-
pytest.fail("connection is left open with unread response")
142177

143-
assert await r.get("bar") == b"bar"
144-
assert await r.ping()
145-
assert await r.get("foo") == b"foo"
178+
with dp.override():
179+
assert await r.get("bar") == b"bar"
180+
assert await r.ping()
181+
assert await r.get("foo") == b"foo"
146182

147183
await dp.stop()

0 commit comments

Comments
 (0)