35
35
36
36
37
37
if not sys .implementation .name == "circuitpython" :
38
- from typing import Optional , Tuple
38
+ from typing import List , Optional , Tuple
39
39
40
40
from circuitpython_typing .socket import (
41
41
CircuitPythonSocketType ,
@@ -71,8 +71,7 @@ class _FakeSSLContext:
71
71
def __init__ (self , iface : InterfaceType ) -> None :
72
72
self ._iface = iface
73
73
74
- # pylint: disable=unused-argument
75
- def wrap_socket (
74
+ def wrap_socket ( # pylint: disable=unused-argument
76
75
self , socket : CircuitPythonSocketType , server_hostname : Optional [str ] = None
77
76
) -> _FakeSSLSocket :
78
77
"""Return the same socket"""
@@ -184,54 +183,75 @@ def __init__(
184
183
) -> None :
185
184
self ._socket_pool = socket_pool
186
185
# Hang onto open sockets so that we can reuse them.
187
- self ._available_socket = {}
188
- self ._open_sockets = {}
186
+ self ._available_sockets = set ()
187
+ self ._managed_socket_by_key = {}
188
+ self ._managed_socket_by_socket = {}
189
189
190
190
def _free_sockets (self , force : bool = False ) -> None :
191
- available_sockets = []
192
- for socket , free in self ._available_socket .items ():
193
- if free or force :
194
- available_sockets .append (socket )
195
-
191
+ # cloning lists since items are being removed
192
+ available_sockets = list (self ._available_sockets )
196
193
for socket in available_sockets :
197
194
self .close_socket (socket )
195
+ if force :
196
+ open_sockets = list (self ._managed_socket_by_key .values ())
197
+ for socket in open_sockets :
198
+ self .close_socket (socket )
198
199
199
- def _get_key_for_socket (self , socket ):
200
+ def _get_connected_socket ( # pylint: disable=too-many-arguments
201
+ self ,
202
+ addr_info : List [Tuple [int , int , int , str , Tuple [str , int ]]],
203
+ host : str ,
204
+ port : int ,
205
+ timeout : float ,
206
+ is_ssl : bool ,
207
+ ssl_context : Optional [SSLContextType ] = None ,
208
+ ):
200
209
try :
201
- return next (
202
- key for key , value in self ._open_sockets .items () if value == socket
203
- )
204
- except StopIteration :
205
- return None
210
+ socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
211
+ except (OSError , RuntimeError ) as exc :
212
+ return exc
206
213
207
- @property
208
- def open_sockets (self ) -> int :
209
- """Get the count of open sockets"""
210
- return len (self ._open_sockets )
214
+ if is_ssl :
215
+ socket = ssl_context .wrap_socket (socket , server_hostname = host )
216
+ connect_host = host
217
+ else :
218
+ connect_host = addr_info [- 1 ][0 ]
219
+ socket .settimeout (timeout ) # socket read timeout
220
+
221
+ try :
222
+ socket .connect ((connect_host , port ))
223
+ except (MemoryError , OSError ) as exc :
224
+ socket .close ()
225
+ return exc
226
+
227
+ return socket
211
228
212
229
@property
213
- def freeable_open_sockets (self ) -> int :
230
+ def available_socket_count (self ) -> int :
214
231
"""Get the count of freeable open sockets"""
215
- return len (
216
- [socket for socket , free in self ._available_socket .items () if free is True ]
217
- )
232
+ return len (self ._available_sockets )
233
+
234
+ @property
235
+ def managed_socket_count (self ) -> int :
236
+ """Get the count of open sockets"""
237
+ return len (self ._managed_socket_by_key )
218
238
219
239
def close_socket (self , socket : SocketType ) -> None :
220
240
"""Close a previously opened socket."""
221
- if socket not in self ._open_sockets .values ():
241
+ if socket not in self ._managed_socket_by_key .values ():
222
242
raise RuntimeError ("Socket not managed" )
223
- key = self ._get_key_for_socket (socket )
224
243
socket .close ()
225
- del self ._available_socket [socket ]
226
- del self ._open_sockets [key ]
244
+ key = self ._managed_socket_by_socket .pop (socket )
245
+ del self ._managed_socket_by_key [key ]
246
+ if socket in self ._available_sockets :
247
+ self ._available_sockets .remove (socket )
227
248
228
249
def free_socket (self , socket : SocketType ) -> None :
229
250
"""Mark a previously opened socket as available so it can be reused if needed."""
230
- if socket not in self ._open_sockets .values ():
251
+ if socket not in self ._managed_socket_by_key .values ():
231
252
raise RuntimeError ("Socket not managed" )
232
- self ._available_socket [ socket ] = True
253
+ self ._available_sockets . add ( socket )
233
254
234
- # pylint: disable=too-many-branches,too-many-locals,too-many-statements
235
255
def get_socket (
236
256
self ,
237
257
host : str ,
@@ -247,10 +267,10 @@ def get_socket(
247
267
if session_id :
248
268
session_id = str (session_id )
249
269
key = (host , port , proto , session_id )
250
- if key in self ._open_sockets :
251
- socket = self ._open_sockets [key ]
252
- if self ._available_socket [ socket ] :
253
- self ._available_socket [ socket ] = False
270
+ if key in self ._managed_socket_by_key :
271
+ socket = self ._managed_socket_by_key [key ]
272
+ if socket in self ._available_sockets :
273
+ self ._available_sockets . remove ( socket )
254
274
return socket
255
275
256
276
raise RuntimeError (f"Socket already connected to { proto } //{ host } :{ port } " )
@@ -266,54 +286,22 @@ def get_socket(
266
286
host , port , 0 , self ._socket_pool .SOCK_STREAM
267
287
)[0 ]
268
288
269
- try_count = 0
270
- socket = None
271
- last_exc = None
272
- while try_count < 2 and socket is None :
273
- try_count += 1
274
- if try_count > 1 :
275
- if any (
276
- socket
277
- for socket , free in self ._available_socket .items ()
278
- if free is True
279
- ):
280
- self ._free_sockets ()
281
- else :
282
- break
283
-
284
- try :
285
- socket = self ._socket_pool .socket (addr_info [0 ], addr_info [1 ])
286
- except OSError as exc :
287
- last_exc = exc
288
- continue
289
- except RuntimeError as exc :
290
- last_exc = exc
291
- continue
292
-
293
- if is_ssl :
294
- socket = ssl_context .wrap_socket (socket , server_hostname = host )
295
- connect_host = host
296
- else :
297
- connect_host = addr_info [- 1 ][0 ]
298
- socket .settimeout (timeout ) # socket read timeout
299
-
300
- try :
301
- socket .connect ((connect_host , port ))
302
- except MemoryError as exc :
303
- last_exc = exc
304
- socket .close ()
305
- socket = None
306
- except OSError as exc :
307
- last_exc = exc
308
- socket .close ()
309
- socket = None
310
-
311
- if socket is None :
312
- raise RuntimeError (f"Error connecting socket: { last_exc } " ) from last_exc
313
-
314
- self ._available_socket [socket ] = False
315
- self ._open_sockets [key ] = socket
316
- return socket
289
+ result = self ._get_connected_socket (
290
+ addr_info , host , port , timeout , is_ssl , ssl_context
291
+ )
292
+ if isinstance (result , Exception ):
293
+ # Got an error, if there are any available sockets, free them and try again
294
+ if self .available_socket_count :
295
+ self ._free_sockets ()
296
+ result = self ._get_connected_socket (
297
+ addr_info , host , port , timeout , is_ssl , ssl_context
298
+ )
299
+ if isinstance (result , Exception ):
300
+ raise RuntimeError (f"Error connecting socket: { result } " ) from result
301
+
302
+ self ._managed_socket_by_key [key ] = result
303
+ self ._managed_socket_by_socket [result ] = key
304
+ return result
317
305
318
306
319
307
# global helpers
0 commit comments