Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions extmod/mbedtls/mbedtls_config_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#define MBEDTLS_KEY_EXCHANGE_RSA_ENABLED
#define MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED
#define MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED
#define MBEDTLS_KEY_EXCHANGE_PSK_ENABLED
#define MBEDTLS_CAN_ECDH
#define MBEDTLS_PK_CAN_ECDSA_SIGN
#define MBEDTLS_PKCS1_V15
Expand Down
186 changes: 182 additions & 4 deletions extmod/modtls_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#endif
#include "mbedtls/debug.h"
#include "mbedtls/error.h"
#include "mbedtls/ssl_ciphersuites.h"
#if MBEDTLS_VERSION_NUMBER >= 0x03000000
#include "mbedtls/build_info.h"
#else
Expand Down Expand Up @@ -127,6 +128,9 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname,
mp_obj_t client_id);

// Helper function to check if a ciphersuite uses PSK
static bool ciphersuite_uses_psk(const mbedtls_ssl_ciphersuite_t *info);

/******************************************************************************/
// Helper functions.

Expand Down Expand Up @@ -293,6 +297,8 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args
self->ecdsa_sign_callback = mp_const_none;
#endif

// Initialize PSK fields

#ifdef MBEDTLS_DEBUG_C
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
mbedtls_debug_set_threshold(3);
Expand Down Expand Up @@ -408,10 +414,93 @@ static mp_obj_t ssl_context_get_ciphers(mp_obj_t self_in) {
}
static MP_DEFINE_CONST_FUN_OBJ_1(ssl_context_get_ciphers_obj, ssl_context_get_ciphers);

// Helper function to set PSK ciphersuites
static void set_psk_ciphersuites(mbedtls_ssl_config *conf) {
// Create a list of PSK ciphersuites
static int *psk_ciphersuites = NULL;

if (psk_ciphersuites == NULL) {
// Define known PSK ciphersuites
// These are common PSK ciphersuites supported by mbedtls
static const int known_psk_ciphersuites[] = {
MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA256,
MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA,
MBEDTLS_TLS_PSK_WITH_AES_256_CBC_SHA,
MBEDTLS_TLS_PSK_WITH_AES_128_GCM_SHA256,
MBEDTLS_TLS_PSK_WITH_AES_256_GCM_SHA384,
0 // Terminating zero
};

// Count available PSK ciphersuites
int count = 0;
for (int i = 0; known_psk_ciphersuites[i] != 0; i++) {
count++;
}
Comment on lines +423 to +438
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Get the actual authoritative list of supported PSK suites by filtering the output from mbedtls_ssl_list_ciphersuites using mbedtls_ssl_ciphersuite_from_id and mbedtls_ssl_ciphersuite_uses_psk.


// Allocate memory for PSK ciphersuites
psk_ciphersuites = m_new(int, count + 1);
if (psk_ciphersuites == NULL) {
mp_raise_OSError(MP_ENOMEM);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant with the m_malloc_fail path already present in the allocator.

}

// Copy the PSK ciphersuites
for (int i = 0; i <= count; i++) { // Include terminating zero
psk_ciphersuites[i] = known_psk_ciphersuites[i];
}
}

// Set PSK ciphersuites
mbedtls_ssl_conf_ciphersuites(conf, psk_ciphersuites);
}

// SSLContext.set_ciphers(ciphersuite)
static mp_obj_t ssl_context_set_ciphers(mp_obj_t self_in, mp_obj_t ciphersuite) {
mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(self_in);

// Check if ciphersuite is a string
if (mp_obj_is_str(ciphersuite)) {
const char *ciphername = mp_obj_str_get_str(ciphersuite);

// Check for generic "PSK" mode
if (strcmp(ciphername, "PSK") == 0) {
set_psk_ciphersuites(&ssl_context->conf);
return mp_const_none;
}

// Try to look up the ciphersuite using mbedtls API
const mbedtls_ssl_ciphersuite_t *info = mbedtls_ssl_ciphersuite_from_string(ciphername);
if (info != NULL) {
// Check if this is a PSK ciphersuite
if (ciphersuite_uses_psk(info)) {
// Create a ciphersuite array with just this one ciphersuite
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should PSK ciphers be singleton, rather than parsed and accumulated as part of the existing list code?

ssl_context->ciphersuites = m_new(int, 2);
if (ssl_context->ciphersuites == NULL) {
mp_raise_OSError(MP_ENOMEM);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant with the m_malloc_fail path already present in the allocator.

}
ssl_context->ciphersuites[0] = mbedtls_ssl_ciphersuite_get_id(info);
ssl_context->ciphersuites[1] = 0; // Terminating zero

// Configure the ciphersuite
mbedtls_ssl_conf_ciphersuites(&ssl_context->conf, (const int *)ssl_context->ciphersuites);
return mp_const_none;
} else {
// Not a PSK ciphersuite, but it's a valid ciphersuite name
// Fall through to handle it as a regular single ciphersuite
ssl_context->ciphersuites = m_new(int, 2);
if (ssl_context->ciphersuites == NULL) {
mp_raise_OSError(MP_ENOMEM);
}
ssl_context->ciphersuites[0] = mbedtls_ssl_ciphersuite_get_id(info);
ssl_context->ciphersuites[1] = 0; // Terminating zero

// Configure the ciphersuite
mbedtls_ssl_conf_ciphersuites(&ssl_context->conf, (const int *)ssl_context->ciphersuites);
return mp_const_none;
}
}
}

// Original implementation for non-PSK ciphersuites
// Check that ciphersuite is a list or tuple.
size_t len = 0;
mp_obj_t *ciphers;
Expand All @@ -420,15 +509,15 @@ static mp_obj_t ssl_context_set_ciphers(mp_obj_t self_in, mp_obj_t ciphersuite)
mbedtls_raise_error(MBEDTLS_ERR_SSL_BAD_CONFIG);
}

