Skip to content

extmod/modussl: Fix ussl read/recv/send/write errors when non-blocking (v2) #6907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions docs/library/ussl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@ facilities for network sockets, both client-side and server-side.
Functions
---------

.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None)

.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True)
Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type),
and returns an instance of ssl.SSLSocket, which wraps the underlying stream in
an SSL context. Returned object has the usual `stream` interface methods like
``read()``, ``write()``, etc. In MicroPython, the returned object does not expose
socket interface and methods like ``recv()``, ``send()``. In particular, a
server-side SSL socket should be created from a normal socket returned from
``read()``, ``write()``, etc.
A server-side SSL socket should be created from a normal socket returned from
:meth:`~usocket.socket.accept()` on a non-SSL listening server socket.

- *do_handshake* determines whether the handshake is done as part of the ``wrap_socket``
or whether it is deferred to be done as part of the initial reads or writes
(there is no ``do_handshake`` method as in CPython).
For blocking sockets doing the handshake immediately is standard. For non-blocking
sockets (i.e. when the *sock* passed into ``wrap_socket`` is in non-blocking mode)
the handshake should generally be deferred because otherwise ``wrap_socket`` blocks
until it completes. Note that in AXTLS the handshake can be deferred until the first
read or write but it then blocks until completion.

Depending on the underlying module implementation in a particular
:term:`MicroPython port`, some or all keyword arguments above may be not supported.

Expand All @@ -31,6 +38,11 @@ Functions
Some implementations of ``ussl`` module do NOT validate server certificates,
which makes an SSL connection established prone to man-in-the-middle attacks.

CPython's ``wrap_socket`` returns an ``SSLSocket`` object which has methods typical
for sockets, such as ``send``, ``recv``, etc. MicroPython's ``wrap_socket``
returns an object more similar to CPython's ``SSLObject`` which does not have
these socket methods.

Exceptions
----------

Expand Down
31 changes: 26 additions & 5 deletions extmod/modussl_axtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,15 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args
o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext);

if (args->do_handshake.u_bool) {
int res = ssl_handshake_status(o->ssl_sock);

if (res != SSL_OK) {
ussl_raise_error(res);
int r = ssl_handshake_status(o->ssl_sock);

if (r != SSL_OK) {
if (r == SSL_CLOSE_NOTIFY) { // EOF
r = MP_ENOTCONN;
} else if (r == SSL_EAGAIN) {
r = MP_EAGAIN;
}
ussl_raise_error(r);
}
}

Expand Down Expand Up @@ -242,8 +247,24 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz
return MP_STREAM_ERROR;
}

mp_int_t r = ssl_write(o->ssl_sock, buf, size);
mp_int_t r;
eagain:
r = ssl_write(o->ssl_sock, buf, size);
if (r == 0) {
// see comment in ussl_socket_read above
if (o->blocking) {
goto eagain;
} else {
r = SSL_EAGAIN;
}
}
if (r < 0) {
if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) {
return 0; // EOF
}
if (r == SSL_EAGAIN) {
r = MP_EAGAIN;
}
*errcode = r;
return MP_STREAM_ERROR;
}
Expand Down
3 changes: 2 additions & 1 deletion extmod/modussl_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) {
}
}

// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket
STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) {
mp_obj_t sock = *(mp_obj_t *)ctx;

Expand Down Expand Up @@ -171,7 +172,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
mbedtls_pk_init(&o->pkey);
mbedtls_ctr_drbg_init(&o->ctr_drbg);
#ifdef MBEDTLS_DEBUG_C
// Debug level (0-4)
// Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose
mbedtls_debug_set_threshold(0);
#endif

Expand Down
8 changes: 5 additions & 3 deletions ports/esp32/modsocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
MP_THREAD_GIL_EXIT();
int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen);
MP_THREAD_GIL_ENTER();
if (r < 0 && errno != EWOULDBLOCK) {
// lwip returns EINPROGRESS when trying to send right after a non-blocking connect
if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
mp_raise_OSError(errno);
}
if (r > 0) {
Expand All @@ -567,7 +568,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) {
check_for_exceptions();
}
if (sentlen == 0) {
mp_raise_OSError(MP_ETIMEDOUT);
mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT);
}
return sentlen;
}
Expand Down Expand Up @@ -650,7 +651,8 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_
if (r > 0) {
return r;
}
if (r < 0 && errno != EWOULDBLOCK) {
// lwip returns MP_EINPROGRESS when trying to write right after a non-blocking connect
if (r < 0 && errno != EWOULDBLOCK && errno != EINPROGRESS) {
*errcode = errno;
return MP_STREAM_ERROR;
}
Expand Down
6 changes: 3 additions & 3 deletions tests/net_hosted/accept_timeout.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# test that socket.accept() on a socket with timeout raises ETIMEDOUT

try:
import usocket as socket
import uerrno as errno, usocket as socket
except:
import socket
import errno, socket

try:
socket.socket.settimeout
Expand All @@ -18,5 +18,5 @@
try:
s.accept()
except OSError as er:
print(er.args[0] in (110, "timed out")) # 110 is ETIMEDOUT; CPython uses a string
print(er.args[0] in (errno.ETIMEDOUT, "timed out")) # CPython uses a string instead of errno
s.close()
147 changes: 147 additions & 0 deletions tests/net_hosted/connect_nonblock_xfer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
# and that an immediate write/send/read/recv does the right thing

try:
import sys, time
import uerrno as errno, usocket as socket, ussl as ssl
except:
import socket, errno, ssl
isMP = sys.implementation.name == "micropython"


def dp(e):
# uncomment next line for development and testing, to print the actual exceptions
# print(repr(e))
pass


# do_connect establishes the socket and wraps it if tls is True.
# If handshake is true, the initial connect (and TLS handshake) is
# allowed to be performed before returning.
def do_connect(peer_addr, tls, handshake):
s = socket.socket()
s.setblocking(False)
try:
# print("Connecting to", peer_addr)
s.connect(peer_addr)
except OSError as er:
print("connect:", er.args[0] == errno.EINPROGRESS)
if er.args[0] != errno.EINPROGRESS:
print(" got", er.args[0])
# wrap with ssl/tls if desired
if tls:
try:
if sys.implementation.name == "micropython":
s = ssl.wrap_socket(s, do_handshake=handshake)
else:
s = ssl.wrap_socket(s, do_handshake_on_connect=handshake)
print("wrap: True")
except Exception as e:
dp(e)
print("wrap:", e)
elif handshake:
# just sleep a little bit, this allows any connect() errors to happen
time.sleep(0.2)
return s


# test runs the test against a specific peer address.
def test(peer_addr, tls=False, handshake=False):
# MicroPython plain sockets have read/write, but CPython's don't
# MicroPython TLS sockets and CPython's have read/write
# hasRW captures this wonderful state of affairs
hasRW = isMP or tls

# MicroPython plain sockets and CPython's have send/recv
# MicroPython TLS sockets don't have send/recv, but CPython's do
# hasSR captures this wonderful state of affairs
hasSR = not (isMP and tls)

# connect + send
if hasSR:
s = do_connect(peer_addr, tls, handshake)
# send -> 4 or EAGAIN
try:
ret = s.send(b"1234")
print("send:", handshake and ret == 4)
except OSError as er:
#
dp(er)
print("send:", er.args[0] in (errno.EAGAIN, errno.EINPROGRESS))
s.close()
else: # fake it...
print("connect:", True)
if tls:
print("wrap:", True)
print("send:", True)

# connect + write
if hasRW:
s = do_connect(peer_addr, tls, handshake)
# write -> None
try:
ret = s.write(b"1234")
print("write:", ret in (4, None)) # SSL may accept 4 into buffer
except OSError as er:
dp(er)
print("write:", False) # should not raise
except ValueError as er: # CPython
dp(er)
print("write:", er.args[0] == "Write on closed or unwrapped SSL socket.")
s.close()
else: # fake it...
print("connect:", True)
if tls:
print("wrap:", True)
print("write:", True)

if hasSR:
# connect + recv
s = do_connect(peer_addr, tls, handshake)
# recv -> EAGAIN
try:
print("recv:", s.recv(10))
except OSError as er:
dp(er)
print("recv:", er.args[0] == errno.EAGAIN)
s.close()
else: # fake it...
print("connect:", True)
if tls:
print("wrap:", True)
print("recv:", True)

# connect + read
if hasRW:
s = do_connect(peer_addr, tls, handshake)
# read -> None
try:
ret = s.read(10)
print("read:", ret is None)
except OSError as er:
dp(er)
print("read:", False) # should not raise
except ValueError as er: # CPython
dp(er)
print("read:", er.args[0] == "Read on closed or unwrapped SSL socket.")
s.close()
else: # fake it...
print("connect:", True)
if tls:
print("wrap:", True)
print("read:", True)


if __name__ == "__main__":
# these tests use a non-existent test IP address, this way the connect takes forever and
# we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737)
print("--- Plain sockets to nowhere ---")
test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False)
print("--- SSL sockets to nowhere ---")
# this test fails with AXTLS because do_handshake=False blocks on first read/write and
# there it times out until the connect is aborted
test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False)
print("--- Plain sockets ---")
test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, True)
print("--- SSL sockets ---")
test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True)
51 changes: 51 additions & 0 deletions tests/net_inet/ssl_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# test that socket.connect() on a non-blocking socket raises EINPROGRESS
# and that an immediate write/send/read/recv does the right thing

import sys

try:
import uerrno as errno, usocket as socket, ussl as ssl
except:
import errno, socket, ssl


def test(addr, hostname, block=True):
print("---", hostname or addr)
s = socket.socket()
s.setblocking(block)
try:
s.connect(addr)
print("connected")
except OSError as e:
if e.args[0] != errno.EINPROGRESS:
raise
print("EINPROGRESS")

try:
if sys.implementation.name == "micropython":
s = ssl.wrap_socket(s, do_handshake=block)
else:
s = ssl.wrap_socket(s, do_handshake_on_connect=block)
print("wrap: True")
except OSError:
print("wrap: error")

if not block:
try:
while s.write(b"0") is None:
pass
except (ValueError, OSError): # CPython raises ValueError, MicroPython raises OSError
print("write: error")
s.close()


if __name__ == "__main__":
# connect to plain HTTP port, oops!
addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
test(addr, None)
# connect to plain HTTP port, oops!
addr = socket.getaddrinfo("micropython.org", 80)[0][-1]
test(addr, None, False)
# connect to server with self-signed cert, oops!
addr = socket.getaddrinfo("test.mosquitto.org", 8883)[0][-1]
test(addr, "test.mosquitto.org")
Loading