diff --git a/docs/library/ssl.rst b/docs/library/ssl.rst index dff90b8da58b9..4327c74bad6c8 100644 --- a/docs/library/ssl.rst +++ b/docs/library/ssl.rst @@ -117,11 +117,32 @@ Exceptions This exception does NOT exist. Instead its base class, OSError, is used. +DTLS support +------------ + +.. admonition:: Difference to CPython + :class: attention + + This is a MicroPython extension. + +This module supports DTLS in client and server mode via the `PROTOCOL_DTLS_CLIENT` +and `PROTOCOL_DTLS_SERVER` constants that can be used as the ``protocol`` argument +of `SSLContext`. + +In this case the underlying socket is expected to behave as a datagram socket (i.e. +like the socket opened with ``socket.socket`` with ``socket.AF_INET`` as ``af`` and +``socket.SOCK_DGRAM`` as ``type``). + +DTLS is only supported on ports that use mbed TLS, and it is not enabled by default: +it requires enabling ``MBEDTLS_SSL_PROTO_DTLS`` in the specific port configuration. + Constants --------- .. data:: ssl.PROTOCOL_TLS_CLIENT ssl.PROTOCOL_TLS_SERVER + ssl.PROTOCOL_DTLS_CLIENT (when DTLS support is enabled) + ssl.PROTOCOL_DTLS_SERVER (when DTLS support is enabled) Supported values for the *protocol* parameter. diff --git a/extmod/mbedtls/mbedtls_config_common.h b/extmod/mbedtls/mbedtls_config_common.h index 6ea8540af9925..6cd14befc3196 100644 --- a/extmod/mbedtls/mbedtls_config_common.h +++ b/extmod/mbedtls/mbedtls_config_common.h @@ -89,6 +89,7 @@ #define MBEDTLS_SHA384_C #define MBEDTLS_SHA512_C #define MBEDTLS_SSL_CLI_C +#define MBEDTLS_SSL_PROTO_DTLS #define MBEDTLS_SSL_SRV_C #define MBEDTLS_SSL_TLS_C #define MBEDTLS_X509_CRT_PARSE_C diff --git a/extmod/modtls_mbedtls.c b/extmod/modtls_mbedtls.c index 3fd416d72f5ef..6c34805da42cb 100644 --- a/extmod/modtls_mbedtls.c +++ b/extmod/modtls_mbedtls.c @@ -37,6 +37,7 @@ #include "py/stream.h" #include "py/objstr.h" #include "py/reader.h" +#include "py/mphal.h" #include "py/gc.h" #include "extmod/vfs.h" @@ -47,6 +48,9 @@ #include "mbedtls/pk.h" #include "mbedtls/entropy.h" #include "mbedtls/ctr_drbg.h" +#ifdef MBEDTLS_SSL_PROTO_DTLS +#include "mbedtls/timing.h" +#endif #include "mbedtls/debug.h" #include "mbedtls/error.h" #if MBEDTLS_VERSION_NUMBER >= 0x03000000 @@ -65,6 +69,14 @@ #define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR) +#define MP_ENDPOINT_IS_SERVER (1 << 0) +#define MP_TRANSPORT_IS_DTLS (1 << 1) + +#define MP_PROTOCOL_TLS_CLIENT 0 +#define MP_PROTOCOL_TLS_SERVER MP_ENDPOINT_IS_SERVER +#define MP_PROTOCOL_DTLS_CLIENT MP_TRANSPORT_IS_DTLS +#define MP_PROTOCOL_DTLS_SERVER MP_ENDPOINT_IS_SERVER | MP_TRANSPORT_IS_DTLS + // This corresponds to an SSLContext object. typedef struct _mp_obj_ssl_context_t { mp_obj_base_t base; @@ -91,6 +103,12 @@ typedef struct _mp_obj_ssl_socket_t { uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next int last_error; // The last error code, if any + + #ifdef MBEDTLS_SSL_PROTO_DTLS + mp_uint_t timer_start_ms; + mp_uint_t timer_fin_ms; + mp_uint_t timer_int_ms; + #endif } mp_obj_ssl_socket_t; static const mp_obj_type_t ssl_context_type; @@ -242,7 +260,10 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args mp_arg_check_num(n_args, n_kw, 1, 1, false); // This is the "protocol" argument. - mp_int_t endpoint = mp_obj_get_int(args[0]); + mp_int_t protocol = mp_obj_get_int(args[0]); + + int endpoint = (protocol & MP_ENDPOINT_IS_SERVER) ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT; + int transport = (protocol & MP_TRANSPORT_IS_DTLS) ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM; // Create SSLContext object. #if MICROPY_PY_SSL_FINALISER @@ -282,7 +303,7 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args } ret = mbedtls_ssl_config_defaults(&self->conf, endpoint, - MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); + transport, MBEDTLS_SSL_PRESET_DEFAULT); if (ret != 0) { mbedtls_raise_error(ret); } @@ -525,6 +546,39 @@ static int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) { } } +#ifdef MBEDTLS_SSL_PROTO_DTLS +static void _mbedtls_timing_set_delay(void *ctx, uint32_t int_ms, uint32_t fin_ms) { + mp_obj_ssl_socket_t *o = (mp_obj_ssl_socket_t *)ctx; + + o->timer_int_ms = int_ms; + o->timer_fin_ms = fin_ms; + + if (fin_ms != 0) { + o->timer_start_ms = mp_hal_ticks_ms(); + } +} + +static int _mbedtls_timing_get_delay(void *ctx) { + mp_obj_ssl_socket_t *o = (mp_obj_ssl_socket_t *)ctx; + + if (o->timer_fin_ms == 0) { + return -1; + } + + mp_uint_t elapsed_ms = mp_hal_ticks_ms() - o->timer_start_ms; + + if (elapsed_ms >= o->timer_fin_ms) { + return 2; + } + + if (elapsed_ms >= o->timer_int_ms) { + return 1; + } + + return 0; +} +#endif + static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t sock, bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname) { @@ -577,6 +631,10 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t mp_raise_ValueError(MP_ERROR_TEXT("CERT_REQUIRED requires server_hostname")); } + #ifdef MBEDTLS_SSL_PROTO_DTLS + mbedtls_ssl_set_timer_cb(&o->ssl, o, _mbedtls_timing_set_delay, _mbedtls_timing_get_delay); + #endif + mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL); if (do_handshake_on_connect) { @@ -788,6 +846,12 @@ static const mp_rom_map_elem_t ssl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) }, { MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) }, { MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) }, + #ifdef MBEDTLS_SSL_PROTO_DTLS + { MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&mp_stream_read1_obj) }, + { MP_ROM_QSTR(MP_QSTR_recv_into), MP_ROM_PTR(&mp_stream_readinto_obj) }, + { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&mp_stream_write1_obj) }, + { MP_ROM_QSTR(MP_QSTR_sendall), MP_ROM_PTR(&mp_stream_write_obj) }, + #endif { MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) }, { MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) }, #if MICROPY_PY_SSL_FINALISER @@ -879,8 +943,12 @@ static const mp_rom_map_elem_t mp_module_tls_globals_table[] = { // Constants. { MP_ROM_QSTR(MP_QSTR_MBEDTLS_VERSION), MP_ROM_PTR(&mbedtls_version_obj)}, - { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(MBEDTLS_SSL_IS_CLIENT) }, - { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(MBEDTLS_SSL_IS_SERVER) }, + { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_CLIENT), MP_ROM_INT(MP_PROTOCOL_TLS_CLIENT) }, + { MP_ROM_QSTR(MP_QSTR_PROTOCOL_TLS_SERVER), MP_ROM_INT(MP_PROTOCOL_TLS_SERVER) }, + #ifdef MBEDTLS_SSL_PROTO_DTLS + { MP_ROM_QSTR(MP_QSTR_PROTOCOL_DTLS_CLIENT), MP_ROM_INT(MP_PROTOCOL_DTLS_CLIENT) }, + { MP_ROM_QSTR(MP_QSTR_PROTOCOL_DTLS_SERVER), MP_ROM_INT(MP_PROTOCOL_DTLS_SERVER) }, + #endif { MP_ROM_QSTR(MP_QSTR_CERT_NONE), MP_ROM_INT(MBEDTLS_SSL_VERIFY_NONE) }, { MP_ROM_QSTR(MP_QSTR_CERT_OPTIONAL), MP_ROM_INT(MBEDTLS_SSL_VERIFY_OPTIONAL) }, { MP_ROM_QSTR(MP_QSTR_CERT_REQUIRED), MP_ROM_INT(MBEDTLS_SSL_VERIFY_REQUIRED) }, diff --git a/ports/esp32/boards/sdkconfig.base b/ports/esp32/boards/sdkconfig.base index e20835c70c42d..530db427119ca 100644 --- a/ports/esp32/boards/sdkconfig.base +++ b/ports/esp32/boards/sdkconfig.base @@ -64,6 +64,9 @@ CONFIG_MBEDTLS_HAVE_TIME_DATE=y CONFIG_MBEDTLS_PLATFORM_TIME_ALT=y CONFIG_MBEDTLS_HAVE_TIME=y +# Enable DTLS +CONFIG_MBEDTLS_SSL_PROTO_DTLS=y + # Disable ALPN support as it's not implemented in MicroPython CONFIG_MBEDTLS_SSL_ALPN=n diff --git a/tests/extmod/tls_dtls.py b/tests/extmod/tls_dtls.py new file mode 100644 index 0000000000000..b2d716769d3f7 --- /dev/null +++ b/tests/extmod/tls_dtls.py @@ -0,0 +1,51 @@ +# Test DTLS functionality including timeout handling + +try: + from tls import PROTOCOL_DTLS_CLIENT, PROTOCOL_DTLS_SERVER, SSLContext, CERT_NONE + import io +except ImportError: + print("SKIP") + raise SystemExit + + +class DummySocket(io.IOBase): + def __init__(self): + self.write_buffer = bytearray() + self.read_buffer = bytearray() + + def write(self, data): + return len(data) + + def readinto(self, buf): + # This is a placeholder socket that doesn't actually read anything + # so the read buffer is always empty. + return None + + def ioctl(self, req, arg): + if req == 4: # MP_STREAM_CLOSE + return 0 + return -1 + + +# Create dummy sockets for testing +server_socket = DummySocket() +client_socket = DummySocket() + +# Wrap the DTLS Server +dtls_server_ctx = SSLContext(PROTOCOL_DTLS_SERVER) +dtls_server_ctx.verify_mode = CERT_NONE +dtls_server = dtls_server_ctx.wrap_socket(server_socket, do_handshake_on_connect=False) +print("Wrapped DTLS Server") + +# Wrap the DTLS Client +dtls_client_ctx = SSLContext(PROTOCOL_DTLS_CLIENT) +dtls_client_ctx.verify_mode = CERT_NONE +dtls_client = dtls_client_ctx.wrap_socket(client_socket, do_handshake_on_connect=False) +print("Wrapped DTLS Client") + +# Trigger the timing check multiple times with different elapsed times +for i in range(10): # Try multiple iterations to hit the timing window + dtls_client.write(b"test") + data = dtls_server.read(1024) # This should eventually hit the timing condition + +print("OK") diff --git a/tests/extmod/tls_dtls.py.exp b/tests/extmod/tls_dtls.py.exp new file mode 100644 index 0000000000000..78d72bff18816 --- /dev/null +++ b/tests/extmod/tls_dtls.py.exp @@ -0,0 +1,3 @@ +Wrapped DTLS Server +Wrapped DTLS Client +OK diff --git a/tests/multi_net/tls_dtls_server_client.py b/tests/multi_net/tls_dtls_server_client.py new file mode 100644 index 0000000000000..d50deb354ed4d --- /dev/null +++ b/tests/multi_net/tls_dtls_server_client.py @@ -0,0 +1,89 @@ +# Test DTLS server and client, sending a small amount of data between them. + +try: + import socket + import tls +except ImportError: + print("SKIP") + raise SystemExit + +PORT = 8000 + +# These are test certificates. See tests/README.md for details. +certfile = "ec_cert.der" +keyfile = "ec_key.der" + +try: + with open(certfile, "rb") as cf: + cert = cadata = cf.read() + with open(keyfile, "rb") as kf: + key = kf.read() +except OSError: + print("SKIP") + raise SystemExit + + +# DTLS server. +def instance0(): + multitest.globals(IP=multitest.get_network_ip()) + + # Create a UDP socket and bind it to accept incoming connections. + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1]) + + multitest.next() + + # Wait for the client to connect. + data, client_addr = s.recvfrom(1) + print("incoming connection", data) + + # Connect back to the client, so the UDP socket can be used like a stream. + s.connect(client_addr) + + # Create the DTLS context and load the certificate. + ctx = tls.SSLContext(tls.PROTOCOL_DTLS_SERVER) + ctx.load_cert_chain(cert, key) + + # Wrap the UDP socket in server mode. + print("wrap socket") + s = ctx.wrap_socket(s, server_side=1) + + # Transfer some data. + for _ in range(4): + print(s.recv(16)) + s.send(b"server to client") + + # Close the DTLS and UDP connection. + s.close() + + +# DTLS client. +def instance1(): + multitest.next() + + # Create a UDP socket and connect to the server. + addr = socket.getaddrinfo(IP, PORT)[0][-1] + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + print("connect") + s.connect(addr) + + # Send one byte to indicate a connection, and so the server can obtain our address. + s.write("X") + + # Create a DTLS context and load the certificate. + ctx = tls.SSLContext(tls.PROTOCOL_DTLS_CLIENT) + ctx.verify_mode = tls.CERT_REQUIRED + ctx.load_verify_locations(cadata) + + # Wrap the UDP socket. + print("wrap socket") + s = ctx.wrap_socket(s, server_hostname="micropython.local") + + # Transfer some data. + for _ in range(4): + s.send(b"client to server") + print(s.recv(16)) + + # Close the DTLS and UDP connection. + s.close() diff --git a/tests/multi_net/tls_dtls_server_client.py.exp b/tests/multi_net/tls_dtls_server_client.py.exp new file mode 100644 index 0000000000000..f2ff396e181df --- /dev/null +++ b/tests/multi_net/tls_dtls_server_client.py.exp @@ -0,0 +1,14 @@ +--- instance0 --- +incoming connection b'X' +wrap socket +b'client to server' +b'client to server' +b'client to server' +b'client to server' +--- instance1 --- +connect +wrap socket +b'server to client' +b'server to client' +b'server to client' +b'server to client'