Skip to content

Commit 9e6472f

Browse files
committed
ssl: work on anything implementing the socket protocol
In principle this allows core SSL code to be used with e.g., wiznet or airlift sockets. It might actually be useful with wiznet ethernet devices (it's probably not with airlift)
1 parent e04fe9a commit 9e6472f

File tree

7 files changed

+115
-61
lines changed

7 files changed

+115
-61
lines changed

shared-bindings/ssl/SSLContext.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,9 @@ STATIC mp_obj_t ssl_sslcontext_wrap_socket(size_t n_args, const mp_obj_t *pos_ar
200200
mp_raise_ValueError(MP_ERROR_TEXT("Server side context cannot have hostname"));
201201
}
202202

203-
socketpool_socket_obj_t *sock = args[ARG_sock].u_obj;
203+
mp_obj_t sock_obj = args[ARG_sock].u_obj;
204204

205-
return common_hal_ssl_sslcontext_wrap_socket(self, sock, server_side, server_hostname);
205+
return common_hal_ssl_sslcontext_wrap_socket(self, sock_obj, server_side, server_hostname);
206206
}
207207
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_sslcontext_wrap_socket_obj, 1, ssl_sslcontext_wrap_socket);
208208

shared-bindings/ssl/SSLContext.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,14 @@
3333
#include "common-hal/ssl/SSLContext.h"
3434
#endif
3535

36-
#include "shared-bindings/socketpool/Socket.h"
3736
#include "shared-bindings/ssl/SSLSocket.h"
3837

3938
extern const mp_obj_type_t ssl_sslcontext_type;
4039

4140
void common_hal_ssl_sslcontext_construct(ssl_sslcontext_obj_t *self);
4241

4342
ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t *self,
44-
socketpool_socket_obj_t *sock, bool server_side, const char *server_hostname);
43+
mp_obj_t socket, bool server_side, const char *server_hostname);
4544

4645
void common_hal_ssl_sslcontext_load_verify_locations(ssl_sslcontext_obj_t *self,
4746
const char *cadata);

shared-bindings/ssl/SSLSocket.c

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(ssl_sslsocket___exit___obj, 4, 4, ssl
7373
//| Returns a tuple of (new_socket, remote_address)"""
7474
STATIC mp_obj_t ssl_sslsocket_accept(mp_obj_t self_in) {
7575
ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in);
76-
uint8_t ip[4];
77-
uint32_t port;
78-
79-
ssl_sslsocket_obj_t *sslsock = common_hal_ssl_sslsocket_accept(self, ip, &port);
80-
81-
mp_obj_t tuple_contents[2];
82-
tuple_contents[0] = MP_OBJ_FROM_PTR(sslsock);
83-
tuple_contents[1] = netutils_format_inet_addr(ip, port, NETUTILS_BIG);
84-
return mp_obj_new_tuple(2, tuple_contents);
76+
return common_hal_ssl_sslsocket_accept(self);
8577
}
8678
STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_sslsocket_accept_obj, ssl_sslsocket_accept);
8779

@@ -96,14 +88,7 @@ STATIC mp_obj_t ssl_sslsocket_bind(mp_obj_t self_in, mp_obj_t addr_in) {
9688
mp_obj_t *addr_items;
9789
mp_obj_get_array_fixed_n(addr_in, 2, &addr_items);
9890

99-
size_t hostlen;
100-
const char *host = mp_obj_str_get_data(addr_items[0], &hostlen);
101-
mp_int_t port = mp_obj_get_int(addr_items[1]);
102-
if (port < 0) {
103-
mp_raise_ValueError(MP_ERROR_TEXT("port must be >= 0"));
104-
}
105-
106-
size_t error = common_hal_ssl_sslsocket_bind(self, host, hostlen, (uint32_t)port);
91+
size_t error = common_hal_ssl_sslsocket_bind(self, addr_in);
10792
if (error != 0) {
10893
mp_raise_OSError(error);
10994
}
@@ -128,18 +113,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_sslsocket_close_obj, ssl_sslsocket_close);
128113
//| ...
129114
STATIC mp_obj_t ssl_sslsocket_connect(mp_obj_t self_in, mp_obj_t addr_in) {
130115
ssl_sslsocket_obj_t *self = MP_OBJ_TO_PTR(self_in);
131-
132-
mp_obj_t *addr_items;
133-
mp_obj_get_array_fixed_n(addr_in, 2, &addr_items);
134-
135-
size_t hostlen;
136-
const char *host = mp_obj_str_get_data(addr_items[0], &hostlen);
137-
mp_int_t port = mp_obj_get_int(addr_items[1]);
138-
if (port < 0) {
139-
mp_raise_ValueError(MP_ERROR_TEXT("port must be >= 0"));
140-
}
141-
142-
common_hal_ssl_sslsocket_connect(self, host, hostlen, (uint32_t)port);
116+
common_hal_ssl_sslsocket_connect(self, addr_in);
143117

144118
return mp_const_none;
145119
}