// Parse list of ciphers.
// Parse list of ciphers using mbedtls API for validation.
ssl_context->ciphersuites = m_new(int, len + 1);
for (size_t i = 0; i < len; ++i) {
const char *ciphername = mp_obj_str_get_str(ciphers[i]);
const int id = mbedtls_ssl_get_ciphersuite_id(ciphername);
if (id == 0) {
const mbedtls_ssl_ciphersuite_t *info = mbedtls_ssl_ciphersuite_from_string(ciphername);
if (info == NULL) {
mbedtls_raise_error(MBEDTLS_ERR_SSL_BAD_CONFIG);
}
ssl_context->ciphersuites[i] = id;
ssl_context->ciphersuites[i] = mbedtls_ssl_ciphersuite_get_id(info);
}
ssl_context->ciphersuites[len] = 0;

Expand All @@ -439,6 +528,46 @@ static mp_obj_t ssl_context_set_ciphers(mp_obj_t self_in, mp_obj_t ciphersuite)
}
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_ciphers_obj, ssl_context_set_ciphers);

// Helper function to check if a ciphersuite uses PSK
static bool ciphersuite_uses_psk(const mbedtls_ssl_ciphersuite_t *info) {
if (info == NULL) {
return false;
}

// Check if ciphersuite ID corresponds to any PSK ciphersuite
int id = mbedtls_ssl_ciphersuite_get_id(info);

// Check for common PSK ciphersuites based on their IDs
// These correspond to the MBEDTLS_TLS_*_PSK_* constants
return (id == 0x2C || // MBEDTLS_TLS_PSK_WITH_NULL_SHA
id == 0x2D || // MBEDTLS_TLS_DHE_PSK_WITH_NULL_SHA
id == 0x2E || // MBEDTLS_TLS_RSA_PSK_WITH_NULL_SHA
id == 0x8C || // MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA
id == 0x8D || // MBEDTLS_TLS_PSK_WITH_AES_256_CBC_SHA
id == 0x90 || // MBEDTLS_TLS_DHE_PSK_WITH_AES_128_CBC_SHA
id == 0x91 || // MBEDTLS_TLS_DHE_PSK_WITH_AES_256_CBC_SHA
id == 0x94 || // MBEDTLS_TLS_RSA_PSK_WITH_AES_128_CBC_SHA
id == 0x95 || // MBEDTLS_TLS_RSA_PSK_WITH_AES_256_CBC_SHA
id == 0xA8 || // MBEDTLS_TLS_PSK_WITH_AES_128_GCM_SHA256
id == 0xA9 || // MBEDTLS_TLS_PSK_WITH_AES_256_GCM_SHA384
id == 0xAA || // MBEDTLS_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256
id == 0xAB || // MBEDTLS_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384
id == 0xAC || // MBEDTLS_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256
id == 0xAD || // MBEDTLS_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384
id == 0xAE || // MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA256
id == 0xAF || // MBEDTLS_TLS_PSK_WITH_AES_256_CBC_SHA384
id == 0xB0 || // MBEDTLS_TLS_PSK_WITH_NULL_SHA256
id == 0xB1 || // MBEDTLS_TLS_PSK_WITH_NULL_SHA384
id == 0xB2 || // MBEDTLS_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256
id == 0xB3 || // MBEDTLS_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384
id == 0xB4 || // MBEDTLS_TLS_DHE_PSK_WITH_NULL_SHA256
id == 0xB5 || // MBEDTLS_TLS_DHE_PSK_WITH_NULL_SHA384
id == 0xB6 || // MBEDTLS_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256
id == 0xB7 || // MBEDTLS_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384
id == 0xB8 || // MBEDTLS_TLS_RSA_PSK_WITH_NULL_SHA256
id == 0xB9); // MBEDTLS_TLS_RSA_PSK_WITH_NULL_SHA384
}

