diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c index a71adc5b366da..86bc4c8c405b1 100644 --- a/extmod/modussl_mbedtls.c +++ b/extmod/modussl_mbedtls.c @@ -66,6 +66,14 @@ struct ssl_args { STATIC const mp_obj_type_t ussl_socket_type; +// Set of exceptions returned by recv, send or +// do_hanshake based on the status of the +// connection. +MP_DEFINE_EXCEPTION(SSLError, Exception) +MP_DEFINE_EXCEPTION(SSLWantReadError, SSLError) +MP_DEFINE_EXCEPTION(SSLWantWriteError, SSLError) +MP_DEFINE_EXCEPTION(SSLInProgress, SSLError) + #ifdef MBEDTLS_DEBUG_C STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, const char *str) { (void)ctx; @@ -311,6 +319,93 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode); } +// Since Python 3.8, ssl.read has been deprecated in favor of +// recv. This will make plain tcp sockets and (u)ssl share the +// same call which can let users write abstractions easily. +// Since SSL sockets are often used as non-blocking, this method +// will raise ussl.SSLWantReadError or ussl.SSlWantWriteError +// depending on the error code returned by mbedtls_ssl_read. +STATIC mp_obj_t socket_recv(mp_obj_t self_in, mp_obj_t size) { + mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in); + + // we set a default value of 1024 by default + // for embedded systems. + mp_int_t buffer_size = 1024; + if (!mp_obj_get_int_maybe(size, &buffer_size)) { + buffer_size = 1024; + } + // make sure we read at most 1 byte or the connection + // will be maked as closed even if the call was sucessful. + if (buffer_size <= 0) { + mp_raise_ValueError("recv argument must be greater than 0"); + } + + byte buff[buffer_size]; + int ret = mbedtls_ssl_read(&o->ssl, buff, (size_t)buffer_size); + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + mp_raise_msg(&mp_type_SSLWantWriteError, NULL); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + mp_raise_msg(&mp_type_SSLWantReadError, NULL); + } else if (ret < 0) { + mp_raise_OSError(MP_EIO); + } + return mp_obj_new_bytes(buff, ret); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_recv, socket_recv); + +// Since Python 3.8, ssl.write has been deprecated in favor of +// send. This will make plain tcp sockets and (u)ssl share the +// same call which can let users write abtractions easily. +// Since SSL sockets are often used as non-blocking, this method +// will raise ussl.SSLWantReadError or ussl.SSlWantWriteError +// depending on the error code returned by mbedtls_ssl_write. +STATIC mp_obj_t socket_send(mp_obj_t self_in, mp_obj_t data) { + mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in); + + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(data, &bufinfo, MP_BUFFER_READ); + + int ret = mbedtls_ssl_write(&o->ssl, bufinfo.buf, bufinfo.len); + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + mp_raise_msg(&mp_type_SSLWantWriteError, NULL); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + mp_raise_msg(&mp_type_SSLWantReadError, NULL); + } else if (ret < 0) { + mp_raise_OSError(MP_EIO); + } + return mp_obj_new_int(ret); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_send, socket_send); + +// If ussl.wrap_socket is called with do_handshake = False, +// the user has now the possibility to call do_handshake +// later before calling any of recv or send. +// ussl.SSLInProgress is raised if all the the socket has +// POLLIN | POLLOUT available but the internal buffer is not +// drained. +// Any other error should be considered as critical and the +// user MUST stop using the current socket. +STATIC mp_obj_t socket_do_handshake(mp_obj_t self_in) { + mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in); + + int ret = mbedtls_ssl_handshake(&o->ssl); + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + mp_raise_msg(&mp_type_SSLWantWriteError, NULL); + } else if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + mp_raise_msg(&mp_type_SSLWantReadError, NULL); +#ifdef MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS // not defined in the esp-idf version. + } else if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) { +#else + } else if (ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS) { +#endif + mp_raise_msg(&mp_type_SSLInProgress, NULL); + } else if (ret != 0) { + mp_raise_OSError(MP_EIO); + } + return mp_obj_new_int(ret); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_ssl_do_handshake, socket_do_handshake); + STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) }, { MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) }, @@ -322,6 +417,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) }, #endif { MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) }, + { MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&mod_ssl_recv) }, + { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&mod_ssl_send) }, + { MP_ROM_QSTR(MP_QSTR_do_handshake), MP_ROM_PTR(&mod_ssl_do_handshake) }, }; STATIC MP_DEFINE_CONST_DICT(ussl_socket_locals_dict, ussl_socket_locals_dict_table); @@ -367,6 +465,9 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socke STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = { { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ussl) }, { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) }, + { MP_ROM_QSTR(MP_QSTR_SSLWantReadError), MP_OBJ_FROM_PTR(&mp_type_SSLWantReadError) }, + { MP_ROM_QSTR(MP_QSTR_SSLWantWriteError), MP_OBJ_FROM_PTR(&mp_type_SSLWantWriteError) }, + { MP_ROM_QSTR(MP_QSTR_SSLInProgress), MP_OBJ_FROM_PTR(&mp_type_SSLInProgress) }, }; STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);