Skip to content

Commit d6dfc80

Browse files
committed
Fix SSLProtocol to correctly propagate errors and abort connections
Add functional tests
1 parent bd27898 commit d6dfc80

File tree

3 files changed

+139
-53
lines changed

3 files changed

+139
-53
lines changed

Lib/asyncio/sslproto.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from . import base_events
99
from . import constants
10+
from . import futures
1011
from . import protocols
1112
from . import transports
1213
from .log import logger
@@ -490,6 +491,12 @@ def connection_lost(self, exc):
490491
if self._session_established:
491492
self._session_established = False
492493
self._loop.call_soon(self._app_protocol.connection_lost, exc)
494+
else:
495+
# Most likely an exception occurred while in SSL handshake.
496+
# Just mark the app transport as closed so that its __del__
497+
# doesn't complain.
498+
if self._app_transport is not None:
499+
self._app_transport._closed = True
493500
self._transport = None
494501
self._app_transport = None
495502
self._wakeup_waiter(exc)
@@ -605,10 +612,12 @@ def _start_handshake(self):
605612

606613
def _check_handshake_timeout(self):
607614
if self._in_handshake is True:
608-
logger.warning(
609-
"SSL handshake for %r is taking longer than %r seconds: "
610-
"aborting the connection", self, self._ssl_handshake_timeout)
611-
self._abort()
615+
msg = (
616+
f"SSL handshake for {self} is taking longer than "
617+
f"{self._ssl_handshake_timeout} seconds: "
618+
f"aborting the connection"
619+
)
620+
self._fatal_error(ConnectionAbortedError(msg))
612621

613622
def _on_handshake_complete(self, handshake_exc):
614623
self._in_handshake = False
@@ -620,21 +629,16 @@ def _on_handshake_complete(self, handshake_exc):
620629
raise handshake_exc
621630

622631
peercert = sslobj.getpeercert()
623-
except BaseException as exc:
624-
if self._loop.get_debug():
625-
if isinstance(exc, ssl.CertificateError):
626-
logger.warning("%r: SSL handshake failed "
627-
"on verifying the certificate",
628-
self, exc_info=True)
629-
else:
630-
logger.warning("%r: SSL handshake failed",
631-
self, exc_info=True)
632-
self._transport.close()
633-
if isinstance(exc, Exception):
634-
self._wakeup_waiter(exc)
635-
return
632+
except Exception as exc:
633+
if isinstance(exc, ssl.CertificateError):
634+
msg = (
635+
f'{self}: SSL handshake failed on verifying '
636+
f'the certificate'
637+
)
636638
else:
637-
raise
639+
msg = f'{self}: SSL handshake failed'
640+
self._fatal_error(exc, msg)
641+
return
638642

639643
if self._loop.get_debug():
640644
dt = self._loop.time() - self._handshake_start_time
@@ -702,19 +706,19 @@ def _process_write_backlog(self):
702706
raise
703707

704708
def _fatal_error(self, exc, message='Fatal error on transport'):
705-
# Should be called from exception handler only.
709+
if self._transport:
710+
self._transport._force_close(exc)
711+
706712
if isinstance(exc, base_events._FATAL_ERROR_IGNORE):
707713
if self._loop.get_debug():
708714
logger.debug("%r: %s", self, message, exc_info=True)
709-
else:
715+
elif not isinstance(exc, futures.CancelledError):
710716
self._loop.call_exception_handler({
711717
'message': message,
712718
'exception': exc,
713719
'transport': self._transport,
714720
'protocol': self,
715721
})
716-
if self._transport:
717-
self._transport._force_close(exc)
718722

719723
def _finalize(self):
720724
self._sslpipe = None

Lib/test/test_asyncio/test_sslproto.py

Lines changed: 110 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -49,35 +49,6 @@ def mock_handshake(callback):
4949
ssl_proto.connection_made(transport)
5050
return transport
5151

52-
def test_cancel_handshake(self):
53-
# Python issue #23197: cancelling a handshake must not raise an
54-
# exception or log an error, even if the handshake failed
55-
waiter = asyncio.Future(loop=self.loop)
56-
ssl_proto = self.ssl_protocol(waiter=waiter)
57-
handshake_fut = asyncio.Future(loop=self.loop)
58-
59-
def do_handshake(callback):
60-
exc = Exception()
61-
callback(exc)
62-
handshake_fut.set_result(None)
63-
return []
64-
65-
waiter.cancel()
66-
self.connection_made(ssl_proto, do_handshake=do_handshake)
67-
68-
with test_utils.disable_logger():
69-
self.loop.run_until_complete(handshake_fut)
70-
71-
def test_handshake_timeout(self):
72-
# bpo-29970: Check that a connection is aborted if handshake is not
73-
# completed in timeout period, instead of remaining open indefinitely
74-
ssl_proto = self.ssl_protocol()
75-
transport = self.connection_made(ssl_proto)
76-
77-
with test_utils.disable_logger():
78-
self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
79-
self.assertTrue(transport.abort.called)
80-
8152
def test_handshake_timeout_zero(self):
8253
sslcontext = test_utils.dummy_ssl_context()
8354
app_proto = mock.Mock()
@@ -477,6 +448,116 @@ async def main():
477448

