Skip to content

add modussl_mbedtls.c methods and exceptions. esp32/unix #5436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions extmod/modussl_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ struct ssl_args {

STATIC const mp_obj_type_t ussl_socket_type;

// Set of exceptions returned by recv, send or
// do_hanshake based on the status of the
// connection.
MP_DEFINE_EXCEPTION(SSLError, Exception)
MP_DEFINE_EXCEPTION(SSLWantReadError, SSLError)
MP_DEFINE_EXCEPTION(SSLWantWriteError, SSLError)
MP_DEFINE_EXCEPTION(SSLInProgress, SSLError)

#ifdef MBEDTLS_DEBUG_C
STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, const char *str) {
(void)ctx;
Expand Down Expand Up @@ -311,6 +319,93 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
}

// Since Python 3.8, ssl.read has been deprecated in favor of
// recv. This will make plain tcp sockets and (u)ssl share the
// same call which can let users write abstractions easily.
// Since SSL sockets are often used as non-blocking, this method
// will raise ussl.SSLWantReadError or ussl.SSlWantWriteError
// depending on the error code returned by mbedtls_ssl_read.
STATIC mp_obj_t socket_recv(mp_obj_t self_in, mp_obj_t size) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in);

// we set a default value of 1024 by default
// for embedded systems.
mp_int_t buffer_size = 1024;
if (!mp_obj_get_int_maybe(size, &buffer_size)) {
buffer_size = 1024;
}
// make sure we read at most 1 byte or the connection
// will be maked as closed even if the call was sucessful.
if (buffer_size <= 0) {
mp_raise_ValueError("recv argument must be greater than 0");
}

byte buff[buffer_size];
int ret = mbedtls_ssl_read(&o->ssl, buff, (size_t)buffer_size);
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
mp_raise_msg(&mp_type_SSLWantWriteError, NULL);
} else if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
mp_raise_msg(&mp_type_SSLWantReadError, NULL);
} else if (ret < 0) {
mp_raise_OSError(MP_EIO);
}
return mp_obj_new_bytes(buff, ret);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_recv, socket_recv);

// Since Python 3.8, ssl.write has been deprecated in favor of
// send. This will make plain tcp sockets and (u)ssl share the
// same call which can let users write abtractions easily.
// Since SSL sockets are often used as non-blocking, this method
// will raise ussl.SSLWantReadError or ussl.SSlWantWriteError
// depending on the error code returned by mbedtls_ssl_write.
STATIC mp_obj_t socket_send(mp_obj_t self_in, mp_obj_t data) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in);

mp_buffer_info_t bufinfo;
mp_get_buffer_raise(data, &bufinfo, MP_BUFFER_READ);

int ret = mbedtls_ssl_write(&o->ssl, bufinfo.buf, bufinfo.len);
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
mp_raise_msg(&mp_type_SSLWantWriteError, NULL);
} else if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
mp_raise_msg(&mp_type_SSLWantReadError, NULL);
} else if (ret < 0) {
mp_raise_OSError(MP_EIO);
}
return mp_obj_new_int(ret);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_send, socket_send);

// If ussl.wrap_socket is called with do_handshake = False,
// the user has now the possibility to call do_handshake
// later before calling any of recv or send.
// ussl.SSLInProgress is raised if all the the socket has
// POLLIN | POLLOUT available but the internal buffer is not
// drained.
// Any other error should be considered as critical and the
// user MUST stop using the current socket.
STATIC mp_obj_t socket_do_handshake(mp_obj_t self_in) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in);

int ret = mbedtls_ssl_handshake(&o->ssl);
if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
mp_raise_msg(&mp_type_SSLWantWriteError, NULL);
} else if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
mp_raise_msg(&mp_type_SSLWantReadError, NULL);
#ifdef MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS // not defined in the esp-idf version.
} else if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
#else
} else if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS) {
#endif
mp_raise_msg(&mp_type_SSLInProgress, NULL);
} else if (ret != 0) {
mp_raise_OSError(MP_EIO);
}
return mp_obj_new_int(ret);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_ssl_do_handshake, socket_do_handshake);

STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
Expand All @@ -322,6 +417,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) },
#endif
{ MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) },
{ MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&mod_ssl_recv) },
{ MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&mod_ssl_send) },
{ MP_ROM_QSTR(MP_QSTR_do_handshake), MP_ROM_PTR(&mod_ssl_do_handshake) },
};

STATIC MP_DEFINE_CONST_DICT(ussl_socket_locals_dict, ussl_socket_locals_dict_table);
Expand Down Expand Up @@ -367,6 +465,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socke
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
{ MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ussl) },
{ MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
{ MP_ROM_QSTR(MP_QSTR_SSLWantReadError), MP_OBJ_FROM_PTR(&mp_type_SSLWantReadError) },
{ MP_ROM_QSTR(MP_QSTR_SSLWantWriteError), MP_OBJ_FROM_PTR(&mp_type_SSLWantWriteError) },
{ MP_ROM_QSTR(MP_QSTR_SSLInProgress), MP_OBJ_FROM_PTR(&mp_type_SSLInProgress) },
};

STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);
Expand Down