Skip to content

Commit 53bb552

Browse files
committed
extmod/modssl_mbedtls: Implement SSLSession support.
Signed-off-by: Daniël van de Giessen <daniel@dvdgiessen.nl>
1 parent 31e131b commit 53bb552

File tree

1 file changed

+106
-4
lines changed

1 file changed

+106
-4
lines changed

extmod/modtls_mbedtls.c

Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ typedef struct _mp_obj_ssl_context_t {
7070
mp_obj_t handler;
7171
} mp_obj_ssl_context_t;
7272

73+
// This corresponds to an SSLSession object.
74+
typedef struct _mp_obj_ssl_session_t {
75+
mp_obj_base_t base;
76+
mbedtls_ssl_session session;
77+
} mp_obj_ssl_session_t;
78+
7379
// This corresponds to an SSLSocket object.
7480
typedef struct _mp_obj_ssl_socket_t {
7581
mp_obj_base_t base;
@@ -81,13 +87,14 @@ typedef struct _mp_obj_ssl_socket_t {
8187
int last_error; // The last error code, if any
8288
} mp_obj_ssl_socket_t;
8389

90+
STATIC const mp_obj_type_t ssl_session_type;
8491
STATIC const mp_obj_type_t ssl_context_type;
8592
STATIC const mp_obj_type_t ssl_socket_type;
8693

8794
STATIC const MP_DEFINE_STR_OBJ(mbedtls_version_obj, MBEDTLS_VERSION_STRING_FULL);
8895

8996
STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
90-
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname);
97+
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, mp_obj_t ssl_session);
9198

9299
/******************************************************************************/
93100
// Helper functions.
@@ -199,6 +206,60 @@ STATIC int ssl_sock_cert_verify(void *ptr, mbedtls_x509_crt *crt, int depth, uin
199206
return mp_obj_get_int(mp_call_function_2(o->handler, MP_OBJ_FROM_PTR(&cert), MP_OBJ_NEW_SMALL_INT(depth)));
200207
}
201208

209+
/******************************************************************************/
210+
// SSLSession type.
211+
212+
STATIC mp_obj_t ssl_session_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) {
213+
mp_arg_check_num(n_args, n_kw, 1, 1, false);
214+
215+
mp_buffer_info_t bufinfo;
216+
mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ);
217+
218+
mp_obj_ssl_session_t *self = m_new_obj(mp_obj_ssl_session_t);
219+
self->base.type = type_in;
220+
221+
mbedtls_ssl_session_init(&self->session);
222+
int ret = mbedtls_ssl_session_load(&self->session, bufinfo.buf, bufinfo.len);
223+
if (ret != 0) {
224+
mbedtls_raise_error(ret);
225+
}
226+
227+
return MP_OBJ_FROM_PTR(self);
228+
}
229+
230+
STATIC mp_obj_t ssl_session_serialize(mp_obj_t self_in) {
231+
mp_obj_ssl_session_t *self = MP_OBJ_TO_PTR(self_in);
232+
size_t len;
233+
vstr_t vstr;
234+
mbedtls_ssl_session_save(&self->session, NULL, 0, &len);
235+
vstr_init_len(&vstr, len);
236+
mbedtls_ssl_session_save(&self->session, (unsigned char *)vstr.buf, len, &len);
237+
return mp_obj_new_bytes_from_vstr(&vstr);
238+
}
239+
STATIC MP_DEFINE_CONST_FUN_OBJ_1(ssl_session_serialize_obj, ssl_session_serialize);
240+
241+
STATIC mp_int_t ssl_session_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) {
242+
if (flags != MP_BUFFER_READ) {
243+
return 1;
244+
}
245+
mp_get_buffer_raise(ssl_session_serialize(self_in), bufinfo, flags);
246+
return 0;
247+
}
248+
249+
STATIC const mp_rom_map_elem_t ssl_session_locals_dict_table[] = {
250+
{ MP_ROM_QSTR(MP_QSTR_serialize), MP_ROM_PTR(&ssl_session_serialize_obj) },
251+
};
252+
STATIC MP_DEFINE_CONST_DICT(ssl_session_locals_dict, ssl_session_locals_dict_table);
253+
254+
STATIC MP_DEFINE_CONST_OBJ_TYPE(
255+
ssl_session_type,
256+
MP_QSTR_SSLSession,
257+
MP_TYPE_FLAG_NONE,
258+
make_new, ssl_session_make_new,
259+
buffer, ssl_session_get_buffer,
260+
locals_dict, &ssl_session_locals_dict
261+
);
262+
202263
/******************************************************************************/
203264
// SSLContext type.
204265

