-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
extmod/modtls_mbedtls: Add support for TLS PSK #17074
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,6 +53,7 @@ | |
#endif | ||
#include "mbedtls/debug.h" | ||
#include "mbedtls/error.h" | ||
#include "mbedtls/ssl_ciphersuites.h" | ||
#if MBEDTLS_VERSION_NUMBER >= 0x03000000 | ||
#include "mbedtls/build_info.h" | ||
#else | ||
|
@@ -127,6 +128,9 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t | |
bool server_side, bool do_handshake_on_connect, mp_obj_t server_hostname, | ||
mp_obj_t client_id); | ||
|
||
// Helper function to check if a ciphersuite uses PSK | ||
static bool ciphersuite_uses_psk(const mbedtls_ssl_ciphersuite_t *info); | ||
|
||
/******************************************************************************/ | ||
// Helper functions. | ||
|
||
|
@@ -293,6 +297,8 @@ static mp_obj_t ssl_context_make_new(const mp_obj_type_t *type_in, size_t n_args | |
self->ecdsa_sign_callback = mp_const_none; | ||
#endif | ||
|
||
// Initialize PSK fields | ||
|
||
#ifdef MBEDTLS_DEBUG_C | ||
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose | ||
mbedtls_debug_set_threshold(3); | ||
|
@@ -408,10 +414,93 @@ static mp_obj_t ssl_context_get_ciphers(mp_obj_t self_in) { | |
} | ||
static MP_DEFINE_CONST_FUN_OBJ_1(ssl_context_get_ciphers_obj, ssl_context_get_ciphers); | ||
|
||
// Helper function to set PSK ciphersuites | ||
static void set_psk_ciphersuites(mbedtls_ssl_config *conf) { | ||
// Create a list of PSK ciphersuites | ||
static int *psk_ciphersuites = NULL; | ||
|
||
if (psk_ciphersuites == NULL) { | ||
// Define known PSK ciphersuites | ||
// These are common PSK ciphersuites supported by mbedtls | ||
static const int known_psk_ciphersuites[] = { | ||
MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA256, | ||
MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA, | ||
MBEDTLS_TLS_PSK_WITH_AES_256_CBC_SHA, | ||
MBEDTLS_TLS_PSK_WITH_AES_128_GCM_SHA256, | ||
MBEDTLS_TLS_PSK_WITH_AES_256_GCM_SHA384, | ||
0 // Terminating zero | ||
}; | ||
|
||
// Count available PSK ciphersuites | ||
int count = 0; | ||
for (int i = 0; known_psk_ciphersuites[i] != 0; i++) { | ||
count++; | ||
} | ||
|
||
// Allocate memory for PSK ciphersuites | ||
psk_ciphersuites = m_new(int, count + 1); | ||
if (psk_ciphersuites == NULL) { | ||
mp_raise_OSError(MP_ENOMEM); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Redundant with the |
||
} | ||
|
||
// Copy the PSK ciphersuites | ||
for (int i = 0; i <= count; i++) { // Include terminating zero | ||
psk_ciphersuites[i] = known_psk_ciphersuites[i]; | ||
} | ||
} | ||
|
||
// Set PSK ciphersuites | ||
mbedtls_ssl_conf_ciphersuites(conf, psk_ciphersuites); | ||
} | ||
|
||
// SSLContext.set_ciphers(ciphersuite) | ||
static mp_obj_t ssl_context_set_ciphers(mp_obj_t self_in, mp_obj_t ciphersuite) { | ||
mp_obj_ssl_context_t *ssl_context = MP_OBJ_TO_PTR(self_in); | ||
|
||
// Check if ciphersuite is a string | ||
if (mp_obj_is_str(ciphersuite)) { | ||
const char *ciphername = mp_obj_str_get_str(ciphersuite); | ||
|
||
// Check for generic "PSK" mode | ||
if (strcmp(ciphername, "PSK") == 0) { | ||
set_psk_ciphersuites(&ssl_context->conf); | ||
return mp_const_none; | ||
} | ||
|
||
// Try to look up the ciphersuite using mbedtls API | ||
const mbedtls_ssl_ciphersuite_t *info = mbedtls_ssl_ciphersuite_from_string(ciphername); | ||
if (info != NULL) { | ||
// Check if this is a PSK ciphersuite | ||
if (ciphersuite_uses_psk(info)) { | ||
// Create a ciphersuite array with just this one ciphersuite | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why should PSK ciphers be singleton, rather than parsed and accumulated as part of the existing list code? |
||
ssl_context->ciphersuites = m_new(int, 2); | ||
if (ssl_context->ciphersuites == NULL) { | ||
mp_raise_OSError(MP_ENOMEM); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Redundant with the |
||
} | ||
ssl_context->ciphersuites[0] = mbedtls_ssl_ciphersuite_get_id(info); | ||
ssl_context->ciphersuites[1] = 0; // Terminating zero | ||
|
||
// Configure the ciphersuite | ||
mbedtls_ssl_conf_ciphersuites(&ssl_context->conf, (const int *)ssl_context->ciphersuites); | ||
return mp_const_none; | ||
} else { | ||
// Not a PSK ciphersuite, but it's a valid ciphersuite name | ||
// Fall through to handle it as a regular single ciphersuite | ||
ssl_context->ciphersuites = m_new(int, 2); | ||
if (ssl_context->ciphersuites == NULL) { | ||
mp_raise_OSError(MP_ENOMEM); | ||
} | ||
ssl_context->ciphersuites[0] = mbedtls_ssl_ciphersuite_get_id(info); | ||
ssl_context->ciphersuites[1] = 0; // Terminating zero | ||
|
||
// Configure the ciphersuite | ||
mbedtls_ssl_conf_ciphersuites(&ssl_context->conf, (const int *)ssl_context->ciphersuites); | ||
return mp_const_none; | ||
} | ||
} | ||
} | ||
|
||
// Original implementation for non-PSK ciphersuites | ||
// Check that ciphersuite is a list or tuple. | ||
size_t len = 0; | ||
mp_obj_t *ciphers; | ||
|
@@ -420,15 +509,15 @@ static mp_obj_t ssl_context_set_ciphers(mp_obj_t self_in, mp_obj_t ciphersuite) | |
mbedtls_raise_error(MBEDTLS_ERR_SSL_BAD_CONFIG); | ||
} | ||
|
||
// Parse list of ciphers. | ||
// Parse list of ciphers using mbedtls API for validation. | ||
ssl_context->ciphersuites = m_new(int, len + 1); | ||
for (size_t i = 0; i < len; ++i) { | ||
const char *ciphername = mp_obj_str_get_str(ciphers[i]); | ||
const int id = mbedtls_ssl_get_ciphersuite_id(ciphername); | ||
if (id == 0) { | ||
const mbedtls_ssl_ciphersuite_t *info = mbedtls_ssl_ciphersuite_from_string(ciphername); | ||
if (info == NULL) { | ||
mbedtls_raise_error(MBEDTLS_ERR_SSL_BAD_CONFIG); | ||
} | ||
ssl_context->ciphersuites[i] = id; | ||
ssl_context->ciphersuites[i] = mbedtls_ssl_ciphersuite_get_id(info); | ||
} | ||
ssl_context->ciphersuites[len] = 0; | ||
|
||
|
@@ -439,6 +528,46 @@ static mp_obj_t ssl_context_set_ciphers(mp_obj_t self_in, mp_obj_t ciphersuite) | |
} | ||
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_ciphers_obj, ssl_context_set_ciphers); | ||
|
||
// Helper function to check if a ciphersuite uses PSK | ||
static bool ciphersuite_uses_psk(const mbedtls_ssl_ciphersuite_t *info) { | ||
if (info == NULL) { | ||
return false; | ||
} | ||
|
||
// Check if ciphersuite ID corresponds to any PSK ciphersuite | ||
int id = mbedtls_ssl_ciphersuite_get_id(info); | ||
|
||
// Check for common PSK ciphersuites based on their IDs | ||
// These correspond to the MBEDTLS_TLS_*_PSK_* constants | ||
return (id == 0x2C || // MBEDTLS_TLS_PSK_WITH_NULL_SHA | ||
id == 0x2D || // MBEDTLS_TLS_DHE_PSK_WITH_NULL_SHA | ||
id == 0x2E || // MBEDTLS_TLS_RSA_PSK_WITH_NULL_SHA | ||
id == 0x8C || // MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA | ||
id == 0x8D || // MBEDTLS_TLS_PSK_WITH_AES_256_CBC_SHA | ||
id == 0x90 || // MBEDTLS_TLS_DHE_PSK_WITH_AES_128_CBC_SHA | ||
id == 0x91 || // MBEDTLS_TLS_DHE_PSK_WITH_AES_256_CBC_SHA | ||
id == 0x94 || // MBEDTLS_TLS_RSA_PSK_WITH_AES_128_CBC_SHA | ||
id == 0x95 || // MBEDTLS_TLS_RSA_PSK_WITH_AES_256_CBC_SHA | ||
id == 0xA8 || // MBEDTLS_TLS_PSK_WITH_AES_128_GCM_SHA256 | ||
id == 0xA9 || // MBEDTLS_TLS_PSK_WITH_AES_256_GCM_SHA384 | ||
id == 0xAA || // MBEDTLS_TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 | ||
id == 0xAB || // MBEDTLS_TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 | ||
id == 0xAC || // MBEDTLS_TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 | ||
id == 0xAD || // MBEDTLS_TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 | ||
id == 0xAE || // MBEDTLS_TLS_PSK_WITH_AES_128_CBC_SHA256 | ||
id == 0xAF || // MBEDTLS_TLS_PSK_WITH_AES_256_CBC_SHA384 | ||
id == 0xB0 || // MBEDTLS_TLS_PSK_WITH_NULL_SHA256 | ||
id == 0xB1 || // MBEDTLS_TLS_PSK_WITH_NULL_SHA384 | ||
id == 0xB2 || // MBEDTLS_TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 | ||
id == 0xB3 || // MBEDTLS_TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 | ||
id == 0xB4 || // MBEDTLS_TLS_DHE_PSK_WITH_NULL_SHA256 | ||
id == 0xB5 || // MBEDTLS_TLS_DHE_PSK_WITH_NULL_SHA384 | ||
id == 0xB6 || // MBEDTLS_TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 | ||
id == 0xB7 || // MBEDTLS_TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 | ||
id == 0xB8 || // MBEDTLS_TLS_RSA_PSK_WITH_NULL_SHA256 | ||
id == 0xB9); // MBEDTLS_TLS_RSA_PSK_WITH_NULL_SHA384 | ||
} | ||
|
||
Comment on lines
+531
to
+570
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a pity the actual int mbedtls_ssl_ciphersuite_uses_psk(const mbedtls_ssl_ciphersuite_t *info); If that's too spicy, add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, don't hardcode the numeric values, use the symbols! You can just |
||
static void ssl_context_load_key(mp_obj_ssl_context_t *self, mp_obj_t key_obj, mp_obj_t cert_obj) { | ||
size_t key_len; | ||
const unsigned char *key = asn1_get_data(key_obj, &key_len); | ||
|
@@ -493,6 +622,23 @@ static mp_obj_t ssl_context_load_verify_locations(mp_obj_t self_in, mp_obj_t cad | |
} | ||
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_load_verify_locations_obj, ssl_context_load_verify_locations); | ||
|
||
// SSLContext.set_psk_identity(identity) and set_psk_key(key) | ||
// These methods now configure PSK directly with mbedtls instead of storing values | ||
static mp_obj_t psk_identity = mp_const_none; | ||
static mp_obj_t psk_key = mp_const_none; | ||
|
||
static mp_obj_t ssl_context_set_psk_identity(mp_obj_t self_in, mp_obj_t identity) { | ||
psk_identity = identity; | ||
return mp_const_none; | ||
} | ||
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_psk_identity_obj, ssl_context_set_psk_identity); | ||
|
||
static mp_obj_t ssl_context_set_psk_key(mp_obj_t self_in, mp_obj_t key) { | ||
psk_key = key; | ||
return mp_const_none; | ||
} | ||
static MP_DEFINE_CONST_FUN_OBJ_2(ssl_context_set_psk_key_obj, ssl_context_set_psk_key); | ||
|
||
Comment on lines
+625
to
+641
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did ChatGPT write this? Because that's a lie:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The extra variables in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The point of my earlier comment was this: The |
||
static mp_obj_t ssl_context_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { | ||
enum { ARG_server_side, ARG_do_handshake_on_connect, ARG_server_hostname, ARG_client_id }; | ||
static const mp_arg_t allowed_args[] = { | ||
|
@@ -530,6 +676,8 @@ static const mp_rom_map_elem_t ssl_context_locals_dict_table[] = { | |
{ MP_ROM_QSTR(MP_QSTR_load_cert_chain), MP_ROM_PTR(&ssl_context_load_cert_chain_obj)}, | ||
{ MP_ROM_QSTR(MP_QSTR_load_verify_locations), MP_ROM_PTR(&ssl_context_load_verify_locations_obj)}, | ||
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&ssl_context_wrap_socket_obj) }, | ||
{ MP_ROM_QSTR(MP_QSTR_set_psk_identity), MP_ROM_PTR(&ssl_context_set_psk_identity_obj) }, | ||
{ MP_ROM_QSTR(MP_QSTR_set_psk_key), MP_ROM_PTR(&ssl_context_set_psk_key_obj) }, | ||
}; | ||
static MP_DEFINE_CONST_DICT(ssl_context_locals_dict, ssl_context_locals_dict_table); | ||
|
||
|
@@ -637,6 +785,36 @@ static mp_obj_t ssl_socket_make_new(mp_obj_ssl_context_t *ssl_context, mp_obj_t | |
|
||
mbedtls_ssl_init(&o->ssl); | ||
|
||
// Configure PSK if PSK ciphersuites are enabled and PSK data is available | ||
if (psk_identity != mp_const_none && psk_key != mp_const_none) { | ||
// Check if any of the configured ciphersuites use PSK | ||
bool has_psk_cipher = false; | ||
if (ssl_context->ciphersuites != NULL) { | ||
for (int i = 0; ssl_context->ciphersuites[i] != 0; i++) { | ||
const mbedtls_ssl_ciphersuite_t *info = mbedtls_ssl_ciphersuite_from_id(ssl_context->ciphersuites[i]); | ||
if (info != NULL && ciphersuite_uses_psk(info)) { | ||
has_psk_cipher = true; | ||
break; | ||
} | ||
} | ||
} | ||
|
||
if (has_psk_cipher) { | ||
// Get PSK identity and key | ||
size_t psk_identity_len; | ||
const byte *psk_identity_data = (const byte *)mp_obj_str_get_data(psk_identity, &psk_identity_len); | ||
|
||
size_t psk_key_len; | ||
const byte *psk_key_data = (const byte *)mp_obj_str_get_data(psk_key, &psk_key_len); | ||
|
||
// Configure PSK | ||
ret = mbedtls_ssl_conf_psk(&ssl_context->conf, psk_key_data, psk_key_len, psk_identity_data, psk_identity_len); | ||
if (ret != 0) { | ||
goto cleanup; | ||
} | ||
} | ||
} | ||
|
||
ret = mbedtls_ssl_setup(&o->ssl, &ssl_context->conf); | ||
#if !MICROPY_MBEDTLS_CONFIG_BARE_METAL | ||
if (ret != 0) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Test TCP server and client with TLS-PSK, using set_psk_identity(), | ||
# set_psk_key(), and set_ciphers("PSK"). | ||
|
||
try: | ||
import socket | ||
import tls | ||
except ImportError: | ||
print("SKIP") | ||
raise SystemExit | ||
|
||
PORT = 8000 | ||
|
||
PSK_ID = "PSK-Identity-1" | ||
PSK_KEY = "c0ffee" | ||
PSK_CIPHER = "PSK" | ||
|
||
|
||
# Server | ||
def instance0(): | ||
multitest.globals(IP=multitest.get_network_ip()) | ||
s = socket.socket() | ||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | ||
s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1]) | ||
s.listen(1) | ||
multitest.next() | ||
s2, _ = s.accept() | ||
server_ctx = tls.SSLContext(tls.PROTOCOL_TLS_SERVER) | ||
# Configure PSK | ||
server_ctx.set_psk_identity(PSK_ID) | ||
server_ctx.set_psk_key(bytes.fromhex(PSK_KEY)) | ||
server_ctx.set_ciphers(PSK_CIPHER) | ||
s2 = server_ctx.wrap_socket(s2, server_side=True) | ||
print(s2.read(16)) | ||
s2.write(b"server to client") | ||
s2.close() | ||
s.close() | ||
|
||
|
||
# Client | ||
def instance1(): | ||
multitest.next() | ||
s = socket.socket() | ||
s.connect(socket.getaddrinfo(IP, PORT)[0][-1]) | ||
client_ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT) | ||
# Configure PSK | ||
client_ctx.set_psk_identity(PSK_ID) | ||
client_ctx.set_psk_key(bytes.fromhex(PSK_KEY)) | ||
client_ctx.set_ciphers(PSK_CIPHER) | ||
s = client_ctx.wrap_socket(s, server_hostname="micropython.local") | ||
s.write(b"client to server") | ||
print(s.read(16)) | ||
s.close() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
--- instance0 --- | ||
b'client to server' | ||
--- instance1 --- | ||
b'server to client' |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Test TCP server and client with TLS-PSK, using set_psk_identity(), | ||
# set_psk_key(), and set_ciphers("TLS-PSK-WITH-AES-128-CBC-SHA256"). | ||
|
||
try: | ||
import socket | ||
import tls | ||
except ImportError: | ||
print("SKIP") | ||
raise SystemExit | ||
|
||
PORT = 8000 | ||
|
||
PSK_ID = "PSK-Identity-1" | ||
PSK_KEY = "c0ffee" | ||
PSK_CIPHER = "TLS-PSK-WITH-AES-128-CBC-SHA256" | ||
|
||
|
||
# Server | ||
def instance0(): | ||
multitest.globals(IP=multitest.get_network_ip()) | ||
s = socket.socket() | ||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) | ||
s.bind(socket.getaddrinfo("0.0.0.0", PORT)[0][-1]) | ||
s.listen(1) | ||
multitest.next() | ||
s2, _ = s.accept() | ||
server_ctx = tls.SSLContext(tls.PROTOCOL_TLS_SERVER) | ||
# Configure PSK with specific ciphersuite | ||
server_ctx.set_psk_identity(PSK_ID) | ||
server_ctx.set_psk_key(bytes.fromhex(PSK_KEY)) | ||
server_ctx.set_ciphers(PSK_CIPHER) | ||
s2 = server_ctx.wrap_socket(s2, server_side=True) | ||
print(s2.read(16)) | ||
s2.write(b"server to client") | ||
s2.close() | ||
s.close() | ||
|
||
|
||
# Client | ||
def instance1(): | ||
multitest.next() | ||
s = socket.socket() | ||
s.connect(socket.getaddrinfo(IP, PORT)[0][-1]) | ||
client_ctx = tls.SSLContext(tls.PROTOCOL_TLS_CLIENT) | ||
# Configure PSK with specific ciphersuite | ||
client_ctx.set_psk_identity(PSK_ID) | ||
client_ctx.set_psk_key(bytes.fromhex(PSK_KEY)) | ||
client_ctx.set_ciphers(PSK_CIPHER) | ||
s = client_ctx.wrap_socket(s, server_hostname="micropython.local") | ||
s.write(b"client to server") | ||
print(s.read(16)) | ||
s.close() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
--- instance0 --- | ||
b'client to server' | ||
--- instance1 --- | ||
b'server to client' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Get the actual authoritative list of supported PSK suites by filtering the output from
mbedtls_ssl_list_ciphersuites
usingmbedtls_ssl_ciphersuite_from_id
andmbedtls_ssl_ciphersuite_uses_psk
.