Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion extmod/modussl_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
#include "mbedtls/debug.h"
#include "mbedtls/error.h"

#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)

typedef struct _mp_obj_ssl_socket_t {
mp_obj_base_t base;
mp_obj_t sock;
Expand All @@ -56,6 +58,9 @@ typedef struct _mp_obj_ssl_socket_t {
mbedtls_x509_crt cacert;
mbedtls_x509_crt cert;
mbedtls_pk_context pkey;

uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next
int last_error; // The last error code, if any
} mp_obj_ssl_socket_t;

struct ssl_args {
Expand Down Expand Up @@ -165,6 +170,8 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
#endif
o->base.type = &ussl_socket_type;
o->sock = sock;
o->poll_mask = 0;
o->last_error = 0;

int ret;
mbedtls_ssl_init(&o->ssl);
Expand Down Expand Up @@ -306,6 +313,12 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin

STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
o->poll_mask = 0;

if (o->last_error) {
*errcode = o->last_error;
return MP_STREAM_ERROR;
}

int ret = mbedtls_ssl_read(&o->ssl, buf, size);
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
Expand All @@ -322,13 +335,22 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
// wanting to write next handshake message. The same may happen with
// renegotation.
ret = MP_EWOULDBLOCK;
o->poll_mask = MP_STREAM_POLL_WR;
} else {
o->last_error = ret;
}
*errcode = ret;
return MP_STREAM_ERROR;
}

STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
o->poll_mask = 0;

if (o->last_error) {
*errcode = o->last_error;
return MP_STREAM_ERROR;
}

int ret = mbedtls_ssl_write(&o->ssl, buf, size);
if (ret >= 0) {
Expand All @@ -341,6 +363,9 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
// wanting to read next handshake message. The same may happen with
// renegotation.
ret = MP_EWOULDBLOCK;
o->poll_mask = MP_STREAM_POLL_RD;
} else {
o->last_error = ret;
}
*errcode = ret;
return MP_STREAM_ERROR;
Expand All @@ -358,17 +383,56 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking);

STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, int *errcode) {
mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(o_in);
mp_uint_t ret = 0;
uintptr_t saved_arg = 0;
mp_obj_t sock = self->sock;
if (sock == MP_OBJ_NULL || (request != MP_STREAM_CLOSE && self->last_error != 0)) {
// Closed or error socket:
return MP_STREAM_POLL_NVAL;
}

if (request == MP_STREAM_CLOSE) {
self->sock = MP_OBJ_NULL;
mbedtls_pk_free(&self->pkey);
mbedtls_x509_crt_free(&self->cert);
mbedtls_x509_crt_free(&self->cacert);
mbedtls_ssl_free(&self->ssl);
mbedtls_ssl_config_free(&self->conf);
mbedtls_ctr_drbg_free(&self->ctr_drbg);
mbedtls_entropy_free(&self->entropy);
} else if (request == MP_STREAM_POLL) {
// If the library signaled us that it needs reading or writing, only check that direction,
// but save what the caller asked because we need to restore it later
if (self->poll_mask && (arg & MP_STREAM_POLL_RDWR)) {
saved_arg = arg & MP_STREAM_POLL_RDWR;
arg = (arg & ~saved_arg) | self->poll_mask;
}

// Take into account that the library might have buffered data already
int has_pending = 0;
if (arg & MP_STREAM_POLL_RD) {
has_pending = mbedtls_ssl_check_pending(&self->ssl);
if (has_pending) {
ret |= MP_STREAM_POLL_RD;
if (arg == MP_STREAM_POLL_RD) {
// Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
return MP_STREAM_POLL_RD;
}
}
}
}

// Pass all requests down to the underlying socket
return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
ret |= mp_get_stream(sock)->ioctl(sock, request, arg, errcode);

if (request == MP_STREAM_POLL) {
// The direction the library needed is available, return a fake result to the caller so that
// it reenters a read or a write to allow the handshake to progress
if (ret & self->poll_mask) {
ret |= saved_arg;
}
}
return ret;
}

STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
Expand All @@ -381,6 +445,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
#if MICROPY_PY_USSL_FINALISER
{ MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) },
#endif
#if MICROPY_UNIX_COVERAGE
{ MP_ROM_QSTR(MP_QSTR_ioctl), MP_ROM_PTR(&mp_stream_ioctl_obj) },
#endif
{ MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) },
};

