27
27
*/
28
28
29
29
#include "shared-bindings/ssl/SSLSocket.h"
30
- #include "shared-bindings/socketpool/Socket.h"
31
30
#include "shared-bindings/ssl/SSLContext.h"
32
- #include "shared-bindings/socketpool/SocketPool.h"
33
- #include "shared-bindings/socketpool/Socket.h"
34
31
35
32
#include "shared/runtime/interrupt_char.h"
33
+ #include "shared/netutils/netutils.h"
36
34
#include "py/mperrno.h"
37
35
#include "py/mphal.h"
38
36
#include "py/objstr.h"
@@ -104,11 +102,72 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
104
102
#endif
105
103
}
106
104
105
+ STATIC int call_method_errno (size_t n_args , const mp_obj_t * args ) {
106
+ nlr_buf_t nlr ;
107
+ mp_int_t result = - MP_EINVAL ;
108
+ if (nlr_push (& nlr ) == 0 ) {
109
+ mp_obj_t obj_result = mp_call_method_n_kw (n_args , 0 , args );
110
+ result = (obj_result == mp_const_none ) ? 0 : mp_obj_get_int (obj_result );
111
+ nlr_pop ();
112
+ return result ;
113
+ } else {
114
+ mp_obj_t exc = MP_OBJ_FROM_PTR (nlr .ret_val );
115
+ if (nlr_push (& nlr ) == 0 ) {
116
+ result = - mp_obj_get_int (mp_load_attr (exc , MP_QSTR_errno ));
117
+ nlr_pop ();
118
+ }
119
+ }
120
+ return result ;
121
+ }
122
+
123
+ static int ssl_socket_send (ssl_sslsocket_obj_t * self , const byte * buf , size_t len ) {
124
+ mp_obj_array_t mv ;
125
+ mp_obj_memoryview_init (& mv , 'B' , 0 , len , (void * )buf );
126
+
127
+ self -> send_args [2 ] = MP_OBJ_FROM_PTR (& mv );
128
+ return call_method_errno (1 , self -> send_args );
129
+ }
130
+
131
+ static int ssl_socket_recv_into (ssl_sslsocket_obj_t * self , byte * buf , size_t len ) {
132
+ mp_obj_array_t mv ;
133
+ mp_obj_memoryview_init (& mv , 'B' | MP_OBJ_ARRAY_TYPECODE_FLAG_RW , 0 , len , buf );
134
+
135
+ self -> recv_into_args [2 ] = MP_OBJ_FROM_PTR (& mv );
136
+ return call_method_errno (1 , self -> recv_into_args );
137
+ }
138
+
139
+ static int ssl_socket_connect (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
140
+ self -> connect_args [2 ] = addr_in ;
141
+ return call_method_errno (1 , self -> connect_args );
142
+ }
143
+
144
+ static int ssl_socket_bind (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
145
+ self -> bind_args [2 ] = addr_in ;
146
+ return call_method_errno (1 , self -> bind_args );
147
+ }
148
+
149
+ static int ssl_socket_close (ssl_sslsocket_obj_t * self ) {
150
+ return call_method_errno (0 , self -> close_args );
151
+ }
152
+
153
+ static int ssl_socket_settimeout (ssl_sslsocket_obj_t * self , mp_int_t timeout_ms ) {
154
+ self -> settimeout_args [2 ] = mp_obj_new_float (timeout_ms * MICROPY_FLOAT_CONST (1e-3 ));
155
+ return call_method_errno (1 , self -> settimeout_args );
156
+ }
157
+
158
+ static int ssl_socket_listen (ssl_sslsocket_obj_t * self , mp_int_t backlog ) {
159
+ self -> listen_args [2 ] = MP_OBJ_NEW_SMALL_INT (backlog );
160
+ return call_method_errno (1 , self -> listen_args );
161
+ }
162
+
163
+ static mp_obj_t ssl_socket_accept (ssl_sslsocket_obj_t * self ) {
164
+ return mp_call_method_n_kw (0 , 0 , self -> accept_args );
165
+ }
166
+
107
167
STATIC int _mbedtls_ssl_send (void * ctx , const byte * buf , size_t len ) {
108
- mp_obj_t sock = * ( mp_obj_t * )ctx ;
168
+ ssl_sslsocket_obj_t * self = ( ssl_sslsocket_obj_t * )ctx ;
109
169
110
- // mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
111
- mp_int_t out_sz = socketpool_socket_send (sock , buf , len );
170
+ mp_int_t out_sz = ssl_socket_send (self , buf , len );
112
171
DEBUG_PRINT ("socket_send() -> %d" , out_sz );
113
172
if (out_sz < 0 ) {
114
173
int err = - out_sz ;
@@ -124,9 +183,9 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
124
183
125
184
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
126
185
STATIC int _mbedtls_ssl_recv (void * ctx , byte * buf , size_t len ) {
127
- mp_obj_t sock = * ( mp_obj_t * )ctx ;
186
+ ssl_sslsocket_obj_t * self = ( ssl_sslsocket_obj_t * )ctx ;
128
187
129
- mp_int_t out_sz = socketpool_socket_recv_into ( sock , buf , len );
188
+ mp_int_t out_sz = ssl_socket_recv_into ( self , buf , len );
130
189
DEBUG_PRINT ("socket_recv() -> %d" , out_sz );
131
190
if (out_sz < 0 ) {
132
191
int err = - out_sz ;
@@ -151,16 +210,26 @@ static int urandom_adapter(void *unused, unsigned char *buf, size_t n) {
151
210
#endif
152
211
153
212
ssl_sslsocket_obj_t * common_hal_ssl_sslcontext_wrap_socket (ssl_sslcontext_obj_t * self ,
154
- socketpool_socket_obj_t * socket , bool server_side , const char * server_hostname ) {
213
+ mp_obj_t socket , bool server_side , const char * server_hostname ) {
155
214
156
- if (socket -> type != SOCKETPOOL_SOCK_STREAM ) {
215
+ mp_int_t socket_type = mp_obj_get_int (mp_load_attr (socket , MP_QSTR_type ));
216
+ if (socket_type != SOCKETPOOL_SOCK_STREAM ) {
157
217
mp_raise_RuntimeError (MP_ERROR_TEXT ("Invalid socket for TLS" ));
158
218
}
159
219
160
220
ssl_sslsocket_obj_t * o = m_new_obj_with_finaliser (ssl_sslsocket_obj_t );
161
221
o -> base .type = & ssl_sslsocket_type ;
162
222
o -> ssl_context = self ;
163
- o -> sock = socket ;
223
+ o -> sock_obj = socket ;
224
+
225
+ mp_load_method (socket , MP_QSTR_accept , o -> accept_args );
226
+ mp_load_method (socket , MP_QSTR_bind , o -> bind_args );
227
+ mp_load_method (socket , MP_QSTR_close , o -> close_args );
228
+ mp_load_method (socket , MP_QSTR_connect , o -> connect_args );
229
+ mp_load_method (socket , MP_QSTR_listen , o -> listen_args );
230
+ mp_load_method (socket , MP_QSTR_recv_into , o -> recv_into_args );
231
+ mp_load_method (socket , MP_QSTR_send , o -> send_args );
232
+ mp_load_method (socket , MP_QSTR_settimeout , o -> settimeout_args );
164
233
165
234
mbedtls_ssl_init (& o -> ssl );
166
235
mbedtls_ssl_config_init (& o -> conf );
@@ -219,7 +288,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
219
288
}
220
289
}
221
290
222
- mbedtls_ssl_set_bio (& o -> ssl , & o -> sock , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
291
+ mbedtls_ssl_set_bio (& o -> ssl , o , _mbedtls_ssl_send , _mbedtls_ssl_recv , NULL );
223
292
224
293
if (self -> cert_buf .buf != NULL ) {
225
294
#if MBEDTLS_VERSION_MAJOR >= 3
@@ -304,13 +373,13 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t
304
373
mp_raise_OSError (ret );
305
374
}
306
375
307
- size_t common_hal_ssl_sslsocket_bind (ssl_sslsocket_obj_t * self , const char * host , size_t hostlen , uint32_t port ) {
308
- return common_hal_socketpool_socket_bind (self -> sock , host , hostlen , port );
376
+ size_t common_hal_ssl_sslsocket_bind (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
377
+ return ssl_socket_bind (self , addr_in );
309
378
}
310
379
311
380
void common_hal_ssl_sslsocket_close (ssl_sslsocket_obj_t * self ) {
312
381
self -> closed = true;
313
- common_hal_socketpool_socket_close (self -> sock );
382
+ ssl_socket_close (self );
314
383
mbedtls_pk_free (& self -> pkey );
315
384
mbedtls_x509_crt_free (& self -> cert );
316
385
mbedtls_x509_crt_free (& self -> cacert );
@@ -356,8 +425,8 @@ STATIC void do_handshake(ssl_sslsocket_obj_t *self) {
356
425
}
357
426
}
358
427
359
- void common_hal_ssl_sslsocket_connect (ssl_sslsocket_obj_t * self , const char * host , size_t hostlen , uint32_t port ) {
360
- common_hal_socketpool_socket_connect (self -> sock , host , hostlen , port );
428
+ void common_hal_ssl_sslsocket_connect (ssl_sslsocket_obj_t * self , mp_obj_t addr_in ) {
429
+ ssl_socket_connect (self , addr_in );
361
430
do_handshake (self );
362
431
}
363
432
@@ -370,16 +439,21 @@ bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self) {
370
439
}
371
440
372
441
bool common_hal_ssl_sslsocket_listen (ssl_sslsocket_obj_t * self , int backlog ) {
373
- return common_hal_socketpool_socket_listen (self -> sock , backlog );
442
+ return ssl_socket_listen (self , backlog );
374
443
}
375
444
376
- ssl_sslsocket_obj_t * common_hal_ssl_sslsocket_accept (ssl_sslsocket_obj_t * self , uint8_t * ip , uint32_t * port ) {
377
- socketpool_socket_obj_t * sock = common_hal_socketpool_socket_accept (self -> sock , ip , port );
445
+ mp_obj_t common_hal_ssl_sslsocket_accept (ssl_sslsocket_obj_t * self ) {
446
+ mp_obj_t accepted = ssl_socket_accept (self );
447
+ mp_obj_t sock = mp_obj_subscr (accepted , MP_OBJ_NEW_SMALL_INT (0 ), MP_OBJ_SENTINEL );
378
448
ssl_sslsocket_obj_t * sslsock = common_hal_ssl_sslcontext_wrap_socket (self -> ssl_context , sock , true, NULL );
379
449
do_handshake (sslsock );
380
- return sslsock ;
450
+ mp_obj_t peer = mp_obj_subscr (accepted , MP_OBJ_NEW_SMALL_INT (0 ), MP_OBJ_SENTINEL );
451
+ mp_obj_t tuple_contents [2 ];
452
+ tuple_contents [0 ] = MP_OBJ_FROM_PTR (sslsock );
453
+ tuple_contents [1 ] = peer ;
454
+ return mp_obj_new_tuple (2 , tuple_contents );
381
455
}
382
456
383
457
void common_hal_ssl_sslsocket_settimeout (ssl_sslsocket_obj_t * self , uint32_t timeout_ms ) {
384
- common_hal_socketpool_socket_settimeout (self -> sock , timeout_ms );
458
+ ssl_socket_settimeout (self , timeout_ms );
385
459
}
0 commit comments