shared-bindings/ssl/SSLSocket.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434

3535
extern const mp_obj_type_t ssl_sslsocket_type;
3636

37-
ssl_sslsocket_obj_t *common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self, uint8_t *ip, uint32_t *port);
38-
size_t common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
37+
mp_obj_t common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self);
38+
size_t common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, mp_obj_t addr);
3939
void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self);
40-
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port);
40+
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr);
4141
bool common_hal_ssl_sslsocket_get_closed(ssl_sslsocket_obj_t *self);
4242
bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self);
4343
bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t *self, int backlog);

shared-module/ssl/SSLContext.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
#include "shared-bindings/ssl/SSLContext.h"
2828
#include "shared-bindings/ssl/SSLSocket.h"
29-
#include "shared-bindings/socketpool/SocketPool.h"
3029

3130
#include "py/runtime.h"
3231
#include "py/stream.h"

shared-module/ssl/SSLSocket.c

Lines changed: 97 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,10 @@
2727
*/
2828

2929
#include "shared-bindings/ssl/SSLSocket.h"
30-
#include "shared-bindings/socketpool/Socket.h"
3130
#include "shared-bindings/ssl/SSLContext.h"
32-
#include "shared-bindings/socketpool/SocketPool.h"
33-
#include "shared-bindings/socketpool/Socket.h"
3431

3532
#include "shared/runtime/interrupt_char.h"
33+
#include "shared/netutils/netutils.h"
3634
#include "py/mperrno.h"
3735
#include "py/mphal.h"
3836
#include "py/objstr.h"
@@ -104,11 +102,73 @@ STATIC NORETURN void mbedtls_raise_error(int err) {
104102
#endif
105103
}
106104

105+
STATIC int call_method_errno(size_t n_args, const mp_obj_t *args) {
106+
nlr_buf_t nlr;
107+
mp_int_t result = -MP_EINVAL;
108+
if (nlr_push(&nlr) == 0) {
109+
mp_obj_t obj_result = mp_call_method_n_kw(n_args, 0, args);
110+
result = (obj_result == mp_const_none) ? 0 : mp_obj_get_int(obj_result);
111+
nlr_pop();
112+
return result;
113+
} else {
114+
mp_obj_t exc = MP_OBJ_FROM_PTR(nlr.ret_val);
115+
mp_obj_print_exception(&mp_plat_print, MP_OBJ_FROM_PTR(nlr.ret_val));
116+
if (nlr_push(&nlr) == 0) {
117+
result = -mp_obj_get_int(mp_load_attr(exc, MP_QSTR_errno));
118+
nlr_pop();
119+
}
120+
}
121+
return result;
122+
}
123+
124+
static int ssl_socket_send(ssl_sslsocket_obj_t *self, const byte *buf, size_t len) {
125+
mp_obj_array_t mv;
126+
mp_obj_memoryview_init(&mv, 'B', 0, len, (void *)buf);
127+
128+
self->send_args[2] = MP_OBJ_FROM_PTR(&mv);
129+
return call_method_errno(1, self->send_args);
130+
}
131+
132+
static int ssl_socket_recv_into(ssl_sslsocket_obj_t *self, byte *buf, size_t len) {
133+
mp_obj_array_t mv;
134+
mp_obj_memoryview_init(&mv, 'B' | MP_OBJ_ARRAY_TYPECODE_FLAG_RW, 0, len, buf);
135+
136+
self->recv_into_args[2] = MP_OBJ_FROM_PTR(&mv);
137+
return call_method_errno(1, self->recv_into_args);
138+
}
139+
140+
static int ssl_socket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr_in) {
141+
self->connect_args[2] = addr_in;
142+
return call_method_errno(1, self->connect_args);
143+
}
144+
145+
static int ssl_socket_bind(ssl_sslsocket_obj_t *self, mp_obj_t addr_in) {
146+
self->bind_args[2] = addr_in;
147+
return call_method_errno(1, self->bind_args);
148+
}
149+
150+
static int ssl_socket_close(ssl_sslsocket_obj_t *self) {
151+
return call_method_errno(0, self->close_args);
152+
}
153+
154+
static int ssl_socket_settimeout(ssl_sslsocket_obj_t *self, mp_int_t timeout_ms) {
155+
self->settimeout_args[2] = mp_obj_new_float(timeout_ms * MICROPY_FLOAT_CONST(1e-3));
156+
return call_method_errno(1, self->settimeout_args);
157+
}
158+
159+
static int ssl_socket_listen(ssl_sslsocket_obj_t *self, mp_int_t backlog) {
160+
self->listen_args[2] = MP_OBJ_NEW_SMALL_INT(backlog);
161+
return call_method_errno(1, self->listen_args);
162+
}
163+
164+
static mp_obj_t ssl_socket_accept(ssl_sslsocket_obj_t *self) {
165+
return mp_call_method_n_kw(0, 0, self->accept_args);
166+
}
167+
107168
STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
108-
mp_obj_t sock = *(mp_obj_t *)ctx;
169+
ssl_sslsocket_obj_t *self = (ssl_sslsocket_obj_t *)ctx;
109170