Expand Down
1 change: 0 additions & 1 deletion tests/extmod/ussl_basic.py.exp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ OSError: client
TestSocket.setblocking(False)
TestSocket.setblocking(True)
TestSocket.ioctl 4 0
TestSocket.ioctl 4 0
OSError: read
OSError: write
196 changes: 196 additions & 0 deletions tests/extmod/ussl_poll.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
try:
import uselect
import ussl
import io
import ubinascii as binascii
except ImportError:
print("SKIP")
raise SystemExit

from micropython import const

_MP_STREAM_POLL_RD = const(0x0001)
_MP_STREAM_POLL_WR = const(0x0004)
_MP_STREAM_POLL_NVAL = const(0x0020)
_MP_STREAM_POLL = const(3)
_MP_STREAM_CLOSE = const(4)


# This self-signed key/cert pair is randomly generated and to be used for
# testing/demonstration only. You should always generate your own key/cert.
key = binascii.unhexlify(
b"3082013b020100024100cc20643fd3d9c21a0acba4f48f61aadd675f52175a9dcf07fbef"
b"610a6a6ba14abb891745cd18a1d4c056580d8ff1a639460f867013c8391cdc9f2e573b0f"
b"872d0203010001024100bb17a54aeb3dd7ae4edec05e775ca9632cf02d29c2a089b563b0"
b"d05cdf95aeca507de674553f28b4eadaca82d5549a86058f9996b07768686a5b02cb240d"
b"d9f1022100f4a63f5549e817547dca97b5c658038e8593cb78c5aba3c4642cc4cd031d86"
b"8f022100d598d870ffe4a34df8de57047a50b97b71f4d23e323f527837c9edae88c79483"
b"02210098560c89a70385c36eb07fd7083235c4c1184e525d838aedf7128958bedfdbb102"
b"2051c0dab7057a8176ca966f3feb81123d4974a733df0f958525f547dfd1c271f9022044"
b"6c2cafad455a671a8cf398e642e1be3b18a3d3aec2e67a9478f83c964c4f1f"
)
cert = binascii.unhexlify(
b"308201d53082017f020203e8300d06092a864886f70d01010505003075310b3009060355"
b"0406130258583114301206035504080c0b54686550726f76696e63653110300e06035504"
b"070c075468654369747931133011060355040a0c0a436f6d70616e7958595a3113301106"
b"0355040b0c0a436f6d70616e7958595a3114301206035504030c0b546865486f73744e61"
b"6d65301e170d3139313231383033333935355a170d3239313231353033333935355a3075"
b"310b30090603550406130258583114301206035504080c0b54686550726f76696e636531"
b"10300e06035504070c075468654369747931133011060355040a0c0a436f6d70616e7958"
b"595a31133011060355040b0c0a436f6d70616e7958595a3114301206035504030c0b5468"
b"65486f73744e616d65305c300d06092a864886f70d0101010500034b003048024100cc20"
b"643fd3d9c21a0acba4f48f61aadd675f52175a9dcf07fbef610a6a6ba14abb891745cd18"
b"a1d4c056580d8ff1a639460f867013c8391cdc9f2e573b0f872d0203010001300d06092a"
b"864886f70d0101050500034100b0513fe2829e9ecbe55b6dd14c0ede7502bde5d46153c8"
b"e960ae3ebc247371b525caeb41bbcf34686015a44c50d226e66aef0a97a63874ca5944ef"
b"979b57f0b3"
)


class _Pipe(io.IOBase):
def __init__(self):
self._other = None
self.block_reads = False
self.block_writes = False

self.write_buffers = []
self.last_poll_arg = None

def readinto(self, buf):
if self.block_reads or len(self._other.write_buffers) == 0:
return None

read_buf = self._other.write_buffers[0]
l = min(len(buf), len(read_buf))
buf[:l] = read_buf[:l]
if l == len(read_buf):
self._other.write_buffers.pop(0)
else:
self._other.write_buffers[0] = read_buf[l:]
return l

def write(self, buf):
if self.block_writes:
return None

self.write_buffers.append(memoryview(bytes(buf)))
return len(buf)

def ioctl(self, request, arg):
if request == _MP_STREAM_POLL:
self.last_poll_arg = arg
ret = 0
if arg & _MP_STREAM_POLL_RD:
if not self.block_reads and self._other.write_buffers:
ret |= _MP_STREAM_POLL_RD
if arg & _MP_STREAM_POLL_WR:
if not self.block_writes:
ret |= _MP_STREAM_POLL_WR
return ret

