166
166
167
167
socket_error = OSError # keep that public name in module namespace
168
168
169
- if _ssl .HAS_TLS_UNIQUE :
170
- CHANNEL_BINDING_TYPES = ['tls-unique' ]
171
- else :
172
- CHANNEL_BINDING_TYPES = []
169
+ CHANNEL_BINDING_TYPES = ['tls-unique' ]
173
170
174
171
HAS_NEVER_CHECK_COMMON_NAME = hasattr (_ssl , 'HOSTFLAG_NEVER_CHECK_SUBJECT' )
175
172
@@ -407,11 +404,11 @@ def wrap_bio(self, incoming, outgoing, server_side=False,
407
404
server_hostname = None , session = None ):
408
405
# Need to encode server_hostname here because _wrap_bio() can only
409
406
# handle ASCII str.
410
- sslobj = self ._wrap_bio (
407
+ return self .sslobject_class (
411
408
incoming , outgoing , server_side = server_side ,
412
- server_hostname = self ._encode_hostname (server_hostname )
409
+ server_hostname = self ._encode_hostname (server_hostname ),
410
+ session = session , _context = self ,
413
411
)
414
- return self .sslobject_class (sslobj , session = session )
415
412
416
413
def set_npn_protocols (self , npn_protocols ):
417
414
protos = bytearray ()
@@ -616,12 +613,13 @@ class SSLObject:
616
613
* The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
617
614
"""
618
615
619
- def __init__ (self , sslobj , owner = None , session = None ):
620
- self ._sslobj = sslobj
621
- # Note: _sslobj takes a weak reference to owner
622
- self ._sslobj .owner = owner or self
623
- if session is not None :
624
- self ._sslobj .session = session
616
+ def __init__ (self , incoming , outgoing , server_side = False ,
617
+ server_hostname = None , session = None , _context = None ):
618
+ self ._sslobj = _context ._wrap_bio (
619
+ incoming , outgoing , server_side = server_side ,
620
+ server_hostname = server_hostname ,
621
+ owner = self , session = session
622
+ )
625
623
626
624
@property
627
625
def context (self ):
@@ -684,7 +682,7 @@ def getpeercert(self, binary_form=False):
684
682
Return None if no certificate was provided, {} if a certificate was
685
683
provided, but not validated.
686
684
"""
687
- return self ._sslobj .peer_certificate (binary_form )
685
+ return self ._sslobj .getpeercert (binary_form )
688
686
689
687
def selected_npn_protocol (self ):
690
688
"""Return the currently selected NPN protocol as a string, or ``None``
@@ -732,13 +730,7 @@ def get_channel_binding(self, cb_type="tls-unique"):
732
730
"""Get channel binding data for current connection. Raise ValueError
733
731
if the requested `cb_type` is not supported. Return bytes of the data
734
732
or None if the data is not available (e.g. before the handshake)."""
735
- if cb_type not in CHANNEL_BINDING_TYPES :
736
- raise ValueError ("Unsupported channel binding type" )
737
- if cb_type != "tls-unique" :
738
- raise NotImplementedError (
739
- "{0} channel binding type not implemented"
740
- .format (cb_type ))
741
- return self ._sslobj .tls_unique_cb ()
733
+ return self ._sslobj .get_channel_binding (cb_type )
742
734
743
735
def version (self ):
744
736
"""Return a string identifying the protocol version used by the
@@ -832,10 +824,10 @@ def __init__(self, sock=None, keyfile=None, certfile=None,
832
824
if connected :
833
825
# create the SSL object
834
826
try :
835
- sslobj = self ._context ._wrap_socket (self , server_side ,
836
- self .server_hostname )
837
- self . _sslobj = SSLObject ( sslobj , owner = self ,
838
- session = self . _session )
827
+ self . _sslobj = self ._context ._wrap_socket (
828
+ self , server_side , self .server_hostname ,
829
+ owner = self , session = self . _session ,
830
+ )
839
831
if do_handshake_on_connect :
840
832
timeout = self .gettimeout ()
841
833
if timeout == 0.0 :
@@ -895,10 +887,13 @@ def read(self, len=1024, buffer=None):
895
887
Return zero-length string on EOF."""
896
888
897
889
self ._checkClosed ()
898
- if not self ._sslobj :
890
+ if self ._sslobj is None :
899
891
raise ValueError ("Read on closed or unwrapped SSL socket." )
900
892
try :
901
- return self ._sslobj .read (len , buffer )
893
+ if buffer is not None :
894
+ return self ._sslobj .read (len , buffer )
895
+ else :
896
+ return self ._sslobj .read (len )
902
897
except SSLError as x :
903
898
if x .args [0 ] == SSL_ERROR_EOF and self .suppress_ragged_eofs :
904
899
if buffer is not None :
@@ -913,7 +908,7 @@ def write(self, data):
913
908
number of bytes of DATA actually transmitted."""
914
909
915
910
self ._checkClosed ()
916
- if not self ._sslobj :
911
+ if self ._sslobj is None :
917
912
raise ValueError ("Write on closed or unwrapped SSL socket." )
918
913
return self ._sslobj .write (data )
919
914
@@ -929,41 +924,42 @@ def getpeercert(self, binary_form=False):
929
924
930
925
def selected_npn_protocol (self ):
931
926
self ._checkClosed ()
932
- if not self ._sslobj or not _ssl .HAS_NPN :
927
+ if self ._sslobj is None or not _ssl .HAS_NPN :
933
928
return None
934
929
else :
935
930
return self ._sslobj .selected_npn_protocol ()
936
931
937
932
def selected_alpn_protocol (self ):
938
933
self ._checkClosed ()
939
- if not self ._sslobj or not _ssl .HAS_ALPN :
934
+ if self ._sslobj is None or not _ssl .HAS_ALPN :
940
935
return None
941
936
else :
942
937
return self ._sslobj .selected_alpn_protocol ()
943
938
944
939
def cipher (self ):
945
940
self ._checkClosed ()
946
- if not self ._sslobj :
941
+ if self ._sslobj is None :
947
942
return None
948
943
else :
949
944
return self ._sslobj .cipher ()
950
945
951
946
def shared_ciphers (self ):
952
947
self ._checkClosed ()
953
- if not self ._sslobj :
948
+ if self ._sslobj is None :
954
949
return None
955
- return self ._sslobj .shared_ciphers ()
950
+ else :
951
+ return self ._sslobj .shared_ciphers ()
956
952
957
953
def compression (self ):
958
954
self ._checkClosed ()
959
- if not self ._sslobj :
955
+ if self ._sslobj is None :
960
956
return None
961
957
else :
962
958
return self ._sslobj .compression ()
963
959
964
960
def send (self , data , flags = 0 ):
965
961
self ._checkClosed ()
966
- if self ._sslobj :
962
+ if self ._sslobj is not None :
967
963
if flags != 0 :
968
964
raise ValueError (
969
965
"non-zero flags not allowed in calls to send() on %s" %
@@ -974,7 +970,7 @@ def send(self, data, flags=0):
974
970
975
971
def sendto (self , data , flags_or_addr , addr = None ):
976
972
self ._checkClosed ()
977
- if self ._sslobj :
973
+ if self ._sslobj is not None :
978
974
raise ValueError ("sendto not allowed on instances of %s" %
979
975
self .__class__ )
980
976
elif addr is None :
@@ -990,7 +986,7 @@ def sendmsg(self, *args, **kwargs):
990
986
991
987
def sendall (self , data , flags = 0 ):
992
988
self ._checkClosed ()
993
- if self ._sslobj :
989
+ if self ._sslobj is not None :
994
990
if flags != 0 :
995
991
raise ValueError (
996
992
"non-zero flags not allowed in calls to sendall() on %s" %
@@ -1008,15 +1004,15 @@ def sendfile(self, file, offset=0, count=None):
1008
1004
"""Send a file, possibly by using os.sendfile() if this is a
1009
1005
clear-text socket. Return the total number of bytes sent.
1010
1006
"""
1011
- if self ._sslobj is None :
1007
+ if self ._sslobj is not None :
1008
+ return self ._sendfile_use_send (file , offset , count )
1009
+ else :
1012
1010
# os.sendfile() works with plain sockets only
1013
1011
return super ().sendfile (file , offset , count )
1014
- else :
1015
- return self ._sendfile_use_send (file , offset , count )
1016
1012
1017
1013
def recv (self , buflen = 1024 , flags = 0 ):
1018
1014
self ._checkClosed ()
1019
- if self ._sslobj :
1015
+ if self ._sslobj is not None :
1020
1016
if flags != 0 :
1021
1017
raise ValueError (
1022
1018
"non-zero flags not allowed in calls to recv() on %s" %
@@ -1031,7 +1027,7 @@ def recv_into(self, buffer, nbytes=None, flags=0):
1031
1027
nbytes = len (buffer )
1032
1028
elif nbytes is None :
1033
1029
nbytes = 1024
1034
- if self ._sslobj :
1030
+ if self ._sslobj is not None :
1035
1031
if flags != 0 :
1036
1032
raise ValueError (
1037
1033
"non-zero flags not allowed in calls to recv_into() on %s" %
@@ -1042,15 +1038,15 @@ def recv_into(self, buffer, nbytes=None, flags=0):
1042
1038
1043
1039
def recvfrom (self , buflen = 1024 , flags = 0 ):
1044
1040
self ._checkClosed ()
1045
- if self ._sslobj :
1041
+ if self ._sslobj is not None :
1046
1042
raise ValueError ("recvfrom not allowed on instances of %s" %
1047
1043
self .__class__ )
1048
1044
else :
1049
1045
return super ().recvfrom (buflen , flags )
1050
1046
1051
1047
def recvfrom_into (self , buffer , nbytes = None , flags = 0 ):
1052
1048
self ._checkClosed ()
1053
- if self ._sslobj :
1049
+ if self ._sslobj is not None :
1054
1050
raise ValueError ("recvfrom_into not allowed on instances of %s" %
1055
1051
self .__class__ )
1056
1052
else :
@@ -1066,7 +1062,7 @@ def recvmsg_into(self, *args, **kwargs):
1066
1062
1067
1063
def pending (self ):
1068
1064
self ._checkClosed ()
1069
- if self ._sslobj :
1065
+ if self ._sslobj is not None :
1070
1066
return self ._sslobj .pending ()
1071
1067
else :
1072
1068
return 0
@@ -1078,7 +1074,7 @@ def shutdown(self, how):
1078
1074
1079
1075
def unwrap (self ):
1080
1076
if self ._sslobj :
1081
- s = self ._sslobj .unwrap ()
1077
+ s = self ._sslobj .shutdown ()
1082
1078
self ._sslobj = None
1083
1079
return s
1084
1080
else :
@@ -1096,6 +1092,11 @@ def do_handshake(self, block=False):
1096
1092
if timeout == 0.0 and block :
1097
1093
self .settimeout (None )
1098
1094
self ._sslobj .do_handshake ()
1095
+ if self .context .check_hostname :
1096
+ if not self .server_hostname :
1097
+ raise ValueError ("check_hostname needs server_hostname "
1098
+ "argument" )
1099
+ match_hostname (self .getpeercert (), self .server_hostname )
1099
1100
finally :
1100
1101
self .settimeout (timeout )
1101
1102
@@ -1104,11 +1105,12 @@ def _real_connect(self, addr, connect_ex):
1104
1105
raise ValueError ("can't connect in server-side mode" )
1105
1106
# Here we assume that the socket is client-side, and not
1106
1107
# connected at the time of the call. We connect it, then wrap it.
1107
- if self ._connected :
1108
+ if self ._connected or self . _sslobj is not None :
1108
1109
raise ValueError ("attempt to connect already-connected SSLSocket!" )
1109
- sslobj = self .context ._wrap_socket (self , False , self .server_hostname )
1110
- self ._sslobj = SSLObject (sslobj , owner = self ,
1111
- session = self ._session )
1110
+ self ._sslobj = self .context ._wrap_socket (
1111
+ self , False , self .server_hostname ,
1112
+ owner = self , session = self ._session
1113
+ )
1112
1114
try :
1113
1115
if connect_ex :
1114
1116
rc = super ().connect_ex (addr )
@@ -1151,18 +1153,24 @@ def get_channel_binding(self, cb_type="tls-unique"):
1151
1153
if the requested `cb_type` is not supported. Return bytes of the data
1152
1154
or None if the data is not available (e.g. before the handshake).
1153
1155
"""
1154
- if self ._sslobj is None :
1156
+ if self ._sslobj is not None :
1157
+ return self ._sslobj .get_channel_binding (cb_type )
1158
+ else :
1159
+ if cb_type not in CHANNEL_BINDING_TYPES :
1160
+ raise ValueError (
1161
+ "{0} channel binding type not implemented" .format (cb_type )
1162
+ )
1155
1163
return None
1156
- return self ._sslobj .get_channel_binding (cb_type )
1157
1164
1158
1165
def version (self ):
1159
1166
"""
1160
1167
Return a string identifying the protocol version used by the
1161
1168
current SSL channel, or None if there is no established channel.
1162
1169
"""
1163
- if self ._sslobj is None :
1170
+ if self ._sslobj is not None :
1171
+ return self ._sslobj .version ()
1172
+ else :
1164
1173
return None
1165
- return self ._sslobj .version ()
1166
1174
1167
1175
1168
1176
# Python does not support forward declaration of types.
0 commit comments