110-
// mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err);
111-
mp_int_t out_sz = socketpool_socket_send(sock, buf, len);
171+
mp_int_t out_sz = ssl_socket_send(self, buf, len);
112172
DEBUG_PRINT("socket_send() -> %d", out_sz);
113173
if (out_sz < 0) {
114174
int err = -out_sz;
@@ -124,9 +184,9 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
124184

125185
// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
126186
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
127-
mp_obj_t sock = *(mp_obj_t *)ctx;
187+
ssl_sslsocket_obj_t *self = (ssl_sslsocket_obj_t *)ctx;
128188

129-
mp_int_t out_sz = socketpool_socket_recv_into(sock, buf, len);
189+
mp_int_t out_sz = ssl_socket_recv_into(self, buf, len);
130190
DEBUG_PRINT("socket_recv() -> %d", out_sz);
131191
if (out_sz < 0) {
132192
int err = -out_sz;
@@ -151,16 +211,26 @@ static int urandom_adapter(void *unused, unsigned char *buf, size_t n) {
151211
#endif
152212

153213
ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t *self,
154-
socketpool_socket_obj_t *socket, bool server_side, const char *server_hostname) {
214+
mp_obj_t socket, bool server_side, const char *server_hostname) {
155215

156-
if (socket->type != SOCKETPOOL_SOCK_STREAM) {
216+
mp_int_t socket_type = mp_obj_get_int(mp_load_attr(socket, MP_QSTR_type));
217+
if (socket_type != SOCKETPOOL_SOCK_STREAM) {
157218
mp_raise_RuntimeError(MP_ERROR_TEXT("Invalid socket for TLS"));
158219
}
159220

160221
ssl_sslsocket_obj_t *o = m_new_obj_with_finaliser(ssl_sslsocket_obj_t);
161222
o->base.type = &ssl_sslsocket_type;
162223
o->ssl_context = self;
163-
o->sock = socket;
224+
o->sock_obj = socket;
225+
226+
mp_load_method(socket, MP_QSTR_accept, o->accept_args);
227+
mp_load_method(socket, MP_QSTR_bind, o->bind_args);
228+
mp_load_method(socket, MP_QSTR_close, o->close_args);
229+
mp_load_method(socket, MP_QSTR_connect, o->connect_args);
230+
mp_load_method(socket, MP_QSTR_listen, o->listen_args);
231+
mp_load_method(socket, MP_QSTR_recv_into, o->recv_into_args);
232+
mp_load_method(socket, MP_QSTR_send, o->send_args);
233+
mp_load_method(socket, MP_QSTR_settimeout, o->settimeout_args);
164234

165235
mbedtls_ssl_init(&o->ssl);
166236
mbedtls_ssl_config_init(&o->conf);
@@ -219,7 +289,7 @@ ssl_sslsocket_obj_t *common_hal_ssl_sslcontext_wrap_socket(ssl_sslcontext_obj_t
219289
}
220290
}
221291

222-
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
292+
mbedtls_ssl_set_bio(&o->ssl, o, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
223293

224294
if (self->cert_buf.buf != NULL) {
225295
#if MBEDTLS_VERSION_MAJOR >= 3
@@ -304,13 +374,13 @@ mp_uint_t common_hal_ssl_sslsocket_send(ssl_sslsocket_obj_t *self, const uint8_t
304374
mp_raise_OSError(ret);
305375
}
306376

307-
size_t common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port) {
308-
return common_hal_socketpool_socket_bind(self->sock, host, hostlen, port);
377+
size_t common_hal_ssl_sslsocket_bind(ssl_sslsocket_obj_t *self, mp_obj_t addr_in) {
378+
return ssl_socket_bind(self, addr_in);
309379
}
310380

311381
void common_hal_ssl_sslsocket_close(ssl_sslsocket_obj_t *self) {
312382
self->closed = true;
313-
common_hal_socketpool_socket_close(self->sock);
383+
ssl_socket_close(self);
314384
mbedtls_pk_free(&self->pkey);
315385
mbedtls_x509_crt_free(&self->cert);
316386
mbedtls_x509_crt_free(&self->cacert);
@@ -356,8 +426,8 @@ STATIC void do_handshake(ssl_sslsocket_obj_t *self) {
356426
}
357427
}
358428

359-
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, const char *host, size_t hostlen, uint32_t port) {
360-
common_hal_socketpool_socket_connect(self->sock, host, hostlen, port);
429+
void common_hal_ssl_sslsocket_connect(ssl_sslsocket_obj_t *self, mp_obj_t addr_in) {
430+
ssl_socket_connect(self, addr_in);
361431
do_handshake(self);
362432
}
363433

@@ -370,16 +440,21 @@ bool common_hal_ssl_sslsocket_get_connected(ssl_sslsocket_obj_t *self) {
370440
}
371441

372442
bool common_hal_ssl_sslsocket_listen(ssl_sslsocket_obj_t *self, int backlog) {
373-
return common_hal_socketpool_socket_listen(self->sock, backlog);
443+
return ssl_socket_listen(self, backlog);
374444
}
375445

376-
ssl_sslsocket_obj_t *common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self, uint8_t *ip, uint32_t *port) {
377-
socketpool_socket_obj_t *sock = common_hal_socketpool_socket_accept(self->sock, ip, port);
446+
mp_obj_t common_hal_ssl_sslsocket_accept(ssl_sslsocket_obj_t *self) {
447+
mp_obj_t accepted = ssl_socket_accept(self);
448+
mp_obj_t sock = mp_obj_subscr(accepted, MP_OBJ_NEW_SMALL_INT(0), MP_OBJ_SENTINEL);
378449
ssl_sslsocket_obj_t *sslsock = common_hal_ssl_sslcontext_wrap_socket(self->ssl_context, sock, true, NULL);
379450
do_handshake(sslsock);
380-
return sslsock;
451+
mp_obj_t peer = mp_obj_subscr(accepted, MP_OBJ_NEW_SMALL_INT(0), MP_OBJ_SENTINEL);
452+
mp_obj_t tuple_contents[2];
453+
tuple_contents[0] = MP_OBJ_FROM_PTR(sslsock);
454+
tuple_contents[1] = peer;
455+
return mp_obj_new_tuple(2, tuple_contents);
381456
}
382457

383458
void common_hal_ssl_sslsocket_settimeout(ssl_sslsocket_obj_t *self, uint32_t timeout_ms) {
384-
common_hal_socketpool_socket_settimeout(self->sock, timeout_ms);
459+
ssl_socket_settimeout(self, timeout_ms);
385460
}

shared-module/ssl/SSLSocket.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
#include "py/obj.h"
3131

3232
#include "shared-module/ssl/SSLContext.h"
33-
#include "common-hal/socketpool/Socket.h"
3433

3534
#include "mbedtls/platform.h"
3635
#include "mbedtls/ssl.h"
@@ -41,7 +40,7 @@
4140

4241
typedef struct ssl_sslsocket_obj {
4342
mp_obj_base_t base;
44-
socketpool_socket_obj_t *sock;
43+
mp_obj_t sock_obj;
4544
ssl_sslcontext_obj_t *ssl_context;
4645
mbedtls_entropy_context entropy;
4746
mbedtls_ctr_drbg_context ctr_drbg;
@@ -51,4 +50,12 @@ typedef struct ssl_sslsocket_obj {
5150
mbedtls_x509_crt cert;
5251
mbedtls_pk_context pkey;
5352
bool closed;
53+
mp_obj_t accept_args[2];
54+
mp_obj_t bind_args[3];
55+
mp_obj_t close_args[2];
56+
mp_obj_t connect_args[3];
57+
mp_obj_t listen_args[3];
58+
mp_obj_t recv_into_args[3];
59+
mp_obj_t send_args[3];
60+
mp_obj_t settimeout_args[3];
5461
} ssl_sslsocket_obj_t;

0 commit comments

Comments
 (0)