Skip to content

extmod/modussl: fix socket and ussl read/recv/send/write errors for non-blocking sockets #5825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions docs/library/ussl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@ facilities for network sockets, both client-side and server-side.
Functions
---------

.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None)

.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True)
Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type),
and returns an instance of ssl.SSLSocket, which wraps the underlying stream in
an SSL context. Returned object has the usual `stream` interface methods like
``read()``, ``write()``, etc. In MicroPython, the returned object does not expose
socket interface and methods like ``recv()``, ``send()``. In particular, a
``read()``, ``write()``, etc. as well as ``recv()``, ``send()``. In particular, a
server-side SSL socket should be created from a normal socket returned from
:meth:`~usocket.socket.accept()` on a non-SSL listening server socket.

- *do_handshake* determines whether the handshake is done as part of the ``wrap_socket``
or whether it is deferred to be done as part of the initial reads or writes
(there is no ``do_handshake`` method as in CPython).
For blocking sockets doing the handshake immediately is standard. For non-blocking
sockets (i.e. when the *sock* passed into ``wrap_socket`` is in non-blocking mode)
the handshake should generally be deferred because otherwise ``wrap_socket`` blocks
until it completes. Note that in AXTLS the handshake can be deferred until the first
read or write but it then blocks until completion.

Depending on the underlying module implementation in a particular
:term:`MicroPython port`, some or all keyword arguments above may be not supported.

Expand Down
63 changes: 58 additions & 5 deletions extmod/modussl_axtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,16 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);

if (args->do_handshake.u_bool) {
int res = ssl_handshake_status(o->ssl_sock);

if (res != SSL_OK) {
ussl_raise_error(res);
int r = ssl_handshake_status(o->ssl_sock);

if (r != SSL_OK) {
ssl_display_error(r);
if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) { // EOF
r = MP_ENOTCONN;
} else if (r == SSL_EAGAIN) {
r = MP_EAGAIN;
}
ussl_raise_error(r);
}
}

Expand Down Expand Up @@ -234,6 +240,22 @@ STATIC mp_uint_t ussl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int
return size;
}

STATIC mp_obj_t ussl_socket_recv(mp_obj_t self_in, mp_obj_t len_in) {
size_t len = mp_obj_int_get_uint_checked(len_in);
vstr_t vstr;
vstr_init_len(&vstr, len);

int errcode;
mp_uint_t ret = ussl_socket_read(self_in, vstr.buf, len, &errcode);
if (ret == MP_STREAM_ERROR) {
mp_raise_OSError(errcode);
}

vstr.len = ret;
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(ussl_socket_recv_obj, ussl_socket_recv);

STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);

Expand All @@ -242,14 +264,43 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz
return MP_STREAM_ERROR;
}

mp_int_t r = ssl_write(o->ssl_sock, buf, size);
mp_int_t r;
eagain:
r = ssl_write(o->ssl_sock, buf, size);
if (r == 0) {
// see comment in ussl_socket_read above
if (o->blocking) {
goto eagain;
} else {
r = SSL_EAGAIN;
}
}
if (r < 0) {
if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) {
return 0; // EOF
}
if (r == SSL_EAGAIN) {
r = MP_EAGAIN;
}
*errcode = r;
return MP_STREAM_ERROR;
}
return r;
}

STATIC mp_obj_t ussl_socket_send(mp_obj_t self_in, mp_obj_t buf_in) {
mp_buffer_info_t bufinfo;
mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ);

int errcode;
mp_uint_t r = ussl_socket_write(self_in, bufinfo.buf, bufinfo.len, &errcode);
if (r == MP_STREAM_ERROR) {
mp_raise_OSError(errcode);
}
return mp_obj_new_int(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(ussl_socket_send_obj, ussl_socket_send);

STATIC mp_uint_t ussl_socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, int *errcode) {
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(o_in);
if (request == MP_STREAM_CLOSE && self->ssl_sock != NULL) {
Expand Down Expand Up @@ -277,7 +328,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
{ MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) },
{ MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&ussl_socket_recv_obj) },
{ MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
{ MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&ussl_socket_send_obj) },
{ MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&ussl_socket_setblocking_obj) },
{ MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) },
#if MICROPY_PY_USSL_FINALISER
Expand Down
34 changes: 33 additions & 1 deletion extmod/modussl_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
}
}

// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;

Expand Down Expand Up @@ -171,7 +172,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
mbedtls_pk_init(&o->pkey);
mbedtls_ctr_drbg_init(&o->ctr_drbg);
#ifdef MBEDTLS_DEBUG_C
// Debug level (0-4)
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
mbedtls_debug_set_threshold(0);
#endif

Expand Down Expand Up @@ -308,6 +309,22 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
return MP_STREAM_ERROR;
}

STATIC mp_obj_t socket_recv(mp_obj_t self_in, mp_obj_t len_in) {
size_t len = mp_obj_get_int(len_in);
vstr_t vstr;
vstr_init_len(&vstr, len);

int errcode;
int ret = socket_read(self_in, vstr.buf, len, &errcode);
if (ret == MP_STREAM_ERROR) {
mp_raise_OSError(errcode);
}

vstr.len = ret;
return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_recv_obj, socket_recv);

STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);

Expand All @@ -327,6 +344,19 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
return MP_STREAM_ERROR;
}

STATIC mp_obj_t socket_send(mp_obj_t self_in, mp_obj_t buf_in) {
mp_buffer_info_t bufinfo;
mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ);

int errcode;
int r = socket_write(self_in, bufinfo.buf, bufinfo.len, &errcode);
if (r == MP_STREAM_ERROR) {
mp_raise_OSError(errcode);
}
return mp_obj_new_int(r);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_send_obj, socket_send);