478449
self.loop.run_until_complete(main())
479450

451+
def test_handshake_timeout(self):
452+
# bpo-29970: Check that a connection is aborted if handshake is not
453+
# completed in timeout period, instead of remaining open indefinitely
454+
client_sslctx = test_utils.simple_client_sslcontext()
455+
456+
# silence error logger
457+
messages = []
458+
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
459+
460+
server_side_aborted = False
461+
462+
def server(sock):
463+
nonlocal server_side_aborted
464+
try:
465+
sock.recv_all(1024 * 1024)
466+
except ConnectionAbortedError:
467+
server_side_aborted = True
468+
finally:
469+
sock.close()
470+
471+
async def client(addr):
472+
await asyncio.wait_for(
473+
self.loop.create_connection(
474+
asyncio.Protocol,
475+
*addr,
476+
ssl=client_sslctx,
477+
server_hostname='',
478+
ssl_handshake_timeout=10.0),
479+
0.5,
480+
loop=self.loop)
481+
482+
with self.tcp_server(server,
483+
max_clients=1,
484+
backlog=1) as srv:
485+
486+
with self.assertRaises(asyncio.TimeoutError):
487+
self.loop.run_until_complete(client(srv.addr))
488+
489+
self.assertTrue(server_side_aborted)
490+
491+
# Python issue #23197: cancelling a handshake must not raise an
492+
# exception or log an error, even if the handshake failed
493+
self.assertEqual(messages, [])
494+
495+
def test_create_connection_ssl_slow_handshake(self):
496+
client_sslctx = test_utils.simple_client_sslcontext()
497+
498+
# silence error logger
499+
self.loop.set_exception_handler(lambda *args: None)
500+
501+
def server(sock):
502+
try:
503+
sock.recv_all(1024 * 1024)
504+
except ConnectionAbortedError:
505+
pass
506+
finally:
507+
sock.close()
508+
509+
async def client(addr):
510+
reader, writer = await asyncio.open_connection(
511+
*addr,
512+
ssl=client_sslctx,
513+
server_hostname='',
514+
loop=self.loop,
515+
ssl_handshake_timeout=1.0)
516+
517+
with self.tcp_server(server,
518+
max_clients=1,
519+
backlog=1) as srv:
520+
521+
with self.assertRaisesRegex(
522+
ConnectionAbortedError,
523+
r'SSL handshake.*is taking longer'):
524+
525+
self.loop.run_until_complete(client(srv.addr))
526+
527+
def test_create_connection_ssl_failed_certificate(self):
528+
# silence error logger
529+
self.loop.set_exception_handler(lambda *args: None)
530+
531+
sslctx = test_utils.simple_server_sslcontext()
532+
client_sslctx = test_utils.simple_client_sslcontext(
533+
disable_verify=False)
534+
535+
def server(sock):
536+
try:
537+
sock.start_tls(
538+
sslctx,
539+
server_side=True)
540+
sock.connect()
541+
except ssl.SSLError:
542+
pass
543+
finally:
544+
sock.close()
545+
546+
async def client(addr):
547+
reader, writer = await asyncio.open_connection(
548+
*addr,
549+
ssl=client_sslctx,
550+
server_hostname='',
551+
loop=self.loop,
552+
ssl_handshake_timeout=1.0)
553+
554+
with self.tcp_server(server,
555+
max_clients=1,
556+
backlog=1) as srv:
557+
558+
with self.assertRaises(ssl.SSLCertVerificationError):
559+
self.loop.run_until_complete(client(srv.addr))
560+
480561

481562
@unittest.skipIf(ssl is None, 'No ssl module')
482563
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):

Lib/test/test_asyncio/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,11 @@ def simple_server_sslcontext():
7777
return server_context
7878

7979

80-
def simple_client_sslcontext():
80+
def simple_client_sslcontext(*, disable_verify=True):
8181
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
8282
client_context.check_hostname = False
83-
client_context.verify_mode = ssl.CERT_NONE
83+
if disable_verify:
84+
client_context.verify_mode = ssl.CERT_NONE
8485
return client_context
8586

8687

0 commit comments

Comments
 (0)