diff --git a/docs/library/ussl.rst b/docs/library/ussl.rst index ffe146331ca08..598fbe79c2954 100644 --- a/docs/library/ussl.rst +++ b/docs/library/ussl.rst @@ -13,16 +13,23 @@ facilities for network sockets, both client-side and server-side. Functions --------- -.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None) - +.. function:: ussl.wrap_socket(sock, server_side=False, keyfile=None, certfile=None, cert_reqs=CERT_NONE, ca_certs=None, do_handshake=True) Takes a `stream` *sock* (usually usocket.socket instance of ``SOCK_STREAM`` type), and returns an instance of ssl.SSLSocket, which wraps the underlying stream in an SSL context. Returned object has the usual `stream` interface methods like - ``read()``, ``write()``, etc. In MicroPython, the returned object does not expose - socket interface and methods like ``recv()``, ``send()``. In particular, a + ``read()``, ``write()``, etc. as well as ``recv()``, ``send()``. In particular, a server-side SSL socket should be created from a normal socket returned from :meth:`~usocket.socket.accept()` on a non-SSL listening server socket. + - *do_handshake* determines whether the handshake is done as part of the ``wrap_socket`` + or whether it is deferred to be done as part of the initial reads or writes + (there is no ``do_handshake`` method as in CPython). + For blocking sockets doing the handshake immediately is standard. For non-blocking + sockets (i.e. when the *sock* passed into ``wrap_socket`` is in non-blocking mode) + the handshake should generally be deferred because otherwise ``wrap_socket`` blocks + until it completes. Note that in AXTLS the handshake can be deferred until the first + read or write but it then blocks until completion. + Depending on the underlying module implementation in a particular :term:`MicroPython port`, some or all keyword arguments above may be not supported. diff --git a/extmod/modussl_axtls.c b/extmod/modussl_axtls.c index da5941a55b33e..b7ce8f063f670 100644 --- a/extmod/modussl_axtls.c +++ b/extmod/modussl_axtls.c @@ -167,10 +167,16 @@ STATIC mp_obj_ssl_socket_t *ussl_socket_new(mp_obj_t sock, struct ssl_args *args o->ssl_sock = ssl_client_new(o->ssl_ctx, (long)sock, NULL, 0, ext); if (args->do_handshake.u_bool) { - int res = ssl_handshake_status(o->ssl_sock); - - if (res != SSL_OK) { - ussl_raise_error(res); + int r = ssl_handshake_status(o->ssl_sock); + + if (r != SSL_OK) { + ssl_display_error(r); + if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) { // EOF + r = MP_ENOTCONN; + } else if (r == SSL_EAGAIN) { + r = MP_EAGAIN; + } + ussl_raise_error(r); } } @@ -234,6 +240,22 @@ STATIC mp_uint_t ussl_socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int return size; } +STATIC mp_obj_t ussl_socket_recv(mp_obj_t self_in, mp_obj_t len_in) { + size_t len = mp_obj_int_get_uint_checked(len_in); + vstr_t vstr; + vstr_init_len(&vstr, len); + + int errcode; + mp_uint_t ret = ussl_socket_read(self_in, vstr.buf, len, &errcode); + if (ret == MP_STREAM_ERROR) { + mp_raise_OSError(errcode); + } + + vstr.len = ret; + return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(ussl_socket_recv_obj, ussl_socket_recv); + STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); @@ -242,14 +264,43 @@ STATIC mp_uint_t ussl_socket_write(mp_obj_t o_in, const void *buf, mp_uint_t siz return MP_STREAM_ERROR; } - mp_int_t r = ssl_write(o->ssl_sock, buf, size); + mp_int_t r; +eagain: + r = ssl_write(o->ssl_sock, buf, size); + if (r == 0) { + // see comment in ussl_socket_read above + if (o->blocking) { + goto eagain; + } else { + r = SSL_EAGAIN; + } + } if (r < 0) { + if (r == SSL_CLOSE_NOTIFY || r == SSL_ERROR_CONN_LOST) { + return 0; // EOF + } + if (r == SSL_EAGAIN) { + r = MP_EAGAIN; + } *errcode = r; return MP_STREAM_ERROR; } return r; } +STATIC mp_obj_t ussl_socket_send(mp_obj_t self_in, mp_obj_t buf_in) { + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ); + + int errcode; + mp_uint_t r = ussl_socket_write(self_in, bufinfo.buf, bufinfo.len, &errcode); + if (r == MP_STREAM_ERROR) { + mp_raise_OSError(errcode); + } + return mp_obj_new_int(r); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(ussl_socket_send_obj, ussl_socket_send); + STATIC mp_uint_t ussl_socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, int *errcode) { mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(o_in); if (request == MP_STREAM_CLOSE && self->ssl_sock != NULL) { @@ -277,7 +328,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) }, { MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) }, { MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) }, + { MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&ussl_socket_recv_obj) }, { MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) }, + { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&ussl_socket_send_obj) }, { MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&ussl_socket_setblocking_obj) }, { MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) }, #if MICROPY_PY_USSL_FINALISER diff --git a/extmod/modussl_mbedtls.c b/extmod/modussl_mbedtls.c index 1677dc6e1ca70..ef4f8c2815e27 100644 --- a/extmod/modussl_mbedtls.c +++ b/extmod/modussl_mbedtls.c @@ -133,6 +133,7 @@ STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) { } } +// _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) { mp_obj_t sock = *(mp_obj_t *)ctx; @@ -171,7 +172,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { mbedtls_pk_init(&o->pkey); mbedtls_ctr_drbg_init(&o->ctr_drbg); #ifdef MBEDTLS_DEBUG_C - // Debug level (0-4) + // Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose mbedtls_debug_set_threshold(0); #endif @@ -308,6 +309,22 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc return MP_STREAM_ERROR; } +STATIC mp_obj_t socket_recv(mp_obj_t self_in, mp_obj_t len_in) { + size_t len = mp_obj_get_int(len_in); + vstr_t vstr; + vstr_init_len(&vstr, len); + + int errcode; + int ret = socket_read(self_in, vstr.buf, len, &errcode); + if (ret == MP_STREAM_ERROR) { + mp_raise_OSError(errcode); + } + + vstr.len = ret; + return mp_obj_new_str_from_vstr(&mp_type_bytes, &vstr); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_recv_obj, socket_recv); + STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); @@ -327,6 +344,19 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in return MP_STREAM_ERROR; } +STATIC mp_obj_t socket_send(mp_obj_t self_in, mp_obj_t buf_in) { + mp_buffer_info_t bufinfo; + mp_get_buffer_raise(buf_in, &bufinfo, MP_BUFFER_READ); + + int errcode; + int r = socket_write(self_in, bufinfo.buf, bufinfo.len, &errcode); + if (r == MP_STREAM_ERROR) { + mp_raise_OSError(errcode); + } + return mp_obj_new_int(r); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_send_obj, socket_send); + STATIC mp_obj_t socket_setblocking(mp_obj_t self_in, mp_obj_t flag_in) { mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in); mp_obj_t sock = o->sock; @@ -356,7 +386,9 @@ STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = { { MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) }, { MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) }, { MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) }, + { MP_ROM_QSTR(MP_QSTR_recv), MP_ROM_PTR(&socket_recv_obj) }, { MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) }, + { MP_ROM_QSTR(MP_QSTR_send), MP_ROM_PTR(&socket_send_obj) }, { MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) }, { MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) }, #if MICROPY_PY_USSL_FINALISER diff --git a/ports/esp32/modsocket.c b/ports/esp32/modsocket.c index 85433e575fc51..005d11d53dea8 100644 --- a/ports/esp32/modsocket.c +++ b/ports/esp32/modsocket.c @@ -330,7 +330,7 @@ STATIC mp_obj_t socket_accept(const mp_obj_t arg0) { if (new_fd >= 0) { break; } - if (errno != EAGAIN) { + if (errno != MP_EAGAIN) { exception_from_errno(errno); } check_for_exceptions(); @@ -523,7 +523,7 @@ STATIC mp_uint_t _socket_read_data(mp_obj_t self_in, void *buf, size_t size, if (r >= 0) { return r; } - if (errno != EWOULDBLOCK) { + if (errno != MP_EWOULDBLOCK) { *errcode = errno; return MP_STREAM_ERROR; } @@ -576,7 +576,8 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) { MP_THREAD_GIL_EXIT(); int r = lwip_write(sock->fd, data + sentlen, datalen - sentlen); MP_THREAD_GIL_ENTER(); - if (r < 0 && errno != EWOULDBLOCK) { + // lwip returns MP_EINPROGRESS when trying to send right after a non-blocking connect + if (r < 0 && errno != MP_EWOULDBLOCK && errno != MP_EINPROGRESS) { exception_from_errno(errno); } if (r > 0) { @@ -585,7 +586,7 @@ int _socket_send(socket_obj_t *sock, const char *data, size_t datalen) { check_for_exceptions(); } if (sentlen == 0) { - mp_raise_OSError(MP_ETIMEDOUT); + mp_raise_OSError(sock->retries == 0 ? MP_EWOULDBLOCK : MP_ETIMEDOUT); } return sentlen; } @@ -634,7 +635,7 @@ STATIC mp_obj_t socket_sendto(mp_obj_t self_in, mp_obj_t data_in, mp_obj_t addr_ if (ret > 0) { return mp_obj_new_int_from_uint(ret); } - if (ret == -1 && errno != EWOULDBLOCK) { + if (ret == -1 && errno != MP_EWOULDBLOCK) { exception_from_errno(errno); } check_for_exceptions(); @@ -668,7 +669,8 @@ STATIC mp_uint_t socket_stream_write(mp_obj_t self_in, const void *buf, mp_uint_ if (r > 0) { return r; } - if (r < 0 && errno != EWOULDBLOCK) { + // lwip returns MP_EINPROGRESS when trying to write right after a non-blocking connect + if (r < 0 && errno != MP_EWOULDBLOCK && errno != MP_EINPROGRESS) { *errcode = errno; return MP_STREAM_ERROR; } diff --git a/tests/net_hosted/connect_nonblock.py b/tests/net_hosted/connect_nonblock.py index 3a3eaa2ba01e4..0a31f770a77ae 100644 --- a/tests/net_hosted/connect_nonblock.py +++ b/tests/net_hosted/connect_nonblock.py @@ -1,20 +1,116 @@ # test that socket.connect() on a non-blocking socket raises EINPROGRESS +# and that an immediate write/send/read/recv does the right thing try: - import usocket as socket + import usocket as socket, ussl as ssl, sys, time except: - import socket + import socket, ssl, sys, time -def test(peer_addr): +def dp(e): + # print(e) # uncomment this line for dev&test to print the actual exceptions + pass + + +# do_connect establishes the socket and wraps it if requested +def do_connect(peer_addr, tls, handshake): s = socket.socket() s.setblocking(False) try: s.connect(peer_addr) except OSError as er: - print(er.args[0] == 115) # 115 is EINPROGRESS + print("connect:", er.args[0] == 115) # 115 is EINPROGRESS + # wrap with ssl/tls if desired + if tls: + try: + if sys.implementation.name == "micropython": + s = ssl.wrap_socket(s, do_handshake=handshake) + else: + s = ssl.wrap_socket(s, do_handshake_on_connect=handshake) + print("wrap: True") + except Exception as e: + dp(e) + print("wrap:", e) + # if handshake is set, we wait after connect() so it has time to actually happen + if handshake and not tls: # with tls the handshake does it + time.sleep(0.2) + return s + + +def test(peer_addr, tls=False, handshake=False): + # a fresh socket is opened for each combination because MP on linux is too fast + + # hasRW is false in CPython for sockets: they don't have read or write methods + hasRW = sys.implementation.name == "micropython" or tls + + # connect + send + s = do_connect(peer_addr, tls, handshake) + # send -> 4 or EAGAIN + try: + ret = s.send(b"1234") + print("send:", handshake and ret == 4) + except OSError as er: + dp(er) + print("send:", er.args[0] == 11) # 11 is EAGAIN s.close() + # connect + write + if hasRW: + s = do_connect(peer_addr, tls, handshake) + # write -> None + try: + ret = s.write(b"1234") + print("write:", ret is (4 if handshake else None)) + except OSError as er: + dp(er) + print("write:", False) # should not raise + except ValueError as er: # CPython + dp(er) + print("write:", er.args[0] == "Write on closed or unwrapped SSL socket.") + s.close() + else: # fake it... + print("connect:", True) + print("write:", True) + + # connect + recv + s = do_connect(peer_addr, tls, handshake) + # recv -> EAGAIN + try: + print("recv:", s.recv(10)) + except OSError as er: + dp(er) + print("recv:", er.args[0] == 11) # 11 is EAGAIN + s.close() + + # connect + read + if hasRW: + s = do_connect(peer_addr, tls, handshake) + # read -> None + try: + ret = s.read(10) + print("read:", ret is None) + except OSError as er: + dp(er) + print("read:", False) # should not raise + except ValueError as er: # CPython + dp(er) + print("read:", er.args[0] == "Read on closed or unwrapped SSL socket.") + s.close() + else: # fake it... + print("connect:", True) + print("read:", True) + if __name__ == "__main__": - test(socket.getaddrinfo("micropython.org", 80)[0][-1]) + # these tests use an non-existant test IP address, this way the connect takes forever and + # we can see EAGAIN/None (https://tools.ietf.org/html/rfc5737) + print("--- Plain sockets to nowhere ---") + test(socket.getaddrinfo("192.0.2.1", 80)[0][-1], False, False) + print("--- SSL sockets to nowhere ---") + # this test fails with AXTLS because do_handshake=False blocks on first read/write and + # there it times out until the connect is aborted + test(socket.getaddrinfo("192.0.2.1", 443)[0][-1], True, False) + print("--- Plain sockets ---") + test(socket.getaddrinfo("micropython.org", 80)[0][-1], False, True) + print("--- SSL sockets ---") + test(socket.getaddrinfo("micropython.org", 443)[0][-1], True, True) diff --git a/tests/net_hosted/connect_nonblock.py.exp b/tests/net_hosted/connect_nonblock.py.exp deleted file mode 100644 index 0ca95142bb715..0000000000000 --- a/tests/net_hosted/connect_nonblock.py.exp +++ /dev/null @@ -1 +0,0 @@ -True diff --git a/tests/net_inet/ssl_errors.py b/tests/net_inet/ssl_errors.py new file mode 100644 index 0000000000000..65976d68a7b9a --- /dev/null +++ b/tests/net_inet/ssl_errors.py @@ -0,0 +1,45 @@ +# test that socket.connect() on a non-blocking socket raises EINPROGRESS +# and that an immediate write/send/read/recv does the right thing + +try: + import usocket as socket, ussl as ssl, sys +except: + import socket, ssl, sys + + +def test(addr, hostname, block=True): + print("---", hostname or addr) + s = socket.socket() + s.setblocking(block) + try: + s.connect(addr) + print("connected") + except OSError as e: + if e.args[0] != 115: # 115 == EINPROGRESS + raise + + try: + s = ssl.wrap_socket(s) + print("wrap: True") + except OSError as e: + print("wrap:", e) + + if not block: + try: + while s.write(b"0") is None: + pass + except OSError as e: + print("write:", e) + s.close() + + +if __name__ == "__main__": + # connect to plain HTTP port, oops! + addr = socket.getaddrinfo("micropython.org", 80)[0][-1] + test(addr, None) + # connect to plain HTTP port, oops! + addr = socket.getaddrinfo("micropython.org", 80)[0][-1] + test(addr, None, False) + # connect to server with self-signed cert, oops! + addr = socket.getaddrinfo("test.mosquitto.org", 8883)[0][-1] + test(addr, "test.mosquitto.org") diff --git a/tests/net_inet/test_tls_nonblock.py b/tests/net_inet/test_tls_nonblock.py new file mode 100644 index 0000000000000..cde0f4f015a90 --- /dev/null +++ b/tests/net_inet/test_tls_nonblock.py @@ -0,0 +1,152 @@ +try: + import usocket as socket, ussl as ssl, sys +except: + import socket, ssl, sys, time, select + + +def test_one(site, opts): + ai = socket.getaddrinfo(site, 443) + addr = ai[0][-1] + print(addr) + + use_send = "send" in opts and opts["send"] + + # Connect the raw socket + s = socket.socket() + s.setblocking(False) + try: + s.connect(addr) + raise OSError(-1, "connect blocks") + except OSError as e: + if e.args[0] != 115: # 115=EINPROGRESS + raise + + if sys.implementation.name != "micropython": + # in CPython we have to wait, otherwise wrap_socket is not happy + select.select([], [s], []) + + try: + # Wrap with SSL + try: + if sys.implementation.name == "micropython": + s = ssl.wrap_socket(s, do_handshake=False) + else: + s = ssl.wrap_socket(s, do_handshake_on_connect=False) + except OSError as e: + if e.args[0] != 115: # 115=EINPROGRESS + raise + print("wrapped") + + # CPython needs to be told to do the handshake + if sys.implementation.name != "micropython": + while True: + try: + s.do_handshake() + break + except ssl.SSLError as err: + if err.args[0] == ssl.SSL_ERROR_WANT_READ: + select.select([s], [], []) + elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: + select.select([], [s], []) + else: + raise + time.sleep(0.1) + # print("shook hands") + + # Write HTTP request + out = b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin") + if use_send: + while len(out) > 0: + try: + n = s.send(out) + except OSError as e: + if e.args[0] != 11: # 11=EAGAIN + raise + continue + if n > 0: + out = out[n:] + else: + raise OSError(-1, "unexpected write result") + else: + while len(out) > 0: + n = s.write(out) + if n is None: + continue + if n > 0: + out = out[n:] + elif n == 0: + raise OSError(-1, "unexpected EOF in write") + print("wrote") + + # Read response + resp = b"" + while True: + if use_send: + try: + b = s.recv(128) + except OSError as e: + if e.args[0] == 11: # 11=EAGAIN + continue + if e.args[0] == 2: # 2=ssl.SSL_ERROR_WANT_READ: + continue + raise + if len(b) > 0: + if len(resp) < 1024: + resp += b + elif len(b) == 0: + break + else: + raise OSError(-1, "unexpected read result") + else: + try: + b = s.read(128) + except OSError as err: + if err.args[0] == 2: # 2=ssl.SSL_ERROR_WANT_READ: + continue + raise + if b is None: + continue + if len(b) > 0: + if len(resp) < 1024: + resp += b + elif len(b) == 0: + break + print("read") + + if resp[:7] != b"HTTP/1.": + raise ValueError("response doesn't start with HTTP/1.") + # print(resp) + + finally: + s.close() + + +SITES = [ + "google.com", + {"host": "www.google.com", "send": True}, + "micropython.org", # used in the built-in upip, it better work... + "pypi.org", # ditto + "api.telegram.org", + {"host": "api.pushbullet.com", "sni": True}, + # this no longer works, not sure which special case it is supposed to test... + # "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", + # {"host": "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", "sni": True}, +] + + +def main(): + for site in SITES: + opts = {} + if isinstance(site, dict): + opts = site + site = opts["host"] + + try: + test_one(site, opts) + print(site, "ok") + except Exception as e: + print(site, e) + print("DONE") + + +main() diff --git a/tests/net_inet/test_tls_sites.py b/tests/net_inet/test_tls_sites.py index d2cb928c8d5b9..d808e9123397d 100644 --- a/tests/net_inet/test_tls_sites.py +++ b/tests/net_inet/test_tls_sites.py @@ -27,6 +27,8 @@ def test_one(site, opts): s.write(b"GET / HTTP/1.0\r\nHost: %s\r\n\r\n" % bytes(site, "latin")) resp = s.read(4096) + if resp[:7] != b"HTTP/1.": + raise ValueError("response doesn't start with HTTP/1.") # print(resp) finally: @@ -36,10 +38,13 @@ def test_one(site, opts): SITES = [ "google.com", "www.google.com", + "micropython.org", # used in the built-in upip, it better work... + "pypi.org", # ditto "api.telegram.org", {"host": "api.pushbullet.com", "sni": True}, - # "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", - {"host": "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", "sni": True}, + # this no longer works, not sure which special case it is supposed to test... + # "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", + # {"host": "w9rybpfril.execute-api.ap-southeast-2.amazonaws.com", "sni": True}, ]