STATIC mp_obj_t socket_setblocking(mp_obj_t self_in, mp_obj_t flag_in) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in);
mp_obj_t sock = o->sock;
Expand Down Expand Up @@ -356,7 +386,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
{ MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
{ MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
{ MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) },
{ MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&socket_recv_obj) },
{ MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
{ MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&socket_send_obj) },
{ MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) },
{ MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) },
#if MICROPY_PY_USSL_FINALISER
Expand Down
14 changes: 8 additions & 6 deletions ports/esp32/modsocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ STATIC mp_obj_t socket_accept(const mp_obj_t arg0) {
if (new_fd >= 0) {
break;
}
if (errno != EAGAIN) {
if (errno != MP_EAGAIN) {
exception_from_errno(errno);
}
check_for_exceptions();
Expand Down Expand Up @@ -523,7 +523,7 @@ STATIC mp_uint_t _socket_read_data(mp_obj_t self_in, void *buf, size_t size,
if (r >= 0) {
return r;
}
if (errno != EWOULDBLOCK) {
if (errno != MP_EWOULDBLOCK) {
*errcode = errno;
return MP_STREAM_ERROR;
}
Expand Down Expand Up @@ -576,7 +576,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
MP_THREAD_GIL_EXIT();
int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen);
MP_THREAD_GIL_ENTER();
if (r < 0 && errno != EWOULDBLOCK) {
// lwip returns MP_EINPROGRESS when trying to send right after a non-blocking connect
if (r < 0 && errno != MP_EWOULDBLOCK && errno != MP_EINPROGRESS) {
exception_from_errno(errno);
}
if (r > 0) {
Expand All @@ -585,7 +586,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
check_for_exceptions();
}
if (sentlen == 0) {
mp_raise_OSError(MP_ETIMEDOUT);
mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT);
}
return sentlen;
}
Expand Down Expand Up @@ -634,7 +635,7 @@ STATIC mp_obj_t socket_sendto(mp_obj_t self_in, mp_obj_t data_in, mp_obj_t addr_
if (ret > 0) {
return mp_obj_new_int_from_uint(ret);
}
if (ret == -1 && errno != EWOULDBLOCK) {
if (ret == -1 && errno != MP_EWOULDBLOCK) {
exception_from_errno(errno);
}
check_for_exceptions();
Expand Down Expand Up @@ -668,7 +669,8 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_
if (r > 0) {
return r;
}
if (r < 0 && errno != EWOULDBLOCK) {
// lwip returns MP_EINPROGRESS when trying to write right after a non-blocking connect
if (r < 0 && errno != MP_EWOULDBLOCK && errno != MP_EINPROGRESS) {
*errcode = errno;
return MP_STREAM_ERROR;
}
Expand Down
106 changes: 101 additions & 5 deletions tests/net_hosted/connect_nonblock.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,116 @@
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
# and that an immediate write/send/read/recv does the right thing

try:
import usocket as socket
import usocket as socket, ussl as ssl, sys, time
except:
import socket
import socket, ssl, sys, time


def test(peer_addr):
def dp(e):
# print(e) # uncomment this line for dev&test to print the actual exceptions
pass


# do_connect establishes the socket and wraps it if requested
def do_connect(peer_addr, tls, handshake):
s = socket.socket()
s.setblocking(False)
try:
s.connect(peer_addr)
except OSError as er:
print(er.args[0] == 115) # 115 is EINPROGRESS
print("connect:", er.args[0] == 115) # 115 is EINPROGRESS
# wrap with ssl/tls if desired
if tls:
try:
if sys.implementation.name == "micropython":
s = ssl.wrap_socket(s, do_handshake=handshake)
else:
s = ssl.wrap_socket(s, do_handshake_on_connect=handshake)
print("wrap: True")
except Exception as e:
dp(e)
print("wrap:", e)
# if handshake is set, we wait after connect() so it has time to actually happen
if handshake and not tls: # with tls the handshake does it
time.sleep(0.2)
return s


def test(peer_addr, tls=False, handshake=False):
# a fresh socket is opened for each combination because MP on linux is too fast

# hasRW is false in CPython for sockets: they don't have read or write methods
hasRW = sys.implementation.name == "micropython" or tls

# connect + send
s = do_connect(peer_addr, tls, handshake)
# send -> 4 or EAGAIN
try:
ret = s.send(b"1234")
print("send:", handshake and ret == 4)
except OSError as er:
dp(er)
print("send:", er.args[0] == 11) # 11 is EAGAIN
s.close()

# connect + write
if hasRW:
s = do_connect(peer_addr, tls, handshake)
# write -> None
try:
ret = s.write(b"1234")
print("write:", ret is (4 if handshake else None))
except OSError as er:
dp(er)
print("write:", False) # should not raise
except ValueError as er: # CPython
dp(er)
print("write:", er.args[0] == "Write on closed or unwrapped SSL socket.")
s.close()
else: # fake it...
print("connect:", True)
print("write:", True)

# connect + recv
s = do_connect(peer_addr, tls, handshake)
# recv -> EAGAIN
try:
print("recv:", s.recv(10))
except OSError as er:
dp(er)
print("recv:", er.args[0] == 11) # 11 is EAGAIN
s.close()

# connect + read
if hasRW:
s = do_connect(peer_addr, tls, handshake)
# read -> None
try:
ret = s.read(10)
print("read:", ret is None)
except OSError as er:
dp(er)
print("read:", False) # should not raise
except ValueError as er: # CPython
dp(er)
print("read:", er.args[0] == "Read on closed or unwrapped SSL socket.")
s.close()
else: # fake it...
print("connect:", True)
print("read:", True)


if __name__ == "__main__":
test(socket.getaddrinfo("micropython.org", 80)[0][-1])
# these tests use an non-existant test IP address, this way the connect takes forever and
# we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
print("--- Plain sockets to nowhere ---")
test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False)
print("--- SSL sockets to nowhere ---")
# this test fails with AXTLS because do_handshake=False blocks on first read/write and
# there it times out until the connect is aborted
test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False)
print("--- Plain sockets ---")
test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, True)
print("--- SSL sockets ---")
test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True)
1 change: 0 additions & 1 deletion tests/net_hosted/connect_nonblock.py.exp

This file was deleted.

Loading