Skip to content

Commit 141c5e8

Browse files
authored
bpo-24334: Cleanup SSLSocket (#5252)
* The SSLSocket is no longer implemented on top of SSLObject to avoid an extra level of indirection. * Owner and session are now handled in the internal constructor. * _ssl._SSLSocket now uses the same method names as SSLSocket and SSLObject. * Channel binding type check is now handled in C code. Channel binding is always available. The patch also changes the signature of SSLObject.__init__(). In my opinion it's fine. A SSLObject is not a user-constructable object. SSLContext.wrap_bio() is the only valid factory.
1 parent b18f8bc commit 141c5e8

File tree

5 files changed

+183
-117
lines changed

5 files changed

+183
-117
lines changed

Lib/ssl.py

Lines changed: 62 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,7 @@
166166

167167
socket_error = OSError # keep that public name in module namespace
168168

169-
if _ssl.HAS_TLS_UNIQUE:
170-
CHANNEL_BINDING_TYPES = ['tls-unique']
171-
else:
172-
CHANNEL_BINDING_TYPES = []
169+
CHANNEL_BINDING_TYPES = ['tls-unique']
173170

174171
HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')
175172

@@ -407,11 +404,11 @@ def wrap_bio(self, incoming, outgoing, server_side=False,
407404
server_hostname=None, session=None):
408405
# Need to encode server_hostname here because _wrap_bio() can only
409406
# handle ASCII str.
410-
sslobj = self._wrap_bio(
407+
return self.sslobject_class(
411408
incoming, outgoing, server_side=server_side,
412-
server_hostname=self._encode_hostname(server_hostname)
409+
server_hostname=self._encode_hostname(server_hostname),
410+
session=session, _context=self,
413411
)
414-
return self.sslobject_class(sslobj, session=session)
415412

416413
def set_npn_protocols(self, npn_protocols):
417414
protos = bytearray()
@@ -616,12 +613,13 @@ class SSLObject:
616613
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
617614
"""
618615

619-
def __init__(self, sslobj, owner=None, session=None):
620-
self._sslobj = sslobj
621-
# Note: _sslobj takes a weak reference to owner
622-
self._sslobj.owner = owner or self
623-
if session is not None:
624-
self._sslobj.session = session
616+
def __init__(self, incoming, outgoing, server_side=False,
617+
server_hostname=None, session=None, _context=None):
618+
self._sslobj = _context._wrap_bio(
619+
incoming, outgoing, server_side=server_side,
620+
server_hostname=server_hostname,
621+
owner=self, session=session
622+
)
625623

626624
@property
627625
def context(self):
@@ -684,7 +682,7 @@ def getpeercert(self, binary_form=False):
684682
Return None if no certificate was provided, {} if a certificate was
685683
provided, but not validated.
686684
"""
687-
return self._sslobj.peer_certificate(binary_form)
685+
return self._sslobj.getpeercert(binary_form)
688686

689687
def selected_npn_protocol(self):
690688
"""Return the currently selected NPN protocol as a string, or ``None``
@@ -732,13 +730,7 @@ def get_channel_binding(self, cb_type="tls-unique"):
732730
"""Get channel binding data for current connection. Raise ValueError
733731
if the requested `cb_type` is not supported. Return bytes of the data
734732
or None if the data is not available (e.g. before the handshake)."""
735-
if cb_type not in CHANNEL_BINDING_TYPES:
736-
raise ValueError("Unsupported channel binding type")
737-
if cb_type != "tls-unique":
738-
raise NotImplementedError(
739-
"{0} channel binding type not implemented"
740-
.format(cb_type))
741-
return self._sslobj.tls_unique_cb()
733+
return self._sslobj.get_channel_binding(cb_type)
742734

743735
def version(self):
744736
"""Return a string identifying the protocol version used by the
@@ -832,10 +824,10 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
832824
if connected:
833825
# create the SSL object
834826
try:
835-
sslobj = self._context._wrap_socket(self, server_side,
836-
self.server_hostname)
837-
self._sslobj = SSLObject(sslobj, owner=self,
838-
session=self._session)
827+
self._sslobj = self._context._wrap_socket(
828+
self, server_side, self.server_hostname,
829+
owner=self, session=self._session,
830+
)
839831
if do_handshake_on_connect:
840832
timeout = self.gettimeout()
841833
if timeout == 0.0:
@@ -895,10 +887,13 @@ def read(self, len=1024, buffer=None):
895887
Return zero-length string on EOF."""
896888

897889
self._checkClosed()
898-
if not self._sslobj:
890+
if self._sslobj is None:
899891
raise ValueError("Read on closed or unwrapped SSL socket.")
900892
try:
901-
return self._sslobj.read(len, buffer)
893+
if buffer is not None:
894+
return self._sslobj.read(len, buffer)
895+
else:
896+
return self._sslobj.read(len)
902897
except SSLError as x:
903898
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
904899
if buffer is not None:
@@ -913,7 +908,7 @@ def write(self, data):
913908
number of bytes of DATA actually transmitted."""
914909

915910
self._checkClosed()
916-
if not self._sslobj:
911+
if self._sslobj is None:
917912
raise ValueError("Write on closed or unwrapped SSL socket.")
918913
return self._sslobj.write(data)
919914

@@ -929,41 +924,42 @@ def getpeercert(self, binary_form=False):
929924

930925
def selected_npn_protocol(self):
931926
self._checkClosed()
932-
if not self._sslobj or not _ssl.HAS_NPN:
927+
if self._sslobj is None or not _ssl.HAS_NPN:
933928
return None
934929
else:
935930
return self._sslobj.selected_npn_protocol()
936931

937932
def selected_alpn_protocol(self):
938933
self._checkClosed()
939-
if not self._sslobj or not _ssl.HAS_ALPN:
934+
if self._sslobj is None or not _ssl.HAS_ALPN:
940935
return None
941936
else:
942937
return self._sslobj.selected_alpn_protocol()
943938

944939
def cipher(self):
945940
self._checkClosed()
946-
if not self._sslobj:
941+
if self._sslobj is None:
947942
return None
948943
else:
949944
return self._sslobj.cipher()
950945

951946
def shared_ciphers(self):
952947
self._checkClosed()
953-
if not self._sslobj:
948+
if self._sslobj is None:
954949
return None
955-
return self._sslobj.shared_ciphers()
950+
else:
951+
return self._sslobj.shared_ciphers()
956952

957953
def compression(self):
958954
self._checkClosed()
959-
if not self._sslobj:
955+
if self._sslobj is None:
960956
return None
961957
else:
962958
return self._sslobj.compression()
963959

964960
def send(self, data, flags=0):
965961
self._checkClosed()
966-
if self._sslobj:
962+
if self._sslobj is not None:
967963
if flags != 0:
968964
raise ValueError(
969965
"non-zero flags not allowed in calls to send() on %s" %
@@ -974,7 +970,7 @@ def send(self, data, flags=0):
974970

975971
def sendto(self, data, flags_or_addr, addr=None):
976972
self._checkClosed()
977-
if self._sslobj:
973+
if self._sslobj is not None:
978974
raise ValueError("sendto not allowed on instances of %s" %
979975
self.__class__)
980976
elif addr is None:
@@ -990,7 +986,7 @@ def sendmsg(self, *args, **kwargs):
990986

991987
def sendall(self, data, flags=0):
992988
self._checkClosed()
993-
if self._sslobj:
989+
if self._sslobj is not None:
994990
if flags != 0:
995991
raise ValueError(
996992
"non-zero flags not allowed in calls to sendall() on %s" %
@@ -1008,15 +1004,15 @@ def sendfile(self, file, offset=0, count=None):
10081004
"""Send a file, possibly by using os.sendfile() if this is a
10091005
clear-text socket. Return the total number of bytes sent.
10101006
"""
1011-
if self._sslobj is None:
1007+
if self._sslobj is not None:
1008+
return self._sendfile_use_send(file, offset, count)
1009+
else:
10121010
# os.sendfile() works with plain sockets only
10131011
return super().sendfile(file, offset, count)
1014-
else:
1015-
return self._sendfile_use_send(file, offset, count)
10161012

10171013
def recv(self, buflen=1024, flags=0):
10181014
self._checkClosed()
1019-
if self._sslobj:
1015+
if self._sslobj is not None:
10201016
if flags != 0:
10211017
raise ValueError(
10221018
"non-zero flags not allowed in calls to recv() on %s" %
@@ -1031,7 +1027,7 @@ def recv_into(self, buffer, nbytes=None, flags=0):
10311027
nbytes = len(buffer)
10321028
elif nbytes is None:
10331029
nbytes = 1024
1034-
if self._sslobj:
1030+
if self._sslobj is not None:
10351031
if flags != 0:
10361032
raise ValueError(
10371033
"non-zero flags not allowed in calls to recv_into() on %s" %
@@ -1042,15 +1038,15 @@ def recv_into(self, buffer, nbytes=None, flags=0):
10421038

10431039
def recvfrom(self, buflen=1024, flags=0):
10441040
self._checkClosed()
1045-
if self._sslobj:
1041+
if self._sslobj is not None:
10461042
raise ValueError("recvfrom not allowed on instances of %s" %
10471043
self.__class__)
10481044
else:
10491045
return super().recvfrom(buflen, flags)
10501046

10511047
def recvfrom_into(self, buffer, nbytes=None, flags=0):
10521048
self._checkClosed()
1053-
if self._sslobj:
1049+
if self._sslobj is not None:
10541050
raise ValueError("recvfrom_into not allowed on instances of %s" %
10551051
self.__class__)
10561052
else:
@@ -1066,7 +1062,7 @@ def recvmsg_into(self, *args, **kwargs):
10661062

10671063
def pending(self):
10681064
self._checkClosed()
1069-
if self._sslobj:
1065+
if self._sslobj is not None:
10701066
return self._sslobj.pending()
10711067
else:
10721068
return 0
@@ -1078,7 +1074,7 @@ def shutdown(self, how):
10781074

10791075
def unwrap(self):
10801076
if self._sslobj:
1081-
s = self._sslobj.unwrap()
1077+
s = self._sslobj.shutdown()
10821078
self._sslobj = None
10831079
return s
10841080
else:
@@ -1096,6 +1092,11 @@ def do_handshake(self, block=False):
10961092
if timeout == 0.0 and block:
10971093
self.settimeout(None)
10981094
self._sslobj.do_handshake()
1095+
if self.context.check_hostname:
1096+
if not self.server_hostname:
1097+
raise ValueError("check_hostname needs server_hostname "
1098+
"argument")
1099+
match_hostname(self.getpeercert(), self.server_hostname)
10991100
finally:
11001101
self.settimeout(timeout)
11011102

@@ -1104,11 +1105,12 @@ def _real_connect(self, addr, connect_ex):
11041105
raise ValueError("can't connect in server-side mode")
11051106
# Here we assume that the socket is client-side, and not
11061107
# connected at the time of the call. We connect it, then wrap it.
1107-
if self._connected:
1108+
if self._connected or self._sslobj is not None:
11081109
raise ValueError("attempt to connect already-connected SSLSocket!")
1109-
sslobj = self.context._wrap_socket(self, False, self.server_hostname)
1110-
self._sslobj = SSLObject(sslobj, owner=self,
1111-
session=self._session)
1110+
self._sslobj = self.context._wrap_socket(
1111+
self, False, self.server_hostname,
1112+
owner=self, session=self._session
1113+
)
11121114
try:
11131115
if connect_ex:
11141116
rc = super().connect_ex(addr)
@@ -1151,18 +1153,24 @@ def get_channel_binding(self, cb_type="tls-unique"):
11511153
if the requested `cb_type` is not supported. Return bytes of the data
11521154
or None if the data is not available (e.g. before the handshake).
11531155
"""
1154-
if self._sslobj is None:
1156+
if self._sslobj is not None:
1157+
return self._sslobj.get_channel_binding(cb_type)
1158+
else:
1159+
if cb_type not in CHANNEL_BINDING_TYPES:
1160+
raise ValueError(
1161+
"{0} channel binding type not implemented".format(cb_type)
1162+
)
11551163
return None
1156-
return self._sslobj.get_channel_binding(cb_type)
11571164

11581165
def version(self):
11591166
"""
11601167
Return a string identifying the protocol version used by the
11611168
current SSL channel, or None if there is no established channel.
11621169
"""
1163-
if self._sslobj is None:
1170+
if self._sslobj is not None:
1171+
return self._sslobj.version()
1172+
else:
11641173
return None
1165-
return self._sslobj.version()
11661174

11671175

11681176
# Python does not support forward declaration of types.

Lib/test/test_ssl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,8 @@ def test_wrapped_unconnected(self):
455455
self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
456456
self.assertRaises(OSError, ss.send, b'x')
457457
self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
458+
self.assertRaises(NotImplementedError, ss.sendmsg,
459+
[b'x'], (), 0, ('0.0.0.0', 0))
458460

459461
def test_timeout(self):
460462
# Issue #8524: when creating an SSL socket, the timeout of the
@@ -3381,11 +3383,13 @@ def test_version_basic(self):
33813383
chatty=False) as server:
33823384
with context.wrap_socket(socket.socket()) as s:
33833385
self.assertIs(s.version(), None)
3386+
self.assertIs(s._sslobj, None)
33843387
s.connect((HOST, server.port))
33853388
if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
33863389
self.assertEqual(s.version(), 'TLSv1.2')
33873390
else: # 0.9.8 to 1.0.1
33883391
self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
3392+
self.assertIs(s._sslobj, None)
33893393
self.assertIs(s.version(), None)
33903394

33913395
@unittest.skipUnless(ssl.HAS_TLSv1_3,
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Internal implementation details of ssl module were cleaned up. The SSLSocket
2+
has one less layer of indirection. Owner and session information are now
3+
handled by the SSLSocket and SSLObject constructor. Channel binding
4+
implementation has been simplified.

0 commit comments

Comments
 (0)