@@ -402,11 +463,12 @@ STATIC mp_obj_t ssl_context_load_verify_locations(mp_obj_t self_in, mp_obj_t cad
402463
STATIC MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_load_verify_locations_obj, ssl_context_load_verify_locations);
403464

404465
STATIC mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
405-
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname };
466+
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname, ARG_session };
406467
static const mp_arg_t allowed_args[] = {
407468
{ MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} },
408469
{ MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} },
409470
{ MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
471+
{ MP_QSTR_session, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} },
410472
};
411473

412474
// Parse arguments.
@@ -417,7 +479,7 @@ STATIC mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args,
417479

418480
// Create and return the new SSLSocket object.
419481
return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool,
420-
args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj);
482+
args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj, args[ARG_session].u_obj);
421483
}
422484
STATIC MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket);
423485

@@ -481,7 +543,7 @@ STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
481543
}
482544

483545
STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock,
484-
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) {
546+
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, mp_obj_t ssl_session) {
485547

486548
// Verify the socket object has the full stream protocol
487549
mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL);
@@ -519,6 +581,14 @@ STATIC mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
519581
mp_raise_ValueError(MP_ERROR_TEXT("CERT_REQUIRED requires server_hostname"));
520582
}
521583

584+
if (ssl_session != mp_const_none) {
585+
mp_obj_ssl_session_t *session = MP_OBJ_TO_PTR(ssl_session);
586+
ret = mbedtls_ssl_set_session(&o->ssl, &session->session);
587+
if (ret != 0) {
588+
goto cleanup;
589+
}
590+
}
591+
522592
mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL);
523593

524594
if (do_handshake_on_connect) {
@@ -716,6 +786,36 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
716786
return ret;
717787
}
718788

789+
STATIC void ssl_socket_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) {
790+
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in);
791+
if (dest[0] == MP_OBJ_NULL) {
792+
// Load attribute.
793+
if (attr == MP_QSTR_session) {
794+
mp_obj_ssl_session_t *o = m_new_obj(mp_obj_ssl_session_t);
795+
o->base.type = &ssl_session_type;
796+
mbedtls_ssl_session_init(&o->session);
797+
int ret = mbedtls_ssl_get_session(&self->ssl, &o->session);
798+
if (ret != 0) {
799+
mbedtls_raise_error(ret);
800+
}
801+
dest[0] = o;
802+
} else {
803+
// Continue lookup in locals_dict.
804+
dest[1] = MP_OBJ_SENTINEL;
805+
}
806+
} else if (dest[1] != MP_OBJ_NULL) {
807+
// Store attribute.
808+
if (attr == MP_QSTR_session) {
809+
mp_obj_ssl_session_t *ssl_session = MP_OBJ_TO_PTR(dest[1]);
810+
dest[0] = MP_OBJ_NULL;
811+
int ret = mbedtls_ssl_set_session(&self->ssl, &ssl_session->session);
812+
if (ret != 0) {
813+
mbedtls_raise_error(ret);
814+
}
815+
}
816+
}
817+
}
818+
719819
STATIC const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = {
720820
{ MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
721821
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
@@ -747,6 +847,7 @@ STATIC MP_DEFINE_CONST_OBJ_TYPE(
747847
MP_QSTR_SSLSocket,
748848
MP_TYPE_FLAG_NONE,
749849
protocol, &ssl_socket_stream_p,
850+
attr, ssl_socket_attr,
750851
locals_dict, &ssl_socket_locals_dict
751852
);
752853

@@ -758,6 +859,7 @@ STATIC const mp_rom_map_elem_t mp_module_tls_globals_table[] = {
758859

759860
// Classes.
760861
{ MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) },
862+
{ MP_ROM_QSTR(MP_QSTR_SSLSession), MP_ROM_PTR(&ssl_session_type) },
761863

762864
// Constants.
763865
{ MP_ROM_QSTR(MP_QSTR_MBEDTLS_VERSION), MP_ROM_PTR(&mbedtls_version_obj)},

0 commit comments

Comments
 (0)