diff --git a/docs/library/ssl.rst b/docs/library/ssl.rst index c86101872c364..74d92d8643d46 100644 --- a/docs/library/ssl.rst +++ b/docs/library/ssl.rst @@ -13,7 +13,7 @@ facilities for network sockets, both client-side and server-side. Functions --------- -.. function:: ssl.wrap_socket(sock, server_side=False, key=None, cert=None, cert_reqs=CERT_NONE, cadata=None, server_hostname=None, do_handshake=True) +.. function:: ssl.wrap_socket(sock, server_side=False, key=None, cert=None, cert_reqs=CERT_NONE, cadata=None, server_hostname=None, do_handshake=True, session=None) Wrap the given *sock* and return a new wrapped-socket object. The implementation of this function is to first create an `SSLContext` and then call the `SSLContext.wrap_socket` @@ -28,6 +28,9 @@ Functions - *cadata* is a bytes object containing the CA certificate chain (in DER format) that will validate the peer's certificate. Currently only a single DER-encoded certificate is supported. + - *session* allows a client socket to reuse a session by passing a SSLSession object + previously retrieved from the ``session`` property of a wrapped-socket object. + Depending on the underlying module implementation in a particular :term:`MicroPython port`, some or all keyword arguments above may be not supported. @@ -66,7 +69,7 @@ class SSLContext Set the available ciphers for sockets created with this context. *ciphers* should be a list of strings in the `IANA cipher suite format `_ . -.. method:: SSLContext.wrap_socket(sock, *, server_side=False, do_handshake_on_connect=True, server_hostname=None, client_id=None) +.. method:: SSLContext.wrap_socket(sock, *, server_side=False, do_handshake_on_connect=True, server_hostname=None, client_id=None, session=None) Takes a `stream` *sock* (usually socket.socket instance of ``SOCK_STREAM`` type), and returns an instance of ssl.SSLSocket, wrapping the underlying stream. @@ -92,6 +95,9 @@ class SSLContext - *client_id* is a MicroPython-specific extension argument used only when implementing a DTLS Server. See :ref:`dtls` for details. + - *session* allows a client socket to reuse a session by passing a SSLSession object + previously retrieved from the ``session`` property of a ssl.SSLSocket object. + .. warning:: Some implementations of ``ssl`` module do NOT validate server certificates, @@ -113,6 +119,19 @@ class SSLContext `mpremote rtc --set ` or ``ntptime``, and ``server_hostname`` must be specified when on the client side. +class SSLSession +---------------- + +.. class:: SSLSession(buf) + + This constructor is a MicroPython extension to reconstruct a SSLSession object using + a bytes object previously returned by the ``serialize`` method. + +.. method:: SSLSession.serialize() + + This function is a MicroPython extension to return a bytes object representing the + session, allowing it to be stored and reconstructed at a later time. + Exceptions ---------- diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c index 58634257328da..c1d188f519f64 100644 --- a/extmod/modtls_mbedtls.c +++ b/extmod/modtls_mbedtls.c @@ -66,6 +66,14 @@ #include "mbedtls/ssl_cookie.h" #endif +#if defined(MBEDTLS_CONFIG_FILE) +#include MBEDTLS_CONFIG_FILE +#endif + +#if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_SSL_TICKET_C) +#include "mbedtls/ssl_ticket.h" +#endif + #ifndef MICROPY_MBEDTLS_CONFIG_BARE_METAL #define MICROPY_MBEDTLS_CONFIG_BARE_METAL (0) #endif @@ -89,6 +97,9 @@ typedef struct _mp_obj_ssl_context_t { mbedtls_x509_crt cacert; mbedtls_x509_crt cert; mbedtls_pk_context pkey; + #if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_SSL_TICKET_C) + mbedtls_ssl_ticket_context ticket; + #endif int authmode; int *ciphersuites; mp_obj_t handler; @@ -101,6 +112,12 @@ typedef struct _mp_obj_ssl_context_t { #endif } mp_obj_ssl_context_t; +// This corresponds to an SSLSession object. +typedef struct _mp_obj_ssl_session_t { + mp_obj_base_t base; + mbedtls_ssl_session session; +} mp_obj_ssl_session_t; + // This corresponds to an SSLSocket object. typedef struct _mp_obj_ssl_socket_t { mp_obj_base_t base; @@ -118,6 +135,7 @@ typedef struct _mp_obj_ssl_socket_t { #endif } mp_obj_ssl_socket_t; +static const mp_obj_type_t ssl_session_type; static const mp_obj_type_t ssl_context_type; static const mp_obj_type_t ssl_socket_type; @@ -125,7 +143,7 @@ static const MP_DEFINE_STR_OBJ(mbedtls_version_obj, MBEDTLS_VERSION_STRING_FULL) static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, - mp_obj_t client_id); + mp_obj_t client_id, mp_obj_t ssl_session); /******************************************************************************/ // Helper functions. @@ -261,6 +279,60 @@ static int ssl_sock_cert_verify(void *ptr, mbedtls_x509_crt *crt, int depth, uin return mp_obj_get_int(mp_call_function_2(o->handler, MP_OBJ_FROM_PTR(&cert), MP_OBJ_NEW_SMALL_INT(depth))); } +/******************************************************************************/ +// SSLSession type. + +static mp_obj_t ssl_session_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) { + mp_arg_check_num(n_args, n_kw, 1, 1, false); + + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ); + + mp_obj_ssl_session_t *self = m_new_obj(mp_obj_ssl_session_t); + self->base.type = type_in; + + mbedtls_ssl_session_init(&self->session); + int ret = mbedtls_ssl_session_load(&self->session, bufinfo.buf, bufinfo.len); + if (ret != 0) { + mbedtls_raise_error(ret); + } + + return MP_OBJ_FROM_PTR(self); +} + +static mp_obj_t ssl_session_serialize(mp_obj_t self_in) { + mp_obj_ssl_session_t *self = MP_OBJ_TO_PTR(self_in); + size_t len; + vstr_t vstr; + mbedtls_ssl_session_save(&self->session, NULL, 0, &len); + vstr_init_len(&vstr, len); + mbedtls_ssl_session_save(&self->session, (unsigned char *)vstr.buf, len, &len); + return mp_obj_new_bytes_from_vstr(&vstr); +} +static MP_DEFINE_CONST_FUN_OBJ_1(ssl_session_serialize_obj, ssl_session_serialize); + +static mp_int_t ssl_session_get_buffer(mp_obj_t self_in, mp_buffer_info_t *bufinfo, mp_uint_t flags) { + if (flags != MP_BUFFER_READ) { + return 1; + } + mp_get_buffer_raise(ssl_session_serialize(self_in), bufinfo, flags); + return 0; +} + +static const mp_rom_map_elem_t ssl_session_locals_dict_table[] = { + { MP_ROM_QSTR(MP_QSTR_serialize), MP_ROM_PTR(&ssl_session_serialize_obj) }, +}; +static MP_DEFINE_CONST_DICT(ssl_session_locals_dict, ssl_session_locals_dict_table); + +static MP_DEFINE_CONST_OBJ_TYPE( + ssl_session_type, + MP_QSTR_SSLSession, + MP_TYPE_FLAG_NONE, + make_new, ssl_session_make_new, + buffer, ssl_session_get_buffer, + locals_dict, &ssl_session_locals_dict + ); + /******************************************************************************/ // SSLContext type. @@ -287,6 +359,9 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args mbedtls_x509_crt_init(&self->cacert); mbedtls_x509_crt_init(&self->cert); mbedtls_pk_init(&self->pkey); + #if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_SSL_TICKET_C) + mbedtls_ssl_ticket_init(&self->ticket); + #endif self->ciphersuites = NULL; self->handler = mp_const_none; #if MICROPY_PY_SSL_ECDSA_SIGN_ALT @@ -341,6 +416,14 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args } #endif // MBEDTLS_SSL_DTLS_HELLO_VERIFY + #if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_SSL_TICKET_C) + ret = mbedtls_ssl_ticket_setup(&self->ticket, mbedtls_ctr_drbg_random, &self->ctr_drbg, MBEDTLS_CIPHER_AES_256_GCM, 86400); + if (ret != 0) { + mbedtls_raise_error(ret); + } + mbedtls_ssl_conf_session_tickets_cb(&self->conf, mbedtls_ssl_ticket_write, mbedtls_ssl_ticket_parse, &self->ticket); + #endif + return MP_OBJ_FROM_PTR(self); } @@ -381,6 +464,9 @@ static void ssl_context_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { #if MICROPY_PY_SSL_FINALISER static mp_obj_t ssl_context___del__(mp_obj_t self_in) { mp_obj_ssl_context_t *self = MP_OBJ_TO_PTR(self_in); + #if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_SSL_TICKET_C) + mbedtls_ssl_ticket_free(&self->ticket); + #endif mbedtls_pk_free(&self->pkey); mbedtls_x509_crt_free(&self->cert); mbedtls_x509_crt_free(&self->cacert); @@ -494,7 +580,7 @@ static mp_obj_t ssl_context_load_verify_locations(mp_obj_t self_in, mp_obj_t cad static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_load_verify_locations_obj, ssl_context_load_verify_locations); static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { - enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname, ARG_client_id }; + enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname, ARG_client_id, ARG_session }; static const mp_arg_t allowed_args[] = { { MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} }, { MP_QSTR_do_handshake_on_connect, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} }, @@ -502,6 +588,7 @@ static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, #ifdef MBEDTLS_SSL_DTLS_HELLO_VERIFY { MP_QSTR_client_id, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, #endif + { MP_QSTR_session, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, }; // Parse arguments. @@ -517,7 +604,7 @@ static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, // Create and return the new SSLSocket object. return ssl_socket_make_new(self, sock, args[ARG_server_side].u_bool, - args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj, client_id); + args[ARG_do_handshake_on_connect].u_bool, args[ARG_server_hostname].u_obj, client_id, args[ARG_session].u_obj); } static MP_DEFINE_CONST_FUN_OBJ_KW(ssl_context_wrap_socket_obj, 2, ssl_context_wrap_socket); @@ -614,7 +701,7 @@ static int _mbedtls_timing_get_delay(void *ctx) { #endif static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, - bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, mp_obj_t client_id) { + bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, mp_obj_t client_id, mp_obj_t ssl_session) { // Store the current SSL context. store_active_context(ssl_context); @@ -665,6 +752,14 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t mp_raise_ValueError(MP_ERROR_TEXT("CERT_REQUIRED requires server_hostname")); } + if (ssl_session != mp_const_none) { + mp_obj_ssl_session_t *session = MP_OBJ_TO_PTR(ssl_session); + ret = mbedtls_ssl_set_session(&o->ssl, &session->session); + if (ret != 0) { + goto cleanup; + } + } + #ifdef MBEDTLS_SSL_PROTO_DTLS mbedtls_ssl_set_timer_cb(&o->ssl, o, _mbedtls_timing_set_delay, _mbedtls_timing_get_delay); #endif @@ -889,6 +984,36 @@ static mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i return ret; } +static void ssl_socket_attr(mp_obj_t self_in, qstr attr, mp_obj_t *dest) { + mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in); + if (dest[0] == MP_OBJ_NULL) { + // Load attribute. + if (attr == MP_QSTR_session) { + mp_obj_ssl_session_t *o = m_new_obj(mp_obj_ssl_session_t); + o->base.type = &ssl_session_type; + mbedtls_ssl_session_init(&o->session); + int ret = mbedtls_ssl_get_session(&self->ssl, &o->session); + if (ret != 0) { + mbedtls_raise_error(ret); + } + dest[0] = MP_OBJ_FROM_PTR(o); + } else { + // Continue lookup in locals_dict. + dest[1] = MP_OBJ_SENTINEL; + } + } else if (dest[1] != MP_OBJ_NULL) { + // Store attribute. + if (attr == MP_QSTR_session) { + mp_obj_ssl_session_t *ssl_session = MP_OBJ_TO_PTR(dest[1]); + dest[0] = MP_OBJ_NULL; + int ret = mbedtls_ssl_set_session(&self->ssl, &ssl_session->session); + if (ret != 0) { + mbedtls_raise_error(ret); + } + } + } +} + static const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) }, { MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) }, @@ -926,6 +1051,7 @@ static MP_DEFINE_CONST_OBJ_TYPE( MP_QSTR_SSLSocket, MP_TYPE_FLAG_NONE, protocol, &ssl_socket_stream_p, + attr, ssl_socket_attr, locals_dict, &ssl_socket_locals_dict ); @@ -988,6 +1114,7 @@ static const mp_rom_map_elem_t mp_module_tls_globals_table[] = { // Classes. { MP_ROM_QSTR(MP_QSTR_SSLContext), MP_ROM_PTR(&ssl_context_type) }, + { MP_ROM_QSTR(MP_QSTR_SSLSession), MP_ROM_PTR(&ssl_session_type) }, // Constants. { MP_ROM_QSTR(MP_QSTR_MBEDTLS_VERSION), MP_ROM_PTR(&mbedtls_version_obj)}, diff --git a/lib/micropython-lib b/lib/micropython-lib index 34c4ee1647ac4..6534f5dfe3817 160000 --- a/lib/micropython-lib +++ b/lib/micropython-lib @@ -1 +1 @@ -Subproject commit 34c4ee1647ac4b177ae40adf0ec514660e433dc0 +Subproject commit 6534f5dfe38174bf8fe8099d9833fbb9636286f6 diff --git a/ports/unix/mbedtls/mbedtls_config_port.h b/ports/unix/mbedtls/mbedtls_config_port.h index aec65e6581e73..d8ca8f0b783b5 100644 --- a/ports/unix/mbedtls/mbedtls_config_port.h +++ b/ports/unix/mbedtls/mbedtls_config_port.h @@ -28,8 +28,11 @@ // Set mbedtls configuration #define MBEDTLS_CIPHER_MODE_CTR // needed for MICROPY_PY_CRYPTOLIB_CTR +#define MBEDTLS_SSL_SESSION_TICKETS // Enable mbedtls modules +#define MBEDTLS_GCM_C +#define MBEDTLS_SSL_TICKET_C #define MBEDTLS_TIMING_C #if defined(MICROPY_UNIX_COVERAGE) diff --git a/tests/multi_net/sslcontext_server_client_session.py b/tests/multi_net/sslcontext_server_client_session.py new file mode 100644 index 0000000000000..cee809347539f --- /dev/null +++ b/tests/multi_net/sslcontext_server_client_session.py @@ -0,0 +1,131 @@ +# Test creating an SSL connection with certificates as bytes objects. + +try: + from io import IOBase + import os + import socket + import ssl +except ImportError: + print("SKIP") + raise SystemExit + +if not hasattr(ssl, "SSLSession"): + print("SKIP") + raise SystemExit + +PORT = 8000 + +# These are test certificates. See tests/README.md for details. +certfile = "ec_cert.der" +keyfile = "ec_key.der" + +try: + os.stat(certfile) + os.stat(keyfile) +except OSError: + print("SKIP") + raise SystemExit + +with open(certfile, "rb") as cf: + cert = cadata = cf.read() + +with open(keyfile, "rb") as kf: + key = kf.read() + + +# Helper class to count number of bytes going over a TCP socket +class CountingStream(IOBase): + def __init__(self, stream): + self.stream = stream + self.count = 0 + + def readinto(self, buf, nbytes=None): + result = self.stream.readinto(buf) if nbytes is None else self.stream.readinto(buf, nbytes) + self.count += result + return result + + def write(self, buf): + self.count += len(buf) + return self.stream.write(buf) + + def ioctl(self, req, arg): + if hasattr(self.stream, "ioctl"): + return self.stream.ioctl(req, arg) + return 0 + + +# Server +def instance0(): + multitest.globals(IP=multitest.get_network_ip()) + s = socket.socket() + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1]) + s.listen(1) + multitest.next() + server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + server_ctx.load_cert_chain(cert, key) + for i in range(7): + s2, _ = s.accept() + s2 = server_ctx.wrap_socket(s2, server_side=True) + print(s2.read(18)) + s2.write(b"server to client {}".format(i)) + s2.close() + s.close() + + +# Client +def instance1(): + multitest.next() + + def connect_and_count(i, session, set_method="wrap_socket"): + s = socket.socket() + s.connect(socket.getaddrinfo(IP, PORT)[0][-1]) + s = CountingStream(s) + client_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + client_ctx.verify_mode = ssl.CERT_REQUIRED + client_ctx.load_verify_locations(cadata=cadata) + wrap_socket_kwargs = {} + if set_method == "wrap_socket": + wrap_socket_kwargs = {"session": session} + elif set_method == "socket_attr": + wrap_socket_kwargs = {"do_handshake_on_connect": False} + s2 = client_ctx.wrap_socket(s, server_hostname="micropython.local", **wrap_socket_kwargs) + if set_method == "socket_attr" and session is not None: + s2.session = session + s2.write(b"client to server {}".format(i)) + print(s2.read(18)) + session = s2.session + print(type(session)) + s2.close() + return session, s.count + + # No session reuse + session, count_without_reuse = connect_and_count(0, None) + + # Direct session reuse + session, count = connect_and_count(1, session, "wrap_socket") + print(count < count_without_reuse) + + # Serialized session reuse + session = ssl.SSLSession(session.serialize()) + session, count = connect_and_count(2, session, "wrap_socket") + print(count < count_without_reuse) + + # Serialized session reuse (using buffer protocol) + session = ssl.SSLSession(bytes(session)) + session, count = connect_and_count(3, session, "wrap_socket") + print(count < count_without_reuse) + + # Direct session reuse + session, count = connect_and_count(4, session, "socket_attr") + print(count < count_without_reuse) + + # Serialized session reuse + session = ssl.SSLSession(session.serialize()) + session, count = connect_and_count(5, session, "socket_attr") + print(count < count_without_reuse) + + # Serialized session reuse (using buffer protocol) + session = ssl.SSLSession(bytes(session)) + session, count = connect_and_count(6, session, "socket_attr") + print(count < count_without_reuse) diff --git a/tests/multi_net/sslcontext_server_client_session.py.exp b/tests/multi_net/sslcontext_server_client_session.py.exp new file mode 100644 index 0000000000000..f3ed2c57d68df --- /dev/null +++ b/tests/multi_net/sslcontext_server_client_session.py.exp @@ -0,0 +1,29 @@ +--- instance0 --- +b'client to server 0' +b'client to server 1' +b'client to server 2' +b'client to server 3' +b'client to server 4' +b'client to server 5' +b'client to server 6' +--- instance1 --- +b'server to client 0' + +b'server to client 1' + +True +b'server to client 2' + +True +b'server to client 3' + +True +b'server to client 4' + +True +b'server to client 5' + +True +b'server to client 6' + +True