27
27
*/
28
28
29
29
#include "shared-bindings/ssl/SSLSocket.h"
30
- #include "shared-bindings/socketpool/Socket.h"
31
30
#include "shared-bindings/ssl/SSLContext.h"
32
- #include "shared-bindings/socketpool/SocketPool.h"
33
- #include "shared-bindings/socketpool/Socket.h"
34
31
35
32
#include "shared/runtime/interrupt_char.h"
33
+ #include "shared/netutils/netutils.h"
36
34
#include "py/mperrno.h"
37
35
#include "py/mphal.h"
38
36
#include "py/objstr.h"
@@ -104,11 +102,73 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
104
102
#endif
105
103
}
106
104
105
+ STATIC int call_method_errno (size_t n_args , const mp_obj_t * args ) {
106
+ nlr_buf_t nlr ;
107
+ mp_int_t result = - MP_EINVAL ;
108
+ if (nlr_push (& nlr ) == 0 ) {
109
+ mp_obj_t obj_result = mp_call_method_n_kw (n_args , 0 , args );
110
+ result = (obj_result == mp_const_none ) ? 0 : mp_obj_get_int (obj_result );
111
+ nlr_pop ();
112
+ return result ;
113
+ } else {
114
+ mp_obj_t exc = MP_OBJ_FROM_PTR (nlr .ret_val );
115
+ mp_obj_print_exception (& mp_plat_print , MP_OBJ_FROM_PTR (nlr .ret_val ));
116
+ if (nlr_push (& nlr ) == 0 ) {
117
+ result = - mp_obj_get_int (mp_load_attr (exc , MP_QSTR_errno ));
118
+ nlr_pop ();
119
+ }
120
+ }
121
+ return result ;
122
+ }
123
+
124
+ static int ssl_socket_send (ssl_sslsocket_obj_t * self , const byte * buf , size_t len ) {
125
+ mp_obj_array_t mv ;
126
+ mp_obj_memoryview_init (& mv , 'B' , 0 , len , (void * )buf );
127
+
128
+ self -> send_args [2 ] = MP_OBJ_FROM_PTR (& mv );
129
+ return call_method_errno (1 , self -> send_args );
130
+ }
131
+
132
+ static int ssl_socket_recv_into (ssl_sslsocket_obj_t * self , byte * buf , size_t len ) {
133
+ mp_obj_array_t mv ;
134
+ mp_obj_memoryview_init (& mv , 'B' | MP_OBJ_ARRAY_TYPECODE_FLAG_RW , 0 , len , buf );
135
+
136
+ self -> recv_into_args [2 ] = MP_OBJ_FROM_PTR (& mv );
137
+ return call_method_errno (1 , self -> recv_into_args );
138
+ }
139
+
140
+ static int ssl_socket_connect (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
141
+ self -> connect_args [2 ] = addr_in ;
142
+ return call_method_errno (1 , self -> connect_args );
143
+ }
144
+
145
+ static int ssl_socket_bind (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
146
+ self -> bind_args [2 ] = addr_in ;
147
+ return call_method_errno (1 , self -> bind_args );
148
+ }
149
+
150
+ static int ssl_socket_close (ssl_sslsocket_obj_t * self ) {
151
+ return call_method_errno (0 , self -> close_args );
152
+ }
153
+
154
+ static int ssl_socket_settimeout (ssl_sslsocket_obj_t * self , mp_int_t timeout_ms ) {
155
+ self -> settimeout_args [2 ] = mp_obj_new_float (timeout_ms * MICROPY_FLOAT_CONST (1e-3 ));
156
+ return call_method_errno (1 , self -> settimeout_args );
157
+ }
158
+
159
+ static int ssl_socket_listen (ssl_sslsocket_obj_t * self , mp_int_t backlog ) {
160
+ self -> listen_args [2 ] = MP_OBJ_NEW_SMALL_INT (backlog );
161
+ return call_method_errno (1 , self -> listen_args );
162
+ }
163
+
164
+ static mp_obj_t ssl_socket_accept (ssl_sslsocket_obj_t * self ) {
165
+ return mp_call_method_n_kw (0 , 0 , self -> accept_args );
166
+ }
167
+
107
168
STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
108
- mp_obj_t sock = * ( mp_obj_t * )ctx ;
169
+ ssl_sslsocket_obj_t * self = ( ssl_sslsocket_obj_t * )ctx ;
109
170
110
- // mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
111
- mp_int_t out_sz = socketpool_socket_send (sock , buf , len );
171
+ mp_int_t out_sz = ssl_socket_send (self , buf , len );
112
172
DEBUG_PRINT ("socket_send() -> %d" , out_sz );
113
173
if (out_sz < 0 ) {
114
174
int err = - out_sz ;
@@ -124,9 +184,9 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
124
184
125
185
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
126
186
STATIC int _mbedtls_ssl_recv (void * ctx , byte * buf , size_t len ) {
127
- mp_obj_t sock = * ( mp_obj_t * )ctx ;
187
+ ssl_sslsocket_obj_t * self = ( ssl_sslsocket_obj_t * )ctx ;
128
188
129
- mp_int_t out_sz = socketpool_socket_recv_into ( sock , buf , len );
189
+ mp_int_t out_sz = ssl_socket_recv_into ( self , buf , len );
130
190
DEBUG_PRINT ("socket_recv() -> %d" , out_sz );
131
191
if (out_sz < 0 ) {
132
192
int err = - out_sz ;
@@ -151,16 +211,26 @@ static int urandom_adapter(void *unused, unsigned char *buf, size_t n) {
151
211
#endif
152
212
153
213
ssl_sslsocket_obj_t * common_hal_ssl_sslcontext_wrap_socket (ssl_sslcontext_obj_t * self ,
154
- socketpool_socket_obj_t * socket , bool server_side , const char * server_hostname ) {
214
+ mp_obj_t socket , bool server_side , const char * server_hostname ) {
155
215
156
- if (socket -> type != SOCKETPOOL_SOCK_STREAM ) {
216
+ mp_int_t socket_type = mp_obj_get_int (mp_load_attr (socket , MP_QSTR_type ));
217
+ if (socket_type != SOCKETPOOL_SOCK_STREAM ) {
157
218
mp_raise_RuntimeError (MP_ERROR_TEXT ("Invalid socket for TLS" ));
158
219
}
159
220
160
221
ssl_sslsocket_obj_t * o = m_new_obj_with_finaliser (ssl_sslsocket_obj_t );
161
222
o -> base .type = & ssl_sslsocket_type ;
162
223
o -> ssl_context = self ;
163
- o -> sock = socket ;
224
+ o -> sock_obj = socket ;
225
+
226
+ mp_load_method (socket , MP_QSTR_accept , o -> accept_args );
227
+ mp_load_method (socket , MP_QSTR_bind , o -> bind_args );
228
+ mp_load_method (socket , MP_QSTR_close , o -> close_args );
229
+ mp_load_method (socket , MP_QSTR_connect , o -> connect_args );
230
+ mp_load_method (socket , MP_QSTR_listen , o -> listen_args );
231
+ mp_load_method (socket , MP_QSTR_recv_into , o -> recv_into_args );
232
+ mp_load_method (socket , MP_QSTR_send , o -> send_args );
233
+ mp_load_method (socket , MP_QSTR_settimeout , o -> settimeout_args );
164
234
165
235
mbedtls_ssl_init (& o -> ssl );
166
236
mbedtls_ssl_config_init (& o -> conf );
@@ -219,7 +289,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
219
289
}
220
290
}
221
291
222
- mbedtls_ssl_set_bio (& o -> ssl , & o -> sock , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
292
+ mbedtls_ssl_set_bio (& o -> ssl , o , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
223
293
224
294
if (self -> cert_buf .buf != NULL ) {
225
295
#if MBEDTLS_VERSION_MAJOR >= 3
@@ -304,13 +374,13 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t
304
374
mp_raise_OSError (ret );
305
375
}
306
376
307
- size_t common_hal_ssl_sslsocket_bind (ssl_sslsocket_obj_t * self , const char * host , size_t hostlen , uint32_t port ) {
308
- return common_hal_socketpool_socket_bind (self -> sock , host , hostlen , port );
377
+ size_t common_hal_ssl_sslsocket_bind (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
378
+ return ssl_socket_bind (self , addr_in );
309
379
}
310
380
311
381
void common_hal_ssl_sslsocket_close (ssl_sslsocket_obj_t * self ) {
312
382
self -> closed = true;
313
- common_hal_socketpool_socket_close (self -> sock );
383
+ ssl_socket_close (self );
314
384
mbedtls_pk_free (& self -> pkey );
315
385
mbedtls_x509_crt_free (& self -> cert );
316
386
mbedtls_x509_crt_free (& self -> cacert );
@@ -356,8 +426,8 @@ STATIC void do_handshake(ssl_sslsocket_obj_t *self) {
356
426
}
357
427
}
358
428
359
- void common_hal_ssl_sslsocket_connect (ssl_sslsocket_obj_t * self , const char * host , size_t hostlen , uint32_t port ) {
360
- common_hal_socketpool_socket_connect (self -> sock , host , hostlen , port );
429
+ void common_hal_ssl_sslsocket_connect (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
430
+ ssl_socket_connect (self , addr_in );
361
431
do_handshake (self );
362
432
}
363
433
@@ -370,16 +440,21 @@ bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self) {
370
440
}
371
441
372
442
bool common_hal_ssl_sslsocket_listen (ssl_sslsocket_obj_t * self , int backlog ) {
373
- return common_hal_socketpool_socket_listen (self -> sock , backlog );
443
+ return ssl_socket_listen (self , backlog );
374
444
}
375
445
376
- ssl_sslsocket_obj_t * common_hal_ssl_sslsocket_accept (ssl_sslsocket_obj_t * self , uint8_t * ip , uint32_t * port ) {
377
- socketpool_socket_obj_t * sock = common_hal_socketpool_socket_accept (self -> sock , ip , port );
446
+ mp_obj_t common_hal_ssl_sslsocket_accept (ssl_sslsocket_obj_t * self ) {
447
+ mp_obj_t accepted = ssl_socket_accept (self );
448
+ mp_obj_t sock = mp_obj_subscr (accepted , MP_OBJ_NEW_SMALL_INT (0 ), MP_OBJ_SENTINEL );
378
449
ssl_sslsocket_obj_t * sslsock = common_hal_ssl_sslcontext_wrap_socket (self -> ssl_context , sock , true, NULL );
379
450
do_handshake (sslsock );
380
- return sslsock ;
451
+ mp_obj_t peer = mp_obj_subscr (accepted , MP_OBJ_NEW_SMALL_INT (0 ), MP_OBJ_SENTINEL );
452
+ mp_obj_t tuple_contents [2 ];
453
+ tuple_contents [0 ] = MP_OBJ_FROM_PTR (sslsock );
454
+ tuple_contents [1 ] = peer ;
455
+ return mp_obj_new_tuple (2 , tuple_contents );
381
456
}
382
457
383
458
void common_hal_ssl_sslsocket_settimeout (ssl_sslsocket_obj_t * self , uint32_t timeout_ms ) {
384
- common_hal_socketpool_socket_settimeout (self -> sock , timeout_ms );
459
+ ssl_socket_settimeout (self , timeout_ms );
385
460
}
0 commit comments