elif request == _MP_STREAM_CLOSE:
return 0

raise NotImplementedError()

@classmethod
def new_pair(cls):
p1 = cls()
p2 = cls()
p1._other = p2
p2._other = p1
return p1, p2


def assert_poll(s, i, arg, expected_arg, expected_ret):
ret = s.ioctl(_MP_STREAM_POLL, arg)
assert i.last_poll_arg == expected_arg
i.last_poll_arg = None
assert ret == expected_ret


def assert_raises(cb, *args, **kwargs):
try:
cb(*args, **kwargs)
raise AssertionError("should have raised")
except Exception as exc:
pass


client_io, server_io = _Pipe.new_pair()

client_io.block_reads = True
client_io.block_writes = True
client_sock = ussl.wrap_socket(client_io, do_handshake=False)

server_sock = ussl.wrap_socket(server_io, key=key, cert=cert, server_side=True, do_handshake=False)

# Do a test read, at this point the TLS handshake wants to write,
# so it returns None:
assert client_sock.read(128) is None

# Polling for either read or write actually check if the underlying socket can write:
assert_poll(client_sock, client_io, _MP_STREAM_POLL_RD, _MP_STREAM_POLL_WR, 0)
assert_poll(client_sock, client_io, _MP_STREAM_POLL_WR, _MP_STREAM_POLL_WR, 0)

# Mark the socket as writable, and do another test read:
client_io.block_writes = False
assert client_sock.read(128) is None

# The client wrote the CLIENT_HELLO message
assert len(client_io.write_buffers) == 1

# At this point the TLS handshake wants to read, but we don't know that yet:
assert_poll(client_sock, client_io, _MP_STREAM_POLL_RD, _MP_STREAM_POLL_RD, 0)
assert_poll(client_sock, client_io, _MP_STREAM_POLL_WR, _MP_STREAM_POLL_WR, _MP_STREAM_POLL_WR)

# Do a test write
client_sock.write(b"foo")

# Now we know that we want to read:
assert_poll(client_sock, client_io, _MP_STREAM_POLL_RD, _MP_STREAM_POLL_RD, 0)
assert_poll(client_sock, client_io, _MP_STREAM_POLL_WR, _MP_STREAM_POLL_RD, 0)

# Unblock reads and nudge the two sockets:
client_io.block_reads = False
while server_io.write_buffers or client_io.write_buffers:
if server_io.write_buffers:
assert client_sock.read(128) is None
if client_io.write_buffers:
assert server_sock.read(128) is None

# At this point, the handshake is done, try writing data:
client_sock.write(b"foo")
assert server_sock.read(3) == b"foo"

# Test reading partial data:
client_sock.write(b"foobar")
assert server_sock.read(3) == b"foo"
server_io.block_reads = True
assert_poll(
server_sock, server_io, _MP_STREAM_POLL_RD, None, _MP_STREAM_POLL_RD
) # Did not go to the socket, just consumed buffered data
assert server_sock.read(3) == b"bar"


# Polling on a closed socket errors out:
client_io, _ = _Pipe.new_pair()
client_sock = ussl.wrap_socket(client_io, do_handshake=False)
client_sock.close()
assert_poll(
client_sock, client_io, _MP_STREAM_POLL_RD, None, _MP_STREAM_POLL_NVAL
) # Did not go to the socket


# Errors propagates to poll:
client_io, server_io = _Pipe.new_pair()
client_sock = ussl.wrap_socket(client_io, do_handshake=False)

# The server returns garbage:
server_io.write(b"fooba") # Needs to be exactly 5 bytes

assert_poll(client_sock, client_io, _MP_STREAM_POLL_RD, _MP_STREAM_POLL_RD, _MP_STREAM_POLL_RD)
assert_raises(client_sock.read, 128)
assert_poll(
client_sock, client_io, _MP_STREAM_POLL_RD, None, _MP_STREAM_POLL_NVAL
) # Did not go to the socket
Empty file added tests/extmod/ussl_poll.py.exp
Empty file.
1 change: 1 addition & 0 deletions tests/run-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def run_tests(pyb, tests, args, result_dir, num_threads=1):
if not has_coverage:
skip_tests.add("cmdline/cmd_parsetree.py")
skip_tests.add("cmdline/repl_sys_ps1_ps2.py")
skip_tests.add("extmod/ussl_poll.py")

# Some tests shouldn't be run on a PC
if args.target == "unix":
Expand Down