diff --git a/Doc/library/asyncio-protocol.rst b/Doc/library/asyncio-protocol.rst index ef6441605cd72c..9a08a4a49021cc 100644 --- a/Doc/library/asyncio-protocol.rst +++ b/Doc/library/asyncio-protocol.rst @@ -463,16 +463,23 @@ The idea of BufferedProtocol is that it allows to manually allocate and control the receive buffer. Event loops can then use the buffer provided by the protocol to avoid unnecessary data copies. This can result in noticeable performance improvement for protocols that -receive big amounts of data. Sophisticated protocols can allocate -the buffer only once at creation time. +receive big amounts of data. Sophisticated protocols implementations +can allocate the buffer only once at creation time. The following callbacks are called on :class:`BufferedProtocol` instances: -.. method:: BufferedProtocol.get_buffer() +.. method:: BufferedProtocol.get_buffer(sizehint) - Called to allocate a new receive buffer. Must return an object - that implements the :ref:`buffer protocol `. + Called to allocate a new receive buffer. + + *sizehint* is a recommended minimal size for the returned + buffer. It is acceptable to return smaller or bigger buffers + than what *sizehint* suggests. When set to -1, the buffer size + can be arbitrary. It is an error to return a zero-sized buffer. + + Must return an object that implements the + :ref:`buffer protocol `. .. method:: BufferedProtocol.buffer_updated(nbytes) diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 09eb440b0ef7af..a0243f5bac9a2c 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -157,7 +157,6 @@ def _run_until_complete_cb(fut): futures._get_loop(fut).stop() - class _SendfileFallbackProtocol(protocols.Protocol): def __init__(self, transp): if not isinstance(transp, transports._FlowControlMixin): @@ -304,6 +303,9 @@ def close(self): async def start_serving(self): self._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. + await tasks.sleep(0, loop=self._loop) async def serve_forever(self): if self._serving_forever_fut is not None: @@ -1363,6 +1365,9 @@ async def create_server( ssl, backlog, ssl_handshake_timeout) if start_serving: server._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. + await tasks.sleep(0, loop=self) if self._debug: logger.info("%r is serving", server) diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 877dfb0746708e..337ed0fb204751 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -30,7 +30,7 @@ def __init__(self, loop, sock, protocol, waiter=None, super().__init__(extra, loop) self._set_extra(sock) self._sock = sock - self._protocol = protocol + self.set_protocol(protocol) self._server = server self._buffer = None # None or bytearray. self._read_fut = None @@ -159,16 +159,26 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + self._loop_reading_cb = None + self._paused = True super().__init__(loop, sock, protocol, waiter, extra, server) - self._paused = False + self._reschedule_on_resume = False + self._loop.call_soon(self._loop_reading) + self._paused = False - if protocols._is_buffered_protocol(protocol): - self._loop_reading = self._loop_reading__get_buffer + def set_protocol(self, protocol): + if isinstance(protocol, protocols.BufferedProtocol): + self._loop_reading_cb = self._loop_reading__get_buffer else: - self._loop_reading = self._loop_reading__data_received + self._loop_reading_cb = self._loop_reading__data_received - self._loop.call_soon(self._loop_reading) + super().set_protocol(protocol) + + if self.is_reading(): + # reset reading callback / buffers / self._read_fut + self.pause_reading() + self.resume_reading() def is_reading(self): return not self._paused and not self._closing @@ -179,6 +189,13 @@ def pause_reading(self): self._paused = True if self._read_fut is not None and not self._read_fut.done(): + # TODO: This is an ugly hack to cancel the current read future + # *and* avoid potential race conditions, as read cancellation + # goes through `future.cancel()` and `loop.call_soon()`. + # We then use this special attribute in the reader callback to + # exit *immediately* without doing any cleanup/rescheduling. + self._read_fut.__asyncio_cancelled_on_pause__ = True + self._read_fut.cancel() self._read_fut = None self._reschedule_on_resume = True @@ -210,7 +227,14 @@ def _loop_reading__on_eof(self): if not keep_open: self.close() - def _loop_reading__data_received(self, fut=None): + def _loop_reading(self, fut=None): + self._loop_reading_cb(fut) + + def _loop_reading__data_received(self, fut): + if (fut is not None and + getattr(fut, '__asyncio_cancelled_on_pause__', False)): + return + if self._paused: self._reschedule_on_resume = True return @@ -253,14 +277,18 @@ def _loop_reading__data_received(self, fut=None): if not self._closing: raise else: - self._read_fut.add_done_callback(self._loop_reading) + self._read_fut.add_done_callback(self._loop_reading__data_received) finally: if data: self._protocol.data_received(data) elif data == b'': self._loop_reading__on_eof() - def _loop_reading__get_buffer(self, fut=None): + def _loop_reading__get_buffer(self, fut): + if (fut is not None and + getattr(fut, '__asyncio_cancelled_on_pause__', False)): + return + if self._paused: self._reschedule_on_resume = True return @@ -310,7 +338,9 @@ def _loop_reading__get_buffer(self, fut=None): return try: - buf = self._protocol.get_buffer() + buf = self._protocol.get_buffer(-1) + if not len(buf): + raise RuntimeError('get_buffer() returned an empty buffer') except Exception as exc: self._fatal_error( exc, 'Fatal error: protocol.get_buffer() call failed.') @@ -319,7 +349,7 @@ def _loop_reading__get_buffer(self, fut=None): try: # schedule a new read self._read_fut = self._loop._proactor.recv_into(self._sock, buf) - self._read_fut.add_done_callback(self._loop_reading) + self._read_fut.add_done_callback(self._loop_reading__get_buffer) except ConnectionAbortedError as exc: if not self._closing: self._fatal_error(exc, 'Fatal read error on pipe transport') diff --git a/Lib/asyncio/protocols.py b/Lib/asyncio/protocols.py index dc298a8d5c9510..b8d2e6be552e1e 100644 --- a/Lib/asyncio/protocols.py +++ b/Lib/asyncio/protocols.py @@ -130,11 +130,15 @@ class BufferedProtocol(BaseProtocol): * CL: connection_lost() """ - def get_buffer(self): + def get_buffer(self, sizehint): """Called to allocate a new receive buffer. + *sizehint* is a recommended minimal size for the returned + buffer. When set to -1, the buffer size can be arbitrary. + Must return an object that implements the :ref:`buffer protocol `. + It is an error to return a zero-sized buffer. """ def buffer_updated(self, nbytes): @@ -185,7 +189,3 @@ def pipe_connection_lost(self, fd, exc): def process_exited(self): """Called when subprocess has exited.""" - - -def _is_buffered_protocol(proto): - return hasattr(proto, 'get_buffer') and not hasattr(proto, 'data_received') diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index f9533a1d77be54..93e2de22b593a5 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -597,8 +597,10 @@ def __init__(self, loop, sock, protocol, extra=None, server=None): self._extra['peername'] = None self._sock = sock self._sock_fd = sock.fileno() - self._protocol = protocol - self._protocol_connected = True + + self._protocol_connected = False + self.set_protocol(protocol) + self._server = server self._buffer = self._buffer_factory() self._conn_lost = 0 # Set when call to connection_lost scheduled. @@ -640,6 +642,7 @@ def abort(self): def set_protocol(self, protocol): self._protocol = protocol + self._protocol_connected = True def get_protocol(self): return self._protocol @@ -721,11 +724,7 @@ class _SelectorSocketTransport(_SelectorTransport): def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): - if protocols._is_buffered_protocol(protocol): - self._read_ready = self._read_ready__get_buffer - else: - self._read_ready = self._read_ready__data_received - + self._read_ready_cb = None super().__init__(loop, sock, protocol, extra, server) self._eof = False self._paused = False @@ -745,6 +744,14 @@ def __init__(self, loop, sock, protocol, waiter=None, self._loop.call_soon(futures._set_result_unless_cancelled, waiter, None) + def set_protocol(self, protocol): + if isinstance(protocol, protocols.BufferedProtocol): + self._read_ready_cb = self._read_ready__get_buffer + else: + self._read_ready_cb = self._read_ready__data_received + + super().set_protocol(protocol) + def is_reading(self): return not self._paused and not self._closing @@ -764,12 +771,17 @@ def resume_reading(self): if self._loop.get_debug(): logger.debug("%r resumes reading", self) + def _read_ready(self): + self._read_ready_cb() + def _read_ready__get_buffer(self): if self._conn_lost: return try: - buf = self._protocol.get_buffer() + buf = self._protocol.get_buffer(-1) + if not len(buf): + raise RuntimeError('get_buffer() returned an empty buffer') except Exception as exc: self._fatal_error( exc, 'Fatal error: protocol.get_buffer() call failed.') diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 2bbf134c0f7e6f..2bfa45dd1585af 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -441,6 +441,8 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._waiter = waiter self._loop = loop self._app_protocol = app_protocol + self._app_protocol_is_buffer = \ + isinstance(app_protocol, protocols.BufferedProtocol) self._app_transport = _SSLProtocolTransport(self._loop, self) # _SSLPipe instance (None until the connection is made) self._sslpipe = None @@ -522,7 +524,16 @@ def data_received(self, data): for chunk in appdata: if chunk: - self._app_protocol.data_received(chunk) + try: + if self._app_protocol_is_buffer: + _feed_data_to_bufferred_proto( + self._app_protocol, chunk) + else: + self._app_protocol.data_received(chunk) + except Exception as ex: + self._fatal_error( + ex, 'application protocol failed to receive SSL data') + return else: self._start_shutdown() break @@ -709,3 +720,22 @@ def _abort(self): self._transport.abort() finally: self._finalize() + + +def _feed_data_to_bufferred_proto(proto, data): + data_len = len(data) + while data_len: + buf = proto.get_buffer(data_len) + buf_len = len(buf) + if not buf_len: + raise RuntimeError('get_buffer() returned an empty buffer') + + if buf_len >= data_len: + buf[:data_len] = data + proto.buffer_updated(data_len) + return + else: + buf[:buf_len] = data[:buf_len] + proto.buffer_updated(buf_len) + data = data[buf_len:] + data_len = len(data) diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index f64037a25c67b8..7cad7e3637a11f 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -20,6 +20,7 @@ from . import events from . import futures from . import selector_events +from . import tasks from . import transports from .log import logger @@ -308,6 +309,9 @@ async def create_unix_server( ssl, backlog, ssl_handshake_timeout) if start_serving: server._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. + await tasks.sleep(0, loop=self) return server diff --git a/Lib/test/test_asyncio/test_buffered_proto.py b/Lib/test/test_asyncio/test_buffered_proto.py index 22f9269e814f99..89d3df72d98b62 100644 --- a/Lib/test/test_asyncio/test_buffered_proto.py +++ b/Lib/test/test_asyncio/test_buffered_proto.py @@ -9,7 +9,7 @@ def __init__(self, cb, con_lost_fut): self.cb = cb self.con_lost_fut = con_lost_fut - def get_buffer(self): + def get_buffer(self, sizehint): self.buffer = bytearray(100) return self.buffer diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 64d726d16d1cd8..d7b0a665a0abc1 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -2095,7 +2095,7 @@ async def connect(cmd=None, **kwds): class SendfileBase: - DATA = b"12345abcde" * 16 * 1024 # 160 KiB + DATA = b"12345abcde" * 64 * 1024 # 64 KiB (don't use smaller sizes) @classmethod def setUpClass(cls): @@ -2452,7 +2452,7 @@ def test_sendfile_ssl_close_peer_after_receiving(self): self.assertEqual(srv_proto.data, self.DATA) self.assertEqual(self.file.tell(), len(self.DATA)) - def test_sendfile_close_peer_in_middle_of_receiving(self): + def test_sendfile_close_peer_in_the_middle_of_receiving(self): srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) with self.assertRaises(ConnectionError): self.run_loop( @@ -2465,7 +2465,7 @@ def test_sendfile_close_peer_in_middle_of_receiving(self): self.file.tell()) self.assertTrue(cli_proto.transport.is_closing()) - def test_sendfile_fallback_close_peer_in_middle_of_receiving(self): + def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): def sendfile_native(transp, file, offset, count): # to raise SendfileNotAvailableError diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py index 6313d594477a74..6da6b4a34db81e 100644 --- a/Lib/test/test_asyncio/test_proactor_events.py +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -465,8 +465,8 @@ def setUp(self): self.loop._proactor = self.proactor self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol) - self.buf = mock.Mock() - self.protocol.get_buffer.side_effect = lambda: self.buf + self.buf = bytearray(1) + self.protocol.get_buffer.side_effect = lambda hint: self.buf self.sock = mock.Mock(socket.socket) @@ -505,6 +505,64 @@ def test_get_buffer_error(self): self.assertTrue(self.protocol.get_buffer.called) self.assertFalse(self.protocol.buffer_updated.called) + def test_get_buffer_zerosized(self): + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + + self.loop.call_exception_handler = mock.Mock() + self.protocol.get_buffer.side_effect = lambda hint: bytearray(0) + + transport._loop_reading() + + self.assertTrue(transport._fatal_error.called) + self.assertTrue(self.protocol.get_buffer.called) + self.assertFalse(self.protocol.buffer_updated.called) + + def test_proto_type_switch(self): + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + tr = self.socket_transport() + + res = asyncio.Future(loop=self.loop) + res.set_result(b'data') + + tr = self.socket_transport() + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv.assert_called_with(self.sock, 32768) + self.protocol.data_received.assert_called_with(b'data') + + # switch protocol to a BufferedProtocol + + buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol) + buf = bytearray(4) + buf_proto.get_buffer.side_effect = lambda hint: buf + + tr.set_protocol(buf_proto) + test_utils.run_briefly(self.loop) + res = asyncio.Future(loop=self.loop) + res.set_result(4) + + tr._read_fut = res + tr._loop_reading(res) + self.loop._proactor.recv_into.assert_called_with(self.sock, buf) + buf_proto.buffer_updated.assert_called_with(4) + + def test_proto_buf_switch(self): + tr = self.socket_transport() + test_utils.run_briefly(self.loop) + self.protocol.get_buffer.assert_called_with(-1) + + # switch protocol to *another* BufferedProtocol + + buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol) + buf = bytearray(4) + buf_proto.get_buffer.side_effect = lambda hint: buf + tr._read_fut.done.side_effect = lambda: False + tr.set_protocol(buf_proto) + self.assertFalse(buf_proto.get_buffer.called) + test_utils.run_briefly(self.loop) + buf_proto.get_buffer.assert_called_with(-1) + def test_buffer_updated_error(self): transport = self.socket_transport() transport._fatal_error = mock.Mock() diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index 684c29dec3e283..a0aac3a594765c 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -772,7 +772,8 @@ def test_accept_connection_multiple(self): accept2_mock.return_value = None with mock_obj(self.loop, 'create_task') as task_mock: task_mock.return_value = None - self.loop._accept_connection(mock.Mock(), sock, backlog=backlog) + self.loop._accept_connection( + mock.Mock(), sock, backlog=backlog) self.assertEqual(sock.accept.call_count, backlog) @@ -1279,8 +1280,8 @@ def setUp(self): self.loop = self.new_test_loop() self.protocol = test_utils.make_test_protocol(asyncio.BufferedProtocol) - self.buf = mock.Mock() - self.protocol.get_buffer.side_effect = lambda: self.buf + self.buf = bytearray(1) + self.protocol.get_buffer.side_effect = lambda hint: self.buf self.sock = mock.Mock(socket.socket) self.sock_fd = self.sock.fileno.return_value = 7 @@ -1313,6 +1314,42 @@ def test_get_buffer_error(self): self.assertTrue(self.protocol.get_buffer.called) self.assertFalse(self.protocol.buffer_updated.called) + def test_get_buffer_zerosized(self): + transport = self.socket_transport() + transport._fatal_error = mock.Mock() + + self.loop.call_exception_handler = mock.Mock() + self.protocol.get_buffer.side_effect = lambda hint: bytearray(0) + + transport._read_ready() + + self.assertTrue(transport._fatal_error.called) + self.assertTrue(self.protocol.get_buffer.called) + self.assertFalse(self.protocol.buffer_updated.called) + + def test_proto_type_switch(self): + self.protocol = test_utils.make_test_protocol(asyncio.Protocol) + transport = self.socket_transport() + + self.sock.recv.return_value = b'data' + transport._read_ready() + + self.protocol.data_received.assert_called_with(b'data') + + # switch protocol to a BufferedProtocol + + buf_proto = test_utils.make_test_protocol(asyncio.BufferedProtocol) + buf = bytearray(4) + buf_proto.get_buffer.side_effect = lambda hint: buf + + transport.set_protocol(buf_proto) + + self.sock.recv_into.return_value = 10 + transport._read_ready() + + buf_proto.get_buffer.assert_called_with(-1) + buf_proto.buffer_updated.assert_called_with(10) + def test_buffer_updated_error(self): transport = self.socket_transport() transport._fatal_error = mock.Mock() @@ -1348,7 +1385,7 @@ def test_read_ready(self): self.sock.recv_into.return_value = 10 transport._read_ready() - self.protocol.get_buffer.assert_called_with() + self.protocol.get_buffer.assert_called_with(-1) self.protocol.buffer_updated.assert_called_with(10) def test_read_ready_eof(self): diff --git a/Lib/test/test_asyncio/test_sslproto.py b/Lib/test/test_asyncio/test_sslproto.py index c534a341352b00..932487a9e3c639 100644 --- a/Lib/test/test_asyncio/test_sslproto.py +++ b/Lib/test/test_asyncio/test_sslproto.py @@ -1,8 +1,7 @@ """Tests for asyncio/sslproto.py.""" -import os import logging -import time +import socket import unittest from unittest import mock try: @@ -185,17 +184,67 @@ def test_write_after_closing(self): class BaseStartTLS(func_tests.FunctionalTestCaseMixin): + PAYLOAD_SIZE = 1024 * 100 + TIMEOUT = 60 + def new_loop(self): raise NotImplementedError - def test_start_tls_client_1(self): - HELLO_MSG = b'1' * 1024 * 1024 + def test_buf_feed_data(self): + + class Proto(asyncio.BufferedProtocol): + + def __init__(self, bufsize, usemv): + self.buf = bytearray(bufsize) + self.mv = memoryview(self.buf) + self.data = b'' + self.usemv = usemv + + def get_buffer(self, sizehint): + if self.usemv: + return self.mv + else: + return self.buf + + def buffer_updated(self, nsize): + if self.usemv: + self.data += self.mv[:nsize] + else: + self.data += self.buf[:nsize] + + for usemv in [False, True]: + proto = Proto(1, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(2, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(2, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'1234') + self.assertEqual(proto.data, b'1234') + + proto = Proto(4, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'1234') + self.assertEqual(proto.data, b'1234') + + proto = Proto(100, usemv) + sslproto._feed_data_to_bufferred_proto(proto, b'12345') + self.assertEqual(proto.data, b'12345') + + proto = Proto(0, usemv) + with self.assertRaisesRegex(RuntimeError, 'empty buffer'): + sslproto._feed_data_to_bufferred_proto(proto, b'12345') + + def test_start_tls_client_reg_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() def serve(sock): - sock.settimeout(5) + sock.settimeout(self.TIMEOUT) data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) @@ -205,6 +254,8 @@ def serve(sock): sock.sendall(b'O') data = sock.recv_all(len(HELLO_MSG)) self.assertEqual(len(data), len(HELLO_MSG)) + + sock.shutdown(socket.SHUT_RDWR) sock.close() class ClientProto(asyncio.Protocol): @@ -246,17 +297,80 @@ async def client(addr): self.loop.run_until_complete( asyncio.wait_for(client(srv.addr), loop=self.loop, timeout=10)) + def test_start_tls_client_buf_proto_1(self): + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + + server_context = test_utils.simple_server_sslcontext() + client_context = test_utils.simple_client_sslcontext() + + def serve(sock): + sock.settimeout(self.TIMEOUT) + + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.start_tls(server_context, server_side=True) + + sock.sendall(b'O') + data = sock.recv_all(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + sock.shutdown(socket.SHUT_RDWR) + sock.close() + + class ClientProto(asyncio.BufferedProtocol): + def __init__(self, on_data, on_eof): + self.on_data = on_data + self.on_eof = on_eof + self.con_made_cnt = 0 + self.buf = bytearray(1) + + def connection_made(proto, tr): + proto.con_made_cnt += 1 + # Ensure connection_made gets called only once. + self.assertEqual(proto.con_made_cnt, 1) + + def get_buffer(self, sizehint): + return self.buf + + def buffer_updated(self, nsize): + assert nsize == 1 + self.on_data.set_result(bytes(self.buf[:nsize])) + + def eof_received(self): + self.on_eof.set_result(True) + + async def client(addr): + await asyncio.sleep(0.5, loop=self.loop) + + on_data = self.loop.create_future() + on_eof = self.loop.create_future() + + tr, proto = await self.loop.create_connection( + lambda: ClientProto(on_data, on_eof), *addr) + + tr.write(HELLO_MSG) + new_tr = await self.loop.start_tls(tr, proto, client_context) + + self.assertEqual(await on_data, b'O') + new_tr.write(HELLO_MSG) + await on_eof + + new_tr.close() + + with self.tcp_server(serve) as srv: + self.loop.run_until_complete( + asyncio.wait_for(client(srv.addr), + loop=self.loop, timeout=self.TIMEOUT)) + def test_start_tls_server_1(self): - HELLO_MSG = b'1' * 1024 * 1024 + HELLO_MSG = b'1' * self.PAYLOAD_SIZE server_context = test_utils.simple_server_sslcontext() client_context = test_utils.simple_client_sslcontext() - # TODO: fix TLSv1.3 support - client_context.options |= ssl.OP_NO_TLSv1_3 def client(sock, addr): - time.sleep(0.5) - sock.settimeout(5) + sock.settimeout(self.TIMEOUT) sock.connect(addr) data = sock.recv_all(len(HELLO_MSG)) @@ -264,12 +378,15 @@ def client(sock, addr): sock.start_tls(client_context) sock.sendall(HELLO_MSG) + + sock.shutdown(socket.SHUT_RDWR) sock.close() class ServerProto(asyncio.Protocol): - def __init__(self, on_con, on_eof): + def __init__(self, on_con, on_eof, on_con_lost): self.on_con = on_con self.on_eof = on_eof + self.on_con_lost = on_con_lost self.data = b'' def connection_made(self, tr): @@ -281,7 +398,13 @@ def data_received(self, data): def eof_received(self): self.on_eof.set_result(1) - async def main(): + def connection_lost(self, exc): + if exc is None: + self.on_con_lost.set_result(None) + else: + self.on_con_lost.set_exception(exc) + + async def main(proto, on_con, on_eof, on_con_lost): tr = await on_con tr.write(HELLO_MSG) @@ -292,24 +415,29 @@ async def main(): server_side=True) await on_eof + await on_con_lost self.assertEqual(proto.data, HELLO_MSG) new_tr.close() - server.close() - await server.wait_closed() + async def run_main(): + on_con = self.loop.create_future() + on_eof = self.loop.create_future() + on_con_lost = self.loop.create_future() + proto = ServerProto(on_con, on_eof, on_con_lost) - on_con = self.loop.create_future() - on_eof = self.loop.create_future() - proto = ServerProto(on_con, on_eof) + server = await self.loop.create_server( + lambda: proto, '127.0.0.1', 0) + addr = server.sockets[0].getsockname() - server = self.loop.run_until_complete( - self.loop.create_server( - lambda: proto, '127.0.0.1', 0)) - addr = server.sockets[0].getsockname() + with self.tcp_client(lambda sock: client(sock, addr)): + await asyncio.wait_for( + main(proto, on_con, on_eof, on_con_lost), + loop=self.loop, timeout=self.TIMEOUT) - with self.tcp_client(lambda sock: client(sock, addr)): - self.loop.run_until_complete( - asyncio.wait_for(main(), loop=self.loop, timeout=10)) + server.close() + await server.wait_closed() + + self.loop.run_until_complete(run_main()) def test_start_tls_wrong_args(self): async def main(): @@ -332,7 +460,6 @@ def new_loop(self): @unittest.skipIf(ssl is None, 'No ssl module') @unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only') -@unittest.skipIf(os.environ.get('APPVEYOR'), 'XXX: issue 32458') class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase): def new_loop(self): diff --git a/Misc/NEWS.d/next/Library/2018-05-26-13-09-34.bpo-33654.IbYWxA.rst b/Misc/NEWS.d/next/Library/2018-05-26-13-09-34.bpo-33654.IbYWxA.rst new file mode 100644 index 00000000000000..3ae506ddc55f41 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-05-26-13-09-34.bpo-33654.IbYWxA.rst @@ -0,0 +1,3 @@ +Fix transport.set_protocol() to support switching between asyncio.Protocol +and asyncio.BufferedProtocol. Fix loop.start_tls() to work with +asyncio.BufferedProtocols.