Skip to content

Commit 24e5520

Browse files
committed
Add a test for proper waiter handling in start_tls
1 parent 0220c96 commit 24e5520

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

Lib/test/test_asyncio/test_sslproto.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,67 @@ async def client(addr):
359359
asyncio.wait_for(client(srv.addr),
360360
loop=self.loop, timeout=self.TIMEOUT))
361361

362+
def test_start_tls_slow_client_cancel(self):
363+
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
364+
365+
client_context = test_utils.simple_client_sslcontext()
366+
server_waits_on_handshake = self.loop.create_future()
367+
368+
def serve(sock):
369+
sock.settimeout(self.TIMEOUT)
370+
371+
data = sock.recv_all(len(HELLO_MSG))
372+
self.assertEqual(len(data), len(HELLO_MSG))
373+
374+
try:
375+
self.loop.call_soon_threadsafe(
376+
server_waits_on_handshake.set_result, None)
377+
data = sock.recv_all(1024 * 1024)
378+
except ConnectionAbortedError:
379+
pass
380+
finally:
381+
sock.close()
382+
383+
class ClientProto(asyncio.Protocol):
384+
def __init__(self, on_data, on_eof):
385+
self.on_data = on_data
386+
self.on_eof = on_eof
387+
self.con_made_cnt = 0
388+
389+
def connection_made(proto, tr):
390+
proto.con_made_cnt += 1
391+
# Ensure connection_made gets called only once.
392+
self.assertEqual(proto.con_made_cnt, 1)
393+
394+
def data_received(self, data):
395+
self.on_data.set_result(data)
396+
397+
def eof_received(self):
398+
self.on_eof.set_result(True)
399+
400+
async def client(addr):
401+
await asyncio.sleep(0.5, loop=self.loop)
402+
403+
on_data = self.loop.create_future()
404+
on_eof = self.loop.create_future()
405+
406+
tr, proto = await self.loop.create_connection(
407+
lambda: ClientProto(on_data, on_eof), *addr)
408+
409+
tr.write(HELLO_MSG)
410+
411+
await server_waits_on_handshake
412+
413+
with self.assertRaises(asyncio.TimeoutError):
414+
await asyncio.wait_for(
415+
self.loop.start_tls(tr, proto, client_context),
416+
0.5,
417+
loop=self.loop)
418+
419+
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
420+
self.loop.run_until_complete(
421+
asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10))
422+
362423
def test_start_tls_server_1(self):
363424
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
364425

0 commit comments

Comments
 (0)