@@ -359,6 +359,67 @@ async def client(addr):
359
359
asyncio .wait_for (client (srv .addr ),
360
360
loop = self .loop , timeout = self .TIMEOUT ))
361
361
362
+ def test_start_tls_slow_client_cancel (self ):
363
+ HELLO_MSG = b'1' * self .PAYLOAD_SIZE
364
+
365
+ client_context = test_utils .simple_client_sslcontext ()
366
+ server_waits_on_handshake = self .loop .create_future ()
367
+
368
+ def serve (sock ):
369
+ sock .settimeout (self .TIMEOUT )
370
+
371
+ data = sock .recv_all (len (HELLO_MSG ))
372
+ self .assertEqual (len (data ), len (HELLO_MSG ))
373
+
374
+ try :
375
+ self .loop .call_soon_threadsafe (
376
+ server_waits_on_handshake .set_result , None )
377
+ data = sock .recv_all (1024 * 1024 )
378
+ except ConnectionAbortedError :
379
+ pass
380
+ finally :
381
+ sock .close ()
382
+
383
+ class ClientProto (asyncio .Protocol ):
384
+ def __init__ (self , on_data , on_eof ):
385
+ self .on_data = on_data
386
+ self .on_eof = on_eof
387
+ self .con_made_cnt = 0
388
+
389
+ def connection_made (proto , tr ):
390
+ proto .con_made_cnt += 1
391
+ # Ensure connection_made gets called only once.
392
+ self .assertEqual (proto .con_made_cnt , 1 )
393
+
394
+ def data_received (self , data ):
395
+ self .on_data .set_result (data )
396
+
397
+ def eof_received (self ):
398
+ self .on_eof .set_result (True )
399
+
400
+ async def client (addr ):
401
+ await asyncio .sleep (0.5 , loop = self .loop )
402
+
403
+ on_data = self .loop .create_future ()
404
+ on_eof = self .loop .create_future ()
405
+
406
+ tr , proto = await self .loop .create_connection (
407
+ lambda : ClientProto (on_data , on_eof ), * addr )
408
+
409
+ tr .write (HELLO_MSG )
410
+
411
+ await server_waits_on_handshake
412
+
413
+ with self .assertRaises (asyncio .TimeoutError ):
414
+ await asyncio .wait_for (
415
+ self .loop .start_tls (tr , proto , client_context ),
416
+ 0.5 ,
417
+ loop = self .loop )
418
+
419
+ with self .tcp_server (serve , timeout = self .TIMEOUT ) as srv :
420
+ self .loop .run_until_complete (
421
+ asyncio .wait_for (client (srv .addr ), loop = self .loop , timeout = 10 ))
422
+
362
423
def test_start_tls_server_1 (self ):
363
424
HELLO_MSG = b'1' * self .PAYLOAD_SIZE
364
425
0 commit comments