Comment on lines +531 to +570
Copy link
Contributor

@AJMansfield AJMansfield Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a pity the actual mbedtls_ssl_ciphersuite_uses_psk is only part of mbedtls/library/ssl_ciphersuites_internal.h header rather than a public include --- but IMO it's still better to just link against it anyway instead of re-implementing it. Just declare it yourself the same way its header does without defining it:

int mbedtls_ssl_ciphersuite_uses_psk(const mbedtls_ssl_ciphersuite_t *info);

If that's too spicy, add MBEDTLS_ALLOW_PRIVATE_ACCESS to mbedtls_config_common.h and examine the value of info->key_exchange against the four values it can have that indicate PSK suites.

Copy link
Contributor

@AJMansfield AJMansfield Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, don't hardcode the numeric values, use the symbols! You can just id == MBEDTLS_TLS_PSK_WITH_NULL_SHA, you already even have the mbedtls/ssl_ciphersuites.h header that defines those.

static void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, mp_obj_t cert_obj) {
size_t key_len;
const unsigned char *key = asn1_get_data(key_obj, &key_len);
Expand Down Expand Up @@ -493,6 +622,23 @@ static mp_obj_t ssl_context_load_verify_locations(mp_obj_t self_in, mp_obj_t cad
}
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_load_verify_locations_obj, ssl_context_load_verify_locations);

// SSLContext.set_psk_identity(identity) and set_psk_key(key)
// These methods now configure PSK directly with mbedtls instead of storing values
static mp_obj_t psk_identity = mp_const_none;
static mp_obj_t psk_key = mp_const_none;

static mp_obj_t ssl_context_set_psk_identity(mp_obj_t self_in, mp_obj_t identity) {
psk_identity = identity;
return mp_const_none;
}
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_psk_identity_obj, ssl_context_set_psk_identity);

static mp_obj_t ssl_context_set_psk_key(mp_obj_t self_in, mp_obj_t key) {
psk_key = key;
return mp_const_none;
}
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_psk_key_obj, ssl_context_set_psk_key);

Comment on lines +625 to +641
Copy link
Contributor

@AJMansfield AJMansfield Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did ChatGPT write this? Because that's a lie:

These methods now configure PSK directly with mbedtls instead of storing values

Copy link
Contributor

@AJMansfield AJMansfield Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra variables in mp_obj_ssl_context_t was a bit sub-optimal, but using shared static variables like this is completely unacceptable. Even if we ignore mutliple contexts clobbering each other, it's also a use-after-free bug. (Can you spot it? It's because BSS variables aren't visible to micropython's garbage collector unless you give them a root-pointer decorator. The key and identity values here could be eligible for collection the instant these functions return.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point of my earlier comment was this: The mp_obj_ssl_context_t object already has a place to store these values, in its contained mbedtls_ssl_config value --- i.e. the place that calling mbedtls_ssl_conf_psk stores them.
The API here should just be the thinnest possible veneer over that function to make it callable from micropython --- just a function that takes two string arguments, validates that they're strings, extracts the underlying string pointers, and calls mbedtls_ssl_conf_psk with them.

static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) {
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname, ARG_client_id };
static const mp_arg_t allowed_args[] = {
Expand Down Expand Up @@ -530,6 +676,8 @@ static const mp_rom_map_elem_t ssl_context_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_load_cert_chain), MP_ROM_PTR(&ssl_context_load_cert_chain_obj)},
{ MP_ROM_QSTR(MP_QSTR_load_verify_locations), MP_ROM_PTR(&ssl_context_load_verify_locations_obj)},
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) },
{ MP_ROM_QSTR(MP_QSTR_set_psk_identity), MP_ROM_PTR(&ssl_context_set_psk_identity_obj) },
{ MP_ROM_QSTR(MP_QSTR_set_psk_key), MP_ROM_PTR(&ssl_context_set_psk_key_obj) },
};
static MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table);

Expand Down Expand Up @@ -637,6 +785,36 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t

mbedtls_ssl_init(&o->ssl);

// Configure PSK if PSK ciphersuites are enabled and PSK data is available
if (psk_identity != mp_const_none && psk_key != mp_const_none) {
// Check if any of the configured ciphersuites use PSK
bool has_psk_cipher = false;
if (ssl_context->ciphersuites != NULL) {
for (int i = 0; ssl_context->ciphersuites[i] != 0; i++) {
const mbedtls_ssl_ciphersuite_t *info = mbedtls_ssl_ciphersuite_from_id(ssl_context->ciphersuites[i]);
if (info != NULL && ciphersuite_uses_psk(info)) {
has_psk_cipher = true;
break;
}
}
}

if (has_psk_cipher) {
// Get PSK identity and key
size_t psk_identity_len;
const byte *psk_identity_data = (const byte *)mp_obj_str_get_data(psk_identity, &psk_identity_len);

size_t psk_key_len;
const byte *psk_key_data = (const byte *)mp_obj_str_get_data(psk_key, &psk_key_len);

// Configure PSK
ret = mbedtls_ssl_conf_psk(&ssl_context->conf, psk_key_data, psk_key_len, psk_identity_data, psk_identity_len);
if (ret != 0) {
goto cleanup;
}
}
}

ret = mbedtls_ssl_setup(&o->ssl, &ssl_context->conf);
#if !MICROPY_MBEDTLS_CONFIG_BARE_METAL
if (ret != 0) {
Expand Down
52 changes: 52 additions & 0 deletions tests/multi_net/sslcontext_server_client_psk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Test TCP server and client with TLS-PSK, using set_psk_identity(),
# set_psk_key(), and set_ciphers("PSK").

try:
import socket
import tls
except ImportError:
print("SKIP")
raise SystemExit

PORT = 8000

PSK_ID = "PSK-Identity-1"
PSK_KEY = "c0ffee"
PSK_CIPHER = "PSK"


# Server
def instance0():
multitest.globals(IP=multitest.get_network_ip())
s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1])
s.listen(1)
multitest.next()
s2, _ = s.accept()
server_ctx = tls.SSLContext(tls.PROTOCOL_TLS_SERVER)
# Configure PSK
server_ctx.set_psk_identity(PSK_ID)
server_ctx.set_psk_key(bytes.fromhex(PSK_KEY))
server_ctx.set_ciphers(PSK_CIPHER)
s2 = server_ctx.wrap_socket(s2, server_side=True)
print(s2.read(16))
s2.write(b"server to client")
s2.close()
s.close()


# Client
def instance1():
multitest.next()
s = socket.socket()
s.connect(socket.getaddrinfo(IP, PORT)[0][-1])
client_ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
# Configure PSK
client_ctx.set_psk_identity(PSK_ID)
client_ctx.set_psk_key(bytes.fromhex(PSK_KEY))
client_ctx.set_ciphers(PSK_CIPHER)
s = client_ctx.wrap_socket(s, server_hostname="micropython.local")
s.write(b"client to server")
print(s.read(16))
s.close()
4 changes: 4 additions & 0 deletions tests/multi_net/sslcontext_server_client_psk.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
--- instance0 ---
b'client to server'
--- instance1 ---
b'server to client'
52 changes: 52 additions & 0 deletions tests/multi_net/sslcontext_server_client_psk_cipher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Test TCP server and client with TLS-PSK, using set_psk_identity(),
# set_psk_key(), and set_ciphers("TLS-PSK-WITH-AES-128-CBC-SHA256").

try:
import socket
import tls
except ImportError:
print("SKIP")
raise SystemExit

PORT = 8000

PSK_ID = "PSK-Identity-1"
PSK_KEY = "c0ffee"
PSK_CIPHER = "TLS-PSK-WITH-AES-128-CBC-SHA256"


# Server
def instance0():
multitest.globals(IP=multitest.get_network_ip())
s = socket.socket()
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1])
s.listen(1)
multitest.next()
s2, _ = s.accept()
server_ctx = tls.SSLContext(tls.PROTOCOL_TLS_SERVER)
# Configure PSK with specific ciphersuite
server_ctx.set_psk_identity(PSK_ID)
server_ctx.set_psk_key(bytes.fromhex(PSK_KEY))
server_ctx.set_ciphers(PSK_CIPHER)
s2 = server_ctx.wrap_socket(s2, server_side=True)
print(s2.read(16))
s2.write(b"server to client")
s2.close()
s.close()


# Client
def instance1():
multitest.next()
s = socket.socket()
s.connect(socket.getaddrinfo(IP, PORT)[0][-1])
client_ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT)
# Configure PSK with specific ciphersuite
client_ctx.set_psk_identity(PSK_ID)
client_ctx.set_psk_key(bytes.fromhex(PSK_KEY))
client_ctx.set_ciphers(PSK_CIPHER)
s = client_ctx.wrap_socket(s, server_hostname="micropython.local")
s.write(b"client to server")
print(s.read(16))
s.close()
4 changes: 4 additions & 0 deletions tests/multi_net/sslcontext_server_client_psk_cipher.py.exp
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
--- instance0 ---
b'client to server'
--- instance1 ---
b'server to client'
Loading