diff --git a/Lib/asyncio/__init__.py b/Lib/asyncio/__init__.py index ff69378ba9..eb84bfb189 100644 --- a/Lib/asyncio/__init__.py +++ b/Lib/asyncio/__init__.py @@ -1,22 +1,14 @@ """The asyncio package, tracking PEP 3156.""" # flake8: noqa -import sys - -import selectors -# XXX RustPython TODO: _overlapped -if sys.platform == 'win32' and False: - # Similar thing for _overlapped. - try: - from . import _overlapped - except ImportError: - import _overlapped # Will also be exported. +import sys # This relies on each of the submodules having an __all__ variable. from .base_events import * from .coroutines import * from .events import * +from .exceptions import * from .futures import * from .locks import * from .protocols import * @@ -25,11 +17,17 @@ from .streams import * from .subprocess import * from .tasks import * +from .threads import * from .transports import * +# Exposed for _asynciomodule.c to implement now deprecated +# Task.all_tasks() method. This function will be removed in 3.9. +from .tasks import _all_tasks_compat # NoQA + __all__ = (base_events.__all__ + coroutines.__all__ + events.__all__ + + exceptions.__all__ + futures.__all__ + locks.__all__ + protocols.__all__ + @@ -38,6 +36,7 @@ streams.__all__ + subprocess.__all__ + tasks.__all__ + + threads.__all__ + transports.__all__) if sys.platform == 'win32': # pragma: no cover diff --git a/Lib/asyncio/__main__.py b/Lib/asyncio/__main__.py new file mode 100644 index 0000000000..18bb87a5bc --- /dev/null +++ b/Lib/asyncio/__main__.py @@ -0,0 +1,125 @@ +import ast +import asyncio +import code +import concurrent.futures +import inspect +import sys +import threading +import types +import warnings + +from . import futures + + +class AsyncIOInteractiveConsole(code.InteractiveConsole): + + def __init__(self, locals, loop): + super().__init__(locals) + self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT + + self.loop = loop + + def runcode(self, code): + future = concurrent.futures.Future() + + def callback(): + global repl_future + global repl_future_interrupted + + repl_future = None + repl_future_interrupted = False + + func = types.FunctionType(code, self.locals) + try: + coro = func() + except SystemExit: + raise + except KeyboardInterrupt as ex: + repl_future_interrupted = True + future.set_exception(ex) + return + except BaseException as ex: + future.set_exception(ex) + return + + if not inspect.iscoroutine(coro): + future.set_result(coro) + return + + try: + repl_future = self.loop.create_task(coro) + futures._chain_future(repl_future, future) + except BaseException as exc: + future.set_exception(exc) + + loop.call_soon_threadsafe(callback) + + try: + return future.result() + except SystemExit: + raise + except BaseException: + if repl_future_interrupted: + self.write("\nKeyboardInterrupt\n") + else: + self.showtraceback() + + +class REPLThread(threading.Thread): + + def run(self): + try: + banner = ( + f'asyncio REPL {sys.version} on {sys.platform}\n' + f'Use "await" directly instead of "asyncio.run()".\n' + f'Type "help", "copyright", "credits" or "license" ' + f'for more information.\n' + f'{getattr(sys, "ps1", ">>> ")}import asyncio' + ) + + console.interact( + banner=banner, + exitmsg='exiting asyncio REPL...') + finally: + warnings.filterwarnings( + 'ignore', + message=r'^coroutine .* was never awaited$', + category=RuntimeWarning) + + loop.call_soon_threadsafe(loop.stop) + + +if __name__ == '__main__': + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + repl_locals = {'asyncio': asyncio} + for key in 
{'__name__', '__package__', + '__loader__', '__spec__', + '__builtins__', '__file__'}: + repl_locals[key] = locals()[key] + + console = AsyncIOInteractiveConsole(repl_locals, loop) + + repl_future = None + repl_future_interrupted = False + + try: + import readline # NoQA + except ImportError: + pass + + repl_thread = REPLThread() + repl_thread.daemon = True + repl_thread.start() + + while True: + try: + loop.run_forever() + except KeyboardInterrupt: + if repl_future and not repl_future.done(): + repl_future.cancel() + repl_future_interrupted = True + continue + else: + break diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 2df379933c..4356bfae01 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -14,13 +14,14 @@ """ import collections +import collections.abc import concurrent.futures +import functools import heapq -import inspect import itertools -import logging import os import socket +import stat import subprocess import threading import time @@ -29,16 +30,26 @@ import warnings import weakref -from . import compat +try: + import ssl +except ImportError: # pragma: no cover + ssl = None + +from . import constants from . import coroutines from . import events +from . import exceptions from . import futures +from . import protocols +from . import sslproto +from . import staggered from . import tasks -from .coroutines import coroutine +from . import transports +from . import trsock from .log import logger -__all__ = ['BaseEventLoop'] +__all__ = 'BaseEventLoop','Server', # Minimum number of _scheduled timer handles before cleanup of @@ -49,10 +60,15 @@ # before cleanup of cancelled handles is performed. _MIN_CANCELLED_TIMER_HANDLES_FRACTION = 0.5 -# Exceptions which must not call the exception handler in fatal error -# methods (_fatal_error()) -_FATAL_ERROR_IGNORE = (BrokenPipeError, - ConnectionResetError, ConnectionAbortedError) + +_HAS_IPv6 = hasattr(socket, 'AF_INET6') + +# Maximum timeout passed to select to avoid OS limitations +MAXIMUM_SELECT_TIMEOUT = 24 * 3600 + +# Used for deprecation and removal of `loop.create_datagram_endpoint()`'s +# *reuse_address* parameter +_unset = object() def _format_handle(handle): @@ -84,21 +100,7 @@ def _set_reuseport(sock): 'SO_REUSEPORT defined but not implemented.') -def _is_stream_socket(sock): - # Linux's socket.type is a bitmask that can include extra info - # about socket, therefore we can't do simple - # `sock_type == socket.SOCK_STREAM`. - return (sock.type & socket.SOCK_STREAM) == socket.SOCK_STREAM - - -def _is_dgram_socket(sock): - # Linux's socket.type is a bitmask that can include extra info - # about socket, therefore we can't do simple - # `sock_type == socket.SOCK_DGRAM`. - return (sock.type & socket.SOCK_DGRAM) == socket.SOCK_DGRAM - - -def _ipaddr_info(host, port, family, type, proto): +def _ipaddr_info(host, port, family, type, proto, flowinfo=0, scopeid=0): # Try to skip getaddrinfo if "host" is already an IP. Users might have # handled name resolution in their own code and pass in resolved IPs. if not hasattr(socket, 'inet_pton'): @@ -109,11 +111,6 @@ def _ipaddr_info(host, port, family, type, proto): return None if type == socket.SOCK_STREAM: - # Linux only: - # getaddrinfo() can raise when socket.type is a bit mask. - # So if socket.type is a bit mask of SOCK_STREAM, and say - # SOCK_NONBLOCK, we simply return None, which will trigger - # a call to getaddrinfo() letting it process this request. 
proto = socket.IPPROTO_TCP elif type == socket.SOCK_DGRAM: proto = socket.IPPROTO_UDP @@ -135,7 +132,7 @@ def _ipaddr_info(host, port, family, type, proto): if family == socket.AF_UNSPEC: afs = [socket.AF_INET] - if hasattr(socket, 'AF_INET6'): + if _HAS_IPv6: afs.append(socket.AF_INET6) else: afs = [family] @@ -151,7 +148,10 @@ def _ipaddr_info(host, port, family, type, proto): try: socket.inet_pton(af, host) # The host has already been resolved. - return af, type, proto, '', (host, port) + if _HAS_IPv6 and af == socket.AF_INET6: + return af, type, proto, '', (host, port, flowinfo, scopeid) + else: + return af, type, proto, '', (host, port) except OSError: pass @@ -159,75 +159,231 @@ def _ipaddr_info(host, port, family, type, proto): return None -def _ensure_resolved(address, *, family=0, type=socket.SOCK_STREAM, proto=0, - flags=0, loop): - host, port = address[:2] - info = _ipaddr_info(host, port, family, type, proto) - if info is not None: - # "host" is already a resolved IP. - fut = loop.create_future() - fut.set_result([info]) - return fut - else: - return loop.getaddrinfo(host, port, family=family, type=type, - proto=proto, flags=flags) +def _interleave_addrinfos(addrinfos, first_address_family_count=1): + """Interleave list of addrinfo tuples by family.""" + # Group addresses by family + addrinfos_by_family = collections.OrderedDict() + for addr in addrinfos: + family = addr[0] + if family not in addrinfos_by_family: + addrinfos_by_family[family] = [] + addrinfos_by_family[family].append(addr) + addrinfos_lists = list(addrinfos_by_family.values()) + + reordered = [] + if first_address_family_count > 1: + reordered.extend(addrinfos_lists[0][:first_address_family_count - 1]) + del addrinfos_lists[0][:first_address_family_count - 1] + reordered.extend( + a for a in itertools.chain.from_iterable( + itertools.zip_longest(*addrinfos_lists) + ) if a is not None) + return reordered def _run_until_complete_cb(fut): - exc = fut._exception - if (isinstance(exc, BaseException) - and not isinstance(exc, Exception)): - # Issue #22429: run_forever() already finished, no need to - # stop it. - return - fut._loop.stop() + if not fut.cancelled(): + exc = fut.exception() + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + # Issue #22429: run_forever() already finished, no need to + # stop it. 
+ return + futures._get_loop(fut).stop() + + +if hasattr(socket, 'TCP_NODELAY'): + def _set_nodelay(sock): + if (sock.family in {socket.AF_INET, socket.AF_INET6} and + sock.type == socket.SOCK_STREAM and + sock.proto == socket.IPPROTO_TCP): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) +else: + def _set_nodelay(sock): + pass + + +def _check_ssl_socket(sock): + if ssl is not None and isinstance(sock, ssl.SSLSocket): + raise TypeError("Socket cannot be of type SSLSocket") + + +class _SendfileFallbackProtocol(protocols.Protocol): + def __init__(self, transp): + if not isinstance(transp, transports._FlowControlMixin): + raise TypeError("transport should be _FlowControlMixin instance") + self._transport = transp + self._proto = transp.get_protocol() + self._should_resume_reading = transp.is_reading() + self._should_resume_writing = transp._protocol_paused + transp.pause_reading() + transp.set_protocol(self) + if self._should_resume_writing: + self._write_ready_fut = self._transport._loop.create_future() + else: + self._write_ready_fut = None + + async def drain(self): + if self._transport.is_closing(): + raise ConnectionError("Connection closed by peer") + fut = self._write_ready_fut + if fut is None: + return + await fut + + def connection_made(self, transport): + raise RuntimeError("Invalid state: " + "connection should have been established already.") + + def connection_lost(self, exc): + if self._write_ready_fut is not None: + # Never happens if peer disconnects after sending the whole content + # Thus disconnection is always an exception from user perspective + if exc is None: + self._write_ready_fut.set_exception( + ConnectionError("Connection is closed by peer")) + else: + self._write_ready_fut.set_exception(exc) + self._proto.connection_lost(exc) + + def pause_writing(self): + if self._write_ready_fut is not None: + return + self._write_ready_fut = self._transport._loop.create_future() + + def resume_writing(self): + if self._write_ready_fut is None: + return + self._write_ready_fut.set_result(False) + self._write_ready_fut = None + + def data_received(self, data): + raise RuntimeError("Invalid state: reading should be paused") + + def eof_received(self): + raise RuntimeError("Invalid state: reading should be paused") + + async def restore(self): + self._transport.set_protocol(self._proto) + if self._should_resume_reading: + self._transport.resume_reading() + if self._write_ready_fut is not None: + # Cancel the future. + # Basically it has no effect because protocol is switched back, + # no code should wait for it anymore. 
+ self._write_ready_fut.cancel() + if self._should_resume_writing: + self._proto.resume_writing() class Server(events.AbstractServer): - def __init__(self, loop, sockets): + def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, + ssl_handshake_timeout): self._loop = loop - self.sockets = sockets + self._sockets = sockets self._active_count = 0 self._waiters = [] + self._protocol_factory = protocol_factory + self._backlog = backlog + self._ssl_context = ssl_context + self._ssl_handshake_timeout = ssl_handshake_timeout + self._serving = False + self._serving_forever_fut = None def __repr__(self): - return '<%s sockets=%r>' % (self.__class__.__name__, self.sockets) + return f'<{self.__class__.__name__} sockets={self.sockets!r}>' def _attach(self): - assert self.sockets is not None + assert self._sockets is not None self._active_count += 1 def _detach(self): assert self._active_count > 0 self._active_count -= 1 - if self._active_count == 0 and self.sockets is None: + if self._active_count == 0 and self._sockets is None: self._wakeup() + def _wakeup(self): + waiters = self._waiters + self._waiters = None + for waiter in waiters: + if not waiter.done(): + waiter.set_result(waiter) + + def _start_serving(self): + if self._serving: + return + self._serving = True + for sock in self._sockets: + sock.listen(self._backlog) + self._loop._start_serving( + self._protocol_factory, sock, self._ssl_context, + self, self._backlog, self._ssl_handshake_timeout) + + def get_loop(self): + return self._loop + + def is_serving(self): + return self._serving + + @property + def sockets(self): + if self._sockets is None: + return () + return tuple(trsock.TransportSocket(s) for s in self._sockets) + def close(self): - sockets = self.sockets + sockets = self._sockets if sockets is None: return - self.sockets = None + self._sockets = None + for sock in sockets: self._loop._stop_serving(sock) + + self._serving = False + + if (self._serving_forever_fut is not None and + not self._serving_forever_fut.done()): + self._serving_forever_fut.cancel() + self._serving_forever_fut = None + if self._active_count == 0: self._wakeup() - def _wakeup(self): - waiters = self._waiters - self._waiters = None - for waiter in waiters: - if not waiter.done(): - waiter.set_result(waiter) + async def start_serving(self): + self._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. 
+ await tasks.sleep(0) + + async def serve_forever(self): + if self._serving_forever_fut is not None: + raise RuntimeError( + f'server {self!r} is already being awaited on serve_forever()') + if self._sockets is None: + raise RuntimeError(f'server {self!r} is closed') + + self._start_serving() + self._serving_forever_fut = self._loop.create_future() - @coroutine - def wait_closed(self): - if self.sockets is None or self._waiters is None: + try: + await self._serving_forever_fut + except exceptions.CancelledError: + try: + self.close() + await self.wait_closed() + finally: + raise + finally: + self._serving_forever_fut = None + + async def wait_closed(self): + if self._sockets is None or self._waiters is None: return waiter = self._loop.create_future() self._waiters.append(waiter) - yield from waiter + await waiter class BaseEventLoop(events.AbstractEventLoop): @@ -243,49 +399,49 @@ def __init__(self): # Identifier of the thread running the event loop, or None if the # event loop is not running self._thread_id = None - self._clock_resolution = 1e-06 #time.get_clock_info('monotonic').resolution + self._clock_resolution = time.get_clock_info('monotonic').resolution self._exception_handler = None - self.set_debug((not sys.flags.ignore_environment - and bool(os.environ.get('PYTHONASYNCIODEBUG')))) + self.set_debug(coroutines._is_debug_mode()) # In debug mode, if the execution of a callback or a step of a task # exceed this duration in seconds, the slow callback/task is logged. self.slow_callback_duration = 0.1 self._current_handle = None self._task_factory = None - self._coroutine_wrapper_set = False - - if hasattr(sys, 'get_asyncgen_hooks'): - # Python >= 3.6 - # A weak set of all asynchronous generators that are - # being iterated by the loop. - self._asyncgens = weakref.WeakSet() - else: - self._asyncgens = None + self._coroutine_origin_tracking_enabled = False + self._coroutine_origin_tracking_saved_depth = None + # A weak set of all asynchronous generators that are + # being iterated by the loop. + self._asyncgens = weakref.WeakSet() # Set to True when `loop.shutdown_asyncgens` is called. self._asyncgens_shutdown_called = False + # Set to True when `loop.shutdown_default_executor` is called. + self._executor_shutdown_called = False def __repr__(self): - return ('<%s running=%s closed=%s debug=%s>' - % (self.__class__.__name__, self.is_running(), - self.is_closed(), self.get_debug())) + return ( + f'<{self.__class__.__name__} running={self.is_running()} ' + f'closed={self.is_closed()} debug={self.get_debug()}>' + ) def create_future(self): """Create a Future object attached to the loop.""" return futures.Future(loop=self) - def create_task(self, coro): + def create_task(self, coro, *, name=None): """Schedule a coroutine object. Return a task object. 
""" self._check_closed() if self._task_factory is None: - task = tasks.Task(coro, loop=self) + task = tasks.Task(coro, loop=self, name=name) if task._source_traceback: del task._source_traceback[-1] else: task = self._task_factory(self, coro) + tasks._set_task_name(task, name) + return task def set_task_factory(self, factory): @@ -311,9 +467,12 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, """Create socket transport.""" raise NotImplementedError - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, - extra=None, server=None): + def _make_ssl_transport( + self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None, + ssl_handshake_timeout=None, + call_connection_made=True): """Create SSL transport.""" raise NotImplementedError @@ -332,10 +491,9 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, """Create write pipe transport.""" raise NotImplementedError - @coroutine - def _make_subprocess_transport(self, protocol, args, shell, - stdin, stdout, stderr, bufsize, - extra=None, **kwargs): + async def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): """Create subprocess transport.""" raise NotImplementedError @@ -356,29 +514,29 @@ def _check_closed(self): if self._closed: raise RuntimeError('Event loop is closed') + def _check_default_executor(self): + if self._executor_shutdown_called: + raise RuntimeError('Executor shutdown has been called') + def _asyncgen_finalizer_hook(self, agen): self._asyncgens.discard(agen) if not self.is_closed(): - self.create_task(agen.aclose()) - # Wake up the loop if the finalizer was called from - # a different thread. - self._write_to_self() + self.call_soon_threadsafe(self.create_task, agen.aclose()) def _asyncgen_firstiter_hook(self, agen): if self._asyncgens_shutdown_called: warnings.warn( - "asynchronous generator {!r} was scheduled after " - "loop.shutdown_asyncgens() call".format(agen), + f"asynchronous generator {agen!r} was scheduled after " + f"loop.shutdown_asyncgens() call", ResourceWarning, source=self) self._asyncgens.add(agen) - @coroutine - def shutdown_asyncgens(self): + async def shutdown_asyncgens(self): """Shutdown all active asynchronous generators.""" self._asyncgens_shutdown_called = True - if self._asyncgens is None or not len(self._asyncgens): + if not len(self._asyncgens): # If Python version is <3.6 or we don't have any asynchronous # generators alive. 
return @@ -386,35 +544,57 @@ def shutdown_asyncgens(self): closing_agens = list(self._asyncgens) self._asyncgens.clear() - shutdown_coro = tasks.gather( + results = await tasks._gather( *[ag.aclose() for ag in closing_agens], return_exceptions=True, loop=self) - results = yield from shutdown_coro for result, agen in zip(results, closing_agens): if isinstance(result, Exception): self.call_exception_handler({ - 'message': 'an error occurred during closing of ' - 'asynchronous generator {!r}'.format(agen), + 'message': f'an error occurred during closing of ' + f'asynchronous generator {agen!r}', 'exception': result, 'asyncgen': agen }) - def run_forever(self): - """Run until stop() is called.""" - self._check_closed() + async def shutdown_default_executor(self): + """Schedule the shutdown of the default executor.""" + self._executor_shutdown_called = True + if self._default_executor is None: + return + future = self.create_future() + thread = threading.Thread(target=self._do_shutdown, args=(future,)) + thread.start() + try: + await future + finally: + thread.join() + + def _do_shutdown(self, future): + try: + self._default_executor.shutdown(wait=True) + self.call_soon_threadsafe(future.set_result, None) + except Exception as ex: + self.call_soon_threadsafe(future.set_exception, ex) + + def _check_running(self): if self.is_running(): raise RuntimeError('This event loop is already running') if events._get_running_loop() is not None: raise RuntimeError( 'Cannot run the event loop while another loop is running') - self._set_coroutine_wrapper(self._debug) + + def run_forever(self): + """Run until stop() is called.""" + self._check_closed() + self._check_running() + self._set_coroutine_origin_tracking(self._debug) self._thread_id = threading.get_ident() - if self._asyncgens is not None: - old_agen_hooks = sys.get_asyncgen_hooks() - sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook, - finalizer=self._asyncgen_finalizer_hook) + + old_agen_hooks = sys.get_asyncgen_hooks() + sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook, + finalizer=self._asyncgen_finalizer_hook) try: events._set_running_loop(self) while True: @@ -425,9 +605,8 @@ def run_forever(self): self._stopping = False self._thread_id = None events._set_running_loop(None) - self._set_coroutine_wrapper(False) - if self._asyncgens is not None: - sys.set_asyncgen_hooks(*old_agen_hooks) + self._set_coroutine_origin_tracking(False) + sys.set_asyncgen_hooks(*old_agen_hooks) def run_until_complete(self, future): """Run until the Future is done. @@ -441,6 +620,7 @@ def run_until_complete(self, future): Return the Future's result, or raise its exception. """ self._check_closed() + self._check_running() new_task = not futures.isfuture(future) future = tasks.ensure_future(future, loop=self) @@ -459,7 +639,8 @@ def run_until_complete(self, future): # local task. future.exception() raise - future.remove_done_callback(_run_until_complete_cb) + finally: + future.remove_done_callback(_run_until_complete_cb) if not future.done(): raise RuntimeError('Event loop stopped before Future completed.') @@ -490,6 +671,7 @@ def close(self): self._closed = True self._ready.clear() self._scheduled.clear() + self._executor_shutdown_called = True executor = self._default_executor if executor is not None: self._default_executor = None @@ -499,16 +681,11 @@ def is_closed(self): """Returns True if the event loop was closed.""" return self._closed - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. 
It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if not self.is_closed(): - warnings.warn("unclosed event loop %r" % self, ResourceWarning, - source=self) - if not self.is_running(): - self.close() + def __del__(self, _warn=warnings.warn): + if not self.is_closed(): + _warn(f"unclosed event loop {self!r}", ResourceWarning, source=self) + if not self.is_running(): + self.close() def is_running(self): """Returns True if the event loop is running.""" @@ -523,7 +700,7 @@ def time(self): """ return time.monotonic() - def call_later(self, delay, callback, *args): + def call_later(self, delay, callback, *args, context=None): """Arrange for a callback to be called at a given time. Return a Handle: an opaque object with a cancel() method that @@ -539,12 +716,13 @@ def call_later(self, delay, callback, *args): Any positional arguments after the callback will be passed to the callback when it is called. """ - timer = self.call_at(self.time() + delay, callback, *args) + timer = self.call_at(self.time() + delay, callback, *args, + context=context) if timer._source_traceback: del timer._source_traceback[-1] return timer - def call_at(self, when, callback, *args): + def call_at(self, when, callback, *args, context=None): """Like call_later(), but uses an absolute time. Absolute time corresponds to the event loop's time() method. @@ -553,14 +731,14 @@ def call_at(self, when, callback, *args): if self._debug: self._check_thread() self._check_callback(callback, 'call_at') - timer = events.TimerHandle(when, callback, args, self) + timer = events.TimerHandle(when, callback, args, self, context) if timer._source_traceback: del timer._source_traceback[-1] heapq.heappush(self._scheduled, timer) timer._scheduled = True return timer - def call_soon(self, callback, *args): + def call_soon(self, callback, *args, context=None): """Arrange for a callback to be called as soon as possible. 
This operates as a FIFO queue: callbacks are called in the @@ -574,7 +752,7 @@ def call_soon(self, callback, *args): if self._debug: self._check_thread() self._check_callback(callback, 'call_soon') - handle = self._call_soon(callback, args) + handle = self._call_soon(callback, args, context) if handle._source_traceback: del handle._source_traceback[-1] return handle @@ -583,15 +761,14 @@ def _check_callback(self, callback, method): if (coroutines.iscoroutine(callback) or coroutines.iscoroutinefunction(callback)): raise TypeError( - "coroutines cannot be used with {}()".format(method)) + f"coroutines cannot be used with {method}()") if not callable(callback): raise TypeError( - 'a callable object was expected by {}(), got {!r}'.format( - method, callback)) - + f'a callable object was expected by {method}(), ' + f'got {callback!r}') - def _call_soon(self, callback, args): - handle = events.Handle(callback, args, self) + def _call_soon(self, callback, args, context): + handle = events.Handle(callback, args, self, context) if handle._source_traceback: del handle._source_traceback[-1] self._ready.append(handle) @@ -614,12 +791,12 @@ def _check_thread(self): "Non-thread-safe operation invoked on an event loop other " "than the current one") - def call_soon_threadsafe(self, callback, *args): + def call_soon_threadsafe(self, callback, *args, context=None): """Like call_soon(), but thread-safe.""" self._check_closed() if self._debug: self._check_callback(callback, 'call_soon_threadsafe') - handle = self._call_soon(callback, args) + handle = self._call_soon(callback, args, context) if handle._source_traceback: del handle._source_traceback[-1] self._write_to_self() @@ -631,24 +808,35 @@ def run_in_executor(self, executor, func, *args): self._check_callback(func, 'run_in_executor') if executor is None: executor = self._default_executor + # Only check when the default executor is being used + self._check_default_executor() if executor is None: - executor = concurrent.futures.ThreadPoolExecutor() + executor = concurrent.futures.ThreadPoolExecutor( + thread_name_prefix='asyncio' + ) self._default_executor = executor - return futures.wrap_future(executor.submit(func, *args), loop=self) + return futures.wrap_future( + executor.submit(func, *args), loop=self) def set_default_executor(self, executor): + if not isinstance(executor, concurrent.futures.ThreadPoolExecutor): + warnings.warn( + 'Using the default executor that is not an instance of ' + 'ThreadPoolExecutor is deprecated and will be prohibited ' + 'in Python 3.9', + DeprecationWarning, 2) self._default_executor = executor def _getaddrinfo_debug(self, host, port, family, type, proto, flags): - msg = ["%s:%r" % (host, port)] + msg = [f"{host}:{port!r}"] if family: - msg.append('family=%r' % family) + msg.append(f'family={family!r}') if type: - msg.append('type=%r' % type) + msg.append(f'type={type!r}') if proto: - msg.append('proto=%r' % proto) + msg.append(f'proto={proto!r}') if flags: - msg.append('flags=%r' % flags) + msg.append(f'flags={flags!r}') msg = ', '.join(msg) logger.debug('Get address info %s', msg) @@ -656,30 +844,139 @@ def _getaddrinfo_debug(self, host, port, family, type, proto, flags): addrinfo = socket.getaddrinfo(host, port, family, type, proto, flags) dt = self.time() - t0 - msg = ('Getting address info %s took %.3f ms: %r' - % (msg, dt * 1e3, addrinfo)) + msg = f'Getting address info {msg} took {dt * 1e3:.3f}ms: {addrinfo!r}' if dt >= self.slow_callback_duration: logger.info(msg) else: logger.debug(msg) return addrinfo - def 
getaddrinfo(self, host, port, *, - family=0, type=0, proto=0, flags=0): + async def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): if self._debug: - return self.run_in_executor(None, self._getaddrinfo_debug, - host, port, family, type, proto, flags) + getaddr_func = self._getaddrinfo_debug else: - return self.run_in_executor(None, socket.getaddrinfo, - host, port, family, type, proto, flags) + getaddr_func = socket.getaddrinfo - def getnameinfo(self, sockaddr, flags=0): - return self.run_in_executor(None, socket.getnameinfo, sockaddr, flags) + return await self.run_in_executor( + None, getaddr_func, host, port, family, type, proto, flags) - @coroutine - def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr=None, server_hostname=None): + async def getnameinfo(self, sockaddr, flags=0): + return await self.run_in_executor( + None, socket.getnameinfo, sockaddr, flags) + + async def sock_sendfile(self, sock, file, offset=0, count=None, + *, fallback=True): + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + _check_ssl_socket(sock) + self._check_sendfile_params(sock, file, offset, count) + try: + return await self._sock_sendfile_native(sock, file, + offset, count) + except exceptions.SendfileNotAvailableError as exc: + if not fallback: + raise + return await self._sock_sendfile_fallback(sock, file, + offset, count) + + async def _sock_sendfile_native(self, sock, file, offset, count): + # NB: sendfile syscall is not supported for SSL sockets and + # non-mmap files even if sendfile is supported by OS + raise exceptions.SendfileNotAvailableError( + f"syscall sendfile is not available for socket {sock!r} " + f"and file {file!r} combination") + + async def _sock_sendfile_fallback(self, sock, file, offset, count): + if offset: + file.seek(offset) + blocksize = ( + min(count, constants.SENDFILE_FALLBACK_READBUFFER_SIZE) + if count else constants.SENDFILE_FALLBACK_READBUFFER_SIZE + ) + buf = bytearray(blocksize) + total_sent = 0 + try: + while True: + if count: + blocksize = min(count - total_sent, blocksize) + if blocksize <= 0: + break + view = memoryview(buf)[:blocksize] + read = await self.run_in_executor(None, file.readinto, view) + if not read: + break # EOF + await self.sock_sendall(sock, view[:read]) + total_sent += read + return total_sent + finally: + if total_sent > 0 and hasattr(file, 'seek'): + file.seek(offset + total_sent) + + def _check_sendfile_params(self, sock, file, offset, count): + if 'b' not in getattr(file, 'mode', 'b'): + raise ValueError("file should be opened in binary mode") + if not sock.type == socket.SOCK_STREAM: + raise ValueError("only SOCK_STREAM type sockets are supported") + if count is not None: + if not isinstance(count, int): + raise TypeError( + "count must be a positive integer (got {!r})".format(count)) + if count <= 0: + raise ValueError( + "count must be a positive integer (got {!r})".format(count)) + if not isinstance(offset, int): + raise TypeError( + "offset must be a non-negative integer (got {!r})".format( + offset)) + if offset < 0: + raise ValueError( + "offset must be a non-negative integer (got {!r})".format( + offset)) + + async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None): + """Create, bind and connect one socket.""" + my_exceptions = [] + exceptions.append(my_exceptions) + family, type_, proto, _, address = addr_info + sock = None + try: + sock = 
socket.socket(family=family, type=type_, proto=proto) + sock.setblocking(False) + if local_addr_infos is not None: + for _, _, _, _, laddr in local_addr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + msg = ( + f'error while attempting to bind on ' + f'address {laddr!r}: ' + f'{exc.strerror.lower()}' + ) + exc = OSError(exc.errno, msg) + my_exceptions.append(exc) + else: # all bind attempts failed + raise my_exceptions.pop() + await self.sock_connect(sock, address) + return sock + except OSError as exc: + my_exceptions.append(exc) + if sock is not None: + sock.close() + raise + except: + if sock is not None: + sock.close() + raise + + async def create_connection( + self, protocol_factory, host=None, port=None, + *, ssl=None, family=0, + proto=0, flags=0, sock=None, + local_addr=None, server_hostname=None, + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): """Connect to a TCP server. Create a streaming transport connection to a given Internet host and @@ -710,68 +1007,60 @@ def create_connection(self, protocol_factory, host=None, port=None, *, 'when using ssl without a host') server_hostname = host + if ssl_handshake_timeout is not None and not ssl: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + + if sock is not None: + _check_ssl_socket(sock) + + if happy_eyeballs_delay is not None and interleave is None: + # If using happy eyeballs, default to interleave addresses by family + interleave = 1 + if host is not None or port is not None: if sock is not None: raise ValueError( 'host/port and sock can not be specified at the same time') - f1 = _ensure_resolved((host, port), family=family, - type=socket.SOCK_STREAM, proto=proto, - flags=flags, loop=self) - fs = [f1] - if local_addr is not None: - f2 = _ensure_resolved(local_addr, family=family, - type=socket.SOCK_STREAM, proto=proto, - flags=flags, loop=self) - fs.append(f2) - else: - f2 = None - - yield from tasks.wait(fs, loop=self) - - infos = f1.result() + infos = await self._ensure_resolved( + (host, port), family=family, + type=socket.SOCK_STREAM, proto=proto, flags=flags, loop=self) if not infos: raise OSError('getaddrinfo() returned empty list') - if f2 is not None: - laddr_infos = f2.result() + + if local_addr is not None: + laddr_infos = await self._ensure_resolved( + local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, + flags=flags, loop=self) if not laddr_infos: raise OSError('getaddrinfo() returned empty list') + else: + laddr_infos = None + + if interleave: + infos = _interleave_addrinfos(infos, interleave) exceptions = [] - for family, type, proto, cname, address in infos: - try: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - if f2 is not None: - for _, _, _, _, laddr in laddr_infos: - try: - sock.bind(laddr) - break - except OSError as exc: - exc = OSError( - exc.errno, 'error while ' - 'attempting to bind on address ' - '{!r}: {}'.format( - laddr, exc.strerror.lower())) - exceptions.append(exc) - else: - sock.close() - sock = None - continue - if self._debug: - logger.debug("connect %r to %r", sock, address) - yield from self.sock_connect(sock, address) - except OSError as exc: - if sock is not None: - sock.close() - exceptions.append(exc) - except: - if sock is not None: - sock.close() - raise - else: - break - else: + if happy_eyeballs_delay is None: + # not using happy eyeballs + for addrinfo in infos: + try: + sock = await self._connect_sock( + exceptions, addrinfo, laddr_infos) + break + 
except OSError: + continue + else: # using happy eyeballs + sock, _, _ = await staggered.staggered_race( + (functools.partial(self._connect_sock, + exceptions, addrinfo, laddr_infos) + for addrinfo in infos), + happy_eyeballs_delay, loop=self) + + if sock is None: + exceptions = [exc for sub in exceptions for exc in sub] if len(exceptions) == 1: raise exceptions[0] else: @@ -788,7 +1077,7 @@ def create_connection(self, protocol_factory, host=None, port=None, *, if sock is None: raise ValueError( 'host and port was not specified and no sock specified') - if not _is_stream_socket(sock): + if sock.type != socket.SOCK_STREAM: # We allow AF_INET, AF_INET6, AF_UNIX as long as they # are SOCK_STREAM. # We support passing AF_UNIX sockets even though we have @@ -796,10 +1085,11 @@ def create_connection(self, protocol_factory, host=None, port=None, *, # Disallowing AF_UNIX in this method, breaks backwards # compatibility. raise ValueError( - 'A Stream Socket was expected, got {!r}'.format(sock)) + f'A Stream Socket was expected, got {sock!r}') - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, server_hostname) + transport, protocol = await self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -808,9 +1098,10 @@ def create_connection(self, protocol_factory, host=None, port=None, *, sock, host, port, transport, protocol) return transport, protocol - @coroutine - def _create_connection_transport(self, sock, protocol_factory, ssl, - server_hostname, server_side=False): + async def _create_connection_transport( + self, sock, protocol_factory, ssl, + server_hostname, server_side=False, + ssl_handshake_timeout=None): sock.setblocking(False) @@ -820,42 +1111,163 @@ def _create_connection_transport(self, sock, protocol_factory, ssl, sslcontext = None if isinstance(ssl, bool) else ssl transport = self._make_ssl_transport( sock, protocol, sslcontext, waiter, - server_side=server_side, server_hostname=server_hostname) + server_side=server_side, server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) else: transport = self._make_socket_transport(sock, protocol, waiter) try: - yield from waiter + await waiter except: transport.close() raise return transport, protocol - @coroutine - def create_datagram_endpoint(self, protocol_factory, - local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0, - reuse_address=None, reuse_port=None, - allow_broadcast=None, sock=None): + async def sendfile(self, transport, file, offset=0, count=None, + *, fallback=True): + """Send a file to transport. + + Return the total number of bytes which were sent. + + The method uses high-performance os.sendfile if available. + + file must be a regular file object opened in binary mode. + + offset tells from where to start reading the file. If specified, + count is the total number of bytes to transmit as opposed to + sending the file until EOF is reached. File position is updated on + return or also in case of error in which case file.tell() + can be used to figure out the number of bytes + which were sent. + + fallback set to True makes asyncio to manually read and send + the file when the platform does not support the sendfile syscall + (e.g. Windows or SSL socket on Unix). 
+ + Raise SendfileNotAvailableError if the system does not support + sendfile syscall and fallback is False. + """ + if transport.is_closing(): + raise RuntimeError("Transport is closing") + mode = getattr(transport, '_sendfile_compatible', + constants._SendfileMode.UNSUPPORTED) + if mode is constants._SendfileMode.UNSUPPORTED: + raise RuntimeError( + f"sendfile is not supported for transport {transport!r}") + if mode is constants._SendfileMode.TRY_NATIVE: + try: + return await self._sendfile_native(transport, file, + offset, count) + except exceptions.SendfileNotAvailableError as exc: + if not fallback: + raise + + if not fallback: + raise RuntimeError( + f"fallback is disabled and native sendfile is not " + f"supported for transport {transport!r}") + + return await self._sendfile_fallback(transport, file, + offset, count) + + async def _sendfile_native(self, transp, file, offset, count): + raise exceptions.SendfileNotAvailableError( + "sendfile syscall is not supported") + + async def _sendfile_fallback(self, transp, file, offset, count): + if offset: + file.seek(offset) + blocksize = min(count, 16384) if count else 16384 + buf = bytearray(blocksize) + total_sent = 0 + proto = _SendfileFallbackProtocol(transp) + try: + while True: + if count: + blocksize = min(count - total_sent, blocksize) + if blocksize <= 0: + return total_sent + view = memoryview(buf)[:blocksize] + read = await self.run_in_executor(None, file.readinto, view) + if not read: + return total_sent # EOF + await proto.drain() + transp.write(view[:read]) + total_sent += read + finally: + if total_sent > 0 and hasattr(file, 'seek'): + file.seek(offset + total_sent) + await proto.restore() + + async def start_tls(self, transport, protocol, sslcontext, *, + server_side=False, + server_hostname=None, + ssl_handshake_timeout=None): + """Upgrade transport to TLS. + + Return a new transport that *protocol* should start using + immediately. + """ + if ssl is None: + raise RuntimeError('Python ssl module is not available') + + if not isinstance(sslcontext, ssl.SSLContext): + raise TypeError( + f'sslcontext is expected to be an instance of ssl.SSLContext, ' + f'got {sslcontext!r}') + + if not getattr(transport, '_start_tls_compatible', False): + raise TypeError( + f'transport {transport!r} is not supported by start_tls()') + + waiter = self.create_future() + ssl_protocol = sslproto.SSLProtocol( + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + call_connection_made=False) + + # Pause early so that "ssl_protocol.data_received()" doesn't + # have a chance to get called before "ssl_protocol.connection_made()". 
+ transport.pause_reading() + + transport.set_protocol(ssl_protocol) + conmade_cb = self.call_soon(ssl_protocol.connection_made, transport) + resume_cb = self.call_soon(transport.resume_reading) + + try: + await waiter + except BaseException: + transport.close() + conmade_cb.cancel() + resume_cb.cancel() + raise + + return ssl_protocol._app_transport + + async def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0, + reuse_address=_unset, reuse_port=None, + allow_broadcast=None, sock=None): """Create datagram connection.""" if sock is not None: - if not _is_dgram_socket(sock): + if sock.type != socket.SOCK_DGRAM: raise ValueError( - 'A UDP Socket was expected, got {!r}'.format(sock)) + f'A UDP Socket was expected, got {sock!r}') if (local_addr or remote_addr or family or proto or flags or - reuse_address or reuse_port or allow_broadcast): + reuse_port or allow_broadcast): # show the problematic kwargs in exception msg opts = dict(local_addr=local_addr, remote_addr=remote_addr, family=family, proto=proto, flags=flags, reuse_address=reuse_address, reuse_port=reuse_port, allow_broadcast=allow_broadcast) - problems = ', '.join( - '{}={}'.format(k, v) for k, v in opts.items() if v) + problems = ', '.join(f'{k}={v}' for k, v in opts.items() if v) raise ValueError( - 'socket modifier keyword arguments can not be used ' - 'when sock is specified. ({})'.format(problems)) + f'socket modifier keyword arguments can not be used ' + f'when sock is specified. ({problems})') sock.setblocking(False) r_addr = None else: @@ -863,15 +1275,34 @@ def create_datagram_endpoint(self, protocol_factory, if family == 0: raise ValueError('unexpected address family') addr_pairs_info = (((family, proto), (None, None)),) + elif hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX: + for addr in (local_addr, remote_addr): + if addr is not None and not isinstance(addr, str): + raise TypeError('string is expected') + + if local_addr and local_addr[0] not in (0, '\x00'): + try: + if stat.S_ISSOCK(os.stat(local_addr).st_mode): + os.remove(local_addr) + except FileNotFoundError: + pass + except OSError as err: + # Directory may have permissions only to create socket. 
+ logger.error('Unable to check or remove stale UNIX ' + 'socket %r: %r', + local_addr, err) + + addr_pairs_info = (((family, proto), + (local_addr, remote_addr)), ) else: # join address by (family, protocol) - addr_infos = collections.OrderedDict() + addr_infos = {} # Using order preserving dict for idx, addr in ((0, local_addr), (1, remote_addr)): if addr is not None: - assert isinstance(addr, tuple) and len(addr) == 2, ( - '2-tuple is expected') + if not (isinstance(addr, tuple) and len(addr) == 2): + raise TypeError('2-tuple is expected') - infos = yield from _ensure_resolved( + infos = await self._ensure_resolved( addr, family=family, type=socket.SOCK_DGRAM, proto=proto, flags=flags, loop=self) if not infos: @@ -894,8 +1325,18 @@ def create_datagram_endpoint(self, protocol_factory, exceptions = [] - if reuse_address is None: - reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + # bpo-37228 + if reuse_address is not _unset: + if reuse_address: + raise ValueError("Passing `reuse_address=True` is no " + "longer supported, as the usage of " + "SO_REUSEPORT in UDP poses a significant " + "security concern.") + else: + warnings.warn("The *reuse_address* parameter has been " + "deprecated as of 3.5.10 and is scheduled " + "for removal in 3.11.", DeprecationWarning, + stacklevel=2) for ((family, proto), (local_address, remote_address)) in addr_pairs_info: @@ -904,9 +1345,6 @@ def create_datagram_endpoint(self, protocol_factory, try: sock = socket.socket( family=family, type=socket.SOCK_DGRAM, proto=proto) - if reuse_address: - sock.setsockopt( - socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if reuse_port: _set_reuseport(sock) if allow_broadcast: @@ -917,7 +1355,8 @@ def create_datagram_endpoint(self, protocol_factory, if local_addr: sock.bind(local_address) if remote_addr: - yield from self.sock_connect(sock, remote_address) + if not allow_broadcast: + await self.sock_connect(sock, remote_address) r_addr = remote_address except OSError as exc: if sock is not None: @@ -947,36 +1386,49 @@ def create_datagram_endpoint(self, protocol_factory, remote_addr, transport, protocol) try: - yield from waiter + await waiter except: transport.close() raise return transport, protocol - @coroutine - def _create_server_getaddrinfo(self, host, port, family, flags): - infos = yield from _ensure_resolved((host, port), family=family, + async def _ensure_resolved(self, address, *, + family=0, type=socket.SOCK_STREAM, + proto=0, flags=0, loop): + host, port = address[:2] + info = _ipaddr_info(host, port, family, type, proto, *address[2:]) + if info is not None: + # "host" is already a resolved IP. 
+ return [info] + else: + return await loop.getaddrinfo(host, port, family=family, type=type, + proto=proto, flags=flags) + + async def _create_server_getaddrinfo(self, host, port, family, flags): + infos = await self._ensure_resolved((host, port), family=family, type=socket.SOCK_STREAM, flags=flags, loop=self) if not infos: - raise OSError('getaddrinfo({!r}) returned empty list'.format(host)) + raise OSError(f'getaddrinfo({host!r}) returned empty list') return infos - @coroutine - def create_server(self, protocol_factory, host=None, port=None, - *, - family=socket.AF_UNSPEC, - flags=socket.AI_PASSIVE, - sock=None, - backlog=100, - ssl=None, - reuse_address=None, - reuse_port=None): + async def create_server( + self, protocol_factory, host=None, port=None, + *, + family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, + sock=None, + backlog=100, + ssl=None, + reuse_address=None, + reuse_port=None, + ssl_handshake_timeout=None, + start_serving=True): """Create a TCP server. - The host parameter can be a string, in that case the TCP server is bound - to host and port. + The host parameter can be a string, in that case the TCP server is + bound to host and port. The host parameter can also be a sequence of strings and in that case the TCP server is bound to all hosts of the sequence. If a host @@ -990,19 +1442,26 @@ def create_server(self, protocol_factory, host=None, port=None, """ if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') + + if ssl_handshake_timeout is not None and ssl is None: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + + if sock is not None: + _check_ssl_socket(sock) + if host is not None or port is not None: if sock is not None: raise ValueError( 'host/port and sock can not be specified at the same time') - AF_INET6 = getattr(socket, 'AF_INET6', 0) if reuse_address is None: reuse_address = os.name == 'posix' and sys.platform != 'cygwin' sockets = [] if host == '': hosts = [None] elif (isinstance(host, str) or - not isinstance(host, collections.Iterable)): + not isinstance(host, collections.abc.Iterable)): hosts = [host] else: hosts = host @@ -1010,7 +1469,7 @@ def create_server(self, protocol_factory, host=None, port=None, fs = [self._create_server_getaddrinfo(host, port, family=family, flags=flags) for host in hosts] - infos = yield from tasks.gather(*fs, loop=self) + infos = await tasks._gather(*fs, loop=self) infos = set(itertools.chain.from_iterable(infos)) completed = False @@ -1035,7 +1494,9 @@ def create_server(self, protocol_factory, host=None, port=None, # Disable IPv4/IPv6 dual stack support (enabled by # default on Linux) which makes a single socket # listen on both address families. 
- if af == AF_INET6 and hasattr(socket, 'IPPROTO_IPV6'): + if (_HAS_IPv6 and + af == socket.AF_INET6 and + hasattr(socket, 'IPPROTO_IPV6')): sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) @@ -1044,7 +1505,7 @@ def create_server(self, protocol_factory, host=None, port=None, except OSError as err: raise OSError(err.errno, 'error while attempting ' 'to bind on address %r: %s' - % (sa, err.strerror.lower())) + % (sa, err.strerror.lower())) from None completed = True finally: if not completed: @@ -1053,22 +1514,29 @@ def create_server(self, protocol_factory, host=None, port=None, else: if sock is None: raise ValueError('Neither host/port nor sock were specified') - if not _is_stream_socket(sock): - raise ValueError( - 'A Stream Socket was expected, got {!r}'.format(sock)) + if sock.type != socket.SOCK_STREAM: + raise ValueError(f'A Stream Socket was expected, got {sock!r}') sockets = [sock] - server = Server(self, sockets) for sock in sockets: - sock.listen(backlog) sock.setblocking(False) - self._start_serving(protocol_factory, sock, ssl, server, backlog) + + server = Server(self, sockets, protocol_factory, + ssl, backlog, ssl_handshake_timeout) + if start_serving: + server._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. + await tasks.sleep(0) + if self._debug: logger.info("%r is serving", server) return server - @coroutine - def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): + async def connect_accepted_socket( + self, protocol_factory, sock, + *, ssl=None, + ssl_handshake_timeout=None): """Handle an accepted connection. This is used by servers that accept connections outside of @@ -1077,12 +1545,19 @@ def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): This method is a coroutine. When completed, the coroutine returns a (transport, protocol) pair. 
""" - if not _is_stream_socket(sock): + if sock.type != socket.SOCK_STREAM: + raise ValueError(f'A Stream Socket was expected, got {sock!r}') + + if ssl_handshake_timeout is not None and not ssl: raise ValueError( - 'A Stream Socket was expected, got {!r}'.format(sock)) + 'ssl_handshake_timeout is only meaningful with ssl') + + if sock is not None: + _check_ssl_socket(sock) - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, '', server_side=True) + transport, protocol = await self._create_connection_transport( + sock, protocol_factory, ssl, '', server_side=True, + ssl_handshake_timeout=ssl_handshake_timeout) if self._debug: # Get the socket from the transport because SSL transport closes # the old socket and creates a new SSL socket @@ -1090,14 +1565,13 @@ def connect_accepted_socket(self, protocol_factory, sock, *, ssl=None): logger.debug("%r handled: (%r, %r)", sock, transport, protocol) return transport, protocol - @coroutine - def connect_read_pipe(self, protocol_factory, pipe): + async def connect_read_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = self.create_future() transport = self._make_read_pipe_transport(pipe, protocol, waiter) try: - yield from waiter + await waiter except: transport.close() raise @@ -1107,14 +1581,13 @@ def connect_read_pipe(self, protocol_factory, pipe): pipe.fileno(), transport, protocol) return transport, protocol - @coroutine - def connect_write_pipe(self, protocol_factory, pipe): + async def connect_write_pipe(self, protocol_factory, pipe): protocol = protocol_factory() waiter = self.create_future() transport = self._make_write_pipe_transport(pipe, protocol, waiter) try: - yield from waiter + await waiter except: transport.close() raise @@ -1127,21 +1600,24 @@ def connect_write_pipe(self, protocol_factory, pipe): def _log_subprocess(self, msg, stdin, stdout, stderr): info = [msg] if stdin is not None: - info.append('stdin=%s' % _format_pipe(stdin)) + info.append(f'stdin={_format_pipe(stdin)}') if stdout is not None and stderr == subprocess.STDOUT: - info.append('stdout=stderr=%s' % _format_pipe(stdout)) + info.append(f'stdout=stderr={_format_pipe(stdout)}') else: if stdout is not None: - info.append('stdout=%s' % _format_pipe(stdout)) + info.append(f'stdout={_format_pipe(stdout)}') if stderr is not None: - info.append('stderr=%s' % _format_pipe(stderr)) + info.append(f'stderr={_format_pipe(stderr)}') logger.debug(' '.join(info)) - @coroutine - def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=False, shell=True, bufsize=0, - **kwargs): + async def subprocess_shell(self, protocol_factory, cmd, *, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=False, + shell=True, bufsize=0, + encoding=None, errors=None, text=None, + **kwargs): if not isinstance(cmd, (bytes, str)): raise ValueError("cmd must be a string") if universal_newlines: @@ -1150,45 +1626,57 @@ def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, raise ValueError("shell must be True") if bufsize != 0: raise ValueError("bufsize must be 0") + if text: + raise ValueError("text must be False") + if encoding is not None: + raise ValueError("encoding must be None") + if errors is not None: + raise ValueError("errors must be None") + protocol = protocol_factory() + debug_log = None if self._debug: # don't log parameters: they may contain sensitive information # 
(password) and may be too long debug_log = 'run shell command %r' % cmd self._log_subprocess(debug_log, stdin, stdout, stderr) - transport = yield from self._make_subprocess_transport( + transport = await self._make_subprocess_transport( protocol, cmd, True, stdin, stdout, stderr, bufsize, **kwargs) - if self._debug: + if self._debug and debug_log is not None: logger.info('%s: %r', debug_log, transport) return transport, protocol - @coroutine - def subprocess_exec(self, protocol_factory, program, *args, - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, universal_newlines=False, - shell=False, bufsize=0, **kwargs): + async def subprocess_exec(self, protocol_factory, program, *args, + stdin=subprocess.PIPE, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, universal_newlines=False, + shell=False, bufsize=0, + encoding=None, errors=None, text=None, + **kwargs): if universal_newlines: raise ValueError("universal_newlines must be False") if shell: raise ValueError("shell must be False") if bufsize != 0: raise ValueError("bufsize must be 0") + if text: + raise ValueError("text must be False") + if encoding is not None: + raise ValueError("encoding must be None") + if errors is not None: + raise ValueError("errors must be None") + popen_args = (program,) + args - for arg in popen_args: - if not isinstance(arg, (str, bytes)): - raise TypeError("program arguments must be " - "a bytes or text string, not %s" - % type(arg).__name__) protocol = protocol_factory() + debug_log = None if self._debug: # don't log parameters: they may contain sensitive information # (password) and may be too long - debug_log = 'execute program %r' % program + debug_log = f'execute program {program!r}' self._log_subprocess(debug_log, stdin, stdout, stderr) - transport = yield from self._make_subprocess_transport( + transport = await self._make_subprocess_transport( protocol, popen_args, False, stdin, stdout, stderr, bufsize, **kwargs) - if self._debug: + if self._debug and debug_log is not None: logger.info('%s: %r', debug_log, transport) return transport, protocol @@ -1210,8 +1698,8 @@ def set_exception_handler(self, handler): documentation for details about context). """ if handler is not None and not callable(handler): - raise TypeError('A callable object or None is expected, ' - 'got {!r}'.format(handler)) + raise TypeError(f'A callable object or None is expected, ' + f'got {handler!r}') self._exception_handler = handler def default_exception_handler(self, context): @@ -1221,6 +1709,11 @@ def default_exception_handler(self, context): handler is set, and can be called by a custom exception handler that wants to defer to the default behavior. + This default handler logs the error message and other + context-dependent information. In debug mode, a truncated + stack trace is also appended showing where the given object + (e.g. a handle or future or task) was created, if any. + The context parameter has the same meaning as in `call_exception_handler()`. 
""" @@ -1234,10 +1727,11 @@ def default_exception_handler(self, context): else: exc_info = False - if ('source_traceback' not in context - and self._current_handle is not None - and self._current_handle._source_traceback): - context['handle_traceback'] = self._current_handle._source_traceback + if ('source_traceback' not in context and + self._current_handle is not None and + self._current_handle._source_traceback): + context['handle_traceback'] = \ + self._current_handle._source_traceback log_lines = [message] for key in sorted(context): @@ -1254,7 +1748,7 @@ def default_exception_handler(self, context): value += tb.rstrip() else: value = repr(value) - log_lines.append('{}: {}'.format(key, value)) + log_lines.append(f'{key}: {value}') logger.error('\n'.join(log_lines), exc_info=exc_info) @@ -1266,6 +1760,7 @@ def call_exception_handler(self, context): - 'message': Error message; - 'exception' (optional): Exception object; - 'future' (optional): Future instance; + - 'task' (optional): Task instance; - 'handle' (optional): Handle instance; - 'protocol' (optional): Protocol instance; - 'transport' (optional): Transport instance; @@ -1282,7 +1777,9 @@ def call_exception_handler(self, context): if self._exception_handler is None: try: self.default_exception_handler(context) - except Exception: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: # Second protection layer for unexpected errors # in the default implementation, as well as for subclassed # event loops with overloaded "default_exception_handler". @@ -1291,7 +1788,9 @@ def call_exception_handler(self, context): else: try: self._exception_handler(self, context) - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: # Exception in the user set custom exception handler. try: # Let's try default handler. @@ -1300,7 +1799,9 @@ def call_exception_handler(self, context): 'exception': exc, 'context': context, }) - except Exception: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: # Guard 'default_exception_handler' in case it is # overloaded. logger.error('Exception in default exception handler ' @@ -1363,30 +1864,9 @@ def _run_once(self): elif self._scheduled: # Compute the desired timeout. when = self._scheduled[0]._when - timeout = max(0, when - self.time()) - - if self._debug and timeout != 0: - t0 = self.time() - event_list = self._selector.select(timeout) - dt = self.time() - t0 - if dt >= 1.0: - level = logging.INFO - else: - level = logging.DEBUG - nevent = len(event_list) - if timeout is None: - logger.log(level, 'poll took %.3f ms: %s events', - dt * 1e3, nevent) - elif nevent: - logger.log(level, - 'poll %.3f ms took %.3f ms: %s events', - timeout * 1e3, dt * 1e3, nevent) - elif dt >= 1.0: - logger.log(level, - 'poll %.3f ms took %.3f ms: timeout', - timeout * 1e3, dt * 1e3) - else: - event_list = self._selector.select(timeout) + timeout = min(max(0, when - self.time()), MAXIMUM_SELECT_TIMEOUT) + + event_list = self._selector.select(timeout) self._process_events(event_list) # Handle 'later' callbacks that are ready. @@ -1425,38 +1905,20 @@ def _run_once(self): handle._run() handle = None # Needed to break cycles when an exception occurs. 
- def _set_coroutine_wrapper(self, enabled): - try: - set_wrapper = sys.set_coroutine_wrapper - get_wrapper = sys.get_coroutine_wrapper - except AttributeError: + def _set_coroutine_origin_tracking(self, enabled): + if bool(enabled) == bool(self._coroutine_origin_tracking_enabled): return - enabled = bool(enabled) - if self._coroutine_wrapper_set == enabled: - return - - wrapper = coroutines.debug_wrapper - current_wrapper = get_wrapper() - if enabled: - if current_wrapper not in (None, wrapper): - warnings.warn( - "loop.set_debug(True): cannot set debug coroutine " - "wrapper; another wrapper is already set %r" % - current_wrapper, RuntimeWarning) - else: - set_wrapper(wrapper) - self._coroutine_wrapper_set = True + self._coroutine_origin_tracking_saved_depth = ( + sys.get_coroutine_origin_tracking_depth()) + sys.set_coroutine_origin_tracking_depth( + constants.DEBUG_STACK_DEPTH) else: - if current_wrapper not in (None, wrapper): - warnings.warn( - "loop.set_debug(False): cannot unset debug coroutine " - "wrapper; another wrapper was set %r" % - current_wrapper, RuntimeWarning) - else: - set_wrapper(None) - self._coroutine_wrapper_set = False + sys.set_coroutine_origin_tracking_depth( + self._coroutine_origin_tracking_saved_depth) + + self._coroutine_origin_tracking_enabled = enabled def get_debug(self): return self._debug @@ -1465,4 +1927,4 @@ def set_debug(self, enabled): self._debug = enabled if self.is_running(): - self._set_coroutine_wrapper(enabled) + self.call_soon_threadsafe(self._set_coroutine_origin_tracking, enabled) diff --git a/Lib/asyncio/base_futures.py b/Lib/asyncio/base_futures.py index 01259a062e..2c01ac98e1 100644 --- a/Lib/asyncio/base_futures.py +++ b/Lib/asyncio/base_futures.py @@ -1,18 +1,9 @@ -__all__ = [] +__all__ = () -import concurrent.futures._base import reprlib +from _thread import get_ident -from . import events - -Error = concurrent.futures._base.Error -CancelledError = concurrent.futures.CancelledError -TimeoutError = concurrent.futures.TimeoutError - - -class InvalidStateError(Error): - """The operation is not allowed in this state.""" - +from . import format_helpers # States for Future. _PENDING = 'PENDING' @@ -38,17 +29,27 @@ def _format_callbacks(cb): cb = '' def format_cb(callback): - return events._format_callback_source(callback, ()) + return format_helpers._format_callback_source(callback, ()) if size == 1: - cb = format_cb(cb[0]) + cb = format_cb(cb[0][0]) elif size == 2: - cb = '{}, {}'.format(format_cb(cb[0]), format_cb(cb[1])) + cb = '{}, {}'.format(format_cb(cb[0][0]), format_cb(cb[1][0])) elif size > 2: - cb = '{}, <{} more>, {}'.format(format_cb(cb[0]), + cb = '{}, <{} more>, {}'.format(format_cb(cb[0][0]), size - 2, - format_cb(cb[-1])) - return 'cb=[%s]' % cb + format_cb(cb[-1][0])) + return f'cb=[{cb}]' + + +# bpo-42183: _repr_running is needed for repr protection +# when a Future or Task result contains itself directly or indirectly. +# The logic is borrowed from @reprlib.recursive_repr decorator. +# Unfortunately, the direct decorator usage is impossible because of +# AttributeError: '_asyncio.Task' object has no attribute '__module__' error. +# +# After fixing this thing we can return to the decorator based approach. 
+_repr_running = set() def _future_repr_info(future): @@ -57,15 +58,23 @@ def _future_repr_info(future): info = [future._state.lower()] if future._state == _FINISHED: if future._exception is not None: - info.append('exception={!r}'.format(future._exception)) + info.append(f'exception={future._exception!r}') else: - # use reprlib to limit the length of the output, especially - # for very long strings - result = reprlib.repr(future._result) - info.append('result={}'.format(result)) + key = id(future), get_ident() + if key in _repr_running: + result = '...' + else: + _repr_running.add(key) + try: + # use reprlib to limit the length of the output, especially + # for very long strings + result = reprlib.repr(future._result) + finally: + _repr_running.discard(key) + info.append(f'result={result}') if future._callbacks: info.append(_format_callbacks(future._callbacks)) if future._source_traceback: frame = future._source_traceback[-1] - info.append('created at %s:%s' % (frame[0], frame[1])) + info.append(f'created at {frame[0]}:{frame[1]}') return info diff --git a/Lib/asyncio/base_subprocess.py b/Lib/asyncio/base_subprocess.py index a00d9d5732..14d5051922 100644 --- a/Lib/asyncio/base_subprocess.py +++ b/Lib/asyncio/base_subprocess.py @@ -2,10 +2,8 @@ import subprocess import warnings -from . import compat from . import protocols from . import transports -from .coroutines import coroutine from .log import logger @@ -59,9 +57,9 @@ def __repr__(self): if self._closed: info.append('closed') if self._pid is not None: - info.append('pid=%s' % self._pid) + info.append(f'pid={self._pid}') if self._returncode is not None: - info.append('returncode=%s' % self._returncode) + info.append(f'returncode={self._returncode}') elif self._pid is not None: info.append('running') else: @@ -69,19 +67,19 @@ def __repr__(self): stdin = self._pipes.get(0) if stdin is not None: - info.append('stdin=%s' % stdin.pipe) + info.append(f'stdin={stdin.pipe}') stdout = self._pipes.get(1) stderr = self._pipes.get(2) if stdout is not None and stderr is stdout: - info.append('stdout=stderr=%s' % stdout.pipe) + info.append(f'stdout=stderr={stdout.pipe}') else: if stdout is not None: - info.append('stdout=%s' % stdout.pipe) + info.append(f'stdout={stdout.pipe}') if stderr is not None: - info.append('stderr=%s' % stderr.pipe) + info.append(f'stderr={stderr.pipe}') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): raise NotImplementedError @@ -105,12 +103,13 @@ def close(self): continue proto.pipe.close() - if (self._proc is not None - # the child process finished? - and self._returncode is None - # the child process finished but the transport was not notified yet? - and self._proc.poll() is None - ): + if (self._proc is not None and + # has the child process finished? + self._returncode is None and + # the child process has finished, but the + # transport hasn't been notified yet? + self._proc.poll() is None): + if self._loop.get_debug(): logger.warning('Close running child process: kill %r', self) @@ -121,15 +120,10 @@ def close(self): # Don't clear the _proc reference yet: _post_init() may still run - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. 
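The recursion guard threaded through _future_repr_info() above is easy to miss in diff form; the same idea, restated as a standalone helper (the function name is illustrative):

    import reprlib
    from _thread import get_ident

    _repr_running = set()

    def guarded_repr(obj):
        # Record (object id, thread id) while this repr is in progress,
        # so a self-referential value prints as '...' instead of
        # recursing until the stack overflows.
        key = id(obj), get_ident()
        if key in _repr_running:
            return '...'
        _repr_running.add(key)
        try:
            return reprlib.repr(obj)
        finally:
            _repr_running.discard(key)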
- if compat.PY34: - def __del__(self): - if not self._closed: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self.close() + def __del__(self, _warn=warnings.warn): + if not self._closed: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self.close() def get_pid(self): return self._pid @@ -159,26 +153,25 @@ def kill(self): self._check_proc() self._proc.kill() - @coroutine - def _connect_pipes(self, waiter): + async def _connect_pipes(self, waiter): try: proc = self._proc loop = self._loop if proc.stdin is not None: - _, pipe = yield from loop.connect_write_pipe( + _, pipe = await loop.connect_write_pipe( lambda: WriteSubprocessPipeProto(self, 0), proc.stdin) self._pipes[0] = pipe if proc.stdout is not None: - _, pipe = yield from loop.connect_read_pipe( + _, pipe = await loop.connect_read_pipe( lambda: ReadSubprocessPipeProto(self, 1), proc.stdout) self._pipes[1] = pipe if proc.stderr is not None: - _, pipe = yield from loop.connect_read_pipe( + _, pipe = await loop.connect_read_pipe( lambda: ReadSubprocessPipeProto(self, 2), proc.stderr) self._pipes[2] = pipe @@ -189,7 +182,9 @@ def _connect_pipes(self, waiter): for callback, data in self._pending_calls: loop.call_soon(callback, *data) self._pending_calls = None - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: if waiter is not None and not waiter.cancelled(): waiter.set_exception(exc) else: @@ -213,8 +208,7 @@ def _process_exited(self, returncode): assert returncode is not None, returncode assert self._returncode is None, self._returncode if self._loop.get_debug(): - logger.info('%r exited with return code %r', - self, returncode) + logger.info('%r exited with return code %r', self, returncode) self._returncode = returncode if self._proc.returncode is None: # asyncio uses a child watcher: copy the status into the Popen @@ -229,8 +223,7 @@ def _process_exited(self, returncode): waiter.set_result(returncode) self._exit_waiters = None - @coroutine - def _wait(self): + async def _wait(self): """Wait until the process exit and return the process return code. 
This method is a coroutine.""" @@ -239,7 +232,7 @@ def _wait(self): waiter = self._loop.create_future() self._exit_waiters.append(waiter) - return (yield from waiter) + return await waiter def _try_finish(self): assert not self._finished @@ -271,8 +264,7 @@ def connection_made(self, transport): self.pipe = transport def __repr__(self): - return ('<%s fd=%s pipe=%r>' - % (self.__class__.__name__, self.fd, self.pipe)) + return f'<{self.__class__.__name__} fd={self.fd} pipe={self.pipe!r}>' def connection_lost(self, exc): self.disconnected = True diff --git a/Lib/asyncio/base_tasks.py b/Lib/asyncio/base_tasks.py index 5f34434c57..09bb171a2c 100644 --- a/Lib/asyncio/base_tasks.py +++ b/Lib/asyncio/base_tasks.py @@ -12,21 +12,30 @@ def _task_repr_info(task): # replace status info[0] = 'cancelling' + info.insert(1, 'name=%r' % task.get_name()) + coro = coroutines._format_coroutine(task._coro) - info.insert(1, 'coro=<%s>' % coro) + info.insert(2, f'coro=<{coro}>') if task._fut_waiter is not None: - info.insert(2, 'wait_for=%r' % task._fut_waiter) + info.insert(3, f'wait_for={task._fut_waiter!r}') return info def _task_get_stack(task, limit): frames = [] - try: - # 'async def' coroutines + if hasattr(task._coro, 'cr_frame'): + # case 1: 'async def' coroutines f = task._coro.cr_frame - except AttributeError: + elif hasattr(task._coro, 'gi_frame'): + # case 2: legacy coroutines f = task._coro.gi_frame + elif hasattr(task._coro, 'ag_frame'): + # case 3: async generators + f = task._coro.ag_frame + else: + # case 4: unknown objects + f = None if f is not None: while f is not None: if limit is not None: @@ -61,15 +70,15 @@ def _task_print_stack(task, limit, file): linecache.checkcache(filename) line = linecache.getline(filename, lineno, f.f_globals) extracted_list.append((filename, lineno, name, line)) + exc = task._exception if not extracted_list: - print('No stack for %r' % task, file=file) + print(f'No stack for {task!r}', file=file) elif exc is not None: - print('Traceback for %r (most recent call last):' % task, - file=file) + print(f'Traceback for {task!r} (most recent call last):', file=file) else: - print('Stack for %r (most recent call last):' % task, - file=file) + print(f'Stack for {task!r} (most recent call last):', file=file) + traceback.print_list(extracted_list, file=file) if exc is not None: for line in traceback.format_exception_only(exc.__class__, exc): diff --git a/Lib/asyncio/compat.py b/Lib/asyncio/compat.py deleted file mode 100644 index 4790bb4a35..0000000000 --- a/Lib/asyncio/compat.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Compatibility helpers for the different Python versions.""" - -import sys - -PY34 = sys.version_info >= (3, 4) -PY35 = sys.version_info >= (3, 5) -PY352 = sys.version_info >= (3, 5, 2) - - -def flatten_list_bytes(list_of_data): - """Concatenate a sequence of bytes-like objects.""" - if not PY34: - # On Python 3.3 and older, bytes.join() doesn't handle - # memoryview. - list_of_data = ( - bytes(data) if isinstance(data, memoryview) else data - for data in list_of_data) - return b''.join(list_of_data) diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py index f9e123281e..33feed60e5 100644 --- a/Lib/asyncio/constants.py +++ b/Lib/asyncio/constants.py @@ -1,7 +1,27 @@ -"""Constants.""" +import enum # After the connection is lost, log warnings after this many write()s. LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 # Seconds to wait before retrying accept(). ACCEPT_RETRY_DELAY = 1 + +# Number of stack entries to capture in debug mode. 
+# The larger the number, the slower the operation in debug mode +# (see extract_stack() in format_helpers.py). +DEBUG_STACK_DEPTH = 10 + +# Number of seconds to wait for SSL handshake to complete +# The default timeout matches that of Nginx. +SSL_HANDSHAKE_TIMEOUT = 60.0 + +# Used in sendfile fallback code. We use fallback for platforms +# that don't support sendfile, or for TLS connections. +SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 256 + +# The enum should be here to break circular dependencies between +# base_events and sslproto +class _SendfileMode(enum.Enum): + UNSUPPORTED = enum.auto() + TRY_NATIVE = enum.auto() + FALLBACK = enum.auto() diff --git a/Lib/asyncio/coroutines.py b/Lib/asyncio/coroutines.py index 08e94412b3..9664ea74d7 100644 --- a/Lib/asyncio/coroutines.py +++ b/Lib/asyncio/coroutines.py @@ -1,87 +1,36 @@ -__all__ = ['coroutine', - 'iscoroutinefunction', 'iscoroutine'] +__all__ = 'coroutine', 'iscoroutinefunction', 'iscoroutine' +import collections.abc import functools import inspect -import opcode import os import sys import traceback import types +import warnings -from . import compat -from . import events from . import base_futures +from . import constants +from . import format_helpers from .log import logger -# Opcode of "yield from" instruction -_YIELD_FROM = opcode.opmap['YIELD_FROM'] - -# If you set _DEBUG to true, @coroutine will wrap the resulting -# generator objects in a CoroWrapper instance (defined below). That -# instance will log a message when the generator is never iterated -# over, which may happen when you forget to use "yield from" with a -# coroutine call. Note that the value of the _DEBUG flag is taken -# when the decorator is used, so to be of any use it must be set -# before you define your coroutines. A downside of using this feature -# is that tracebacks show entries for the CoroWrapper.__next__ method -# when _DEBUG is true. -_DEBUG = (not sys.flags.ignore_environment and - bool(os.environ.get('PYTHONASYNCIODEBUG'))) - - -try: - _types_coroutine = types.coroutine - _types_CoroutineType = types.CoroutineType -except AttributeError: - # Python 3.4 - _types_coroutine = None - _types_CoroutineType = None - -try: - _inspect_iscoroutinefunction = inspect.iscoroutinefunction -except AttributeError: - # Python 3.4 - _inspect_iscoroutinefunction = lambda func: False - -try: - from collections.abc import Coroutine as _CoroutineABC, \ - Awaitable as _AwaitableABC -except ImportError: - _CoroutineABC = _AwaitableABC = None - - -# Check for CPython issue #21209 -def has_yield_from_bug(): - class MyGen: - def __init__(self): - self.send_args = None - def __iter__(self): - return self - def __next__(self): - return 42 - def send(self, *what): - self.send_args = what - return None - def yield_from_gen(gen): - yield from gen - value = (1, 2, 3) - gen = MyGen() - coro = yield_from_gen(gen) - next(coro) - coro.send(value) - return gen.send_args != (value,) -_YIELD_FROM_BUG = has_yield_from_bug() -del has_yield_from_bug - - -def debug_wrapper(gen): - # This function is called from 'sys.set_coroutine_wrapper'. - # We only wrap here coroutines defined via 'async def' syntax. - # Generator-based coroutines are wrapped in @coroutine - # decorator. - return CoroWrapper(gen, None) +def _is_debug_mode(): + # If you set _DEBUG to true, @coroutine will wrap the resulting + # generator objects in a CoroWrapper instance (defined below). 
That + # instance will log a message when the generator is never iterated + # over, which may happen when you forget to use "await" or "yield from" + # with a coroutine call. + # Note that the value of the _DEBUG flag is taken + # when the decorator is used, so to be of any use it must be set + # before you define your coroutines. A downside of using this feature + # is that tracebacks show entries for the CoroWrapper.__next__ method + # when _DEBUG is true. + return sys.flags.dev_mode or (not sys.flags.ignore_environment and + bool(os.environ.get('PYTHONASYNCIODEBUG'))) + + +_DEBUG = _is_debug_mode() class CoroWrapper: @@ -91,7 +40,7 @@ def __init__(self, gen, func=None): assert inspect.isgenerator(gen) or inspect.iscoroutine(gen), gen self.gen = gen self.func = func # Used to unwrap @coroutine decorator - self._source_traceback = traceback.extract_stack(sys._getframe(1)) + self._source_traceback = format_helpers.extract_stack(sys._getframe(1)) self.__name__ = getattr(gen, '__name__', None) self.__qualname__ = getattr(gen, '__qualname__', None) @@ -99,8 +48,9 @@ def __repr__(self): coro_repr = _format_coroutine(self) if self._source_traceback: frame = self._source_traceback[-1] - coro_repr += ', created at %s:%s' % (frame[0], frame[1]) - return '<%s %s>' % (self.__class__.__name__, coro_repr) + coro_repr += f', created at {frame[0]}:{frame[1]}' + + return f'<{self.__class__.__name__} {coro_repr}>' def __iter__(self): return self @@ -108,21 +58,8 @@ def __iter__(self): def __next__(self): return self.gen.send(None) - if _YIELD_FROM_BUG: - # For for CPython issue #21209: using "yield from" and a custom - # generator, generator.send(tuple) unpacks the tuple instead of passing - # the tuple unchanged. Check if the caller is a generator using "yield - # from" to decide if the parameter should be unpacked or not. - def send(self, *value): - frame = sys._getframe() - caller = frame.f_back - assert caller.f_lasti >= 0 - if caller.f_code.co_code[caller.f_lasti] != _YIELD_FROM: - value = value[0] - return self.gen.send(value) - else: - def send(self, value): - return self.gen.send(value) + def send(self, value): + return self.gen.send(value) def throw(self, type, value=None, traceback=None): return self.gen.throw(type, value, traceback) @@ -142,49 +79,25 @@ def gi_running(self): def gi_code(self): return self.gen.gi_code - if compat.PY35: - - def __await__(self): - cr_await = getattr(self.gen, 'cr_await', None) - if cr_await is not None: - raise RuntimeError( - "Cannot await on coroutine {!r} while it's " - "awaiting for {!r}".format(self.gen, cr_await)) - return self - - @property - def gi_yieldfrom(self): - return self.gen.gi_yieldfrom - - @property - def cr_await(self): - return self.gen.cr_await - - @property - def cr_running(self): - return self.gen.cr_running - - @property - def cr_code(self): - return self.gen.cr_code + def __await__(self): + return self - @property - def cr_frame(self): - return self.gen.cr_frame + @property + def gi_yieldfrom(self): + return self.gen.gi_yieldfrom def __del__(self): # Be careful accessing self.gen.frame -- self.gen might not exist. 
gen = getattr(self, 'gen', None) frame = getattr(gen, 'gi_frame', None) - if frame is None: - frame = getattr(gen, 'cr_frame', None) if frame is not None and frame.f_lasti == -1: - msg = '%r was never yielded from' % self + msg = f'{self!r} was never yielded from' tb = getattr(self, '_source_traceback', ()) if tb: tb = ''.join(traceback.format_list(tb)) - msg += ('\nCoroutine object created at ' - '(most recent call last):\n') + msg += (f'\nCoroutine object created at ' + f'(most recent call last, truncated to ' + f'{constants.DEBUG_STACK_DEPTH} last lines):\n') msg += tb.rstrip() logger.error(msg) @@ -195,11 +108,12 @@ def coroutine(func): If the coroutine is not yielded from before it is destroyed, an error message is logged. """ - if _inspect_iscoroutinefunction(func): + warnings.warn('"@coroutine" decorator is deprecated since Python 3.8, use "async def" instead', + DeprecationWarning, + stacklevel=2) + if inspect.iscoroutinefunction(func): # In Python 3.5 that's all we need to do for coroutines - # defiend with "async def". - # Wrapping in CoroWrapper will happen via - # 'sys.set_coroutine_wrapper' function. + # defined with "async def". return func if inspect.isgeneratorfunction(func): @@ -209,25 +123,22 @@ def coroutine(func): def coro(*args, **kw): res = func(*args, **kw) if (base_futures.isfuture(res) or inspect.isgenerator(res) or - isinstance(res, CoroWrapper)): + isinstance(res, CoroWrapper)): res = yield from res - elif _AwaitableABC is not None: - # If 'func' returns an Awaitable (new in 3.5) we - # want to run it. + else: + # If 'res' is an awaitable, run it. try: await_meth = res.__await__ except AttributeError: pass else: - if isinstance(res, _AwaitableABC): + if isinstance(res, collections.abc.Awaitable): res = yield from await_meth() return res + coro = types.coroutine(coro) if not _DEBUG: - if _types_coroutine is None: - wrapper = coro - else: - wrapper = _types_coroutine(coro) + wrapper = coro else: @functools.wraps(func) def wrapper(*args, **kwds): @@ -252,93 +163,107 @@ def wrapper(*args, **kwds): def iscoroutinefunction(func): """Return True if func is a decorated coroutine function.""" - return (getattr(func, '_is_coroutine', None) is _is_coroutine or - _inspect_iscoroutinefunction(func)) + return (inspect.iscoroutinefunction(func) or + getattr(func, '_is_coroutine', None) is _is_coroutine) -_COROUTINE_TYPES = (types.GeneratorType, CoroWrapper) -if _CoroutineABC is not None: - _COROUTINE_TYPES += (_CoroutineABC,) -if _types_CoroutineType is not None: - # Prioritize native coroutine check to speed-up - # asyncio.iscoroutine. - _COROUTINE_TYPES = (_types_CoroutineType,) + _COROUTINE_TYPES +# Prioritize native coroutine check to speed-up +# asyncio.iscoroutine. +_COROUTINE_TYPES = (types.CoroutineType, types.GeneratorType, + collections.abc.Coroutine, CoroWrapper) +_iscoroutine_typecache = set() def iscoroutine(obj): """Return True if obj is a coroutine object.""" - return isinstance(obj, _COROUTINE_TYPES) + if type(obj) in _iscoroutine_typecache: + return True + + if isinstance(obj, _COROUTINE_TYPES): + # Just in case we don't want to cache more than 100 + # positive types. That shouldn't ever happen, unless + # someone stressing the system on purpose. + if len(_iscoroutine_typecache) < 100: + _iscoroutine_typecache.add(type(obj)) + return True + else: + return False def _format_coroutine(coro): assert iscoroutine(coro) - if not hasattr(coro, 'cr_code') and not hasattr(coro, 'gi_code'): - # Most likely a built-in type or a Cython coroutine. 
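Restating the new iscoroutine() fast path outside the module may make the trade-off clearer: a small, bounded set of exact types short-circuits the slower isinstance() walk over the ABC machinery (names here are illustrative):

    import collections.abc
    import types

    _COROUTINE_TYPES = (types.CoroutineType, types.GeneratorType,
                        collections.abc.Coroutine)
    _typecache = set()

    def fast_iscoroutine(obj):
        if type(obj) in _typecache:      # O(1) exact-type hit
            return True
        if isinstance(obj, _COROUTINE_TYPES):
            if len(_typecache) < 100:    # keep the cache bounded
                _typecache.add(type(obj))
            return True
        return False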
+ is_corowrapper = isinstance(coro, CoroWrapper) - # Built-in types might not have __qualname__ or __name__. - coro_name = getattr( - coro, '__qualname__', - getattr(coro, '__name__', type(coro).__name__)) - coro_name = '{}()'.format(coro_name) + def get_name(coro): + # Coroutines compiled with Cython sometimes don't have + # proper __qualname__ or __name__. While that is a bug + # in Cython, asyncio shouldn't crash with an AttributeError + # in its __repr__ functions. + if is_corowrapper: + return format_helpers._format_callback(coro.func, (), {}) + + if hasattr(coro, '__qualname__') and coro.__qualname__: + coro_name = coro.__qualname__ + elif hasattr(coro, '__name__') and coro.__name__: + coro_name = coro.__name__ + else: + # Stop masking Cython bugs, expose them in a friendly way. + coro_name = f'<{type(coro).__name__} without __name__>' + return f'{coro_name}()' - running = False + def is_running(coro): try: - running = coro.cr_running + return coro.cr_running except AttributeError: try: - running = coro.gi_running + return coro.gi_running except AttributeError: - pass - - if running: - return '{} running'.format(coro_name) - else: - return coro_name + return False - coro_name = None - if isinstance(coro, CoroWrapper): - func = coro.func - coro_name = coro.__qualname__ - if coro_name is not None: - coro_name = '{}()'.format(coro_name) - else: - func = coro + coro_code = None + if hasattr(coro, 'cr_code') and coro.cr_code: + coro_code = coro.cr_code + elif hasattr(coro, 'gi_code') and coro.gi_code: + coro_code = coro.gi_code - if coro_name is None: - coro_name = events._format_callback(func, (), {}) + coro_name = get_name(coro) - try: - coro_code = coro.gi_code - except AttributeError: - coro_code = coro.cr_code + if not coro_code: + # Built-in types might not have __qualname__ or __name__. + if is_running(coro): + return f'{coro_name} running' + else: + return coro_name - try: + coro_frame = None + if hasattr(coro, 'gi_frame') and coro.gi_frame: coro_frame = coro.gi_frame - except AttributeError: + elif hasattr(coro, 'cr_frame') and coro.cr_frame: coro_frame = coro.cr_frame - filename = coro_code.co_filename + # If Cython's coroutine has a fake code object without proper + # co_filename -- expose that. 
+ filename = coro_code.co_filename or '' + lineno = 0 - if (isinstance(coro, CoroWrapper) and - not inspect.isgeneratorfunction(coro.func) and - coro.func is not None): - source = events._get_function_source(coro.func) + if (is_corowrapper and + coro.func is not None and + not inspect.isgeneratorfunction(coro.func)): + source = format_helpers._get_function_source(coro.func) if source is not None: filename, lineno = source if coro_frame is None: - coro_repr = ('%s done, defined at %s:%s' - % (coro_name, filename, lineno)) + coro_repr = f'{coro_name} done, defined at {filename}:{lineno}' else: - coro_repr = ('%s running, defined at %s:%s' - % (coro_name, filename, lineno)) + coro_repr = f'{coro_name} running, defined at {filename}:{lineno}' + elif coro_frame is not None: lineno = coro_frame.f_lineno - coro_repr = ('%s running at %s:%s' - % (coro_name, filename, lineno)) + coro_repr = f'{coro_name} running at {filename}:{lineno}' + else: lineno = coro_code.co_firstlineno - coro_repr = ('%s done, defined at %s:%s' - % (coro_name, filename, lineno)) + coro_repr = f'{coro_name} done, defined at {filename}:{lineno}' return coro_repr diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index 466db6d9a3..413ff2aaa6 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -1,96 +1,45 @@ """Event loop and event loop policy.""" -__all__ = ['AbstractEventLoopPolicy', - 'AbstractEventLoop', 'AbstractServer', - 'Handle', 'TimerHandle', - 'get_event_loop_policy', 'set_event_loop_policy', - 'get_event_loop', 'set_event_loop', 'new_event_loop', - 'get_child_watcher', 'set_child_watcher', - '_set_running_loop', 'get_running_loop', - '_get_running_loop', - ] - -import functools -import inspect -import reprlib +__all__ = ( + 'AbstractEventLoopPolicy', + 'AbstractEventLoop', 'AbstractServer', + 'Handle', 'TimerHandle', + 'get_event_loop_policy', 'set_event_loop_policy', + 'get_event_loop', 'set_event_loop', 'new_event_loop', + 'get_child_watcher', 'set_child_watcher', + '_set_running_loop', 'get_running_loop', + '_get_running_loop', +) + +import contextvars +import os import socket import subprocess import sys import threading -import traceback -from asyncio import compat - - -def _get_function_source(func): - if compat.PY34: - func = inspect.unwrap(func) - elif hasattr(func, '__wrapped__'): - func = func.__wrapped__ - if inspect.isfunction(func): - code = func.__code__ - return (code.co_filename, code.co_firstlineno) - if isinstance(func, functools.partial): - return _get_function_source(func.func) - if compat.PY34 and isinstance(func, functools.partialmethod): - return _get_function_source(func.func) - return None - - -def _format_args_and_kwargs(args, kwargs): - """Format function arguments and keyword arguments. - - Special case for a single parameter: ('hello',) is formatted as ('hello'). 
- """ - # use reprlib to limit the length of the output - items = [] - if args: - items.extend(reprlib.repr(arg) for arg in args) - if kwargs: - items.extend('{}={}'.format(k, reprlib.repr(v)) - for k, v in kwargs.items()) - return '(' + ', '.join(items) + ')' - - -def _format_callback(func, args, kwargs, suffix=''): - if isinstance(func, functools.partial): - suffix = _format_args_and_kwargs(args, kwargs) + suffix - return _format_callback(func.func, func.args, func.keywords, suffix) - - if hasattr(func, '__qualname__'): - func_repr = getattr(func, '__qualname__') - elif hasattr(func, '__name__'): - func_repr = getattr(func, '__name__') - else: - func_repr = repr(func) - - func_repr += _format_args_and_kwargs(args, kwargs) - if suffix: - func_repr += suffix - return func_repr - -def _format_callback_source(func, args): - func_repr = _format_callback(func, args, None) - source = _get_function_source(func) - if source: - func_repr += ' at %s:%s' % source - return func_repr +from . import format_helpers class Handle: """Object returned by callback registration methods.""" __slots__ = ('_callback', '_args', '_cancelled', '_loop', - '_source_traceback', '_repr', '__weakref__') + '_source_traceback', '_repr', '__weakref__', + '_context') - def __init__(self, callback, args, loop): + def __init__(self, callback, args, loop, context=None): + if context is None: + context = contextvars.copy_context() + self._context = context self._loop = loop self._callback = callback self._args = args self._cancelled = False self._repr = None if self._loop.get_debug(): - self._source_traceback = traceback.extract_stack(sys._getframe(1)) + self._source_traceback = format_helpers.extract_stack( + sys._getframe(1)) else: self._source_traceback = None @@ -99,17 +48,18 @@ def _repr_info(self): if self._cancelled: info.append('cancelled') if self._callback is not None: - info.append(_format_callback_source(self._callback, self._args)) + info.append(format_helpers._format_callback_source( + self._callback, self._args)) if self._source_traceback: frame = self._source_traceback[-1] - info.append('created at %s:%s' % (frame[0], frame[1])) + info.append(f'created at {frame[0]}:{frame[1]}') return info def __repr__(self): if self._repr is not None: return self._repr info = self._repr_info() - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def cancel(self): if not self._cancelled: @@ -122,12 +72,18 @@ def cancel(self): self._callback = None self._args = None + def cancelled(self): + return self._cancelled + def _run(self): try: - self._callback(*self._args) - except Exception as exc: - cb = _format_callback_source(self._callback, self._args) - msg = 'Exception in callback {}'.format(cb) + self._context.run(self._callback, *self._args) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + cb = format_helpers._format_callback_source( + self._callback, self._args) + msg = f'Exception in callback {cb}' context = { 'message': msg, 'exception': exc, @@ -144,9 +100,9 @@ class TimerHandle(Handle): __slots__ = ['_scheduled', '_when'] - def __init__(self, when, callback, args, loop): + def __init__(self, when, callback, args, loop, context=None): assert when is not None - super().__init__(callback, args, loop) + super().__init__(callback, args, loop, context) if self._source_traceback: del self._source_traceback[-1] self._when = when @@ -155,27 +111,31 @@ def __init__(self, when, callback, args, loop): def _repr_info(self): info = super()._repr_info() pos = 2 if self._cancelled else 
1 - info.insert(pos, 'when=%s' % self._when) + info.insert(pos, f'when={self._when}') return info def __hash__(self): return hash(self._when) def __lt__(self, other): - return self._when < other._when + if isinstance(other, TimerHandle): + return self._when < other._when + return NotImplemented def __le__(self, other): - if self._when < other._when: - return True - return self.__eq__(other) + if isinstance(other, TimerHandle): + return self._when < other._when or self.__eq__(other) + return NotImplemented def __gt__(self, other): - return self._when > other._when + if isinstance(other, TimerHandle): + return self._when > other._when + return NotImplemented def __ge__(self, other): - if self._when > other._when: - return True - return self.__eq__(other) + if isinstance(other, TimerHandle): + return self._when > other._when or self.__eq__(other) + return NotImplemented def __eq__(self, other): if isinstance(other, TimerHandle): @@ -185,26 +145,60 @@ def __eq__(self, other): self._cancelled == other._cancelled) return NotImplemented - def __ne__(self, other): - equal = self.__eq__(other) - return NotImplemented if equal is NotImplemented else not equal - def cancel(self): if not self._cancelled: self._loop._timer_handle_cancelled(self) super().cancel() + def when(self): + """Return a scheduled callback time. + + The time is an absolute timestamp, using the same time + reference as loop.time(). + """ + return self._when + class AbstractServer: """Abstract server returned by create_server().""" def close(self): """Stop serving. This leaves existing connections open.""" - return NotImplemented + raise NotImplementedError + + def get_loop(self): + """Get the event loop the Server object is attached to.""" + raise NotImplementedError + + def is_serving(self): + """Return True if the server is accepting connections.""" + raise NotImplementedError + + async def start_serving(self): + """Start accepting connections. + + This method is idempotent, so it can be called when + the server is already being serving. + """ + raise NotImplementedError + + async def serve_forever(self): + """Start accepting connections until the coroutine is cancelled. + + The server is closed when the coroutine is cancelled. + """ + raise NotImplementedError - def wait_closed(self): + async def wait_closed(self): """Coroutine to wait until service is closed.""" - return NotImplemented + raise NotImplementedError + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + self.close() + await self.wait_closed() class AbstractEventLoop: @@ -250,23 +244,27 @@ def close(self): """ raise NotImplementedError - def shutdown_asyncgens(self): + async def shutdown_asyncgens(self): """Shutdown all active asynchronous generators.""" raise NotImplementedError + async def shutdown_default_executor(self): + """Schedule the shutdown of the default executor.""" + raise NotImplementedError + # Methods scheduling callbacks. All these return Handles. 
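Handle now runs its callback through a captured contextvars.Context, and the scheduling methods declared just below grow an explicit context= argument. A small usage sketch of the observable behavior:

    import asyncio
    import contextvars

    request_id = contextvars.ContextVar('request_id', default='-')

    def report():
        # Handle._run() invokes this via context.run(), so the value set
        # before scheduling is visible here.
        print('request:', request_id.get())

    async def main():
        loop = asyncio.get_running_loop()
        request_id.set('abc123')
        loop.call_soon(report)   # captures a copy of the current context
        loop.call_soon(report, context=contextvars.copy_context())
        await asyncio.sleep(0)

    asyncio.run(main())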
def _timer_handle_cancelled(self, handle): """Notification that a TimerHandle has been cancelled.""" raise NotImplementedError - def call_soon(self, callback, *args): - return self.call_later(0, callback, *args) + def call_soon(self, callback, *args, context=None): + return self.call_later(0, callback, *args, context=context) - def call_later(self, delay, callback, *args): + def call_later(self, delay, callback, *args, context=None): raise NotImplementedError - def call_at(self, when, callback, *args): + def call_at(self, when, callback, *args, context=None): raise NotImplementedError def time(self): @@ -277,12 +275,12 @@ def create_future(self): # Method scheduling a coroutine object: create a task. - def create_task(self, coro): + def create_task(self, coro, *, name=None): raise NotImplementedError # Methods for interacting with threads. - def call_soon_threadsafe(self, callback, *args): + def call_soon_threadsafe(self, callback, *args, context=None): raise NotImplementedError def run_in_executor(self, executor, func, *args): @@ -293,21 +291,29 @@ def set_default_executor(self, executor): # Network I/O methods returning Futures. - def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): + async def getaddrinfo(self, host, port, *, + family=0, type=0, proto=0, flags=0): raise NotImplementedError - def getnameinfo(self, sockaddr, flags=0): + async def getnameinfo(self, sockaddr, flags=0): raise NotImplementedError - def create_connection(self, protocol_factory, host=None, port=None, *, - ssl=None, family=0, proto=0, flags=0, sock=None, - local_addr=None, server_hostname=None): + async def create_connection( + self, protocol_factory, host=None, port=None, + *, ssl=None, family=0, proto=0, + flags=0, sock=None, local_addr=None, + server_hostname=None, + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): raise NotImplementedError - def create_server(self, protocol_factory, host=None, port=None, *, - family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, - sock=None, backlog=100, ssl=None, reuse_address=None, - reuse_port=None): + async def create_server( + self, protocol_factory, host=None, port=None, + *, family=socket.AF_UNSPEC, + flags=socket.AI_PASSIVE, sock=None, backlog=100, + ssl=None, reuse_address=None, reuse_port=None, + ssl_handshake_timeout=None, + start_serving=True): """A coroutine which creates a TCP server bound to host and port. The return value is a Server object which can be used to stop @@ -315,8 +321,8 @@ def create_server(self, protocol_factory, host=None, port=None, *, If host is an empty string or None all interfaces are assumed and a list of multiple sockets will be returned (most likely - one for IPv4 and another one for IPv6). The host parameter can also be a - sequence (e.g. list) of hosts to bind to. + one for IPv4 and another one for IPv6). The host parameter can also be + a sequence (e.g. list) of hosts to bind to. family can be set to either AF_INET or AF_INET6 to force the socket to use IPv4 or IPv6. If not set it will be determined @@ -342,22 +348,55 @@ def create_server(self, protocol_factory, host=None, port=None, *, the same port as other existing endpoints are bound to, so long as they all set this flag when being created. This option is not supported on Windows. + + ssl_handshake_timeout is the time in seconds that an SSL server + will wait for completion of the SSL handshake before aborting the + connection. Default is 60s. 
+ + start_serving set to True (default) causes the created server + to start accepting connections immediately. When set to False, + the user should await Server.start_serving() or Server.serve_forever() + to make the server to start accepting connections. + """ + raise NotImplementedError + + async def sendfile(self, transport, file, offset=0, count=None, + *, fallback=True): + """Send a file through a transport. + + Return an amount of sent bytes. + """ + raise NotImplementedError + + async def start_tls(self, transport, protocol, sslcontext, *, + server_side=False, + server_hostname=None, + ssl_handshake_timeout=None): + """Upgrade a transport to TLS. + + Return a new transport that *protocol* should start using + immediately. """ raise NotImplementedError - def create_unix_connection(self, protocol_factory, path, *, - ssl=None, sock=None, - server_hostname=None): + async def create_unix_connection( + self, protocol_factory, path=None, *, + ssl=None, sock=None, + server_hostname=None, + ssl_handshake_timeout=None): raise NotImplementedError - def create_unix_server(self, protocol_factory, path, *, - sock=None, backlog=100, ssl=None): + async def create_unix_server( + self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None, + ssl_handshake_timeout=None, + start_serving=True): """A coroutine which creates a UNIX Domain Socket server. The return value is a Server object, which can be used to stop the service. - path is a str, representing a file systsem path to bind the + path is a str, representing a file system path to bind the server socket to. sock can optionally be specified in order to use a preexisting @@ -368,14 +407,22 @@ def create_unix_server(self, protocol_factory, path, *, ssl can be set to an SSLContext to enable SSL over the accepted connections. + + ssl_handshake_timeout is the time in seconds that an SSL server + will wait for the SSL handshake to complete (defaults to 60s). + + start_serving set to True (default) causes the created server + to start accepting connections immediately. When set to False, + the user should await Server.start_serving() or Server.serve_forever() + to make the server to start accepting connections. """ raise NotImplementedError - def create_datagram_endpoint(self, protocol_factory, - local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0, - reuse_address=None, reuse_port=None, - allow_broadcast=None, sock=None): + async def create_datagram_endpoint(self, protocol_factory, + local_addr=None, remote_addr=None, *, + family=0, proto=0, flags=0, + reuse_address=None, reuse_port=None, + allow_broadcast=None, sock=None): """A coroutine which creates a datagram endpoint. This method will try to establish the endpoint in the background. @@ -383,8 +430,8 @@ def create_datagram_endpoint(self, protocol_factory, protocol_factory must be a callable returning a protocol instance. - socket family AF_INET or socket.AF_INET6 depending on host (or - family if specified), socket type SOCK_DGRAM. + socket family AF_INET, socket.AF_INET6 or socket.AF_UNIX depending on + host (or family if specified), socket type SOCK_DGRAM. reuse_address tells the kernel to reuse a local socket in TIME_WAIT state, without waiting for its natural timeout to @@ -408,7 +455,7 @@ def create_datagram_endpoint(self, protocol_factory, # Pipes and subprocesses. - def connect_read_pipe(self, protocol_factory, pipe): + async def connect_read_pipe(self, protocol_factory, pipe): """Register read pipe in event loop. Set the pipe to non-blocking mode. 
protocol_factory should instantiate object with Protocol interface. @@ -418,10 +465,10 @@ def connect_read_pipe(self, protocol_factory, pipe): # The reason to accept file-like object instead of just file descriptor # is: we need to own pipe and close it at transport finishing # Can got complicated errors if pass f.fileno(), - # close fd in pipe transport then close f and vise versa. + # close fd in pipe transport then close f and vice versa. raise NotImplementedError - def connect_write_pipe(self, protocol_factory, pipe): + async def connect_write_pipe(self, protocol_factory, pipe): """Register write pipe in event loop. protocol_factory should instantiate object with BaseProtocol interface. @@ -431,17 +478,21 @@ def connect_write_pipe(self, protocol_factory, pipe): # The reason to accept file-like object instead of just file descriptor # is: we need to own pipe and close it at transport finishing # Can got complicated errors if pass f.fileno(), - # close fd in pipe transport then close f and vise versa. + # close fd in pipe transport then close f and vice versa. raise NotImplementedError - def subprocess_shell(self, protocol_factory, cmd, *, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - **kwargs): + async def subprocess_shell(self, protocol_factory, cmd, *, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + **kwargs): raise NotImplementedError - def subprocess_exec(self, protocol_factory, *args, stdin=subprocess.PIPE, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - **kwargs): + async def subprocess_exec(self, protocol_factory, *args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + **kwargs): raise NotImplementedError # Ready-based callback registration methods. @@ -463,16 +514,23 @@ def remove_writer(self, fd): # Completion based I/O methods returning Futures. - def sock_recv(self, sock, nbytes): + async def sock_recv(self, sock, nbytes): + raise NotImplementedError + + async def sock_recv_into(self, sock, buf): raise NotImplementedError - def sock_sendall(self, sock, data): + async def sock_sendall(self, sock, data): raise NotImplementedError - def sock_connect(self, sock, address): + async def sock_connect(self, sock, address): raise NotImplementedError - def sock_accept(self, sock): + async def sock_accept(self, sock): + raise NotImplementedError + + async def sock_sendfile(self, sock, file, offset=0, count=None, + *, fallback=None): raise NotImplementedError # Signal handling. @@ -571,17 +629,19 @@ def __init__(self): self._local = self._Local() def get_event_loop(self): - """Get the event loop. + """Get the event loop for the current context. - This may be None or an instance of EventLoop. + Returns an instance of EventLoop or raises an exception. """ if (self._local._loop is None and - not self._local._set_called and - isinstance(threading.current_thread(), threading._MainThread)): + not self._local._set_called and + threading.current_thread() is threading.main_thread()): self.set_event_loop(self.new_event_loop()) + if self._local._loop is None: raise RuntimeError('There is no current event loop in thread %r.' % threading.current_thread().name) + return self._local._loop def set_event_loop(self, loop): @@ -611,7 +671,9 @@ def new_event_loop(self): # A TLS for the running event loop, used by _get_running_loop. 
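With the sock_* methods above now declared as coroutines, direct socket I/O on the loop is awaited like any other operation. A sketch, assuming a reachable TCP peer:

    import asyncio
    import socket

    async def roundtrip(host, port, payload):
        loop = asyncio.get_running_loop()
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setblocking(False)      # the sock_* methods require this
        try:
            await loop.sock_connect(sock, (host, port))
            await loop.sock_sendall(sock, payload)
            return await loop.sock_recv(sock, 1024)
        finally:
            sock.close()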
class _RunningLoop(threading.local): - _loop = None + loop_pid = (None, None) + + _running_loop = _RunningLoop() @@ -633,7 +695,10 @@ def _get_running_loop(): This is a low-level function intended to be used by event loops. This function is thread-specific. """ - return _running_loop._loop + # NOTE: this function is implemented in C (see _asynciomodule.c) + running_loop, pid = _running_loop.loop_pid + if running_loop is not None and pid == os.getpid(): + return running_loop def _set_running_loop(loop): @@ -642,7 +707,8 @@ def _set_running_loop(loop): This is a low-level function intended to be used by event loops. This function is thread-specific. """ - _running_loop._loop = loop + # NOTE: this function is implemented in C (see _asynciomodule.c) + _running_loop.loop_pid = (loop, os.getpid()) def _init_event_loop_policy(): @@ -678,6 +744,7 @@ def get_event_loop(): If there is no running event loop set, the function will return the result of `get_event_loop_policy().get_event_loop()` call. """ + # NOTE: this function is implemented in C (see _asynciomodule.c) current_loop = _get_running_loop() if current_loop is not None: return current_loop @@ -703,3 +770,26 @@ def set_child_watcher(watcher): """Equivalent to calling get_event_loop_policy().set_child_watcher(watcher).""" return get_event_loop_policy().set_child_watcher(watcher) + + +# Alias pure-Python implementations for testing purposes. +_py__get_running_loop = _get_running_loop +_py__set_running_loop = _set_running_loop +_py_get_running_loop = get_running_loop +_py_get_event_loop = get_event_loop + + +try: + # get_event_loop() is one of the most frequently called + # functions in asyncio. Pure Python implementation is + # about 4 times slower than C-accelerated. + from _asyncio import (_get_running_loop, _set_running_loop, + get_running_loop, get_event_loop) +except ImportError: + pass +else: + # Alias C implementations for testing purposes. + _c__get_running_loop = _get_running_loop + _c__set_running_loop = _set_running_loop + _c_get_running_loop = get_running_loop + _c_get_event_loop = get_event_loop diff --git a/Lib/asyncio/exceptions.py b/Lib/asyncio/exceptions.py new file mode 100644 index 0000000000..f07e448657 --- /dev/null +++ b/Lib/asyncio/exceptions.py @@ -0,0 +1,58 @@ +"""asyncio exceptions.""" + + +__all__ = ('CancelledError', 'InvalidStateError', 'TimeoutError', + 'IncompleteReadError', 'LimitOverrunError', + 'SendfileNotAvailableError') + + +class CancelledError(BaseException): + """The Future or Task was cancelled.""" + + +class TimeoutError(Exception): + """The operation exceeded the given deadline.""" + + +class InvalidStateError(Exception): + """The operation is not allowed in this state.""" + + +class SendfileNotAvailableError(RuntimeError): + """Sendfile syscall is not available. + + Raised if OS does not support sendfile syscall for given socket or + file type. + """ + + +class IncompleteReadError(EOFError): + """ + Incomplete read error. 
Attributes: + + - partial: read bytes string before the end of stream was reached + - expected: total number of expected bytes (or None if unknown) + """ + def __init__(self, partial, expected): + r_expected = 'undefined' if expected is None else repr(expected) + super().__init__(f'{len(partial)} bytes read on a total of ' + f'{r_expected} expected bytes') + self.partial = partial + self.expected = expected + + def __reduce__(self): + return type(self), (self.partial, self.expected) + + +class LimitOverrunError(Exception): + """Reached the buffer limit while looking for a separator. + + Attributes: + - consumed: total number of to be consumed bytes. + """ + def __init__(self, message, consumed): + super().__init__(message) + self.consumed = consumed + + def __reduce__(self): + return type(self), (self.args[0], self.consumed) diff --git a/Lib/asyncio/format_helpers.py b/Lib/asyncio/format_helpers.py new file mode 100644 index 0000000000..27d11fd4fa --- /dev/null +++ b/Lib/asyncio/format_helpers.py @@ -0,0 +1,76 @@ +import functools +import inspect +import reprlib +import sys +import traceback + +from . import constants + + +def _get_function_source(func): + func = inspect.unwrap(func) + if inspect.isfunction(func): + code = func.__code__ + return (code.co_filename, code.co_firstlineno) + if isinstance(func, functools.partial): + return _get_function_source(func.func) + if isinstance(func, functools.partialmethod): + return _get_function_source(func.func) + return None + + +def _format_callback_source(func, args): + func_repr = _format_callback(func, args, None) + source = _get_function_source(func) + if source: + func_repr += f' at {source[0]}:{source[1]}' + return func_repr + + +def _format_args_and_kwargs(args, kwargs): + """Format function arguments and keyword arguments. + + Special case for a single parameter: ('hello',) is formatted as ('hello'). + """ + # use reprlib to limit the length of the output + items = [] + if args: + items.extend(reprlib.repr(arg) for arg in args) + if kwargs: + items.extend(f'{k}={reprlib.repr(v)}' for k, v in kwargs.items()) + return '({})'.format(', '.join(items)) + + +def _format_callback(func, args, kwargs, suffix=''): + if isinstance(func, functools.partial): + suffix = _format_args_and_kwargs(args, kwargs) + suffix + return _format_callback(func.func, func.args, func.keywords, suffix) + + if hasattr(func, '__qualname__') and func.__qualname__: + func_repr = func.__qualname__ + elif hasattr(func, '__name__') and func.__name__: + func_repr = func.__name__ + else: + func_repr = repr(func) + + func_repr += _format_args_and_kwargs(args, kwargs) + if suffix: + func_repr += suffix + return func_repr + + +def extract_stack(f=None, limit=None): + """Replacement for traceback.extract_stack() that only does the + necessary work for asyncio debug mode. + """ + if f is None: + f = sys._getframe().f_back + if limit is None: + # Limit the amount of work to a reasonable amount, as extract_stack() + # can be called for each coroutine and future in debug mode. 
+ limit = constants.DEBUG_STACK_DEPTH + stack = traceback.StackSummary.extract(traceback.walk_stack(f), + limit=limit, + lookup_lines=False) + stack.reverse() + return stack diff --git a/Lib/asyncio/futures.py b/Lib/asyncio/futures.py index 82c03330ad..aaab09c28e 100644 --- a/Lib/asyncio/futures.py +++ b/Lib/asyncio/futures.py @@ -1,21 +1,21 @@ """A Future class similar to the one in PEP 3148.""" -__all__ = ['CancelledError', 'TimeoutError', 'InvalidStateError', - 'Future', 'wrap_future', 'isfuture'] +__all__ = ( + 'Future', 'wrap_future', 'isfuture', +) import concurrent.futures +import contextvars import logging import sys -import traceback +from types import GenericAlias from . import base_futures -from . import compat from . import events +from . import exceptions +from . import format_helpers -CancelledError = base_futures.CancelledError -InvalidStateError = base_futures.InvalidStateError -TimeoutError = base_futures.TimeoutError isfuture = base_futures.isfuture @@ -27,96 +27,18 @@ STACK_DEBUG = logging.DEBUG - 1 # heavy-duty debugging -class _TracebackLogger: - """Helper to log a traceback upon destruction if not cleared. - - This solves a nasty problem with Futures and Tasks that have an - exception set: if nobody asks for the exception, the exception is - never logged. This violates the Zen of Python: 'Errors should - never pass silently. Unless explicitly silenced.' - - However, we don't want to log the exception as soon as - set_exception() is called: if the calling code is written - properly, it will get the exception and handle it properly. But - we *do* want to log it if result() or exception() was never called - -- otherwise developers waste a lot of time wondering why their - buggy code fails silently. - - An earlier attempt added a __del__() method to the Future class - itself, but this backfired because the presence of __del__() - prevents garbage collection from breaking cycles. A way out of - this catch-22 is to avoid having a __del__() method on the Future - class itself, but instead to have a reference to a helper object - with a __del__() method that logs the traceback, where we ensure - that the helper object doesn't participate in cycles, and only the - Future has a reference to it. - - The helper object is added when set_exception() is called. When - the Future is collected, and the helper is present, the helper - object is also collected, and its __del__() method will log the - traceback. When the Future's result() or exception() method is - called (and a helper object is present), it removes the helper - object, after calling its clear() method to prevent it from - logging. - - One downside is that we do a fair amount of work to extract the - traceback from the exception, even when it is never logged. It - would seem cheaper to just store the exception object, but that - references the traceback, which references stack frames, which may - reference the Future, which references the _TracebackLogger, and - then the _TracebackLogger would be included in a cycle, which is - what we're trying to avoid! As an optimization, we don't - immediately format the exception; we only do the work when - activate() is called, which call is delayed until after all the - Future's callbacks have run. Since usually a Future has at least - one callback (typically set by 'yield from') and usually that - callback extracts the callback, thereby removing the need to - format the exception. - - PS. I don't claim credit for this solution. 
I first heard of it - in a discussion about closing files when they are collected. - """ - - __slots__ = ('loop', 'source_traceback', 'exc', 'tb') - - def __init__(self, future, exc): - self.loop = future._loop - self.source_traceback = future._source_traceback - self.exc = exc - self.tb = None - - def activate(self): - exc = self.exc - if exc is not None: - self.exc = None - self.tb = traceback.format_exception(exc.__class__, exc, - exc.__traceback__) - - def clear(self): - self.exc = None - self.tb = None - - def __del__(self): - if self.tb: - msg = 'Future/Task exception was never retrieved\n' - if self.source_traceback: - src = ''.join(traceback.format_list(self.source_traceback)) - msg += 'Future/Task created at (most recent call last):\n' - msg += '%s\n' % src.rstrip() - msg += ''.join(self.tb).rstrip() - self.loop.call_exception_handler({'message': msg}) - - class Future: """This class is *almost* compatible with concurrent.futures.Future. Differences: + - This class is not thread-safe. + - result() and exception() do not take a timeout argument and raise an exception when the future isn't done yet. - Callbacks registered with add_done_callback() are always called - via the event loop's call_soon_threadsafe(). + via the event loop's call_soon(). - This class is not compatible with the wait() and as_completed() methods in the concurrent.futures package. @@ -130,6 +52,9 @@ class Future: _exception = None _loop = None _source_traceback = None + _cancel_message = None + # A saved CancelledError for later chaining as an exception context. + _cancelled_exc = None # This field is used for a dual purpose: # - Its presence is a marker to declare that a class implements @@ -137,12 +62,12 @@ class Future: # The value must also be not-None, to enable a subclass to declare # that it is not compatible by setting this to None. # - It is set by __iter__() below so that Task._step() can tell - # the difference between `yield from Future()` (correct) vs. + # the difference between + # `await Future()` or`yield from Future()` (correct) vs. # `yield Future()` (incorrect). _asyncio_future_blocking = False - _log_traceback = False # Used for Python 3.4 and later - _tb_logger = None # Used for Python 3.3 only + __log_traceback = False def __init__(self, *, loop=None): """Initialize the future. @@ -157,50 +82,81 @@ def __init__(self, *, loop=None): self._loop = loop self._callbacks = [] if self._loop.get_debug(): - self._source_traceback = traceback.extract_stack(sys._getframe(1)) + self._source_traceback = format_helpers.extract_stack( + sys._getframe(1)) _repr_info = base_futures._future_repr_info def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, ' '.join(self._repr_info())) - - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. 
- if compat.PY34: - def __del__(self): - if not self._log_traceback: - # set_exception() was not called, or result() or exception() - # has consumed the exception - return - exc = self._exception - context = { - 'message': ('%s exception was never retrieved' - % self.__class__.__name__), - 'exception': exc, - 'future': self, - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - - def __class_getitem__(cls, type): - return cls - - def cancel(self): + return '<{} {}>'.format(self.__class__.__name__, + ' '.join(self._repr_info())) + + def __del__(self): + if not self.__log_traceback: + # set_exception() was not called, or result() or exception() + # has consumed the exception + return + exc = self._exception + context = { + 'message': + f'{self.__class__.__name__} exception was never retrieved', + 'exception': exc, + 'future': self, + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + + __class_getitem__ = classmethod(GenericAlias) + + @property + def _log_traceback(self): + return self.__log_traceback + + @_log_traceback.setter + def _log_traceback(self, val): + if bool(val): + raise ValueError('_log_traceback can only be set to False') + self.__log_traceback = False + + def get_loop(self): + """Return the event loop the Future is bound to.""" + loop = self._loop + if loop is None: + raise RuntimeError("Future object is not initialized.") + return loop + + def _make_cancelled_error(self): + """Create the CancelledError to raise if the Future is cancelled. + + This should only be called once when handling a cancellation since + it erases the saved context exception value. + """ + if self._cancel_message is None: + exc = exceptions.CancelledError() + else: + exc = exceptions.CancelledError(self._cancel_message) + exc.__context__ = self._cancelled_exc + # Remove the reference since we don't need this anymore. + self._cancelled_exc = None + return exc + + def cancel(self, msg=None): """Cancel the future and schedule callbacks. If the future is already done or cancelled, return False. Otherwise, change the future's state to cancelled, schedule the callbacks and return True. """ + self.__log_traceback = False if self._state != _PENDING: return False self._state = _CANCELLED - self._schedule_callbacks() + self._cancel_message = msg + self.__schedule_callbacks() return True - def _schedule_callbacks(self): + def __schedule_callbacks(self): """Internal: Ask the event loop to call all callbacks. The callbacks are scheduled to be called as soon as possible. Also @@ -211,8 +167,8 @@ def _schedule_callbacks(self): return self._callbacks[:] = [] - for callback in callbacks: - self._loop.call_soon(callback, self) + for callback, ctx in callbacks: + self._loop.call_soon(callback, self, context=ctx) def cancelled(self): """Return True if the future was cancelled.""" @@ -236,13 +192,11 @@ def result(self): the future is done and has an exception set, this exception is raised. 
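        For example:

            if fut.done() and not fut.cancelled():
                value = fut.result()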
""" if self._state == _CANCELLED: - raise CancelledError + exc = self._make_cancelled_error() + raise exc if self._state != _FINISHED: - raise InvalidStateError('Result is not ready.') - self._log_traceback = False - if self._tb_logger is not None: - self._tb_logger.clear() - self._tb_logger = None + raise exceptions.InvalidStateError('Result is not ready.') + self.__log_traceback = False if self._exception is not None: raise self._exception return self._result @@ -256,16 +210,14 @@ def exception(self): InvalidStateError. """ if self._state == _CANCELLED: - raise CancelledError + exc = self._make_cancelled_error() + raise exc if self._state != _FINISHED: - raise InvalidStateError('Exception is not set.') - self._log_traceback = False - if self._tb_logger is not None: - self._tb_logger.clear() - self._tb_logger = None + raise exceptions.InvalidStateError('Exception is not set.') + self.__log_traceback = False return self._exception - def add_done_callback(self, fn): + def add_done_callback(self, fn, *, context=None): """Add a callback to be run when the future becomes done. The callback is called with a single argument - the future object. If @@ -273,9 +225,11 @@ def add_done_callback(self, fn): scheduled with call_soon. """ if self._state != _PENDING: - self._loop.call_soon(fn, self) + self._loop.call_soon(fn, self, context=context) else: - self._callbacks.append(fn) + if context is None: + context = contextvars.copy_context() + self._callbacks.append((fn, context)) # New method not in PEP 3148. @@ -284,7 +238,9 @@ def remove_done_callback(self, fn): Returns the number of callbacks removed. """ - filtered_callbacks = [f for f in self._callbacks if f != fn] + filtered_callbacks = [(f, ctx) + for (f, ctx) in self._callbacks + if f != fn] removed_count = len(self._callbacks) - len(filtered_callbacks) if removed_count: self._callbacks[:] = filtered_callbacks @@ -299,10 +255,10 @@ def set_result(self, result): InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError('{}: {!r}'.format(self._state, self)) + raise exceptions.InvalidStateError(f'{self._state}: {self!r}') self._result = result self._state = _FINISHED - self._schedule_callbacks() + self.__schedule_callbacks() def set_exception(self, exception): """Mark the future done and set an exception. @@ -311,7 +267,7 @@ def set_exception(self, exception): InvalidStateError. """ if self._state != _PENDING: - raise InvalidStateError('{}: {!r}'.format(self._state, self)) + raise exceptions.InvalidStateError(f'{self._state}: {self!r}') if isinstance(exception, type): exception = exception() if type(exception) is StopIteration: @@ -319,30 +275,36 @@ def set_exception(self, exception): "and cannot be raised into a Future") self._exception = exception self._state = _FINISHED - self._schedule_callbacks() - if compat.PY34: - self._log_traceback = True - else: - self._tb_logger = _TracebackLogger(self, exception) - # Arrange for the logger to be activated after all callbacks - # have had a chance to call result() or exception(). - self._loop.call_soon(self._tb_logger.activate) + self.__schedule_callbacks() + self.__log_traceback = True - def __iter__(self): + def __await__(self): if not self.done(): self._asyncio_future_blocking = True yield self # This tells Task to wait for completion. - assert self.done(), "yield from wasn't used with future" + if not self.done(): + raise RuntimeError("await wasn't used with future") return self.result() # May raise too. 
- if compat.PY35: - __await__ = __iter__ # make compatible with 'await' expression + __iter__ = __await__ # make compatible with 'yield from'. # Needed for testing purposes. _PyFuture = Future +def _get_loop(fut): + # Tries to call Future.get_loop() if it's available. + # Otherwise fallbacks to using the old '_loop' property. + try: + get_loop = fut.get_loop + except AttributeError: + pass + else: + return get_loop() + return fut._loop + + def _set_result_unless_cancelled(fut, result): """Helper setting the result only if the future was not cancelled.""" if fut.cancelled(): @@ -350,6 +312,18 @@ def _set_result_unless_cancelled(fut, result): fut.set_result(result) +def _convert_future_exc(exc): + exc_class = type(exc) + if exc_class is concurrent.futures.CancelledError: + return exceptions.CancelledError(*exc.args) + elif exc_class is concurrent.futures.TimeoutError: + return exceptions.TimeoutError(*exc.args) + elif exc_class is concurrent.futures.InvalidStateError: + return exceptions.InvalidStateError(*exc.args) + else: + return exc + + def _set_concurrent_future_state(concurrent, source): """Copy state from a future to a concurrent.futures.Future.""" assert source.done() @@ -359,7 +333,7 @@ def _set_concurrent_future_state(concurrent, source): return exception = source.exception() if exception is not None: - concurrent.set_exception(exception) + concurrent.set_exception(_convert_future_exc(exception)) else: result = source.result() concurrent.set_result(result) @@ -379,7 +353,7 @@ def _copy_future_state(source, dest): else: exception = source.exception() if exception is not None: - dest.set_exception(exception) + dest.set_exception(_convert_future_exc(exception)) else: result = source.result() dest.set_result(result) @@ -398,8 +372,8 @@ def _chain_future(source, destination): if not isfuture(destination) and not isinstance(destination, concurrent.futures.Future): raise TypeError('A future is required for destination argument') - source_loop = source._loop if isfuture(source) else None - dest_loop = destination._loop if isfuture(destination) else None + source_loop = _get_loop(source) if isfuture(source) else None + dest_loop = _get_loop(destination) if isfuture(destination) else None def _set_state(future, other): if isfuture(future): @@ -415,6 +389,9 @@ def _call_check_cancel(destination): source_loop.call_soon_threadsafe(source.cancel) def _call_set_state(source): + if (destination.cancelled() and + dest_loop is not None and dest_loop.is_closed()): + return if dest_loop is None or dest_loop is source_loop: _set_state(destination, source) else: @@ -429,7 +406,7 @@ def wrap_future(future, *, loop=None): if isfuture(future): return future assert isinstance(future, concurrent.futures.Future), \ - 'concurrent.futures.Future is expected, got {!r}'.format(future) + f'concurrent.futures.Future is expected, got {future!r}' if loop is None: loop = events.get_event_loop() new_future = loop.create_future() diff --git a/Lib/asyncio/locks.py b/Lib/asyncio/locks.py index deefc938ec..d17d7ccd81 100644 --- a/Lib/asyncio/locks.py +++ b/Lib/asyncio/locks.py @@ -1,89 +1,23 @@ """Synchronization primitives.""" -__all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore'] +__all__ = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore') import collections +import warnings -from . import compat from . import events -from . import futures -from .coroutines import coroutine +from . import exceptions -class _ContextManager: - """Context manager. 
- - This enables the following idiom for acquiring and releasing a - lock around a block: - - with (yield from lock): - - - while failing loudly when accidentally using: - - with lock: - - """ - - def __init__(self, lock): - self._lock = lock - - def __enter__(self): +class _ContextManagerMixin: + async def __aenter__(self): + await self.acquire() # We have no use for the "as ..." clause in the with # statement for locks. return None - def __exit__(self, *args): - try: - self._lock.release() - finally: - self._lock = None # Crudely prevent reuse. - - -class _ContextManagerMixin: - def __enter__(self): - raise RuntimeError( - '"yield from" should be used as context manager expression') - - def __exit__(self, *args): - # This must exist because __enter__ exists, even though that - # always raises; that's how the with-statement works. - pass - - @coroutine - def __iter__(self): - # This is not a coroutine. It is meant to enable the idiom: - # - # with (yield from lock): - # - # - # as an alternative to: - # - # yield from lock.acquire() - # try: - # - # finally: - # lock.release() - yield from self.acquire() - return _ContextManager(self) - - if compat.PY35: - - def __await__(self): - # To make "with await lock" work. - yield from self.acquire() - return _ContextManager(self) - - @coroutine - def __aenter__(self): - yield from self.acquire() - # We have no use for the "as ..." clause in the with - # statement for locks. - return None - - @coroutine - def __aexit__(self, exc_type, exc, tb): - self.release() + async def __aexit__(self, exc_type, exc, tb): + self.release() class Lock(_ContextManagerMixin): @@ -108,16 +42,16 @@ class Lock(_ContextManagerMixin): release() call resets the state to unlocked; first coroutine which is blocked in acquire() is being processed. - acquire() is a coroutine and should be called with 'yield from'. + acquire() is a coroutine and should be called with 'await'. - Locks also support the context management protocol. '(yield from lock)' - should be used as the context manager expression. + Locks also support the asynchronous context management protocol. + 'async with lock' statement should be used. Usage: lock = Lock() ... - yield from lock + await lock.acquire() try: ... finally: @@ -127,13 +61,13 @@ class Lock(_ContextManagerMixin): lock = Lock() ... - with (yield from lock): + async with lock: ... Lock objects can be tested for locking state: if not lock.locked(): - yield from lock + await lock.acquire() else: # lock is acquired ... @@ -141,43 +75,58 @@ class Lock(_ContextManagerMixin): """ def __init__(self, *, loop=None): - self._waiters = collections.deque() + self._waiters = None self._locked = False - if loop is not None: - self._loop = loop - else: + if loop is None: self._loop = events.get_event_loop() + else: + self._loop = loop + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) def __repr__(self): res = super().__repr__() extra = 'locked' if self._locked else 'unlocked' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' def locked(self): """Return True if lock is acquired.""" return self._locked - @coroutine - def acquire(self): + async def acquire(self): """Acquire a lock. This method blocks until the lock is unlocked, then sets it to locked and returns True. 
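        For example, to bound the wait (sketch; asyncio.wait_for cancels
        the acquire on timeout):

            try:
                await asyncio.wait_for(lock.acquire(), timeout=1.0)
            except asyncio.TimeoutError:
                ...  # lock was not acquired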
""" - if not self._locked and all(w.cancelled() for w in self._waiters): + if (not self._locked and (self._waiters is None or + all(w.cancelled() for w in self._waiters))): self._locked = True return True + if self._waiters is None: + self._waiters = collections.deque() fut = self._loop.create_future() self._waiters.append(fut) + + # Finally block should be called before the CancelledError + # handling as we don't want CancelledError to call + # _wake_up_first() and attempt to wake up itself. try: - yield from fut - self._locked = True - return True - finally: - self._waiters.remove(fut) + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + if not self._locked: + self._wake_up_first() + raise + + self._locked = True + return True def release(self): """Release a lock. @@ -192,14 +141,25 @@ def release(self): """ if self._locked: self._locked = False - # Wake up the first waiter who isn't cancelled. - for fut in self._waiters: - if not fut.done(): - fut.set_result(True) - break + self._wake_up_first() else: raise RuntimeError('Lock is not acquired.') + def _wake_up_first(self): + """Wake up the first waiter if it isn't done.""" + if not self._waiters: + return + try: + fut = next(iter(self._waiters)) + except StopIteration: + return + + # .done() necessarily means that a waiter will wake up later on and + # either take the lock, or, if it was cancelled and lock wasn't + # taken already, will hit this again and wake up a new waiter. + if not fut.done(): + fut.set_result(True) + class Event: """Asynchronous equivalent to threading.Event. @@ -213,17 +173,20 @@ class Event: def __init__(self, *, loop=None): self._waiters = collections.deque() self._value = False - if loop is not None: - self._loop = loop - else: + if loop is None: self._loop = events.get_event_loop() + else: + self._loop = loop + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) def __repr__(self): res = super().__repr__() extra = 'set' if self._value else 'unset' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' def is_set(self): """Return True if and only if the internal flag is true.""" @@ -247,8 +210,7 @@ def clear(self): to true again.""" self._value = False - @coroutine - def wait(self): + async def wait(self): """Block until the internal flag is true. 
If the internal flag is true on entry, return True @@ -261,7 +223,7 @@ def wait(self): fut = self._loop.create_future() self._waiters.append(fut) try: - yield from fut + await fut return True finally: self._waiters.remove(fut) @@ -278,13 +240,16 @@ class Condition(_ContextManagerMixin): """ def __init__(self, lock=None, *, loop=None): - if loop is not None: - self._loop = loop - else: + if loop is None: self._loop = events.get_event_loop() + else: + self._loop = loop + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) if lock is None: - lock = Lock(loop=self._loop) + lock = Lock(loop=loop) elif lock._loop is not self._loop: raise ValueError("loop argument must agree with lock") @@ -300,11 +265,10 @@ def __repr__(self): res = super().__repr__() extra = 'locked' if self.locked() else 'unlocked' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' - @coroutine - def wait(self): + async def wait(self): """Wait until notified. If the calling coroutine has not acquired the lock when this @@ -323,22 +287,25 @@ def wait(self): fut = self._loop.create_future() self._waiters.append(fut) try: - yield from fut + await fut return True finally: self._waiters.remove(fut) finally: # Must reacquire lock even if wait is cancelled + cancelled = False while True: try: - yield from self.acquire() + await self.acquire() break - except futures.CancelledError: - pass + except exceptions.CancelledError: + cancelled = True - @coroutine - def wait_for(self, predicate): + if cancelled: + raise exceptions.CancelledError + + async def wait_for(self, predicate): """Wait until a predicate becomes true. The predicate should be a callable which result will be @@ -347,7 +314,7 @@ def wait_for(self, predicate): """ result = predicate() while not result: - yield from self.wait() + await self.wait() result = predicate() return result @@ -404,32 +371,35 @@ def __init__(self, value=1, *, loop=None): raise ValueError("Semaphore initial value must be >= 0") self._value = value self._waiters = collections.deque() - if loop is not None: - self._loop = loop - else: + if loop is None: self._loop = events.get_event_loop() + else: + self._loop = loop + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) + self._wakeup_scheduled = False def __repr__(self): res = super().__repr__() - extra = 'locked' if self.locked() else 'unlocked,value:{}'.format( - self._value) + extra = 'locked' if self.locked() else f'unlocked, value:{self._value}' if self._waiters: - extra = '{},waiters:{}'.format(extra, len(self._waiters)) - return '<{} [{}]>'.format(res[1:-1], extra) + extra = f'{extra}, waiters:{len(self._waiters)}' + return f'<{res[1:-1]} [{extra}]>' def _wake_up_next(self): while self._waiters: waiter = self._waiters.popleft() if not waiter.done(): waiter.set_result(None) + self._wakeup_scheduled = True return def locked(self): """Returns True if semaphore can not be acquired immediately.""" return self._value == 0 - @coroutine - def acquire(self): + async def acquire(self): """Acquire a semaphore. If the internal counter is larger than zero on entry, @@ -438,16 +408,17 @@ def acquire(self): called release() to make it larger than 0, and then return True. 
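        A common use is capping concurrency (sketch; fetch() is a
        placeholder):

            sem = Semaphore(5)
            async with sem:
                await fetch(url)   # at most five concurrent fetches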
""" - while self._value <= 0: + # _wakeup_scheduled is set if *another* task is scheduled to wakeup + # but its acquire() is not resumed yet + while self._wakeup_scheduled or self._value <= 0: fut = self._loop.create_future() self._waiters.append(fut) try: - yield from fut - except: - # See the similar code in Queue.get. - fut.cancel() - if self._value > 0 and not fut.cancelled(): - self._wake_up_next() + await fut + # reset _wakeup_scheduled *after* waiting for a future + self._wakeup_scheduled = False + except exceptions.CancelledError: + self._wake_up_next() raise self._value -= 1 return True @@ -469,6 +440,11 @@ class BoundedSemaphore(Semaphore): """ def __init__(self, value=1, *, loop=None): + if loop: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) + self._bound_value = value super().__init__(value, loop=loop) diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ff12877fae..e3f95cf21d 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -4,20 +4,45 @@ proactor is only implemented on Windows with IOCP. """ -__all__ = ['BaseProactorEventLoop'] +__all__ = 'BaseProactorEventLoop', +import io +import os import socket import warnings +import signal +import threading +import collections from . import base_events -from . import compat from . import constants from . import futures +from . import exceptions +from . import protocols from . import sslproto from . import transports +from . import trsock from .log import logger +def _set_socket_extra(transport, sock): + transport._extra['socket'] = trsock.TransportSocket(sock) + + try: + transport._extra['sockname'] = sock.getsockname() + except socket.error: + if transport._loop.get_debug(): + logger.warning( + "getsockname() failed on %r", sock, exc_info=True) + + if 'peername' not in transport._extra: + try: + transport._extra['peername'] = sock.getpeername() + except socket.error: + # UDP sockets may not have a peer name + transport._extra['peername'] = None + + class _ProactorBasePipeTransport(transports._FlowControlMixin, transports.BaseTransport): """Base class for pipe and socket transports.""" @@ -27,7 +52,7 @@ def __init__(self, loop, sock, protocol, waiter=None, super().__init__(extra, loop) self._set_extra(sock) self._sock = sock - self._protocol = protocol + self.set_protocol(protocol) self._server = server self._buffer = None # None or bytearray. self._read_fut = None @@ -51,17 +76,16 @@ def __repr__(self): elif self._closing: info.append('closing') if self._sock is not None: - info.append('fd=%s' % self._sock.fileno()) + info.append(f'fd={self._sock.fileno()}') if self._read_fut is not None: - info.append('read=%s' % self._read_fut) + info.append(f'read={self._read_fut!r}') if self._write_fut is not None: - info.append("write=%r" % self._write_fut) + info.append(f'write={self._write_fut!r}') if self._buffer: - bufsize = len(self._buffer) - info.append('write_bufsize=%s' % bufsize) + info.append(f'write_bufsize={len(self._buffer)}') if self._eof_written: info.append('EOF written') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def _set_extra(self, sock): self._extra['pipe'] = sock @@ -86,30 +110,32 @@ def close(self): self._read_fut.cancel() self._read_fut = None - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. 
- if compat.PY34: - def __del__(self): - if self._sock is not None: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self.close() + def __del__(self, _warn=warnings.warn): + if self._sock is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self.close() def _fatal_error(self, exc, message='Fatal error on pipe transport'): - if isinstance(exc, base_events._FATAL_ERROR_IGNORE): - if self._loop.get_debug(): - logger.debug("%r: %s", self, message, exc_info=True) - else: - self._loop.call_exception_handler({ - 'message': message, - 'exception': exc, - 'transport': self, - 'protocol': self._protocol, - }) - self._force_close(exc) + try: + if isinstance(exc, OSError): + if self._loop.get_debug(): + logger.debug("%r: %s", self, message, exc_info=True) + else: + self._loop.call_exception_handler({ + 'message': message, + 'exception': exc, + 'transport': self, + 'protocol': self._protocol, + }) + finally: + self._force_close(exc) def _force_close(self, exc): + if self._empty_waiter is not None and not self._empty_waiter.done(): + if exc is None: + self._empty_waiter.set_result(None) + else: + self._empty_waiter.set_exception(exc) if self._closing: return self._closing = True @@ -132,7 +158,7 @@ def _call_connection_lost(self, exc): # end then it may fail with ERROR_NETNAME_DELETED if we # just close our end. First calling shutdown() seems to # cure it, but maybe using DisconnectEx() would be better. - if hasattr(self._sock, 'shutdown'): + if hasattr(self._sock, 'shutdown') and self._sock.fileno() != -1: self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() self._sock = None @@ -154,40 +180,107 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + self._pending_data = None + self._paused = True super().__init__(loop, sock, protocol, waiter, extra, server) - self._paused = False + self._loop.call_soon(self._loop_reading) + self._paused = False + + def is_reading(self): + return not self._paused and not self._closing def pause_reading(self): - if self._closing: - raise RuntimeError('Cannot pause_reading() when closing') - if self._paused: - raise RuntimeError('Already paused') + if self._closing or self._paused: + return self._paused = True + + # bpo-33694: Don't cancel self._read_fut because cancelling an + # overlapped WSASend() loss silently data with the current proactor + # implementation. + # + # If CancelIoEx() fails with ERROR_NOT_FOUND, it means that WSASend() + # completed (even if HasOverlappedIoCompleted() returns 0), but + # Overlapped.cancel() currently silently ignores the ERROR_NOT_FOUND + # error. Once the overlapped is ignored, the IOCP loop will ignores the + # completion I/O event and so not read the result of the overlapped + # WSARecv(). + if self._loop.get_debug(): logger.debug("%r pauses reading", self) def resume_reading(self): - if not self._paused: - raise RuntimeError('Not paused') - self._paused = False - if self._closing: + if self._closing or not self._paused: return - self._loop.call_soon(self._loop_reading, self._read_fut) + + self._paused = False + if self._read_fut is None: + self._loop.call_soon(self._loop_reading, None) + + data = self._pending_data + self._pending_data = None + if data is not None: + # Call the protocol methode after calling _loop_reading(), + # since the protocol can decide to pause reading again. 
+ self._loop.call_soon(self._data_received, data) + if self._loop.get_debug(): logger.debug("%r resumes reading", self) - def _loop_reading(self, fut=None): + def _eof_received(self): + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + + try: + keep_open = self._protocol.eof_received() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.eof_received() call failed.') + return + + if not keep_open: + self.close() + + def _data_received(self, data): if self._paused: + # Don't call any protocol method while reading is paused. + # The protocol will be called on resume_reading(). + assert self._pending_data is None + self._pending_data = data return - data = None + if not data: + self._eof_received() + return + + if isinstance(self._protocol, protocols.BufferedProtocol): + try: + protocols._feed_data_to_buffered_proto(self._protocol, data) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error(exc, + 'Fatal error: protocol.buffer_updated() ' + 'call failed.') + return + else: + self._protocol.data_received(data) + + def _loop_reading(self, fut=None): + data = None try: if fut is not None: assert self._read_fut is fut or (self._read_fut is None and self._closing) self._read_fut = None - data = fut.result() # deliver data later in "finally" clause + if fut.done(): + # deliver data later in "finally" clause + data = fut.result() + else: + # the future will be replaced by next proactor.recv call + fut.cancel() if self._closing: # since close() has been called we ignore any read data @@ -198,8 +291,12 @@ def _loop_reading(self, fut=None): # we got end-of-file so no need to reschedule a new read return - # reschedule a new read - self._read_fut = self._loop._proactor.recv(self._sock, 4096) + # bpo-33694: buffer_updated() has currently no fast path because of + # a data loss issue caused by overlapped WSASend() cancellation. 
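+            # (So BufferedProtocol connections currently take the plain
+            # recv()/_data_received() path below as well.)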
+ + if not self._paused: + # reschedule a new read + self._read_fut = self._loop._proactor.recv(self._sock, 32768) except ConnectionAbortedError as exc: if not self._closing: self._fatal_error(exc, 'Fatal read error on pipe transport') @@ -210,32 +307,36 @@ def _loop_reading(self, fut=None): self._force_close(exc) except OSError as exc: self._fatal_error(exc, 'Fatal read error on pipe transport') - except futures.CancelledError: + except exceptions.CancelledError: if not self._closing: raise else: - self._read_fut.add_done_callback(self._loop_reading) + if not self._paused: + self._read_fut.add_done_callback(self._loop_reading) finally: - if data: - self._protocol.data_received(data) - elif data is not None: - if self._loop.get_debug(): - logger.debug("%r received EOF", self) - keep_open = self._protocol.eof_received() - if not keep_open: - self.close() + if data is not None: + self._data_received(data) class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, transports.WriteTransport): """Transport for write pipes.""" + _start_tls_compatible = True + + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._empty_waiter = None + def write(self, data): if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be byte-ish (%r)', - type(data)) + raise TypeError( + f"data argument must be a bytes-like object, " + f"not {type(data).__name__}") if self._eof_written: raise RuntimeError('write_eof() already called') + if self._empty_waiter is not None: + raise RuntimeError('unable to write; sendfile is in progress') if not data: return @@ -267,6 +368,10 @@ def write(self, data): def _loop_writing(self, f=None, data=None): try: + if f is not None and self._write_fut is None and self._closing: + # XXX most likely self._force_close() has been called, and + # it has set self._write_fut to None. + return assert f is self._write_fut self._write_fut = None self._pending_write = 0 @@ -295,6 +400,8 @@ def _loop_writing(self, f=None, data=None): self._maybe_pause_protocol() else: self._write_fut.add_done_callback(self._loop_writing) + if self._empty_waiter is not None and self._write_fut is None: + self._empty_waiter.set_result(None) except ConnectionResetError as exc: self._force_close(exc) except OSError as exc: @@ -309,6 +416,17 @@ def write_eof(self): def abort(self): self._force_close(None) + def _make_empty_waiter(self): + if self._empty_waiter is not None: + raise RuntimeError("Empty waiter is already set") + self._empty_waiter = self._loop.create_future() + if self._write_fut is None: + self._empty_waiter.set_result(None) + return self._empty_waiter + + def _reset_empty_waiter(self): + self._empty_waiter = None + class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): def __init__(self, *args, **kw): @@ -332,6 +450,135 @@ def _pipe_closed(self, fut): self.close() +class _ProactorDatagramTransport(_ProactorBasePipeTransport, + transports.DatagramTransport): + max_size = 256 * 1024 + def __init__(self, loop, sock, protocol, address=None, + waiter=None, extra=None): + self._address = address + self._empty_waiter = None + # We don't need to call _protocol.connection_made() since our base + # constructor does it for us. 
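+        # (It also schedules completion of *waiter* after connection_made().)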
+ super().__init__(loop, sock, protocol, waiter=waiter, extra=extra) + + # The base constructor sets _buffer = None, so we set it here + self._buffer = collections.deque() + self._loop.call_soon(self._loop_reading) + + def _set_extra(self, sock): + _set_socket_extra(self, sock) + + def get_write_buffer_size(self): + return sum(len(data) for data, _ in self._buffer) + + def abort(self): + self._force_close(None) + + def sendto(self, data, addr=None): + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError('data argument must be bytes-like object (%r)', + type(data)) + + if not data: + return + + if self._address is not None and addr not in (None, self._address): + raise ValueError( + f'Invalid address: must be None or {self._address}') + + if self._conn_lost and self._address: + if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: + logger.warning('socket.sendto() raised exception.') + self._conn_lost += 1 + return + + # Ensure that what we buffer is immutable. + self._buffer.append((bytes(data), addr)) + + if self._write_fut is None: + # No current write operations are active, kick one off + self._loop_writing() + # else: A write operation is already kicked off + + self._maybe_pause_protocol() + + def _loop_writing(self, fut=None): + try: + if self._conn_lost: + return + + assert fut is self._write_fut + self._write_fut = None + if fut: + # We are in a _loop_writing() done callback, get the result + fut.result() + + if not self._buffer or (self._conn_lost and self._address): + # The connection has been closed + if self._closing: + self._loop.call_soon(self._call_connection_lost, None) + return + + data, addr = self._buffer.popleft() + if self._address is not None: + self._write_fut = self._loop._proactor.send(self._sock, + data) + else: + self._write_fut = self._loop._proactor.sendto(self._sock, + data, + addr=addr) + except OSError as exc: + self._protocol.error_received(exc) + except Exception as exc: + self._fatal_error(exc, 'Fatal write error on datagram transport') + else: + self._write_fut.add_done_callback(self._loop_writing) + self._maybe_resume_protocol() + + def _loop_reading(self, fut=None): + data = None + try: + if self._conn_lost: + return + + assert self._read_fut is fut or (self._read_fut is None and + self._closing) + + self._read_fut = None + if fut is not None: + res = fut.result() + + if self._closing: + # since close() has been called we ignore any read data + data = None + return + + if self._address is not None: + data, addr = res, self._address + else: + data, addr = res + + if self._conn_lost: + return + if self._address is not None: + self._read_fut = self._loop._proactor.recv(self._sock, + self.max_size) + else: + self._read_fut = self._loop._proactor.recvfrom(self._sock, + self.max_size) + except OSError as exc: + self._protocol.error_received(exc) + except exceptions.CancelledError: + if not self._closing: + raise + else: + if self._read_fut is not None: + self._read_fut.add_done_callback(self._loop_reading) + finally: + if data: + self._protocol.datagram_received(data, addr) + + class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport, _ProactorBaseWritePipeTransport, transports.Transport): @@ -349,21 +596,15 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, transports.Transport): """Transport for connected sockets.""" + _sendfile_compatible = constants._SendfileMode.TRY_NATIVE + + def __init__(self, loop, sock, protocol, waiter=None, + extra=None, server=None): + super().__init__(loop, sock, protocol, 
waiter, extra, server) + base_events._set_nodelay(sock) + def _set_extra(self, sock): - self._extra['socket'] = sock - try: - self._extra['sockname'] = sock.getsockname() - except (socket.error, AttributeError): - if self._loop.get_debug(): - logger.warning("getsockname() failed on %r", - sock, exc_info=True) - if 'peername' not in self._extra: - try: - self._extra['peername'] = sock.getpeername() - except (socket.error, AttributeError): - if self._loop.get_debug(): - logger.warning("getpeername() failed on %r", - sock, exc_info=True) + _set_socket_extra(self, sock) def can_write_eof(self): return True @@ -387,26 +628,33 @@ def __init__(self, proactor): self._accept_futures = {} # socket file descriptor => Future proactor.set_loop(self) self._make_self_pipe() + if threading.current_thread() is threading.main_thread(): + # wakeup fd can only be installed to a file descriptor from the main thread + signal.set_wakeup_fd(self._csock.fileno()) def _make_socket_transport(self, sock, protocol, waiter=None, extra=None, server=None): return _ProactorSocketTransport(self, sock, protocol, waiter, extra, server) - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, - extra=None, server=None): - if not sslproto._is_sslproto_available(): - raise NotImplementedError("Proactor event loop requires Python 3.5" - " or newer (ssl.MemoryBIO) to support " - "SSL") - - ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, - server_side, server_hostname) + def _make_ssl_transport( + self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None, + ssl_handshake_timeout=None): + ssl_protocol = sslproto.SSLProtocol( + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) _ProactorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport + def _make_datagram_transport(self, sock, protocol, + address=None, waiter=None, extra=None): + return _ProactorDatagramTransport(self, sock, protocol, address, + waiter, extra) + def _make_duplex_pipe_transport(self, sock, protocol, waiter=None, extra=None): return _ProactorDuplexPipeTransport(self, @@ -428,6 +676,8 @@ def close(self): if self.is_closed(): return + if threading.current_thread() is threading.main_thread(): + signal.set_wakeup_fd(-1) # Call these methods before closing the event loop (before calling # BaseEventLoop.close), because they can schedule callbacks with # call_soon(), which is forbidden when the event loop is closed. 
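The proactor machinery above is only used on Windows in practice; a short
sketch of selecting it explicitly via the stock policy (other platforms skip
the branch and keep their selector-based loop):

import asyncio
import sys

if sys.platform == 'win32':
    # Pick the IOCP-based proactor loop explicitly; since Python 3.8 it is
    # already the default on Windows, so this only makes the choice visible.
    asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

asyncio.run(asyncio.sleep(0))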
@@ -440,20 +690,61 @@ def close(self): # Close the event loop super().close() - def sock_recv(self, sock, n): - return self._proactor.recv(sock, n) + async def sock_recv(self, sock, n): + return await self._proactor.recv(sock, n) - def sock_sendall(self, sock, data): - return self._proactor.send(sock, data) + async def sock_recv_into(self, sock, buf): + return await self._proactor.recv_into(sock, buf) - def sock_connect(self, sock, address): - return self._proactor.connect(sock, address) + async def sock_sendall(self, sock, data): + return await self._proactor.send(sock, data) - def sock_accept(self, sock): - return self._proactor.accept(sock) + async def sock_connect(self, sock, address): + return await self._proactor.connect(sock, address) - def _socketpair(self): - raise NotImplementedError + async def sock_accept(self, sock): + return await self._proactor.accept(sock) + + async def _sock_sendfile_native(self, sock, file, offset, count): + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise exceptions.SendfileNotAvailableError("not a regular file") + try: + fsize = os.fstat(fileno).st_size + except OSError: + raise exceptions.SendfileNotAvailableError("not a regular file") + blocksize = count if count else fsize + if not blocksize: + return 0 # empty file + + blocksize = min(blocksize, 0xffff_ffff) + end_pos = min(offset + count, fsize) if count else fsize + offset = min(offset, fsize) + total_sent = 0 + try: + while True: + blocksize = min(end_pos - offset, blocksize) + if blocksize <= 0: + return total_sent + await self._proactor.sendfile(sock, file, offset, blocksize) + offset += blocksize + total_sent += blocksize + finally: + if total_sent > 0: + file.seek(offset) + + async def _sendfile_native(self, transp, file, offset, count): + resume_reading = transp.is_reading() + transp.pause_reading() + await transp._make_empty_waiter() + try: + return await self.sock_sendfile(transp._sock, file, offset, count, + fallback=False) + finally: + transp._reset_empty_waiter() + if resume_reading: + transp.resume_reading() def _close_self_pipe(self): if self._self_reading_future is not None: @@ -467,21 +758,30 @@ def _close_self_pipe(self): def _make_self_pipe(self): # A self-socket, really. :-) - self._ssock, self._csock = self._socketpair() + self._ssock, self._csock = socket.socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) self._internal_fds += 1 - self.call_soon(self._loop_self_reading) def _loop_self_reading(self, f=None): try: if f is not None: f.result() # may raise + if self._self_reading_future is not f: + # When we scheduled this Future, we assigned it to + # _self_reading_future. If it's not there now, something has + # tried to cancel the loop while this callback was still in the + # queue (see windows_events.ProactorEventLoop.run_forever). In + # that case stop here instead of continuing to schedule a new + # iteration. 
+ return f = self._proactor.recv(self._ssock, 4096) - except futures.CancelledError: + except exceptions.CancelledError: # _close_self_pipe() has been called, stop waiting for data return - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self.call_exception_handler({ 'message': 'Error on reading from the event loop self pipe', 'exception': exc, @@ -492,10 +792,26 @@ def _loop_self_reading(self, f=None): f.add_done_callback(self._loop_self_reading) def _write_to_self(self): - self._csock.send(b'\0') + # This may be called from a different thread, possibly after + # _close_self_pipe() has been called or even while it is + # running. Guard for self._csock being None or closed. When + # a socket is closed, send() raises OSError (with errno set to + # EBADF, but let's not rely on the exact error code). + csock = self._csock + if csock is None: + return + + try: + csock.send(b'\0') + except OSError: + if self._debug: + logger.debug("Fail to write a null byte into the " + "self-pipe socket", + exc_info=True) def _start_serving(self, protocol_factory, sock, - sslcontext=None, server=None, backlog=100): + sslcontext=None, server=None, backlog=100, + ssl_handshake_timeout=None): def loop(f=None): try: @@ -508,7 +824,8 @@ def loop(f=None): if sslcontext is not None: self._make_ssl_transport( conn, protocol, sslcontext, server_side=True, - extra={'peername': addr}, server=server) + extra={'peername': addr}, server=server, + ssl_handshake_timeout=ssl_handshake_timeout) else: self._make_socket_transport( conn, protocol, @@ -521,13 +838,13 @@ def loop(f=None): self.call_exception_handler({ 'message': 'Accept failed on a socket', 'exception': exc, - 'socket': sock, + 'socket': trsock.TransportSocket(sock), }) sock.close() elif self._debug: logger.debug("Accept failed on socket %r", sock, exc_info=True) - except futures.CancelledError: + except exceptions.CancelledError: sock.close() else: self._accept_futures[sock.fileno()] = f @@ -545,6 +862,8 @@ def _stop_accept_futures(self): self._accept_futures.clear() def _stop_serving(self, sock): - self._stop_accept_futures() + future = self._accept_futures.pop(sock.fileno(), None) + if future: + future.cancel() self._proactor._stop_serving(sock) sock.close() diff --git a/Lib/asyncio/protocols.py b/Lib/asyncio/protocols.py index 80fcac9a82..09987b164c 100644 --- a/Lib/asyncio/protocols.py +++ b/Lib/asyncio/protocols.py @@ -1,7 +1,9 @@ -"""Abstract Protocol class.""" +"""Abstract Protocol base classes.""" -__all__ = ['BaseProtocol', 'Protocol', 'DatagramProtocol', - 'SubprocessProtocol'] +__all__ = ( + 'BaseProtocol', 'Protocol', 'DatagramProtocol', + 'SubprocessProtocol', 'BufferedProtocol', +) class BaseProtocol: @@ -14,6 +16,8 @@ class BaseProtocol: write-only transport like write pipe """ + __slots__ = () + def connection_made(self, transport): """Called when a connection is made. @@ -85,6 +89,8 @@ class Protocol(BaseProtocol): * CL: connection_lost() """ + __slots__ = () + def data_received(self, data): """Called when some data is received. @@ -100,9 +106,64 @@ def eof_received(self): """ +class BufferedProtocol(BaseProtocol): + """Interface for stream protocol with manual buffer control. + + Event methods, such as `create_server` and `create_connection`, + accept factories that return protocols that implement this interface. + + The idea of BufferedProtocol is that it allows to manually allocate + and control the receive buffer. 
Event loops can then use the buffer + provided by the protocol to avoid unnecessary data copies. This + can result in noticeable performance improvement for protocols that + receive big amounts of data. Sophisticated protocols can allocate + the buffer only once at creation time. + + State machine of calls: + + start -> CM [-> GB [-> BU?]]* [-> ER?] -> CL -> end + + * CM: connection_made() + * GB: get_buffer() + * BU: buffer_updated() + * ER: eof_received() + * CL: connection_lost() + """ + + __slots__ = () + + def get_buffer(self, sizehint): + """Called to allocate a new receive buffer. + + *sizehint* is a recommended minimal size for the returned + buffer. When set to -1, the buffer size can be arbitrary. + + Must return an object that implements the + :ref:`buffer protocol `. + It is an error to return a zero-sized buffer. + """ + + def buffer_updated(self, nbytes): + """Called when the buffer was updated with the received data. + + *nbytes* is the total number of bytes that were written to + the buffer. + """ + + def eof_received(self): + """Called when the other end calls write_eof() or equivalent. + + If this returns a false value (including None), the transport + will close itself. If it returns a true value, closing the + transport is up to the protocol. + """ + + class DatagramProtocol(BaseProtocol): """Interface for datagram protocol.""" + __slots__ = () + def datagram_received(self, data, addr): """Called when some datagram is received.""" @@ -116,6 +177,8 @@ def error_received(self, exc): class SubprocessProtocol(BaseProtocol): """Interface for protocol for subprocess calls.""" + __slots__ = () + def pipe_data_received(self, fd, data): """Called when the subprocess writes data into stdout/stderr pipe. @@ -132,3 +195,22 @@ def pipe_connection_lost(self, fd, exc): def process_exited(self): """Called when subprocess has exited.""" + + +def _feed_data_to_buffered_proto(proto, data): + data_len = len(data) + while data_len: + buf = proto.get_buffer(data_len) + buf_len = len(buf) + if not buf_len: + raise RuntimeError('get_buffer() returned an empty buffer') + + if buf_len >= data_len: + buf[:data_len] = data + proto.buffer_updated(data_len) + return + else: + buf[:buf_len] = data[:buf_len] + proto.buffer_updated(buf_len) + data = data[buf_len:] + data_len = len(data) diff --git a/Lib/asyncio/queues.py b/Lib/asyncio/queues.py index e16c46ae73..14ae87e0a2 100644 --- a/Lib/asyncio/queues.py +++ b/Lib/asyncio/queues.py @@ -1,27 +1,21 @@ -"""Queues""" - -__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty'] +__all__ = ('Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty') import collections import heapq +import warnings +from types import GenericAlias -from . import compat from . import events from . import locks -from .coroutines import coroutine class QueueEmpty(Exception): - """Exception raised when Queue.get_nowait() is called on a Queue object - which is empty. - """ + """Raised when Queue.get_nowait() is called on an empty Queue.""" pass class QueueFull(Exception): - """Exception raised when the Queue.put_nowait() method is called on a Queue - object which is full. - """ + """Raised when the Queue.put_nowait() method is called on a full Queue.""" pass @@ -29,7 +23,7 @@ class Queue: """A queue, useful for coordinating producer and consumer coroutines. If maxsize is less than or equal to zero, the queue size is infinite. 
If it - is an integer greater than 0, then "yield from put()" will block when the + is an integer greater than 0, then "await put()" will block when the queue reaches maxsize, until an item is removed by get(). Unlike the standard library Queue, you can reliably know this Queue's size @@ -42,6 +36,9 @@ def __init__(self, maxsize=0, *, loop=None): self._loop = events.get_event_loop() else: self._loop = loop + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) self._maxsize = maxsize # Futures. @@ -49,7 +46,7 @@ def __init__(self, maxsize=0, *, loop=None): # Futures. self._putters = collections.deque() self._unfinished_tasks = 0 - self._finished = locks.Event(loop=self._loop) + self._finished = locks.Event(loop=loop) self._finished.set() self._init(maxsize) @@ -75,25 +72,23 @@ def _wakeup_next(self, waiters): break def __repr__(self): - return '<{} at {:#x} {}>'.format( - type(self).__name__, id(self), self._format()) + return f'<{type(self).__name__} at {id(self):#x} {self._format()}>' def __str__(self): - return '<{} {}>'.format(type(self).__name__, self._format()) + return f'<{type(self).__name__} {self._format()}>' - def __class_getitem__(cls, type): - return cls + __class_getitem__ = classmethod(GenericAlias) def _format(self): - result = 'maxsize={!r}'.format(self._maxsize) + result = f'maxsize={self._maxsize!r}' if getattr(self, '_queue', None): - result += ' _queue={!r}'.format(list(self._queue)) + result += f' _queue={list(self._queue)!r}' if self._getters: - result += ' _getters[{}]'.format(len(self._getters)) + result += f' _getters[{len(self._getters)}]' if self._putters: - result += ' _putters[{}]'.format(len(self._putters)) + result += f' _putters[{len(self._putters)}]' if self._unfinished_tasks: - result += ' tasks={}'.format(self._unfinished_tasks) + result += f' tasks={self._unfinished_tasks}' return result def qsize(self): @@ -120,22 +115,26 @@ def full(self): else: return self.qsize() >= self._maxsize - @coroutine - def put(self, item): + async def put(self, item): """Put an item into the queue. Put an item into the queue. If the queue is full, wait until a free slot is available before adding item. - - This method is a coroutine. """ while self.full(): putter = self._loop.create_future() self._putters.append(putter) try: - yield from putter + await putter except: putter.cancel() # Just in case putter is not done yet. + try: + # Clean self._putters from canceled putters. + self._putters.remove(putter) + except ValueError: + # The putter could be removed from self._putters by a + # previous get_nowait call. + pass if not self.full() and not putter.cancelled(): # We were woken up by get_nowait(), but can't take # the call. Wake up the next in line. @@ -155,21 +154,25 @@ def put_nowait(self, item): self._finished.clear() self._wakeup_next(self._getters) - @coroutine - def get(self): + async def get(self): """Remove and return an item from the queue. If queue is empty, wait until an item is available. - - This method is a coroutine. """ while self.empty(): getter = self._loop.create_future() self._getters.append(getter) try: - yield from getter + await getter except: getter.cancel() # Just in case getter is not done yet. + try: + # Clean self._getters from canceled getters. + self._getters.remove(getter) + except ValueError: + # The getter could be removed from self._getters by a + # previous put_nowait call. 
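+                    # (Mirrors the putter cleanup in put() above.)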
+ pass if not self.empty() and not getter.cancelled(): # We were woken up by put_nowait(), but can't take # the call. Wake up the next in line. @@ -208,8 +211,7 @@ def task_done(self): if self._unfinished_tasks == 0: self._finished.set() - @coroutine - def join(self): + async def join(self): """Block until all items in the queue have been gotten and processed. The count of unfinished tasks goes up whenever an item is added to the @@ -218,7 +220,7 @@ def join(self): When the count of unfinished tasks drops to zero, join() unblocks. """ if self._unfinished_tasks > 0: - yield from self._finished.wait() + await self._finished.wait() class PriorityQueue(Queue): @@ -248,9 +250,3 @@ def _put(self, item): def _get(self): return self._queue.pop() - - -if not compat.PY35: - JoinableQueue = Queue - """Deprecated alias for Queue.""" - __all__.append('JoinableQueue') diff --git a/Lib/asyncio/runners.py b/Lib/asyncio/runners.py index c3a696ef57..6920acba38 100644 --- a/Lib/asyncio/runners.py +++ b/Lib/asyncio/runners.py @@ -1,12 +1,12 @@ -__all__ = ['run'] +__all__ = 'run', from . import coroutines from . import events from . import tasks -def run(main, *, debug=False): - """Run a coroutine. +def run(main, *, debug=None): + """Execute the coroutine and return the result. This function runs the passed coroutine, taking care of managing the asyncio event loop and finalizing asynchronous @@ -39,12 +39,14 @@ async def main(): loop = events.new_event_loop() try: events.set_event_loop(loop) - loop.set_debug(debug) + if debug is not None: + loop.set_debug(debug) return loop.run_until_complete(main) finally: try: _cancel_all_tasks(loop) loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete(loop.shutdown_default_executor()) finally: events.set_event_loop(None) loop.close() @@ -59,7 +61,7 @@ def _cancel_all_tasks(loop): task.cancel() loop.run_until_complete( - tasks.gather(*to_cancel, loop=loop, return_exceptions=True)) + tasks._gather(*to_cancel, loop=loop, return_exceptions=True)) for task in to_cancel: if task.cancelled(): diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 9dbe550b01..572d4a8ce1 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -4,11 +4,12 @@ also includes support for signal handling, see the unix_events sub-module. """ -__all__ = ['BaseSelectorEventLoop'] +__all__ = 'BaseSelectorEventLoop', import collections import errno import functools +import selectors import socket import warnings import weakref @@ -18,14 +19,13 @@ ssl = None from . import base_events -from . import compat from . import constants from . import events from . import futures -from . import selectors -from . import transports +from . import protocols from . import sslproto -from .coroutines import coroutine +from . import transports +from . import trsock from .log import logger @@ -40,17 +40,6 @@ def _test_selector_event(selector, fd, event): return bool(key.events & event) -if hasattr(socket, 'TCP_NODELAY'): - def _set_nodelay(sock): - if (sock.family in {socket.AF_INET, socket.AF_INET6} and - sock.type == socket.SOCK_STREAM and - sock.proto == socket.IPPROTO_TCP): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) -else: - def _set_nodelay(sock): - pass - - class BaseSelectorEventLoop(base_events.BaseEventLoop): """Selector event loop. 
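The queue changes above (cancellation-safe put()/get() plus the
join()/task_done() bookkeeping) support the usual producer/consumer shape;
a small self-contained sketch:

import asyncio

async def consumer(queue):
    while True:
        item = await queue.get()
        # ... process item ...
        queue.task_done()       # lets join() account for this item

async def main():
    queue = asyncio.Queue(maxsize=2)
    worker = asyncio.create_task(consumer(queue))
    for i in range(5):
        await queue.put(i)      # blocks while the queue is full
    await queue.join()          # returns once every item is task_done()
    worker.cancel()

asyncio.run(main())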
@@ -72,31 +61,19 @@ def _make_socket_transport(self, sock, protocol, waiter=None, *, return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) - def _make_ssl_transport(self, rawsock, protocol, sslcontext, waiter=None, - *, server_side=False, server_hostname=None, - extra=None, server=None): - if not sslproto._is_sslproto_available(): - return self._make_legacy_ssl_transport( - rawsock, protocol, sslcontext, waiter, - server_side=server_side, server_hostname=server_hostname, - extra=extra, server=server) - - ssl_protocol = sslproto.SSLProtocol(self, protocol, sslcontext, waiter, - server_side, server_hostname) + def _make_ssl_transport( + self, rawsock, protocol, sslcontext, waiter=None, + *, server_side=False, server_hostname=None, + extra=None, server=None, + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): + ssl_protocol = sslproto.SSLProtocol( + self, protocol, sslcontext, waiter, + server_side, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) _SelectorSocketTransport(self, rawsock, ssl_protocol, extra=extra, server=server) return ssl_protocol._app_transport - def _make_legacy_ssl_transport(self, rawsock, protocol, sslcontext, - waiter, *, - server_side=False, server_hostname=None, - extra=None, server=None): - # Use the legacy API: SSL_write, SSL_read, etc. The legacy API is used - # on Python 3.4 and older, when ssl.MemoryBIO is not available. - return _SelectorSslTransport( - self, rawsock, protocol, sslcontext, waiter, - server_side, server_hostname, extra, server) - def _make_datagram_transport(self, sock, protocol, address=None, waiter=None, extra=None): return _SelectorDatagramTransport(self, sock, protocol, @@ -113,9 +90,6 @@ def close(self): self._selector.close() self._selector = None - def _socketpair(self): - raise NotImplementedError - def _close_self_pipe(self): self._remove_reader(self._ssock.fileno()) self._ssock.close() @@ -126,7 +100,7 @@ def _close_self_pipe(self): def _make_self_pipe(self): # A self-socket, really. :-) - self._ssock, self._csock = self._socketpair() + self._ssock, self._csock = socket.socketpair() self._ssock.setblocking(False) self._csock.setblocking(False) self._internal_fds += 1 @@ -154,22 +128,28 @@ def _write_to_self(self): # a socket is closed, send() raises OSError (with errno set to # EBADF, but let's not rely on the exact error code). csock = self._csock - if csock is not None: - try: - csock.send(b'\0') - except OSError: - if self._debug: - logger.debug("Fail to write a null byte into the " - "self-pipe socket", - exc_info=True) + if csock is None: + return + + try: + csock.send(b'\0') + except OSError: + if self._debug: + logger.debug("Fail to write a null byte into the " + "self-pipe socket", + exc_info=True) def _start_serving(self, protocol_factory, sock, - sslcontext=None, server=None, backlog=100): + sslcontext=None, server=None, backlog=100, + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): self._add_reader(sock.fileno(), self._accept_connection, - protocol_factory, sock, sslcontext, server, backlog) + protocol_factory, sock, sslcontext, server, backlog, + ssl_handshake_timeout) - def _accept_connection(self, protocol_factory, sock, - sslcontext=None, server=None, backlog=100): + def _accept_connection( + self, protocol_factory, sock, + sslcontext=None, server=None, backlog=100, + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): # This method is only called once for each event loop tick where the # listening socket has triggered an EVENT_READ. 
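The self-pipe built in _make_self_pipe() (now a plain socket.socketpair()) is what lets foreign threads wake the selector: call_soon_threadsafe() ends up writing the null byte shown in _write_to_self(). A hedged sketch of the pattern from the caller's side (the thread body is illustrative):

    import asyncio
    import threading

    def from_worker_thread(loop):
        # Schedules a callback on the loop from another thread; under the
        # hood this sends b'\0' through the self-pipe to wake the selector.
        loop.call_soon_threadsafe(print, 'hello from worker')

    async def main():
        loop = asyncio.get_running_loop()
        threading.Thread(target=from_worker_thread, args=(loop,)).start()
        await asyncio.sleep(0.1)        # give the callback a chance to run

    asyncio.run(main())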
There may be multiple # connections waiting for an .accept() so it is called in a loop. @@ -194,24 +174,26 @@ def _accept_connection(self, protocol_factory, sock, self.call_exception_handler({ 'message': 'socket.accept() out of system resource', 'exception': exc, - 'socket': sock, + 'socket': trsock.TransportSocket(sock), }) self._remove_reader(sock.fileno()) self.call_later(constants.ACCEPT_RETRY_DELAY, self._start_serving, protocol_factory, sock, sslcontext, server, - backlog) + backlog, ssl_handshake_timeout) else: raise # The event loop will catch, log and ignore it. else: extra = {'peername': addr} - accept = self._accept_connection2(protocol_factory, conn, extra, - sslcontext, server) + accept = self._accept_connection2( + protocol_factory, conn, extra, sslcontext, server, + ssl_handshake_timeout) self.create_task(accept) - @coroutine - def _accept_connection2(self, protocol_factory, conn, extra, - sslcontext=None, server=None): + async def _accept_connection2( + self, protocol_factory, conn, extra, + sslcontext=None, server=None, + ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT): protocol = None transport = None try: @@ -220,24 +202,27 @@ def _accept_connection2(self, protocol_factory, conn, extra, if sslcontext: transport = self._make_ssl_transport( conn, protocol, sslcontext, waiter=waiter, - server_side=True, extra=extra, server=server) + server_side=True, extra=extra, server=server, + ssl_handshake_timeout=ssl_handshake_timeout) else: transport = self._make_socket_transport( conn, protocol, waiter=waiter, extra=extra, server=server) try: - yield from waiter - except: + await waiter + except BaseException: transport.close() raise + # It's now up to the protocol to handle the connection. - # It's now up to the protocol to handle the connection. - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: if self._debug: context = { - 'message': ('Error on transport creation ' - 'for incoming connection'), + 'message': + 'Error on transport creation for incoming connection', 'exception': exc, } if protocol is not None: @@ -247,19 +232,26 @@ def _accept_connection2(self, protocol_factory, conn, extra, self.call_exception_handler(context) def _ensure_fd_no_transport(self, fd): + fileno = fd + if not isinstance(fileno, int): + try: + fileno = int(fileno.fileno()) + except (AttributeError, TypeError, ValueError): + # This code matches selectors._fileobj_to_fd function. 
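_accept_connection2() above calls protocol_factory() for each accepted socket and wires the result to a new transport. A minimal sketch of a server protocol that this accept path would drive (the port number is illustrative):

    import asyncio

    class EchoServer(asyncio.Protocol):
        def connection_made(self, transport):
            self.transport = transport  # called once the accept path completes

        def data_received(self, data):
            self.transport.write(data)  # echo back, then close
            self.transport.close()

    async def main():
        loop = asyncio.get_running_loop()
        server = await loop.create_server(EchoServer, '127.0.0.1', 8888)
        async with server:
            await server.serve_forever()

    asyncio.run(main())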
+ raise ValueError(f"Invalid file object: {fd!r}") from None try: - transport = self._transports[fd] + transport = self._transports[fileno] except KeyError: pass else: if not transport.is_closing(): raise RuntimeError( - 'File descriptor {!r} is used by transport {!r}'.format( - fd, transport)) + f'File descriptor {fd!r} is used by transport ' + f'{transport!r}') def _add_reader(self, fd, callback, *args): self._check_closed() - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) try: key = self._selector.get_key(fd) except KeyError: @@ -271,6 +263,7 @@ def _add_reader(self, fd, callback, *args): (handle, writer)) if reader is not None: reader.cancel() + return handle def _remove_reader(self, fd): if self.is_closed(): @@ -295,7 +288,7 @@ def _remove_reader(self, fd): def _add_writer(self, fd, callback, *args): self._check_closed() - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) try: key = self._selector.get_key(fd) except KeyError: @@ -307,6 +300,7 @@ def _add_writer(self, fd, callback, *args): (reader, handle)) if writer is not None: writer.cancel() + return handle def _remove_writer(self, fd): """Remove a writer callback.""" @@ -334,7 +328,7 @@ def _remove_writer(self, fd): def add_reader(self, fd, callback, *args): """Add a reader callback.""" self._ensure_fd_no_transport(fd) - return self._add_reader(fd, callback, *args) + self._add_reader(fd, callback, *args) def remove_reader(self, fd): """Remove a reader callback.""" @@ -344,50 +338,94 @@ def remove_reader(self, fd): def add_writer(self, fd, callback, *args): """Add a writer callback..""" self._ensure_fd_no_transport(fd) - return self._add_writer(fd, callback, *args) + self._add_writer(fd, callback, *args) def remove_writer(self, fd): """Remove a writer callback.""" self._ensure_fd_no_transport(fd) return self._remove_writer(fd) - def sock_recv(self, sock, n): + async def sock_recv(self, sock, n): """Receive data from the socket. The return value is a bytes object representing the data received. The maximum amount of data to be received at once is specified by nbytes. - - This method is a coroutine. """ + base_events._check_ssl_socket(sock) if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") + try: + return sock.recv(n) + except (BlockingIOError, InterruptedError): + pass fut = self.create_future() - self._sock_recv(fut, False, sock, n) - return fut + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recv, fut, sock, n) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + return await fut + + def _sock_read_done(self, fd, fut, handle=None): + if handle is None or not handle.cancelled(): + self.remove_reader(fd) - def _sock_recv(self, fut, registered, sock, n): + def _sock_recv(self, fut, sock, n): # _sock_recv() can add itself as an I/O callback if the operation can't # be done immediately. Don't use it directly, call sock_recv(). - fd = sock.fileno() - if registered: - # Remove the callback early. It should be rare that the - # selector says the fd is ready but the call still returns - # EAGAIN, and I am willing to take a hit in that case in - # order to simplify the common case. 
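_add_reader()/_add_writer() now return the Handle so that sock_recv()'s done callback can tell whether the read callback was already cancelled before removing it. The public add_reader() API is unchanged; a small sketch with a socketpair (selector event loops only, the proactor loop does not implement add_reader):

    import asyncio
    import socket

    async def main():
        loop = asyncio.get_running_loop()
        rsock, wsock = socket.socketpair()
        fut = loop.create_future()
        # Invoke the callback whenever rsock becomes readable.
        loop.add_reader(rsock.fileno(), lambda: fut.set_result(rsock.recv(1024)))
        wsock.send(b'ping')
        print(await fut)                # b'ping'
        loop.remove_reader(rsock.fileno())
        rsock.close()
        wsock.close()

    asyncio.run(main())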
- self.remove_reader(fd) - if fut.cancelled(): + if fut.done(): return try: data = sock.recv(n) except (BlockingIOError, InterruptedError): - self.add_reader(fd, self._sock_recv, fut, True, sock, n) - except Exception as exc: + return # try again next time + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) else: fut.set_result(data) - def sock_sendall(self, sock, data): + async def sock_recv_into(self, sock, buf): + """Receive data from the socket. + + The received data is written into *buf* (a writable buffer). + The return value is the number of bytes written. + """ + base_events._check_ssl_socket(sock) + if self._debug and sock.gettimeout() != 0: + raise ValueError("the socket must be non-blocking") + try: + return sock.recv_into(buf) + except (BlockingIOError, InterruptedError): + pass + fut = self.create_future() + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_recv_into, fut, sock, buf) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + return await fut + + def _sock_recv_into(self, fut, sock, buf): + # _sock_recv_into() can add itself as an I/O callback if the operation + # can't be done immediately. Don't use it directly, call + # sock_recv_into(). + if fut.done(): + return + try: + nbytes = sock.recv_into(buf) + except (BlockingIOError, InterruptedError): + return # try again next time + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + fut.set_exception(exc) + else: + fut.set_result(nbytes) + + async def sock_sendall(self, sock, data): """Send data to the socket. The socket must be connected to a remote socket. This method continues @@ -395,60 +433,71 @@ def sock_sendall(self, sock, data): error occurs. None is returned on success. On error, an exception is raised, and there is no way to determine how much data, if any, was successfully processed by the receiving end of the connection. - - This method is a coroutine. 
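sock_recv_into(), added above, mirrors sock_recv() but fills a caller-provided buffer, avoiding an allocation per read; like sock_recv() it first attempts the read synchronously and only registers a reader on BlockingIOError. A self-contained sketch:

    import asyncio
    import socket

    async def main():
        loop = asyncio.get_running_loop()
        rsock, wsock = socket.socketpair()
        rsock.setblocking(False)        # the loop requires non-blocking sockets
        wsock.send(b'hello')
        buf = bytearray(1024)
        nbytes = await loop.sock_recv_into(rsock, buf)
        print(buf[:nbytes])             # bytearray(b'hello')
        rsock.close()
        wsock.close()

    asyncio.run(main())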
""" + base_events._check_ssl_socket(sock) if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") - fut = self.create_future() - if data: - self._sock_sendall(fut, False, sock, data) - else: - fut.set_result(None) - return fut - - def _sock_sendall(self, fut, registered, sock, data): - fd = sock.fileno() + try: + n = sock.send(data) + except (BlockingIOError, InterruptedError): + n = 0 - if registered: - self.remove_writer(fd) - if fut.cancelled(): + if n == len(data): + # all data sent return + fut = self.create_future() + fd = sock.fileno() + self._ensure_fd_no_transport(fd) + # use a trick with a list in closure to store a mutable state + handle = self._add_writer(fd, self._sock_sendall, fut, sock, + memoryview(data), [n]) + fut.add_done_callback( + functools.partial(self._sock_write_done, fd, handle=handle)) + return await fut + + def _sock_sendall(self, fut, sock, view, pos): + if fut.done(): + # Future cancellation can be scheduled on previous loop iteration + return + start = pos[0] try: - n = sock.send(data) + n = sock.send(view[start:]) except (BlockingIOError, InterruptedError): - n = 0 - except Exception as exc: + return + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) return - if n == len(data): + start += n + + if start == len(view): fut.set_result(None) else: - if n: - data = data[n:] - self.add_writer(fd, self._sock_sendall, fut, True, sock, data) + pos[0] = start - @coroutine - def sock_connect(self, sock, address): + async def sock_connect(self, sock, address): """Connect to a remote socket at address. This method is a coroutine. """ + base_events._check_ssl_socket(sock) if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") - if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX: - resolved = base_events._ensure_resolved( - address, family=sock.family, proto=sock.proto, loop=self) - if not resolved.done(): - yield from resolved - _, _, _, _, address = resolved.result()[0] + if sock.family == socket.AF_INET or ( + base_events._HAS_IPv6 and sock.family == socket.AF_INET6): + resolved = await self._ensure_resolved( + address, family=sock.family, type=sock.type, proto=sock.proto, + loop=self, + ) + _, _, _, _, address = resolved[0] fut = self.create_future() self._sock_connect(fut, sock, address) - return (yield from fut) + return await fut def _sock_connect(self, fut, sock, address): fd = sock.fileno() @@ -459,66 +508,87 @@ def _sock_connect(self, fut, sock, address): # connection runs in background. We have to wait until the socket # becomes writable to be notified when the connection succeed or # fails. 
+ self._ensure_fd_no_transport(fd) + handle = self._add_writer( + fd, self._sock_connect_cb, fut, sock, address) fut.add_done_callback( - functools.partial(self._sock_connect_done, fd)) - self.add_writer(fd, self._sock_connect_cb, fut, sock, address) - except Exception as exc: + functools.partial(self._sock_write_done, fd, handle=handle)) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) else: fut.set_result(None) - def _sock_connect_done(self, fd, fut): - self.remove_writer(fd) + def _sock_write_done(self, fd, fut, handle=None): + if handle is None or not handle.cancelled(): + self.remove_writer(fd) def _sock_connect_cb(self, fut, sock, address): - if fut.cancelled(): + if fut.done(): return try: err = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) if err != 0: # Jump to any except clause below. - raise OSError(err, 'Connect call failed %s' % (address,)) + raise OSError(err, f'Connect call failed {address}') except (BlockingIOError, InterruptedError): # socket is still registered, the callback will be retried later pass - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) else: fut.set_result(None) - def sock_accept(self, sock): + async def sock_accept(self, sock): """Accept a connection. The socket must be bound to an address and listening for connections. The return value is a pair (conn, address) where conn is a new socket object usable to send and receive data on the connection, and address is the address bound to the socket on the other end of the connection. - - This method is a coroutine. """ + base_events._check_ssl_socket(sock) if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") fut = self.create_future() - self._sock_accept(fut, False, sock) - return fut + self._sock_accept(fut, sock) + return await fut - def _sock_accept(self, fut, registered, sock): + def _sock_accept(self, fut, sock): fd = sock.fileno() - if registered: - self.remove_reader(fd) - if fut.cancelled(): - return try: conn, address = sock.accept() conn.setblocking(False) except (BlockingIOError, InterruptedError): - self.add_reader(fd, self._sock_accept, fut, True, sock) - except Exception as exc: + self._ensure_fd_no_transport(fd) + handle = self._add_reader(fd, self._sock_accept, fut, sock) + fut.add_done_callback( + functools.partial(self._sock_read_done, fd, handle=handle)) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: fut.set_exception(exc) else: fut.set_result((conn, address)) + async def _sendfile_native(self, transp, file, offset, count): + del self._transports[transp._sock_fd] + resume_reading = transp.is_reading() + transp.pause_reading() + await transp._make_empty_waiter() + try: + return await self.sock_sendfile(transp._sock, file, offset, count, + fallback=False) + finally: + transp._reset_empty_waiter() + if resume_reading: + transp.resume_reading() + self._transports[transp._sock_fd] = transp + def _process_events(self, event_list): for key, mask in event_list: fileobj, (reader, writer) = key.fileobj, key.data @@ -552,8 +622,11 @@ class _SelectorTransport(transports._FlowControlMixin, def __init__(self, loop, sock, protocol, extra=None, server=None): super().__init__(extra, loop) - self._extra['socket'] = sock - self._extra['sockname'] = sock.getsockname() + self._extra['socket'] = trsock.TransportSocket(sock) + try: + self._extra['sockname'] = sock.getsockname() + except OSError: + 
self._extra['sockname'] = None if 'peername' not in self._extra: try: self._extra['peername'] = sock.getpeername() @@ -561,8 +634,10 @@ def __init__(self, loop, sock, protocol, extra=None, server=None): self._extra['peername'] = None self._sock = sock self._sock_fd = sock.fileno() - self._protocol = protocol - self._protocol_connected = True + + self._protocol_connected = False + self.set_protocol(protocol) + self._server = server self._buffer = self._buffer_factory() self._conn_lost = 0 # Set when call to connection_lost scheduled. @@ -577,7 +652,7 @@ def __repr__(self): info.append('closed') elif self._closing: info.append('closing') - info.append('fd=%s' % self._sock_fd) + info.append(f'fd={self._sock_fd}') # test if the transport was closed if self._loop is not None and not self._loop.is_closed(): polling = _test_selector_event(self._loop._selector, @@ -596,14 +671,15 @@ def __repr__(self): state = 'idle' bufsize = self.get_write_buffer_size() - info.append('write=<%s, bufsize=%s>' % (state, bufsize)) - return '<%s>' % ' '.join(info) + info.append(f'write=<{state}, bufsize={bufsize}>') + return '<{}>'.format(' '.join(info)) def abort(self): self._force_close(None) def set_protocol(self, protocol): self._protocol = protocol + self._protocol_connected = True def get_protocol(self): return self._protocol @@ -621,19 +697,14 @@ def close(self): self._loop._remove_writer(self._sock_fd) self._loop.call_soon(self._call_connection_lost, None) - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if self._sock is not None: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self._sock.close() + def __del__(self, _warn=warnings.warn): + if self._sock is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self._sock.close() def _fatal_error(self, exc, message='Fatal error on transport'): # Should be called from exception handler only. - if isinstance(exc, base_events._FATAL_ERROR_IGNORE): + if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) else: @@ -674,79 +745,162 @@ def _call_connection_lost(self, exc): def get_write_buffer_size(self): return len(self._buffer) + def _add_reader(self, fd, callback, *args): + if self._closing: + return + + self._loop._add_reader(fd, callback, *args) + class _SelectorSocketTransport(_SelectorTransport): + _start_tls_compatible = True + _sendfile_compatible = constants._SendfileMode.TRY_NATIVE + def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): + + self._read_ready_cb = None super().__init__(loop, sock, protocol, extra, server) self._eof = False self._paused = False + self._empty_waiter = None # Disable the Nagle algorithm -- small writes will be # sent without waiting for the TCP ACK. This generally # decreases the latency (in some cases significantly.) 
- _set_nodelay(self._sock) + base_events._set_nodelay(self._sock) self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called - self._loop.call_soon(self._loop._add_reader, + self._loop.call_soon(self._add_reader, self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called self._loop.call_soon(futures._set_result_unless_cancelled, waiter, None) + def set_protocol(self, protocol): + if isinstance(protocol, protocols.BufferedProtocol): + self._read_ready_cb = self._read_ready__get_buffer + else: + self._read_ready_cb = self._read_ready__data_received + + super().set_protocol(protocol) + + def is_reading(self): + return not self._paused and not self._closing + def pause_reading(self): - if self._closing: - raise RuntimeError('Cannot pause_reading() when closing') - if self._paused: - raise RuntimeError('Already paused') + if self._closing or self._paused: + return self._paused = True self._loop._remove_reader(self._sock_fd) if self._loop.get_debug(): logger.debug("%r pauses reading", self) def resume_reading(self): - if not self._paused: - raise RuntimeError('Not paused') - self._paused = False - if self._closing: + if self._closing or not self._paused: return - self._loop._add_reader(self._sock_fd, self._read_ready) + self._paused = False + self._add_reader(self._sock_fd, self._read_ready) if self._loop.get_debug(): logger.debug("%r resumes reading", self) def _read_ready(self): + self._read_ready_cb() + + def _read_ready__get_buffer(self): + if self._conn_lost: + return + + try: + buf = self._protocol.get_buffer(-1) + if not len(buf): + raise RuntimeError('get_buffer() returned an empty buffer') + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.get_buffer() call failed.') + return + + try: + nbytes = self._sock.recv_into(buf) + except (BlockingIOError, InterruptedError): + return + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error(exc, 'Fatal read error on socket transport') + return + + if not nbytes: + self._read_ready__on_eof() + return + + try: + self._protocol.buffer_updated(nbytes) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.buffer_updated() call failed.') + + def _read_ready__data_received(self): if self._conn_lost: return try: data = self._sock.recv(self.max_size) except (BlockingIOError, InterruptedError): - pass - except Exception as exc: + return + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._fatal_error(exc, 'Fatal read error on socket transport') + return + + if not data: + self._read_ready__on_eof() + return + + try: + self._protocol.data_received(data) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.data_received() call failed.') + + def _read_ready__on_eof(self): + if self._loop.get_debug(): + logger.debug("%r received EOF", self) + + try: + keep_open = self._protocol.eof_received() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal error: protocol.eof_received() call failed.') + return + + if keep_open: + # We're keeping the connection open so the + # protocol can write more, but we still can't + # receive more, so remove the reader callback. 
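set_protocol() above dispatches reads to _read_ready__get_buffer() when the protocol is a protocols.BufferedProtocol, letting the transport recv_into() the protocol's own buffer instead of allocating a bytes object per read. A minimal sketch of such a protocol (the buffer size is illustrative):

    import asyncio

    class BufferedEcho(asyncio.BufferedProtocol):
        def connection_made(self, transport):
            self.transport = transport
            self._buffer = bytearray(4096)

        def get_buffer(self, sizehint):
            # The transport will recv_into() this buffer directly.
            return self._buffer

        def buffer_updated(self, nbytes):
            # nbytes of self._buffer were just filled by the transport.
            self.transport.write(bytes(self._buffer[:nbytes]))

        def eof_received(self):
            return False                # close the transport on EOF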
+ self._loop._remove_reader(self._sock_fd) else: - if data: - self._protocol.data_received(data) - else: - if self._loop.get_debug(): - logger.debug("%r received EOF", self) - keep_open = self._protocol.eof_received() - if keep_open: - # We're keeping the connection open so the - # protocol can write more, but we still can't - # receive more, so remove the reader callback. - self._loop._remove_reader(self._sock_fd) - else: - self.close() + self.close() def write(self, data): if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be a bytes-like object, ' - 'not %r' % type(data).__name__) + raise TypeError(f'data argument must be a bytes-like object, ' + f'not {type(data).__name__!r}') if self._eof: raise RuntimeError('Cannot call write() after write_eof()') + if self._empty_waiter is not None: + raise RuntimeError('unable to write; sendfile is in progress') if not data: return @@ -762,7 +916,9 @@ def write(self, data): n = self._sock.send(data) except (BlockingIOError, InterruptedError): pass - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._fatal_error(exc, 'Fatal write error on socket transport') return else: @@ -785,23 +941,29 @@ def _write_ready(self): n = self._sock.send(self._buffer) except (BlockingIOError, InterruptedError): pass - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._loop._remove_writer(self._sock_fd) self._buffer.clear() self._fatal_error(exc, 'Fatal write error on socket transport') + if self._empty_waiter is not None: + self._empty_waiter.set_exception(exc) else: if n: del self._buffer[:n] self._maybe_resume_protocol() # May append to buffer. if not self._buffer: self._loop._remove_writer(self._sock_fd) + if self._empty_waiter is not None: + self._empty_waiter.set_result(None) if self._closing: self._call_connection_lost(None) elif self._eof: self._sock.shutdown(socket.SHUT_WR) def write_eof(self): - if self._eof: + if self._closing or self._eof: return self._eof = True if not self._buffer: @@ -810,237 +972,22 @@ def write_eof(self): def can_write_eof(self): return True - -class _SelectorSslTransport(_SelectorTransport): - - _buffer_factory = bytearray - - def __init__(self, loop, rawsock, protocol, sslcontext, waiter=None, - server_side=False, server_hostname=None, - extra=None, server=None): - if ssl is None: - raise RuntimeError('stdlib ssl module not available') - - if not sslcontext: - sslcontext = sslproto._create_transport_context(server_side, server_hostname) - - wrap_kwargs = { - 'server_side': server_side, - 'do_handshake_on_connect': False, - } - if server_hostname and not server_side: - wrap_kwargs['server_hostname'] = server_hostname - sslsock = sslcontext.wrap_socket(rawsock, **wrap_kwargs) - - super().__init__(loop, sslsock, protocol, extra, server) - # the protocol connection is only made after the SSL handshake - self._protocol_connected = False - - self._server_hostname = server_hostname - self._waiter = waiter - self._sslcontext = sslcontext - self._paused = False - - # SSL-specific extra info. 
(peercert is set later) - self._extra.update(sslcontext=sslcontext) - - if self._loop.get_debug(): - logger.debug("%r starts SSL handshake", self) - start_time = self._loop.time() - else: - start_time = None - self._on_handshake(start_time) - - def _wakeup_waiter(self, exc=None): - if self._waiter is None: - return - if not self._waiter.cancelled(): - if exc is not None: - self._waiter.set_exception(exc) - else: - self._waiter.set_result(None) - self._waiter = None - - def _on_handshake(self, start_time): - try: - self._sock.do_handshake() - except ssl.SSLWantReadError: - self._loop._add_reader(self._sock_fd, - self._on_handshake, start_time) - return - except ssl.SSLWantWriteError: - self._loop._add_writer(self._sock_fd, - self._on_handshake, start_time) - return - except BaseException as exc: - if self._loop.get_debug(): - logger.warning("%r: SSL handshake failed", - self, exc_info=True) - self._loop._remove_reader(self._sock_fd) - self._loop._remove_writer(self._sock_fd) - self._sock.close() - self._wakeup_waiter(exc) - if isinstance(exc, Exception): - return - else: - raise - - self._loop._remove_reader(self._sock_fd) - self._loop._remove_writer(self._sock_fd) - - peercert = self._sock.getpeercert() - if not hasattr(self._sslcontext, 'check_hostname'): - # Verify hostname if requested, Python 3.4+ uses check_hostname - # and checks the hostname in do_handshake() - if (self._server_hostname and - self._sslcontext.verify_mode != ssl.CERT_NONE): - try: - ssl.match_hostname(peercert, self._server_hostname) - except Exception as exc: - if self._loop.get_debug(): - logger.warning("%r: SSL handshake failed " - "on matching the hostname", - self, exc_info=True) - self._sock.close() - self._wakeup_waiter(exc) - return - - # Add extra info that becomes available after handshake. - self._extra.update(peercert=peercert, - cipher=self._sock.cipher(), - compression=self._sock.compression(), - ssl_object=self._sock, - ) - - self._read_wants_write = False - self._write_wants_read = False - self._loop._add_reader(self._sock_fd, self._read_ready) - self._protocol_connected = True - self._loop.call_soon(self._protocol.connection_made, self) - # only wake up the waiter when connection_made() has been called - self._loop.call_soon(self._wakeup_waiter) - - if self._loop.get_debug(): - dt = self._loop.time() - start_time - logger.debug("%r: SSL handshake took %.1f ms", self, dt * 1e3) - - def pause_reading(self): - # XXX This is a bit icky, given the comment at the top of - # _read_ready(). Is it possible to evoke a deadlock? I don't - # know, although it doesn't look like it; write() will still - # accept more data for the buffer and eventually the app will - # call resume_reading() again, and things will flow again. 
- - if self._closing: - raise RuntimeError('Cannot pause_reading() when closing') - if self._paused: - raise RuntimeError('Already paused') - self._paused = True - self._loop._remove_reader(self._sock_fd) - if self._loop.get_debug(): - logger.debug("%r pauses reading", self) - - def resume_reading(self): - if not self._paused: - raise RuntimeError('Not paused') - self._paused = False - if self._closing: - return - self._loop._add_reader(self._sock_fd, self._read_ready) - if self._loop.get_debug(): - logger.debug("%r resumes reading", self) - - def _read_ready(self): - if self._conn_lost: - return - if self._write_wants_read: - self._write_wants_read = False - self._write_ready() - - if self._buffer: - self._loop._add_writer(self._sock_fd, self._write_ready) - - try: - data = self._sock.recv(self.max_size) - except (BlockingIOError, InterruptedError, ssl.SSLWantReadError): - pass - except ssl.SSLWantWriteError: - self._read_wants_write = True - self._loop._remove_reader(self._sock_fd) - self._loop._add_writer(self._sock_fd, self._write_ready) - except Exception as exc: - self._fatal_error(exc, 'Fatal read error on SSL transport') - else: - if data: - self._protocol.data_received(data) - else: - try: - if self._loop.get_debug(): - logger.debug("%r received EOF", self) - keep_open = self._protocol.eof_received() - if keep_open: - logger.warning('returning true from eof_received() ' - 'has no effect when using ssl') - finally: - self.close() - - def _write_ready(self): - if self._conn_lost: - return - if self._read_wants_write: - self._read_wants_write = False - self._read_ready() - - if not (self._paused or self._closing): - self._loop._add_reader(self._sock_fd, self._read_ready) - - if self._buffer: - try: - n = self._sock.send(self._buffer) - except (BlockingIOError, InterruptedError, ssl.SSLWantWriteError): - n = 0 - except ssl.SSLWantReadError: - n = 0 - self._loop._remove_writer(self._sock_fd) - self._write_wants_read = True - except Exception as exc: - self._loop._remove_writer(self._sock_fd) - self._buffer.clear() - self._fatal_error(exc, 'Fatal write error on SSL transport') - return - - if n: - del self._buffer[:n] - - self._maybe_resume_protocol() # May append to buffer. - - if not self._buffer: - self._loop._remove_writer(self._sock_fd) - if self._closing: - self._call_connection_lost(None) - - def write(self, data): - if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be a bytes-like object, ' - 'not %r' % type(data).__name__) - if not data: - return - - if self._conn_lost: - if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: - logger.warning('socket.send() raised exception.') - self._conn_lost += 1 - return - + def _call_connection_lost(self, exc): + super()._call_connection_lost(exc) + if self._empty_waiter is not None: + self._empty_waiter.set_exception( + ConnectionError("Connection is closed by peer")) + + def _make_empty_waiter(self): + if self._empty_waiter is not None: + raise RuntimeError("Empty waiter is already set") + self._empty_waiter = self._loop.create_future() if not self._buffer: - self._loop._add_writer(self._sock_fd, self._write_ready) - - # Add it to the buffer. 
- self._buffer.extend(data) - self._maybe_pause_protocol() + self._empty_waiter.set_result(None) + return self._empty_waiter - def can_write_eof(self): - return False + def _reset_empty_waiter(self): + self._empty_waiter = None class _SelectorDatagramTransport(_SelectorTransport): @@ -1053,7 +1000,7 @@ def __init__(self, loop, sock, protocol, address=None, self._address = address self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called - self._loop.call_soon(self._loop._add_reader, + self._loop.call_soon(self._add_reader, self._sock_fd, self._read_ready) if waiter is not None: # only wake up the waiter when connection_made() has been called @@ -1072,21 +1019,25 @@ def _read_ready(self): pass except OSError as exc: self._protocol.error_received(exc) - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._fatal_error(exc, 'Fatal read error on datagram transport') else: self._protocol.datagram_received(data, addr) def sendto(self, data, addr=None): if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError('data argument must be a bytes-like object, ' - 'not %r' % type(data).__name__) + raise TypeError(f'data argument must be a bytes-like object, ' + f'not {type(data).__name__!r}') if not data: return - if self._address and addr not in (None, self._address): - raise ValueError('Invalid address: must be None or %s' % - (self._address,)) + if self._address: + if addr not in (None, self._address): + raise ValueError( + f'Invalid address: must be None or {self._address}') + addr = self._address if self._conn_lost and self._address: if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES: @@ -1097,7 +1048,7 @@ def sendto(self, data, addr=None): if not self._buffer: # Attempt to send it right away first. try: - if self._address: + if self._extra['peername']: self._sock.send(data) else: self._sock.sendto(data, addr) @@ -1107,9 +1058,11 @@ def sendto(self, data, addr=None): except OSError as exc: self._protocol.error_received(exc) return - except Exception as exc: - self._fatal_error(exc, - 'Fatal write error on datagram transport') + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal write error on datagram transport') return # Ensure that what we buffer is immutable. @@ -1120,7 +1073,7 @@ def _sendto_ready(self): while self._buffer: data, addr = self._buffer.popleft() try: - if self._address: + if self._extra['peername']: self._sock.send(data) else: self._sock.sendto(data, addr) @@ -1130,9 +1083,11 @@ def _sendto_ready(self): except OSError as exc: self._protocol.error_received(exc) return - except Exception as exc: - self._fatal_error(exc, - 'Fatal write error on datagram transport') + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._fatal_error( + exc, 'Fatal write error on datagram transport') return self._maybe_resume_protocol() # May append to buffer. diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 7ad28d6aa0..00fc16c014 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -5,8 +5,7 @@ except ImportError: # pragma: no cover ssl = None -from . import base_events -from . import compat +from . import constants from . import protocols from . 
import transports from .log import logger @@ -19,25 +18,13 @@ def _create_transport_context(server_side, server_hostname): # Client side may pass ssl=True to use a default # context; in that case the sslcontext passed is None. # The default is secure for client connections. - if hasattr(ssl, 'create_default_context'): - # Python 3.4+: use up-to-date strong settings. - sslcontext = ssl.create_default_context() - if not server_hostname: - sslcontext.check_hostname = False - else: - # Fallback for Python 3.3. - sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23) - sslcontext.options |= ssl.OP_NO_SSLv2 - sslcontext.options |= ssl.OP_NO_SSLv3 - sslcontext.set_default_verify_paths() - sslcontext.verify_mode = ssl.CERT_REQUIRED + # Python 3.4+: use up-to-date strong settings. + sslcontext = ssl.create_default_context() + if not server_hostname: + sslcontext.check_hostname = False return sslcontext -def _is_sslproto_available(): - return hasattr(ssl, "MemoryBIO") - - # States of an _SSLPipe. _UNWRAPPED = "UNWRAPPED" _DO_HANDSHAKE = "DO_HANDSHAKE" @@ -226,13 +213,14 @@ def feed_ssldata(self, data, only_handshake=False): # Drain possible plaintext data after close_notify. appdata.append(self._incoming.read()) except (ssl.SSLError, ssl.CertificateError) as exc: - if getattr(exc, 'errno', None) not in ( + exc_errno = getattr(exc, 'errno', None) + if exc_errno not in ( ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, ssl.SSL_ERROR_SYSCALL): if self._state == _DO_HANDSHAKE and self._handshake_cb: self._handshake_cb(exc) raise - self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) + self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) # Check for record level data that needs to be sent back. # Happens for the initial handshake and renegotiations. @@ -275,13 +263,14 @@ def feed_appdata(self, data, offset=0): # It is not allowed to call write() after unwrap() until the # close_notify is acknowledged. We return the condition to the # caller as a short write. + exc_errno = getattr(exc, 'errno', None) if exc.reason == 'PROTOCOL_IS_SHUTDOWN': - exc.errno = ssl.SSL_ERROR_WANT_READ - if exc.errno not in (ssl.SSL_ERROR_WANT_READ, + exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ + if exc_errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE, ssl.SSL_ERROR_SYSCALL): raise - self._need_ssldata = (exc.errno == ssl.SSL_ERROR_WANT_READ) + self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ) # See if there's any record level data back for us. 
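With the Python 3.3 fallback gone, _create_transport_context() above always builds its context with ssl.create_default_context(), relaxing only the hostname check when no server_hostname was supplied. The equivalent user-level setup:

    import ssl

    # Secure-by-default client context: verifies certificates against the
    # system CA store and checks the hostname.
    ctx = ssl.create_default_context()

    # When connecting without a hostname to verify, asyncio's helper
    # disables only check_hostname, as in the hunk above:
    anon_ctx = ssl.create_default_context()
    anon_ctx.check_hostname = False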
if self._outgoing.pending: @@ -294,11 +283,12 @@ def feed_appdata(self, data, offset=0): class _SSLProtocolTransport(transports._FlowControlMixin, transports.Transport): - def __init__(self, loop, ssl_protocol, app_protocol): + _sendfile_compatible = constants._SendfileMode.FALLBACK + + def __init__(self, loop, ssl_protocol): self._loop = loop # SSLProtocol instance self._ssl_protocol = ssl_protocol - self._app_protocol = app_protocol self._closed = False def get_extra_info(self, name, default=None): @@ -306,10 +296,10 @@ def get_extra_info(self, name, default=None): return self._ssl_protocol._get_extra_info(name, default) def set_protocol(self, protocol): - self._app_protocol = protocol + self._ssl_protocol._set_app_protocol(protocol) def get_protocol(self): - return self._app_protocol + return self._ssl_protocol._app_protocol def is_closing(self): return self._closed @@ -325,15 +315,16 @@ def close(self): self._closed = True self._ssl_protocol._start_shutdown() - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. - if compat.PY34: - def __del__(self): - if not self._closed: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self.close() + def __del__(self, _warn=warnings.warn): + if not self._closed: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self.close() + + def is_reading(self): + tr = self._ssl_protocol._transport + if tr is None: + raise RuntimeError('SSL transport has not been initialized yet') + return tr.is_reading() def pause_reading(self): """Pause the receiving end. @@ -376,6 +367,17 @@ def get_write_buffer_size(self): """Return the current size of the write buffer.""" return self._ssl_protocol._transport.get_write_buffer_size() + def get_write_buffer_limits(self): + """Get the high and low watermarks for write flow control. + Return a tuple (low, high) where low and high are + positive number of bytes.""" + return self._ssl_protocol._transport.get_write_buffer_limits() + + @property + def _protocol_paused(self): + # Required for sendfile fallback pause_writing/resume_writing logic + return self._ssl_protocol._transport._protocol_paused + def write(self, data): """Write some data bytes to the transport. @@ -383,8 +385,8 @@ def write(self, data): to be sent out asynchronously. """ if not isinstance(data, (bytes, bytearray, memoryview)): - raise TypeError("data: expecting a bytes-like instance, got {!r}" - .format(type(data).__name__)) + raise TypeError(f"data: expecting a bytes-like instance, " + f"got {type(data).__name__}") if not data: return self._ssl_protocol._write_appdata(data) @@ -401,6 +403,7 @@ def abort(self): called with None as its argument. 
""" self._ssl_protocol._abort() + self._closed = True class SSLProtocol(protocols.Protocol): @@ -412,12 +415,21 @@ class SSLProtocol(protocols.Protocol): def __init__(self, loop, app_protocol, sslcontext, waiter, server_side=False, server_hostname=None, - call_connection_made=True): + call_connection_made=True, + ssl_handshake_timeout=None): if ssl is None: raise RuntimeError('stdlib ssl module not available') + if ssl_handshake_timeout is None: + ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT + elif ssl_handshake_timeout <= 0: + raise ValueError( + f"ssl_handshake_timeout should be a positive number, " + f"got {ssl_handshake_timeout}") + if not sslcontext: - sslcontext = _create_transport_context(server_side, server_hostname) + sslcontext = _create_transport_context( + server_side, server_hostname) self._server_side = server_side if server_hostname and not server_side: @@ -435,9 +447,8 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, self._waiter = waiter self._loop = loop - self._app_protocol = app_protocol - self._app_transport = _SSLProtocolTransport(self._loop, - self, self._app_protocol) + self._set_app_protocol(app_protocol) + self._app_transport = _SSLProtocolTransport(self._loop, self) # _SSLPipe instance (None until the connection is made) self._sslpipe = None self._session_established = False @@ -446,6 +457,12 @@ def __init__(self, loop, app_protocol, sslcontext, waiter, # transport, ex: SelectorSocketTransport self._transport = None self._call_connection_made = call_connection_made + self._ssl_handshake_timeout = ssl_handshake_timeout + + def _set_app_protocol(self, app_protocol): + self._app_protocol = app_protocol + self._app_protocol_is_buffer = \ + isinstance(app_protocol, protocols.BufferedProtocol) def _wakeup_waiter(self, exc=None): if self._waiter is None: @@ -478,9 +495,19 @@ def connection_lost(self, exc): if self._session_established: self._session_established = False self._loop.call_soon(self._app_protocol.connection_lost, exc) + else: + # Most likely an exception occurred while in SSL handshake. + # Just mark the app transport as closed so that its __del__ + # doesn't complain. + if self._app_transport is not None: + self._app_transport._closed = True self._transport = None self._app_transport = None + if getattr(self, '_handshake_timeout_handle', None): + self._handshake_timeout_handle.cancel() self._wakeup_waiter(exc) + self._app_protocol = None + self._sslpipe = None def pause_writing(self): """Called when the low-level transport's buffer goes over @@ -499,13 +526,16 @@ def data_received(self, data): The argument is a bytes object. 
""" + if self._sslpipe is None: + # transport closing, sslpipe is destroyed + return + try: ssldata, appdata = self._sslpipe.feed_ssldata(data) - except ssl.SSLError as e: - if self._loop.get_debug(): - logger.warning('%r: SSL error %s (reason %s)', - self, e.errno, e.reason) - self._abort() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as e: + self._fatal_error(e, 'SSL error in data received') return for chunk in ssldata: @@ -513,7 +543,18 @@ def data_received(self, data): for chunk in appdata: if chunk: - self._app_protocol.data_received(chunk) + try: + if self._app_protocol_is_buffer: + protocols._feed_data_to_buffered_proto( + self._app_protocol, chunk) + else: + self._app_protocol.data_received(chunk) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as ex: + self._fatal_error( + ex, 'application protocol failed to receive SSL data') + return else: self._start_shutdown() break @@ -543,14 +584,19 @@ def eof_received(self): def _get_extra_info(self, name, default=None): if name in self._extra: return self._extra[name] - else: + elif self._transport is not None: return self._transport.get_extra_info(name, default) + else: + return default def _start_shutdown(self): if self._in_shutdown: return - self._in_shutdown = True - self._write_appdata(b'') + if self._in_handshake: + self._abort() + else: + self._in_shutdown = True + self._write_appdata(b'') def _write_appdata(self, data): self._write_backlog.append((data, 0)) @@ -567,10 +613,23 @@ def _start_handshake(self): # (b'', 1) is a special value in _process_write_backlog() to do # the SSL handshake self._write_backlog.append((b'', 1)) - self._loop.call_soon(self._process_write_backlog) + self._handshake_timeout_handle = \ + self._loop.call_later(self._ssl_handshake_timeout, + self._check_handshake_timeout) + self._process_write_backlog() + + def _check_handshake_timeout(self): + if self._in_handshake is True: + msg = ( + f"SSL handshake is taking longer than " + f"{self._ssl_handshake_timeout} seconds: " + f"aborting the connection" + ) + self._fatal_error(ConnectionAbortedError(msg)) def _on_handshake_complete(self, handshake_exc): self._in_handshake = False + self._handshake_timeout_handle.cancel() sslobj = self._sslpipe.ssl_object try: @@ -578,27 +637,15 @@ def _on_handshake_complete(self, handshake_exc): raise handshake_exc peercert = sslobj.getpeercert() - if not hasattr(self._sslcontext, 'check_hostname'): - # Verify hostname if requested, Python 3.4+ uses check_hostname - # and checks the hostname in do_handshake() - if (self._server_hostname - and self._sslcontext.verify_mode != ssl.CERT_NONE): - ssl.match_hostname(peercert, self._server_hostname) + except (SystemExit, KeyboardInterrupt): + raise except BaseException as exc: - if self._loop.get_debug(): - if isinstance(exc, ssl.CertificateError): - logger.warning("%r: SSL handshake failed " - "on verifying the certificate", - self, exc_info=True) - else: - logger.warning("%r: SSL handshake failed", - self, exc_info=True) - self._transport.close() - if isinstance(exc, Exception): - self._wakeup_waiter(exc) - return + if isinstance(exc, ssl.CertificateError): + msg = 'SSL handshake failed on verifying the certificate' else: - raise + msg = 'SSL handshake failed' + self._fatal_error(exc, msg) + return if self._loop.get_debug(): dt = self._loop.time() - self._handshake_start_time @@ -623,7 +670,7 @@ def _on_handshake_complete(self, handshake_exc): def _process_write_backlog(self): # Try to make progress on the write backlog. 
- if self._transport is None: + if self._transport is None or self._sslpipe is None: return try: @@ -655,19 +702,17 @@ def _process_write_backlog(self): # delete it and reduce the outstanding buffer size. del self._write_backlog[0] self._write_buffer_size -= len(data) + except (SystemExit, KeyboardInterrupt): + raise except BaseException as exc: if self._in_handshake: - # BaseExceptions will be re-raised in _on_handshake_complete. + # Exceptions will be re-raised in _on_handshake_complete. self._on_handshake_complete(exc) else: self._fatal_error(exc, 'Fatal error on SSL transport') - if not isinstance(exc, Exception): - # BaseException - raise def _fatal_error(self, exc, message='Fatal error on transport'): - # Should be called from exception handler only. - if isinstance(exc, base_events._FATAL_ERROR_IGNORE): + if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) else: @@ -681,12 +726,14 @@ def _fatal_error(self, exc, message='Fatal error on transport'): self._transport._force_close(exc) def _finalize(self): + self._sslpipe = None + if self._transport is not None: self._transport.close() def _abort(self): - if self._transport is not None: - try: + try: + if self._transport is not None: self._transport.abort() - finally: - self._finalize() + finally: + self._finalize() diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py new file mode 100644 index 0000000000..451a53a16f --- /dev/null +++ b/Lib/asyncio/staggered.py @@ -0,0 +1,149 @@ +"""Support for running coroutines in parallel with staggered start times.""" + +__all__ = 'staggered_race', + +import contextlib +import typing + +from . import events +from . import exceptions as exceptions_mod +from . import locks +from . import tasks + + +async def staggered_race( + coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]], + delay: typing.Optional[float], + *, + loop: events.AbstractEventLoop = None, +) -> typing.Tuple[ + typing.Any, + typing.Optional[int], + typing.List[typing.Optional[Exception]] +]: + """Run coroutines with staggered start times and take the first to finish. + + This method takes an iterable of coroutine functions. The first one is + started immediately. From then on, whenever the immediately preceding one + fails (raises an exception), or when *delay* seconds has passed, the next + coroutine is started. This continues until one of the coroutines complete + successfully, in which case all others are cancelled, or until all + coroutines fail. + + The coroutines provided should be well-behaved in the following way: + + * They should only ``return`` if completed successfully. + + * They should always raise an exception if they did not complete + successfully. In particular, if they handle cancellation, they should + probably reraise, like this:: + + try: + # do work + except asyncio.CancelledError: + # undo partially completed work + raise + + Args: + coro_fns: an iterable of coroutine functions, i.e. callables that + return a coroutine object when called. Use ``functools.partial`` or + lambdas to pass arguments. + + delay: amount of time, in seconds, between starting coroutines. If + ``None``, the coroutines will run sequentially. + + loop: the event loop to use. + + Returns: + tuple *(winner_result, winner_index, exceptions)* where + + - *winner_result*: the result of the winning coroutine, or ``None`` + if no coroutines won. + + - *winner_index*: the index of the winning coroutine in + ``coro_fns``, or ``None`` if no coroutines won. 
If the winning + coroutine may return None on success, *winner_index* can be used + to definitively determine whether any coroutine won. + + - *exceptions*: list of exceptions returned by the coroutines. + ``len(exceptions)`` is equal to the number of coroutines actually + started, and the order is the same as in ``coro_fns``. The winning + coroutine's entry is ``None``. + + """ + # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. + loop = loop or events.get_running_loop() + enum_coro_fns = enumerate(coro_fns) + winner_result = None + winner_index = None + exceptions = [] + running_tasks = [] + + async def run_one_coro( + previous_failed: typing.Optional[locks.Event]) -> None: + # Wait for the previous task to finish, or for delay seconds + if previous_failed is not None: + with contextlib.suppress(exceptions_mod.TimeoutError): + # Use asyncio.wait_for() instead of asyncio.wait() here, so + # that if we get cancelled at this point, Event.wait() is also + # cancelled, otherwise there will be a "Task destroyed but it is + # pending" later. + await tasks.wait_for(previous_failed.wait(), delay) + # Get the next coroutine to run + try: + this_index, coro_fn = next(enum_coro_fns) + except StopIteration: + return + # Start task that will run the next coroutine + this_failed = locks.Event() + next_task = loop.create_task(run_one_coro(this_failed)) + running_tasks.append(next_task) + assert len(running_tasks) == this_index + 2 + # Prepare place to put this coroutine's exceptions if not won + exceptions.append(None) + assert len(exceptions) == this_index + 1 + + try: + result = await coro_fn() + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as e: + exceptions[this_index] = e + this_failed.set() # Kickstart the next coroutine + else: + # Store winner's results + nonlocal winner_index, winner_result + assert winner_index is None + winner_index = this_index + winner_result = result + # Cancel all other tasks. We take care to not cancel the current + # task as well. If we do so, then since there is no `await` after + # here and CancelledError are usually thrown at one, we will + # encounter a curious corner case where the current task will end + # up as done() == True, cancelled() == False, exception() == + # asyncio.CancelledError. This behavior is specified in + # https://bugs.python.org/issue30048 + for i, t in enumerate(running_tasks): + if i != this_index: + t.cancel() + + first_task = loop.create_task(run_one_coro(None)) + running_tasks.append(first_task) + try: + # Wait for a growing list of tasks to all finish: poor man's version of + # curio's TaskGroup or trio's nursery + done_count = 0 + while done_count != len(running_tasks): + done, _ = await tasks.wait(running_tasks) + done_count = len(done) + # If run_one_coro raises an unhandled exception, it's probably a + # programming error, and I want to see it. 
+ if __debug__: + for d in done: + if d.done() and not d.cancelled() and d.exception(): + raise d.exception() + return winner_result, winner_index, exceptions + finally: + # Make sure no tasks are left running if we leave this function + for t in running_tasks: + t.cancel() diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index a82cc79aca..3c80bb8892 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -1,55 +1,29 @@ -"""Stream-related things.""" - -__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol', - 'open_connection', 'start_server', - 'IncompleteReadError', - 'LimitOverrunError', - ] +__all__ = ( + 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', + 'open_connection', 'start_server') import socket +import sys +import warnings +import weakref if hasattr(socket, 'AF_UNIX'): - __all__.extend(['open_unix_connection', 'start_unix_server']) + __all__ += ('open_unix_connection', 'start_unix_server') from . import coroutines -from . import compat from . import events +from . import exceptions +from . import format_helpers from . import protocols -from .coroutines import coroutine from .log import logger +from .tasks import sleep -_DEFAULT_LIMIT = 2 ** 16 - - -class IncompleteReadError(EOFError): - """ - Incomplete read error. Attributes: - - - partial: read bytes string before the end of stream was reached - - expected: total number of expected bytes (or None if unknown) - """ - def __init__(self, partial, expected): - super().__init__("%d bytes read on a total of %r expected bytes" - % (len(partial), expected)) - self.partial = partial - self.expected = expected - - -class LimitOverrunError(Exception): - """Reached the buffer limit while looking for a separator. - - Attributes: - - consumed: total number of to be consumed bytes. - """ - def __init__(self, message, consumed): - super().__init__(message) - self.consumed = consumed +_DEFAULT_LIMIT = 2 ** 16 # 64 KiB -@coroutine -def open_connection(host=None, port=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): +async def open_connection(host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): """A wrapper for create_connection() returning a (reader, writer) pair. The reader returned is a StreamReader instance; the writer is a @@ -69,17 +43,20 @@ def open_connection(host=None, port=None, *, """ if loop is None: loop = events.get_event_loop() + else: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.create_connection( + transport, _ = await loop.create_connection( lambda: protocol, host, port, **kwds) writer = StreamWriter(transport, protocol, reader, loop) return reader, writer -@coroutine -def start_server(client_connected_cb, host=None, port=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): +async def start_server(client_connected_cb, host=None, port=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): """Start a socket server, call back for each client connected. 
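staggered_race(), defined in the new staggered.py above, is the building block for Happy Eyeballs-style connection attempts: each coroutine gets a head start of *delay* seconds before the next one is launched, and the first success cancels the rest. A usage sketch with illustrative attempt functions standing in for real connect() calls:

    import asyncio
    import functools
    from asyncio.staggered import staggered_race

    async def attempt(addr, latency):
        await asyncio.sleep(latency)    # stand-in for a connect() attempt
        return addr

    async def main():
        winner, index, errors = await staggered_race(
            [functools.partial(attempt, 'ipv6-addr', 0.30),
             functools.partial(attempt, 'ipv4-addr', 0.05)],
            delay=0.1)                  # start the next attempt every 100 ms
        print(winner, index)            # 'ipv4-addr' 1

    asyncio.run(main())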
The first parameter, `client_connected_cb`, takes two parameters: @@ -103,6 +80,10 @@ def start_server(client_connected_cb, host=None, port=None, *, """ if loop is None: loop = events.get_event_loop() + else: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) def factory(): reader = StreamReader(limit=limit, loop=loop) @@ -110,31 +91,37 @@ def factory(): loop=loop) return protocol - return (yield from loop.create_server(factory, host, port, **kwds)) + return await loop.create_server(factory, host, port, **kwds) if hasattr(socket, 'AF_UNIX'): # UNIX Domain Sockets are supported on this platform - @coroutine - def open_unix_connection(path=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): + async def open_unix_connection(path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `open_connection` but works with UNIX Domain Sockets.""" if loop is None: loop = events.get_event_loop() + else: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) reader = StreamReader(limit=limit, loop=loop) protocol = StreamReaderProtocol(reader, loop=loop) - transport, _ = yield from loop.create_unix_connection( + transport, _ = await loop.create_unix_connection( lambda: protocol, path, **kwds) writer = StreamWriter(transport, protocol, reader, loop) return reader, writer - @coroutine - def start_unix_server(client_connected_cb, path=None, *, - loop=None, limit=_DEFAULT_LIMIT, **kwds): + async def start_unix_server(client_connected_cb, path=None, *, + loop=None, limit=_DEFAULT_LIMIT, **kwds): """Similar to `start_server` but works with UNIX Domain Sockets.""" if loop is None: loop = events.get_event_loop() + else: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) def factory(): reader = StreamReader(limit=limit, loop=loop) @@ -142,14 +129,14 @@ def factory(): loop=loop) return protocol - return (yield from loop.create_unix_server(factory, path, **kwds)) + return await loop.create_unix_server(factory, path, **kwds) class FlowControlMixin(protocols.Protocol): """Reusable flow control logic for StreamWriter.drain(). This implements the protocol methods pause_writing(), - resume_reading() and connection_lost(). If the subclass overrides + resume_writing() and connection_lost(). If the subclass overrides these it must call the super methods. StreamWriter.drain() must wait for _drain_helper() coroutine. @@ -198,8 +185,7 @@ def connection_lost(self, exc): else: waiter.set_exception(exc) - @coroutine - def _drain_helper(self): + async def _drain_helper(self): if self._connection_lost: raise ConnectionResetError('Connection lost') if not self._paused: @@ -208,7 +194,10 @@ def _drain_helper(self): assert waiter is None or waiter.cancelled() waiter = self._loop.create_future() self._drain_waiter = waiter - yield from waiter + await waiter + + def _get_close_waiter(self, stream): + raise NotImplementedError class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): @@ -220,40 +209,86 @@ class StreamReaderProtocol(FlowControlMixin, protocols.Protocol): call inappropriate methods of the protocol.) 
""" + _source_traceback = None + def __init__(self, stream_reader, client_connected_cb=None, loop=None): super().__init__(loop=loop) - self._stream_reader = stream_reader + if stream_reader is not None: + self._stream_reader_wr = weakref.ref(stream_reader) + self._source_traceback = stream_reader._source_traceback + else: + self._stream_reader_wr = None + if client_connected_cb is not None: + # This is a stream created by the `create_server()` function. + # Keep a strong reference to the reader until a connection + # is established. + self._strong_reader = stream_reader + self._reject_connection = False self._stream_writer = None + self._transport = None self._client_connected_cb = client_connected_cb self._over_ssl = False + self._closed = self._loop.create_future() + + @property + def _stream_reader(self): + if self._stream_reader_wr is None: + return None + return self._stream_reader_wr() def connection_made(self, transport): - self._stream_reader.set_transport(transport) + if self._reject_connection: + context = { + 'message': ('An open stream was garbage collected prior to ' + 'establishing network connection; ' + 'call "stream.close()" explicitly.') + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + transport.abort() + return + self._transport = transport + reader = self._stream_reader + if reader is not None: + reader.set_transport(transport) self._over_ssl = transport.get_extra_info('sslcontext') is not None if self._client_connected_cb is not None: self._stream_writer = StreamWriter(transport, self, - self._stream_reader, + reader, self._loop) - res = self._client_connected_cb(self._stream_reader, + res = self._client_connected_cb(reader, self._stream_writer) if coroutines.iscoroutine(res): self._loop.create_task(res) + self._strong_reader = None def connection_lost(self, exc): - if self._stream_reader is not None: + reader = self._stream_reader + if reader is not None: + if exc is None: + reader.feed_eof() + else: + reader.set_exception(exc) + if not self._closed.done(): if exc is None: - self._stream_reader.feed_eof() + self._closed.set_result(None) else: - self._stream_reader.set_exception(exc) + self._closed.set_exception(exc) super().connection_lost(exc) - self._stream_reader = None + self._stream_reader_wr = None self._stream_writer = None + self._transport = None def data_received(self, data): - self._stream_reader.feed_data(data) + reader = self._stream_reader + if reader is not None: + reader.feed_data(data) def eof_received(self): - self._stream_reader.feed_eof() + reader = self._stream_reader + if reader is not None: + reader.feed_eof() if self._over_ssl: # Prevent a warning in SSLProtocol.eof_received: # "returning true from eof_received() @@ -261,6 +296,16 @@ def eof_received(self): return False return True + def _get_close_waiter(self, stream): + return self._closed + + def __del__(self): + # Prevent reports about unhandled exceptions. + # Better than self._closed._log_traceback = False hack + closed = self._closed + if closed.done() and not closed.cancelled(): + closed.exception() + class StreamWriter: """Wraps a Transport. 
@@ -279,12 +324,14 @@ def __init__(self, transport, protocol, reader, loop): assert reader is None or isinstance(reader, StreamReader) self._reader = reader self._loop = loop + self._complete_fut = self._loop.create_future() + self._complete_fut.set_result(None) def __repr__(self): - info = [self.__class__.__name__, 'transport=%r' % self._transport] + info = [self.__class__.__name__, f'transport={self._transport!r}'] if self._reader is not None: - info.append('reader=%r' % self._reader) - return '<%s>' % ' '.join(info) + info.append(f'reader={self._reader!r}') + return '<{}>'.format(' '.join(info)) @property def transport(self): @@ -305,36 +352,45 @@ def can_write_eof(self): def close(self): return self._transport.close() + def is_closing(self): + return self._transport.is_closing() + + async def wait_closed(self): + await self._protocol._get_close_waiter(self) + def get_extra_info(self, name, default=None): return self._transport.get_extra_info(name, default) - @coroutine - def drain(self): + async def drain(self): """Flush the write buffer. The intended use is to write w.write(data) - yield from w.drain() + await w.drain() """ if self._reader is not None: exc = self._reader.exception() if exc is not None: raise exc - if self._transport is not None: - if self._transport.is_closing(): - # Yield to the event loop so connection_lost() may be - # called. Without this, _drain_helper() would return - # immediately, and code that calls - # write(...); yield from drain() - # in a loop would never call connection_lost(), so it - # would not see an error when the socket is closed. - yield - yield from self._protocol._drain_helper() + if self._transport.is_closing(): + # Wait for protocol.connection_lost() call + # Raise connection closing error if any, + # ConnectionResetError otherwise + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); await drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await sleep(0) + await self._protocol._drain_helper() class StreamReader: + _source_traceback = None + def __init__(self, limit=_DEFAULT_LIMIT, loop=None): # The line length limit is a security feature; # it also doubles as half the buffer limit. 
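With the StreamWriter additions above (is_closing(), wait_closed(), and a drain() that yields to the loop while the transport is closing), the client-side write pattern becomes write(), await drain(), close(), await wait_closed(). A hedged sketch; the host, port, and payload are illustrative and assume a server is already listening:

    import asyncio

    async def send_message(message: bytes) -> bytes:
        reader, writer = await asyncio.open_connection('127.0.0.1', 8888)
        writer.write(message)       # buffer the bytes
        await writer.drain()        # respect flow control; raises if the peer is gone
        data = await reader.read(100)
        writer.close()
        await writer.wait_closed()  # new in this diff: waits for connection_lost()
        return data

    asyncio.run(send_message(b'ping'))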
@@ -353,24 +409,27 @@ def __init__(self, limit=_DEFAULT_LIMIT, loop=None): self._exception = None self._transport = None self._paused = False + if self._loop.get_debug(): + self._source_traceback = format_helpers.extract_stack( + sys._getframe(1)) def __repr__(self): info = ['StreamReader'] if self._buffer: - info.append('%d bytes' % len(self._buffer)) + info.append(f'{len(self._buffer)} bytes') if self._eof: info.append('eof') if self._limit != _DEFAULT_LIMIT: - info.append('l=%d' % self._limit) + info.append(f'limit={self._limit}') if self._waiter: - info.append('w=%r' % self._waiter) + info.append(f'waiter={self._waiter!r}') if self._exception: - info.append('e=%r' % self._exception) + info.append(f'exception={self._exception!r}') if self._transport: - info.append('t=%r' % self._transport) + info.append(f'transport={self._transport!r}') if self._paused: info.append('paused') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def exception(self): return self._exception @@ -431,8 +490,7 @@ def feed_data(self, data): else: self._paused = True - @coroutine - def _wait_for_data(self, func_name): + async def _wait_for_data(self, func_name): """Wait until feed_data() or feed_eof() is called. If stream was paused, automatically resume it. @@ -442,8 +500,9 @@ def _wait_for_data(self, func_name): # would have an unexpected behaviour. It would not possible to know # which coroutine would get the next data. if self._waiter is not None: - raise RuntimeError('%s() called while another coroutine is ' - 'already waiting for incoming data' % func_name) + raise RuntimeError( + f'{func_name}() called while another coroutine is ' + f'already waiting for incoming data') assert not self._eof, '_wait_for_data after EOF' @@ -455,12 +514,11 @@ def _wait_for_data(self, func_name): self._waiter = self._loop.create_future() try: - yield from self._waiter + await self._waiter finally: self._waiter = None - @coroutine - def readline(self): + async def readline(self): """Read chunk of data from the stream until newline (b'\n') is found. On success, return chunk that ends with newline. If only partial @@ -479,10 +537,10 @@ def readline(self): sep = b'\n' seplen = len(sep) try: - line = yield from self.readuntil(sep) - except IncompleteReadError as e: + line = await self.readuntil(sep) + except exceptions.IncompleteReadError as e: return e.partial - except LimitOverrunError as e: + except exceptions.LimitOverrunError as e: if self._buffer.startswith(sep, e.consumed): del self._buffer[:e.consumed + seplen] else: @@ -491,8 +549,7 @@ def readline(self): raise ValueError(e.args[0]) return line - @coroutine - def readuntil(self, separator=b'\n'): + async def readuntil(self, separator=b'\n'): """Read data from the stream until ``separator`` is found. On success, the data and separator will be removed from the @@ -558,7 +615,7 @@ def readuntil(self, separator=b'\n'): # see upper comment for explanation. offset = buflen + 1 - seplen if offset > self._limit: - raise LimitOverrunError( + raise exceptions.LimitOverrunError( 'Separator is not found, and chunk exceed the limit', offset) @@ -569,13 +626,13 @@ def readuntil(self, separator=b'\n'): if self._eof: chunk = bytes(self._buffer) self._buffer.clear() - raise IncompleteReadError(chunk, None) + raise exceptions.IncompleteReadError(chunk, None) # _wait_for_data() will resume reading if stream was paused. 
- yield from self._wait_for_data('readuntil') + await self._wait_for_data('readuntil') if isep > self._limit: - raise LimitOverrunError( + raise exceptions.LimitOverrunError( 'Separator is found, but chunk is longer than limit', isep) chunk = self._buffer[:isep + seplen] @@ -583,8 +640,7 @@ def readuntil(self, separator=b'\n'): self._maybe_resume_transport() return bytes(chunk) - @coroutine - def read(self, n=-1): + async def read(self, n=-1): """Read up to `n` bytes from the stream. If n is not provided, or set to -1, read until EOF and return all read @@ -618,14 +674,14 @@ def read(self, n=-1): # bytes. So just call self.read(self._limit) until EOF. blocks = [] while True: - block = yield from self.read(self._limit) + block = await self.read(self._limit) if not block: break blocks.append(block) return b''.join(blocks) if not self._buffer and not self._eof: - yield from self._wait_for_data('read') + await self._wait_for_data('read') # This will work right even if buffer is less than n bytes data = bytes(self._buffer[:n]) @@ -634,8 +690,7 @@ def read(self, n=-1): self._maybe_resume_transport() return data - @coroutine - def readexactly(self, n): + async def readexactly(self, n): """Read exactly `n` bytes. Raise an IncompleteReadError if EOF is reached before `n` bytes can be @@ -663,9 +718,9 @@ def readexactly(self, n): if self._eof: incomplete = bytes(self._buffer) self._buffer.clear() - raise IncompleteReadError(incomplete, n) + raise exceptions.IncompleteReadError(incomplete, n) - yield from self._wait_for_data('readexactly') + await self._wait_for_data('readexactly') if len(self._buffer) == n: data = bytes(self._buffer) @@ -676,20 +731,11 @@ def readexactly(self, n): self._maybe_resume_transport() return data - if compat.PY35: - @coroutine - def __aiter__(self): - return self - - @coroutine - def __anext__(self): - val = yield from self.readline() - if val == b'': - raise StopAsyncIteration - return val - - if compat.PY352: - # In Python 3.5.2 and greater, __aiter__ should return - # the asynchronous iterator directly. - def __aiter__(self): - return self + def __aiter__(self): + return self + + async def __anext__(self): + val = await self.readline() + if val == b'': + raise StopAsyncIteration + return val diff --git a/Lib/asyncio/subprocess.py b/Lib/asyncio/subprocess.py index b2f5304f77..820304ecca 100644 --- a/Lib/asyncio/subprocess.py +++ b/Lib/asyncio/subprocess.py @@ -1,12 +1,12 @@ -__all__ = ['create_subprocess_exec', 'create_subprocess_shell'] +__all__ = 'create_subprocess_exec', 'create_subprocess_shell' import subprocess +import warnings from . import events from . import protocols from . import streams from . 
import tasks -from .coroutines import coroutine from .log import logger @@ -24,16 +24,19 @@ def __init__(self, limit, loop): self._limit = limit self.stdin = self.stdout = self.stderr = None self._transport = None + self._process_exited = False + self._pipe_fds = [] + self._stdin_closed = self._loop.create_future() def __repr__(self): info = [self.__class__.__name__] if self.stdin is not None: - info.append('stdin=%r' % self.stdin) + info.append(f'stdin={self.stdin!r}') if self.stdout is not None: - info.append('stdout=%r' % self.stdout) + info.append(f'stdout={self.stdout!r}') if self.stderr is not None: - info.append('stderr=%r' % self.stderr) - return '<%s>' % ' '.join(info) + info.append(f'stderr={self.stderr!r}') + return '<{}>'.format(' '.join(info)) def connection_made(self, transport): self._transport = transport @@ -43,12 +46,14 @@ def connection_made(self, transport): self.stdout = streams.StreamReader(limit=self._limit, loop=self._loop) self.stdout.set_transport(stdout_transport) + self._pipe_fds.append(1) stderr_transport = transport.get_pipe_transport(2) if stderr_transport is not None: self.stderr = streams.StreamReader(limit=self._limit, loop=self._loop) self.stderr.set_transport(stderr_transport) + self._pipe_fds.append(2) stdin_transport = transport.get_pipe_transport(0) if stdin_transport is not None: @@ -73,6 +78,10 @@ def pipe_connection_lost(self, fd, exc): if pipe is not None: pipe.close() self.connection_lost(exc) + if exc is None: + self._stdin_closed.set_result(None) + else: + self._stdin_closed.set_exception(exc) return if fd == 1: reader = self.stdout @@ -80,15 +89,28 @@ def pipe_connection_lost(self, fd, exc): reader = self.stderr else: reader = None - if reader != None: + if reader is not None: if exc is None: reader.feed_eof() else: reader.set_exception(exc) + if fd in self._pipe_fds: + self._pipe_fds.remove(fd) + self._maybe_close_transport() + def process_exited(self): - self._transport.close() - self._transport = None + self._process_exited = True + self._maybe_close_transport() + + def _maybe_close_transport(self): + if len(self._pipe_fds) == 0 and self._process_exited: + self._transport.close() + self._transport = None + + def _get_close_waiter(self, stream): + if stream is self.stdin: + return self._stdin_closed class Process: @@ -102,18 +124,15 @@ def __init__(self, transport, protocol, loop): self.pid = transport.get_pid() def __repr__(self): - return '<%s %s>' % (self.__class__.__name__, self.pid) + return f'<{self.__class__.__name__} {self.pid}>' @property def returncode(self): return self._transport.get_returncode() - @coroutine - def wait(self): - """Wait until the process exit and return the process return code. 
-
-        This method is a coroutine."""
-        return (yield from self._transport._wait())
+    async def wait(self):
+        """Wait until the process exits and return the process return code."""
+        return await self._transport._wait()
 
     def send_signal(self, signal):
         self._transport.send_signal(signal)
@@ -124,15 +143,14 @@ def terminate(self):
     def kill(self):
         self._transport.kill()
 
-    @coroutine
-    def _feed_stdin(self, input):
+    async def _feed_stdin(self, input):
         debug = self._loop.get_debug()
         self.stdin.write(input)
         if debug:
-            logger.debug('%r communicate: feed stdin (%s bytes)',
-                         self, len(input))
+            logger.debug(
+                '%r communicate: feed stdin (%s bytes)', self, len(input))
         try:
-            yield from self.stdin.drain()
+            await self.stdin.drain()
         except (BrokenPipeError, ConnectionResetError) as exc:
             # communicate() ignores BrokenPipeError and ConnectionResetError
             if debug:
@@ -142,12 +160,10 @@ def _feed_stdin(self, input):
             logger.debug('%r communicate: close stdin', self)
         self.stdin.close()
 
-    @coroutine
-    def _noop(self):
+    async def _noop(self):
         return None
 
-    @coroutine
-    def _read_stream(self, fd):
+    async def _read_stream(self, fd):
         transport = self._transport.get_pipe_transport(fd)
         if fd == 2:
             stream = self.stderr
@@ -157,15 +173,14 @@ def _read_stream(self, fd):
         if self._loop.get_debug():
             name = 'stdout' if fd == 1 else 'stderr'
             logger.debug('%r communicate: read %s', self, name)
-        output = yield from stream.read()
+        output = await stream.read()
         if self._loop.get_debug():
             name = 'stdout' if fd == 1 else 'stderr'
             logger.debug('%r communicate: close %s', self, name)
         transport.close()
         return output
 
-    @coroutine
-    def communicate(self, input=None):
+    async def communicate(self, input=None):
         if input is not None:
             stdin = self._feed_stdin(input)
         else:
@@ -178,36 +193,49 @@ def communicate(self, input=None):
             stderr = self._read_stream(2)
         else:
             stderr = self._noop()
-        stdin, stdout, stderr = yield from tasks.gather(stdin, stdout, stderr,
-                                                        loop=self._loop)
-        yield from self.wait()
+        stdin, stdout, stderr = await tasks._gather(stdin, stdout, stderr,
+                                                    loop=self._loop)
+        await self.wait()
         return (stdout, stderr)
 
 
-@coroutine
-def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
-                            loop=None, limit=streams._DEFAULT_LIMIT, **kwds):
+async def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
                                  loop=None, limit=streams._DEFAULT_LIMIT,
+                                  **kwds):
     if loop is None:
         loop = events.get_event_loop()
+    else:
+        warnings.warn("The loop argument is deprecated since Python 3.8 "
+                      "and scheduled for removal in Python 3.10.",
+                      DeprecationWarning,
+                      stacklevel=2
+                      )
+
     protocol_factory = lambda: SubprocessStreamProtocol(limit=limit,
                                                         loop=loop)
-    transport, protocol = yield from loop.subprocess_shell(
-        protocol_factory,
-        cmd, stdin=stdin, stdout=stdout,
-        stderr=stderr, **kwds)
+    transport, protocol = await loop.subprocess_shell(
+        protocol_factory,
+        cmd, stdin=stdin, stdout=stdout,
+        stderr=stderr, **kwds)
     return Process(transport, protocol, loop)
 
 
-@coroutine
-def create_subprocess_exec(program, *args, stdin=None, stdout=None,
-                           stderr=None, loop=None,
-                           limit=streams._DEFAULT_LIMIT, **kwds):
+
+async def create_subprocess_exec(program, *args, stdin=None, stdout=None,
+                                 stderr=None, loop=None,
+                                 limit=streams._DEFAULT_LIMIT, **kwds):
     if loop is None:
         loop = events.get_event_loop()
+    else:
+        warnings.warn("The loop argument is deprecated since Python 3.8 "
+                      "and scheduled for removal in Python 3.10.",
+                      DeprecationWarning,
+                      stacklevel=2
+                      )
    protocol_factory = lambda: 
SubprocessStreamProtocol(limit=limit, loop=loop) - transport, protocol = yield from loop.subprocess_exec( - protocol_factory, - program, *args, - stdin=stdin, stdout=stdout, - stderr=stderr, **kwds) + transport, protocol = await loop.subprocess_exec( + protocol_factory, + program, *args, + stdin=stdin, stdout=stdout, + stderr=stderr, **kwds) return Process(transport, protocol, loop) diff --git a/Lib/asyncio/tasks.py b/Lib/asyncio/tasks.py index 8a8427fe68..53252f2079 100644 --- a/Lib/asyncio/tasks.py +++ b/Lib/asyncio/tasks.py @@ -1,24 +1,35 @@ """Support for tasks, coroutines and the scheduler.""" -__all__ = ['Task', 'create_task', - 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', - 'wait', 'wait_for', 'as_completed', 'sleep', 'async', - 'gather', 'shield', 'ensure_future', 'run_coroutine_threadsafe', - 'all_tasks' - ] +__all__ = ( + 'Task', 'create_task', + 'FIRST_COMPLETED', 'FIRST_EXCEPTION', 'ALL_COMPLETED', + 'wait', 'wait_for', 'as_completed', 'sleep', + 'gather', 'shield', 'ensure_future', 'run_coroutine_threadsafe', + 'current_task', 'all_tasks', + '_register_task', '_unregister_task', '_enter_task', '_leave_task', +) import concurrent.futures +import contextvars import functools import inspect +import itertools +import types import warnings import weakref +from types import GenericAlias from . import base_tasks -from . import compat from . import coroutines from . import events +from . import exceptions from . import futures -from .coroutines import coroutine +from .coroutines import _is_coroutine + +# Helper to generate new task names +# This uses itertools.count() instead of a "+= 1" operation because the latter +# is not thread safe. See bpo-11866 for a longer explanation. +_task_name_counter = itertools.count(1).__next__ def current_task(loop=None): @@ -32,7 +43,22 @@ def all_tasks(loop=None): """Return a set of all tasks for the loop.""" if loop is None: loop = events.get_running_loop() - return {t for t in _all_tasks + # Looping over a WeakSet (_all_tasks) isn't safe as it can be updated from another + # thread while we do so. Therefore we cast it to list prior to filtering. The list + # cast itself requires iteration, so we repeat it several times ignoring + # RuntimeErrors (which are not very likely to occur). See issues 34970 and 36607 for + # details. + i = 0 + while True: + try: + tasks = list(_all_tasks) + except RuntimeError: + i += 1 + if i >= 1000: + raise + else: + break + return {t for t in tasks if futures._get_loop(t) is loop and not t.done()} @@ -42,7 +68,22 @@ def _all_tasks_compat(loop=None): # method. if loop is None: loop = events.get_event_loop() - return {t for t in _all_tasks if futures._get_loop(t) is loop} + # Looping over a WeakSet (_all_tasks) isn't safe as it can be updated from another + # thread while we do so. Therefore we cast it to list prior to filtering. The list + # cast itself requires iteration, so we repeat it several times ignoring + # RuntimeErrors (which are not very likely to occur). See issues 34970 and 36607 for + # details. + i = 0 + while True: + try: + tasks = list(_all_tasks) + except RuntimeError: + i += 1 + if i >= 1000: + raise + else: + break + return {t for t in tasks if futures._get_loop(t) is loop} def _set_task_name(task, name): @@ -55,7 +96,9 @@ def _set_task_name(task, name): set_name(name) -class Task(futures.Future): +class Task(futures._PyFuture): # Inherit Python Task implementation + # from a Python Future implementation. 
+ """A coroutine wrapped in a Future.""" # An important invariant maintained while a Task not done: @@ -67,68 +110,64 @@ class Task(futures.Future): # _wakeup(). When _fut_waiter is not None, one of its callbacks # must be _wakeup(). - # Weak set containing all tasks alive. - _all_tasks = weakref.WeakSet() - - # Dictionary containing tasks that are currently active in - # all running event loops. {EventLoop: Task} - _current_tasks = {} - # If False, don't log a message if the task is destroyed whereas its # status is still pending _log_destroy_pending = True - @classmethod - def current_task(cls, loop=None): - """Return the currently running task in an event loop or None. - - By default the current task for the current event loop is returned. - - None is returned when called not in the context of a Task. - """ - if loop is None: - loop = events.get_event_loop() - return cls._current_tasks.get(loop) - - @classmethod - def all_tasks(cls, loop=None): - """Return a set of all tasks for an event loop. - - By default all tasks for the current event loop are returned. - """ - if loop is None: - loop = events.get_event_loop() - return {t for t in cls._all_tasks if t._loop is loop} - - def __init__(self, coro, *, loop=None): - assert coroutines.iscoroutine(coro), repr(coro) + def __init__(self, coro, *, loop=None, name=None): super().__init__(loop=loop) if self._source_traceback: del self._source_traceback[-1] - self._coro = coro - self._fut_waiter = None + if not coroutines.iscoroutine(coro): + # raise after Future.__init__(), attrs are required for __del__ + # prevent logging for pending task in __del__ + self._log_destroy_pending = False + raise TypeError(f"a coroutine was expected, got {coro!r}") + + if name is None: + self._name = f'Task-{_task_name_counter()}' + else: + self._name = str(name) + self._must_cancel = False - self._loop.call_soon(self._step) - self.__class__._all_tasks.add(self) - - # On Python 3.3 or older, objects with a destructor that are part of a - # reference cycle are never destroyed. That's not the case any more on - # Python 3.4 thanks to the PEP 442. - if compat.PY34: - def __del__(self): - if self._state == futures._PENDING and self._log_destroy_pending: - context = { - 'task': self, - 'message': 'Task was destroyed but it is pending!', - } - if self._source_traceback: - context['source_traceback'] = self._source_traceback - self._loop.call_exception_handler(context) - futures.Future.__del__(self) + self._fut_waiter = None + self._coro = coro + self._context = contextvars.copy_context() + + self._loop.call_soon(self.__step, context=self._context) + _register_task(self) + + def __del__(self): + if self._state == futures._PENDING and self._log_destroy_pending: + context = { + 'task': self, + 'message': 'Task was destroyed but it is pending!', + } + if self._source_traceback: + context['source_traceback'] = self._source_traceback + self._loop.call_exception_handler(context) + super().__del__() + + __class_getitem__ = classmethod(GenericAlias) def _repr_info(self): return base_tasks._task_repr_info(self) + def get_coro(self): + return self._coro + + def get_name(self): + return self._name + + def set_name(self, value): + self._name = str(value) + + def set_result(self, result): + raise RuntimeError('Task does not support set_result operation') + + def set_exception(self, exception): + raise RuntimeError('Task does not support set_exception operation') + def get_stack(self, *, limit=None): """Return the list of stack frames for this task's coroutine. 
@@ -163,7 +202,7 @@ def print_stack(self, *, limit=None, file=None): """ return base_tasks._task_print_stack(self, limit, file) - def cancel(self): + def cancel(self, msg=None): """Request that this task cancel itself. This arranges for a CancelledError to be thrown into the @@ -183,29 +222,32 @@ def cancel(self): terminates with a CancelledError exception (even if cancel() was not called). """ + self._log_traceback = False if self.done(): return False if self._fut_waiter is not None: - if self._fut_waiter.cancel(): + if self._fut_waiter.cancel(msg=msg): # Leave self._fut_waiter; it may be a Task that # catches and ignores the cancellation so we may have # to cancel it again later. return True - # It must be the case that self._step is already scheduled. + # It must be the case that self.__step is already scheduled. self._must_cancel = True + self._cancel_message = msg return True - def _step(self, exc=None): - assert not self.done(), \ - '_step(): already done: {!r}, {!r}'.format(self, exc) + def __step(self, exc=None): + if self.done(): + raise exceptions.InvalidStateError( + f'_step(): already done: {self!r}, {exc!r}') if self._must_cancel: - if not isinstance(exc, futures.CancelledError): - exc = futures.CancelledError() + if not isinstance(exc, exceptions.CancelledError): + exc = self._make_cancelled_error() self._must_cancel = False coro = self._coro self._fut_waiter = None - self.__class__._current_tasks[self._loop] = self + _enter_task(self._loop, self) # Call either coro.throw(exc) or coro.send(None). try: if exc is None: @@ -215,71 +257,78 @@ def _step(self, exc=None): else: result = coro.throw(exc) except StopIteration as exc: - self.set_result(exc.value) - except futures.CancelledError: + if self._must_cancel: + # Task is cancelled right before coro stops. + self._must_cancel = False + super().cancel(msg=self._cancel_message) + else: + super().set_result(exc.value) + except exceptions.CancelledError as exc: + # Save the original exception so we can chain it later. + self._cancelled_exc = exc super().cancel() # I.e., Future.cancel(self). - except Exception as exc: - self.set_exception(exc) - except BaseException as exc: - self.set_exception(exc) + except (KeyboardInterrupt, SystemExit) as exc: + super().set_exception(exc) raise + except BaseException as exc: + super().set_exception(exc) else: blocking = getattr(result, '_asyncio_future_blocking', None) if blocking is not None: # Yielded Future must come from Future.__iter__(). 
- if result._loop is not self._loop: + if futures._get_loop(result) is not self._loop: + new_exc = RuntimeError( + f'Task {self!r} got Future ' + f'{result!r} attached to a different loop') self._loop.call_soon( - self._step, - RuntimeError( - 'Task {!r} got Future {!r} attached to a ' - 'different loop'.format(self, result))) + self.__step, new_exc, context=self._context) elif blocking: if result is self: + new_exc = RuntimeError( + f'Task cannot await on itself: {self!r}') self._loop.call_soon( - self._step, - RuntimeError( - 'Task cannot await on itself: {!r}'.format( - self))) + self.__step, new_exc, context=self._context) else: result._asyncio_future_blocking = False - result.add_done_callback(self._wakeup) + result.add_done_callback( + self.__wakeup, context=self._context) self._fut_waiter = result if self._must_cancel: - if self._fut_waiter.cancel(): + if self._fut_waiter.cancel( + msg=self._cancel_message): self._must_cancel = False else: + new_exc = RuntimeError( + f'yield was used instead of yield from ' + f'in task {self!r} with {result!r}') self._loop.call_soon( - self._step, - RuntimeError( - 'yield was used instead of yield from ' - 'in task {!r} with {!r}'.format(self, result))) + self.__step, new_exc, context=self._context) + elif result is None: # Bare yield relinquishes control for one event loop iteration. - self._loop.call_soon(self._step) + self._loop.call_soon(self.__step, context=self._context) elif inspect.isgenerator(result): # Yielding a generator is just wrong. + new_exc = RuntimeError( + f'yield was used instead of yield from for ' + f'generator in task {self!r} with {result!r}') self._loop.call_soon( - self._step, - RuntimeError( - 'yield was used instead of yield from for ' - 'generator in task {!r} with {}'.format( - self, result))) + self.__step, new_exc, context=self._context) else: # Yielding something else is an error. + new_exc = RuntimeError(f'Task got bad yield: {result!r}') self._loop.call_soon( - self._step, - RuntimeError( - 'Task got bad yield: {!r}'.format(result))) + self.__step, new_exc, context=self._context) finally: - self.__class__._current_tasks.pop(self._loop) + _leave_task(self._loop, self) self = None # Needed to break cycles when an exception occurs. - def _wakeup(self, future): + def __wakeup(self, future): try: future.result() - except Exception as exc: + except BaseException as exc: # This may also be a cancellation. - self._step(exc) + self.__step(exc) else: # Don't pass the value of `future.result()` explicitly, # as `Future.__iter__` and `Future.__await__` don't need it. @@ -287,7 +336,7 @@ def _wakeup(self, future): # Python eval loop would use `.send(value)` method call, # instead of `__next__()`, which is slower for futures # that return non-generator iterators from their `__iter__`. - self._step() + self.__step() self = None # Needed to break cycles when an exception occurs. @@ -321,11 +370,10 @@ def create_task(coro, *, name=None): ALL_COMPLETED = concurrent.futures.ALL_COMPLETED -@coroutine -def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): +async def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): """Wait for the Futures and coroutines given by fs to complete. - The sequence futures must not be empty. + The fs iterable must not be empty. Coroutines will be wrapped in Tasks. 
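The wait() rework in the surrounding hunks deprecates both the explicit loop argument and passing bare coroutines. A sketch of the updated calling convention, also exercising the Task name support added by this diff (job() and the names are illustrative):

    import asyncio

    async def job(n):
        await asyncio.sleep(n)
        return n

    async def main():
        # Wrap coroutines in Tasks yourself; handing bare coroutines to
        # asyncio.wait() now emits a DeprecationWarning.
        tasks = {asyncio.create_task(job(n), name=f'job-{n}') for n in (1, 2)}
        done, pending = await asyncio.wait(
            tasks, return_when=asyncio.FIRST_COMPLETED)
        for t in pending:
            t.cancel()
        return {t.get_name(): t.result() for t in done}

    print(asyncio.run(main()))   # {'job-1': 1}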
@@ -333,24 +381,36 @@ def wait(fs, *, loop=None, timeout=None, return_when=ALL_COMPLETED): Usage: - done, pending = yield from asyncio.wait(fs) + done, pending = await asyncio.wait(fs) Note: This does not raise TimeoutError! Futures that aren't done when the timeout occurs are returned in the second set. """ if futures.isfuture(fs) or coroutines.iscoroutine(fs): - raise TypeError("expect a list of futures, not %s" % type(fs).__name__) + raise TypeError(f"expect a list of futures, not {type(fs).__name__}") if not fs: raise ValueError('Set of coroutines/Futures is empty.') if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): - raise ValueError('Invalid return_when value: {}'.format(return_when)) + raise ValueError(f'Invalid return_when value: {return_when}') if loop is None: - loop = events.get_event_loop() + loop = events.get_running_loop() + else: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) + + fs = set(fs) + + if any(coroutines.iscoroutine(f) for f in fs): + warnings.warn("The explicit passing of coroutine objects to " + "asyncio.wait() is deprecated since Python 3.8, and " + "scheduled for removal in Python 3.11.", + DeprecationWarning, stacklevel=2) - fs = {ensure_future(f, loop=loop) for f in set(fs)} + fs = {ensure_future(f, loop=loop) for f in fs} - return (yield from _wait(fs, timeout, return_when, loop)) + return await _wait(fs, timeout, return_when, loop) def _release_waiter(waiter, *args): @@ -358,8 +418,7 @@ def _release_waiter(waiter, *args): waiter.set_result(None) -@coroutine -def wait_for(fut, timeout, *, loop=None): +async def wait_for(fut, timeout, *, loop=None): """Wait for the single Future or coroutine to complete, with timeout. Coroutine will be wrapped in Task. @@ -373,10 +432,26 @@ def wait_for(fut, timeout, *, loop=None): This function is a coroutine. """ if loop is None: - loop = events.get_event_loop() + loop = events.get_running_loop() + else: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) if timeout is None: - return (yield from fut) + return await fut + + if timeout <= 0: + fut = ensure_future(fut, loop=loop) + + if fut.done(): + return fut.result() + + await _cancel_and_wait(fut, loop=loop) + try: + return fut.result() + except exceptions.CancelledError as exc: + raise exceptions.TimeoutError() from exc waiter = loop.create_future() timeout_handle = loop.call_later(timeout, _release_waiter, waiter) @@ -388,25 +463,39 @@ def wait_for(fut, timeout, *, loop=None): try: # wait until the future completes or the timeout try: - yield from waiter - except futures.CancelledError: - fut.remove_done_callback(cb) - fut.cancel() - raise + await waiter + except exceptions.CancelledError: + if fut.done(): + return fut.result() + else: + fut.remove_done_callback(cb) + # We must ensure that the task is not running + # after wait_for() returns. + # See https://bugs.python.org/issue32751 + await _cancel_and_wait(fut, loop=loop) + raise if fut.done(): return fut.result() else: fut.remove_done_callback(cb) - fut.cancel() - raise futures.TimeoutError() + # We must ensure that the task is not running + # after wait_for() returns. 
+ # See https://bugs.python.org/issue32751 + await _cancel_and_wait(fut, loop=loop) + # In case task cancellation failed with some + # exception, we should re-raise it + # See https://bugs.python.org/issue40607 + try: + return fut.result() + except exceptions.CancelledError as exc: + raise exceptions.TimeoutError() from exc finally: timeout_handle.cancel() -@coroutine -def _wait(fs, timeout, return_when, loop): - """Internal helper for wait() and wait_for(). +async def _wait(fs, timeout, return_when, loop): + """Internal helper for wait(). The fs argument must be a collection of Futures. """ @@ -433,14 +522,15 @@ def _on_completion(f): f.add_done_callback(_on_completion) try: - yield from waiter + await waiter finally: if timeout_handle is not None: timeout_handle.cancel() + for f in fs: + f.remove_done_callback(_on_completion) done, pending = set(), set() for f in fs: - f.remove_done_callback(_on_completion) if f.done(): done.add(f) else: @@ -448,6 +538,22 @@ def _on_completion(f): return done, pending +async def _cancel_and_wait(fut, loop): + """Cancel the *fut* future or task and wait until it completes.""" + + waiter = loop.create_future() + cb = functools.partial(_release_waiter, waiter) + fut.add_done_callback(cb) + + try: + fut.cancel() + # We cannot wait on *fut* directly to make + # sure _cancel_and_wait itself is reliably cancellable. + await waiter + finally: + fut.remove_done_callback(cb) + + # This is *not* a @coroutine! It is just an iterator (yielding Futures). def as_completed(fs, *, loop=None, timeout=None): """Return an iterator whose values are coroutines. @@ -459,20 +565,28 @@ def as_completed(fs, *, loop=None, timeout=None): This differs from PEP 3148; the proper way to use this is: for f in as_completed(fs): - result = yield from f # The 'yield from' may raise. + result = await f # The 'await' may raise. # Use result. - If a timeout is specified, the 'yield from' will raise + If a timeout is specified, the 'await' will raise TimeoutError when the timeout occurs before all Futures are done. Note: The futures 'f' are not necessarily members of fs. """ if futures.isfuture(fs) or coroutines.iscoroutine(fs): - raise TypeError("expect a list of futures, not %s" % type(fs).__name__) - loop = loop if loop is not None else events.get_event_loop() - todo = {ensure_future(f, loop=loop) for f in set(fs)} + raise TypeError(f"expect an iterable of futures, not {type(fs).__name__}") + + if loop is not None: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) + from .queues import Queue # Import here to avoid circular import problem. done = Queue(loop=loop) + + if loop is None: + loop = events.get_event_loop() + todo = {ensure_future(f, loop=loop) for f in set(fs)} timeout_handle = None def _on_timeout(): @@ -489,12 +603,11 @@ def _on_completion(f): if not todo and timeout_handle is not None: timeout_handle.cancel() - @coroutine - def _wait_for_one(): - f = yield from done.get() + async def _wait_for_one(): + f = await done.get() if f is None: # Dummy value from _on_timeout(). - raise futures.TimeoutError + raise exceptions.TimeoutError return f.result() # May raise f.exception(). for f in todo: @@ -505,67 +618,67 @@ def _wait_for_one(): yield _wait_for_one() -@coroutine -def sleep(delay, result=None, *, loop=None): +@types.coroutine +def __sleep0(): + """Skip one event loop run cycle. + + This is a private helper for 'asyncio.sleep()', used + when the 'delay' is set to 0. 
It uses a bare 'yield' + expression (which Task.__step knows how to handle) + instead of creating a Future object. + """ + yield + + +async def sleep(delay, result=None, *, loop=None): """Coroutine that completes after a given time (in seconds).""" - if delay == 0: - yield + if loop is not None: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) + + if delay <= 0: + await __sleep0() return result if loop is None: - loop = events.get_event_loop() + loop = events.get_running_loop() + future = loop.create_future() - h = future._loop.call_later(delay, - futures._set_result_unless_cancelled, - future, result) + h = loop.call_later(delay, + futures._set_result_unless_cancelled, + future, result) try: - return (yield from future) + return await future finally: h.cancel() -def async_(coro_or_future, *, loop=None): - """Wrap a coroutine in a future. - - If the argument is a Future, it is returned directly. - - This function is deprecated in 3.5. Use asyncio.ensure_future() instead. - """ - - warnings.warn("asyncio.async() function is deprecated, use ensure_future()", - DeprecationWarning) - - return ensure_future(coro_or_future, loop=loop) - -# Silence DeprecationWarning: -globals()['async'] = async_ -async_.__name__ = 'async' -del async_ - - def ensure_future(coro_or_future, *, loop=None): """Wrap a coroutine or an awaitable in a future. If the argument is a Future, it is returned directly. """ - if futures.isfuture(coro_or_future): - if loop is not None and loop is not coro_or_future._loop: - raise ValueError('loop argument must agree with Future') - return coro_or_future - elif coroutines.iscoroutine(coro_or_future): + if coroutines.iscoroutine(coro_or_future): if loop is None: loop = events.get_event_loop() task = loop.create_task(coro_or_future) if task._source_traceback: del task._source_traceback[-1] return task - elif compat.PY35 and inspect.isawaitable(coro_or_future): + elif futures.isfuture(coro_or_future): + if loop is not None and loop is not futures._get_loop(coro_or_future): + raise ValueError('The future belongs to a different loop than ' + 'the one specified as the loop argument') + return coro_or_future + elif inspect.isawaitable(coro_or_future): return ensure_future(_wrap_awaitable(coro_or_future), loop=loop) else: - raise TypeError('A Future, a coroutine or an awaitable is required') + raise TypeError('An asyncio.Future, a coroutine or an awaitable is ' + 'required') -@coroutine +@types.coroutine def _wrap_awaitable(awaitable): """Helper for asyncio.ensure_future(). @@ -574,6 +687,8 @@ def _wrap_awaitable(awaitable): """ return (yield from awaitable.__await__()) +_wrap_awaitable._is_coroutine = _is_coroutine + class _GatheringFuture(futures.Future): """Helper for gather(). @@ -586,20 +701,25 @@ class _GatheringFuture(futures.Future): def __init__(self, children, *, loop=None): super().__init__(loop=loop) self._children = children + self._cancel_requested = False - def cancel(self): + def cancel(self, msg=None): if self.done(): return False ret = False for child in self._children: - if child.cancel(): + if child.cancel(msg=msg): ret = True + if ret: + # If any child tasks were actually cancelled, we should + # propagate the cancellation request regardless of + # *return_exceptions* argument. See issue 32684. 
+ self._cancel_requested = True return ret def gather(*coros_or_futures, loop=None, return_exceptions=False): - """Return a future aggregating results from the given coroutines - or futures. + """Return a future aggregating results from the given coroutines/futures. Coroutines will be wrapped in a future and scheduled in the event loop. They will not necessarily be scheduled in the same order as @@ -620,7 +740,23 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): the outer Future is *not* cancelled in this case. (This is to prevent the cancellation of one child to cause other children to be cancelled.) + + If *return_exceptions* is False, cancelling gather() after it + has been marked done won't cancel any submitted awaitables. + For instance, gather can be marked done after propagating an + exception to the caller, therefore, calling ``gather.cancel()`` + after catching an exception (raised by one of the awaitables) from + gather won't cancel any other awaitables. """ + if loop is not None: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) + + return _gather(*coros_or_futures, loop=loop, return_exceptions=return_exceptions) + + +def _gather(*coros_or_futures, loop=None, return_exceptions=False): if not coros_or_futures: if loop is None: loop = events.get_event_loop() @@ -628,56 +764,89 @@ def gather(*coros_or_futures, loop=None, return_exceptions=False): outer.set_result([]) return outer - arg_to_fut = {} - for arg in set(coros_or_futures): - if not futures.isfuture(arg): - fut = ensure_future(arg, loop=loop) - if loop is None: - loop = fut._loop - # The caller cannot control this future, the "destroy pending task" - # warning should not be emitted. - fut._log_destroy_pending = False - else: - fut = arg - if loop is None: - loop = fut._loop - elif fut._loop is not loop: - raise ValueError("futures are tied to different event loops") - arg_to_fut[arg] = fut - - children = [arg_to_fut[arg] for arg in coros_or_futures] - nchildren = len(children) - outer = _GatheringFuture(children, loop=loop) - nfinished = 0 - results = [None] * nchildren - - def _done_callback(i, fut): + def _done_callback(fut): nonlocal nfinished - if outer.done(): + nfinished += 1 + + if outer is None or outer.done(): if not fut.cancelled(): # Mark exception retrieved. fut.exception() return - if fut.cancelled(): - res = futures.CancelledError() - if not return_exceptions: - outer.set_exception(res) - return - elif fut._exception is not None: - res = fut.exception() # Mark exception retrieved. - if not return_exceptions: - outer.set_exception(res) + if not return_exceptions: + if fut.cancelled(): + # Check if 'fut' is cancelled first, as + # 'fut.exception()' will *raise* a CancelledError + # instead of returning it. + exc = fut._make_cancelled_error() + outer.set_exception(exc) return + else: + exc = fut.exception() + if exc is not None: + outer.set_exception(exc) + return + + if nfinished == nfuts: + # All futures are done; create a list of results + # and set it to the 'outer' future. + results = [] + + for fut in children: + if fut.cancelled(): + # Check if 'fut' is cancelled first, as 'fut.exception()' + # will *raise* a CancelledError instead of returning it. + # Also, since we're adding the exception return value + # to 'results' instead of raising it, don't bother + # setting __context__. This also lets us preserve + # calling '_make_cancelled_error()' at most once. 
+ res = exceptions.CancelledError( + '' if fut._cancel_message is None else + fut._cancel_message) + else: + res = fut.exception() + if res is None: + res = fut.result() + results.append(res) + + if outer._cancel_requested: + # If gather is being cancelled we must propagate the + # cancellation regardless of *return_exceptions* argument. + # See issue 32684. + exc = fut._make_cancelled_error() + outer.set_exception(exc) + else: + outer.set_result(results) + + arg_to_fut = {} + children = [] + nfuts = 0 + nfinished = 0 + outer = None # bpo-46672 + for arg in coros_or_futures: + if arg not in arg_to_fut: + fut = ensure_future(arg, loop=loop) + if loop is None: + loop = futures._get_loop(fut) + if fut is not arg: + # 'arg' was not a Future, therefore, 'fut' is a new + # Future created specifically for 'arg'. Since the caller + # can't control it, disable the "destroy pending task" + # warning. + fut._log_destroy_pending = False + + nfuts += 1 + arg_to_fut[arg] = fut + fut.add_done_callback(_done_callback) + else: - res = fut._result - results[i] = res - nfinished += 1 - if nfinished == nchildren: - outer.set_result(results) + # There's a duplicate Future object in coros_or_futures. + fut = arg_to_fut[arg] - for i, fut in enumerate(children): - fut.add_done_callback(functools.partial(_done_callback, i)) + children.append(fut) + + outer = _GatheringFuture(children, loop=loop) return outer @@ -686,11 +855,11 @@ def shield(arg, *, loop=None): The statement - res = yield from shield(something()) + res = await shield(something()) is exactly equivalent to the statement - res = yield from something() + res = await something() *except* that if the coroutine containing it is cancelled, the task running in something() is not cancelled. From the POV of @@ -703,18 +872,22 @@ def shield(arg, *, loop=None): you can combine shield() with a try/except clause, as follows: try: - res = yield from shield(something()) + res = await shield(something()) except CancelledError: res = None """ + if loop is not None: + warnings.warn("The loop argument is deprecated since Python 3.8, " + "and scheduled for removal in Python 3.10.", + DeprecationWarning, stacklevel=2) inner = ensure_future(arg, loop=loop) if inner.done(): # Shortcut. return inner - loop = inner._loop + loop = futures._get_loop(inner) outer = loop.create_future() - def _done_callback(inner): + def _inner_done_callback(inner): if outer.cancelled(): if not inner.cancelled(): # Mark inner's result as retrieved. 
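The shield() docstring above spells out the cancellation contract, and the hunk below adds a callback so an abandoned outer future stops tracking the inner one. A hedged sketch of the documented pattern combined with wait_for(); fetch() and the timeouts are illustrative:

    import asyncio

    async def fetch():
        await asyncio.sleep(2)
        return 'result'

    async def main():
        task = asyncio.ensure_future(fetch())
        try:
            # The timeout cancels only shield()'s outer future; the inner
            # task keeps running and can still be awaited afterwards.
            return await asyncio.wait_for(asyncio.shield(task), timeout=0.1)
        except asyncio.TimeoutError:
            return await task

    print(asyncio.run(main()))   # result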
@@ -730,7 +903,13 @@ def _done_callback(inner): else: outer.set_result(inner.result()) - inner.add_done_callback(_done_callback) + + def _outer_done_callback(outer): + if not inner.done(): + inner.remove_done_callback(_inner_done_callback) + + inner.add_done_callback(_inner_done_callback) + outer.add_done_callback(_outer_done_callback) return outer @@ -746,7 +925,9 @@ def run_coroutine_threadsafe(coro, loop): def callback(): try: futures._chain_future(ensure_future(coro, loop=loop), future) - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: if future.set_running_or_notify_cancel(): future.set_exception(exc) raise diff --git a/Lib/asyncio/test_utils.py b/Lib/asyncio/test_utils.py deleted file mode 100644 index 99e3839f45..0000000000 --- a/Lib/asyncio/test_utils.py +++ /dev/null @@ -1,503 +0,0 @@ -"""Utilities shared by tests.""" - -import collections -import contextlib -import io -import logging -import os -import re -import socket -import socketserver -import sys -import tempfile -import threading -import time -import unittest -import weakref - -from unittest import mock - -from http.server import HTTPServer -from wsgiref.simple_server import WSGIRequestHandler, WSGIServer - -try: - import ssl -except ImportError: # pragma: no cover - ssl = None - -from . import base_events -from . import compat -from . import events -from . import futures -from . import selectors -from . import tasks -from .coroutines import coroutine -from .log import logger - - -if sys.platform == 'win32': # pragma: no cover - from .windows_utils import socketpair -else: - from socket import socketpair # pragma: no cover - - -def dummy_ssl_context(): - if ssl is None: - return None - else: - return ssl.SSLContext(ssl.PROTOCOL_SSLv23) - - -def run_briefly(loop): - @coroutine - def once(): - pass - gen = once() - t = loop.create_task(gen) - # Don't log a warning if the task is not done after run_until_complete(). - # It occurs if the loop is stopped or if a task raises a BaseException. - t._log_destroy_pending = False - try: - loop.run_until_complete(t) - finally: - gen.close() - - -def run_until(loop, pred, timeout=30): - deadline = time.time() + timeout - while not pred(): - if timeout is not None: - timeout = deadline - time.time() - if timeout <= 0: - raise futures.TimeoutError() - loop.run_until_complete(tasks.sleep(0.001, loop=loop)) - - -def run_once(loop): - """Legacy API to run once through the event loop. - - This is the recommended pattern for test code. It will poll the - selector once and run all callbacks scheduled in response to I/O - events. - """ - loop.call_soon(loop.stop) - loop.run_forever() - - -class SilentWSGIRequestHandler(WSGIRequestHandler): - - def get_stderr(self): - return io.StringIO() - - def log_message(self, format, *args): - pass - - -class SilentWSGIServer(WSGIServer): - - request_timeout = 2 - - def get_request(self): - request, client_addr = super().get_request() - request.settimeout(self.request_timeout) - return request, client_addr - - def handle_error(self, request, client_address): - pass - - -class SSLWSGIServerMixin: - - def finish_request(self, request, client_address): - # The relative location of our test directory (which - # contains the ssl key and certificate files) differs - # between the stdlib and stand-alone asyncio. - # Prefer our own if we can find it. 
- here = os.path.join(os.path.dirname(__file__), '..', 'tests') - if not os.path.isdir(here): - here = os.path.join(os.path.dirname(os.__file__), - 'test', 'test_asyncio') - keyfile = os.path.join(here, 'ssl_key.pem') - certfile = os.path.join(here, 'ssl_cert.pem') - context = ssl.SSLContext() - context.load_cert_chain(certfile, keyfile) - - ssock = context.wrap_socket(request, server_side=True) - try: - self.RequestHandlerClass(ssock, client_address, self) - ssock.close() - except OSError: - # maybe socket has been closed by peer - pass - - -class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer): - pass - - -def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls): - - def app(environ, start_response): - status = '200 OK' - headers = [('Content-type', 'text/plain')] - start_response(status, headers) - return [b'Test message'] - - # Run the test WSGI server in a separate thread in order not to - # interfere with event handling in the main thread - server_class = server_ssl_cls if use_ssl else server_cls - httpd = server_class(address, SilentWSGIRequestHandler) - httpd.set_app(app) - httpd.address = httpd.server_address - server_thread = threading.Thread( - target=lambda: httpd.serve_forever(poll_interval=0.05)) - server_thread.start() - try: - yield httpd - finally: - httpd.shutdown() - httpd.server_close() - server_thread.join() - - -if hasattr(socket, 'AF_UNIX'): - - class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer): - - def server_bind(self): - socketserver.UnixStreamServer.server_bind(self) - self.server_name = '127.0.0.1' - self.server_port = 80 - - - class UnixWSGIServer(UnixHTTPServer, WSGIServer): - - request_timeout = 2 - - def server_bind(self): - UnixHTTPServer.server_bind(self) - self.setup_environ() - - def get_request(self): - request, client_addr = super().get_request() - request.settimeout(self.request_timeout) - # Code in the stdlib expects that get_request - # will return a socket and a tuple (host, port). 
- # However, this isn't true for UNIX sockets, - # as the second return value will be a path; - # hence we return some fake data sufficient - # to get the tests going - return request, ('127.0.0.1', '') - - - class SilentUnixWSGIServer(UnixWSGIServer): - - def handle_error(self, request, client_address): - pass - - - class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer): - pass - - - def gen_unix_socket_path(): - with tempfile.NamedTemporaryFile() as file: - return file.name - - - @contextlib.contextmanager - def unix_socket_path(): - path = gen_unix_socket_path() - try: - yield path - finally: - try: - os.unlink(path) - except OSError: - pass - - - @contextlib.contextmanager - def run_test_unix_server(*, use_ssl=False): - with unix_socket_path() as path: - yield from _run_test_server(address=path, use_ssl=use_ssl, - server_cls=SilentUnixWSGIServer, - server_ssl_cls=UnixSSLWSGIServer) - - -@contextlib.contextmanager -def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False): - yield from _run_test_server(address=(host, port), use_ssl=use_ssl, - server_cls=SilentWSGIServer, - server_ssl_cls=SSLWSGIServer) - - -def make_test_protocol(base): - dct = {} - for name in dir(base): - if name.startswith('__') and name.endswith('__'): - # skip magic names - continue - dct[name] = MockCallback(return_value=None) - return type('TestProtocol', (base,) + base.__bases__, dct)() - - -class TestSelector(selectors.BaseSelector): - - def __init__(self): - self.keys = {} - - def register(self, fileobj, events, data=None): - key = selectors.SelectorKey(fileobj, 0, events, data) - self.keys[fileobj] = key - return key - - def unregister(self, fileobj): - return self.keys.pop(fileobj) - - def select(self, timeout): - return [] - - def get_map(self): - return self.keys - - -class TestLoop(base_events.BaseEventLoop): - """Loop for unittests. - - It manages self time directly. - If something scheduled to be executed later then - on next loop iteration after all ready handlers done - generator passed to __init__ is calling. - - Generator should be like this: - - def gen(): - ... - when = yield ... - ... = yield time_advance - - Value returned by yield is absolute time of next scheduled handler. - Value passed to yield is time advance to move loop's time forward. 
- """ - - def __init__(self, gen=None): - super().__init__() - - if gen is None: - def gen(): - yield - self._check_on_close = False - else: - self._check_on_close = True - - self._gen = gen() - next(self._gen) - self._time = 0 - self._clock_resolution = 1e-9 - self._timers = [] - self._selector = TestSelector() - - self.readers = {} - self.writers = {} - self.reset_counters() - - self._transports = weakref.WeakValueDictionary() - - def time(self): - return self._time - - def advance_time(self, advance): - """Move test time forward.""" - if advance: - self._time += advance - - def close(self): - super().close() - if self._check_on_close: - try: - self._gen.send(0) - except StopIteration: - pass - else: # pragma: no cover - raise AssertionError("Time generator is not finished") - - def _add_reader(self, fd, callback, *args): - self.readers[fd] = events.Handle(callback, args, self) - - def _remove_reader(self, fd): - self.remove_reader_count[fd] += 1 - if fd in self.readers: - del self.readers[fd] - return True - else: - return False - - def assert_reader(self, fd, callback, *args): - assert fd in self.readers, 'fd {} is not registered'.format(fd) - handle = self.readers[fd] - assert handle._callback == callback, '{!r} != {!r}'.format( - handle._callback, callback) - assert handle._args == args, '{!r} != {!r}'.format( - handle._args, args) - - def _add_writer(self, fd, callback, *args): - self.writers[fd] = events.Handle(callback, args, self) - - def _remove_writer(self, fd): - self.remove_writer_count[fd] += 1 - if fd in self.writers: - del self.writers[fd] - return True - else: - return False - - def assert_writer(self, fd, callback, *args): - assert fd in self.writers, 'fd {} is not registered'.format(fd) - handle = self.writers[fd] - assert handle._callback == callback, '{!r} != {!r}'.format( - handle._callback, callback) - assert handle._args == args, '{!r} != {!r}'.format( - handle._args, args) - - def _ensure_fd_no_transport(self, fd): - try: - transport = self._transports[fd] - except KeyError: - pass - else: - raise RuntimeError( - 'File descriptor {!r} is used by transport {!r}'.format( - fd, transport)) - - def add_reader(self, fd, callback, *args): - """Add a reader callback.""" - self._ensure_fd_no_transport(fd) - return self._add_reader(fd, callback, *args) - - def remove_reader(self, fd): - """Remove a reader callback.""" - self._ensure_fd_no_transport(fd) - return self._remove_reader(fd) - - def add_writer(self, fd, callback, *args): - """Add a writer callback..""" - self._ensure_fd_no_transport(fd) - return self._add_writer(fd, callback, *args) - - def remove_writer(self, fd): - """Remove a writer callback.""" - self._ensure_fd_no_transport(fd) - return self._remove_writer(fd) - - def reset_counters(self): - self.remove_reader_count = collections.defaultdict(int) - self.remove_writer_count = collections.defaultdict(int) - - def _run_once(self): - super()._run_once() - for when in self._timers: - advance = self._gen.send(when) - self.advance_time(advance) - self._timers = [] - - def call_at(self, when, callback, *args): - self._timers.append(when) - return super().call_at(when, callback, *args) - - def _process_events(self, event_list): - return - - def _write_to_self(self): - pass - - -def MockCallback(**kwargs): - return mock.Mock(spec=['__call__'], **kwargs) - - -class MockPattern(str): - """A regex based str with a fuzzy __eq__. - - Use this helper with 'mock.assert_called_with', or anywhere - where a regex comparison between strings is needed. 
- - For instance: - mock_call.assert_called_with(MockPattern('spam.*ham')) - """ - def __eq__(self, other): - return bool(re.search(str(self), other, re.S)) - - -def get_function_source(func): - source = events._get_function_source(func) - if source is None: - raise ValueError("unable to get the source of %r" % (func,)) - return source - - -class TestCase(unittest.TestCase): - def set_event_loop(self, loop, *, cleanup=True): - assert loop is not None - # ensure that the event loop is passed explicitly in asyncio - events.set_event_loop(None) - if cleanup: - self.addCleanup(loop.close) - - def new_test_loop(self, gen=None): - loop = TestLoop(gen) - self.set_event_loop(loop) - return loop - - def setUp(self): - self._get_running_loop = events._get_running_loop - events._get_running_loop = lambda: None - - def tearDown(self): - events._get_running_loop = self._get_running_loop - - events.set_event_loop(None) - - # Detect CPython bug #23353: ensure that yield/yield-from is not used - # in an except block of a generator - self.assertEqual(sys.exc_info(), (None, None, None)) - - if not compat.PY34: - # Python 3.3 compatibility - def subTest(self, *args, **kwargs): - class EmptyCM: - def __enter__(self): - pass - def __exit__(self, *exc): - pass - return EmptyCM() - - -@contextlib.contextmanager -def disable_logger(): - """Context manager to disable asyncio logger. - - For example, it can be used to ignore warnings in debug mode. - """ - old_level = logger.level - try: - logger.setLevel(logging.CRITICAL+1) - yield - finally: - logger.setLevel(old_level) - - -def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM, - family=socket.AF_INET): - """Create a mock of a non-blocking socket.""" - sock = mock.MagicMock(socket.socket) - sock.proto = proto - sock.type = type - sock.family = family - sock.gettimeout.return_value = 0.0 - return sock - - -def force_legacy_ssl_support(): - return mock.patch('asyncio.sslproto._is_sslproto_available', - return_value=False) diff --git a/Lib/asyncio/threads.py b/Lib/asyncio/threads.py new file mode 100644 index 0000000000..db048a8231 --- /dev/null +++ b/Lib/asyncio/threads.py @@ -0,0 +1,25 @@ +"""High-level support for working with threads in asyncio""" + +import functools +import contextvars + +from . import events + + +__all__ = "to_thread", + + +async def to_thread(func, /, *args, **kwargs): + """Asynchronously run function *func* in a separate thread. + + Any *args and **kwargs supplied for this function are directly passed + to *func*. Also, the current :class:`contextvars.Context` is propagated, + allowing context variables from the main thread to be accessed in the + separate thread. + + Return a coroutine that can be awaited to get the eventual result of *func*. 
+ """ + loop = events.get_running_loop() + ctx = contextvars.copy_context() + func_call = functools.partial(ctx.run, func, *args, **kwargs) + return await loop.run_in_executor(None, func_call) diff --git a/Lib/asyncio/transports.py b/Lib/asyncio/transports.py index 0db0875715..73b1fa2de4 100644 --- a/Lib/asyncio/transports.py +++ b/Lib/asyncio/transports.py @@ -1,15 +1,16 @@ """Abstract Transport class.""" -from asyncio import compat - -__all__ = ['BaseTransport', 'ReadTransport', 'WriteTransport', - 'Transport', 'DatagramTransport', 'SubprocessTransport', - ] +__all__ = ( + 'BaseTransport', 'ReadTransport', 'WriteTransport', + 'Transport', 'DatagramTransport', 'SubprocessTransport', +) class BaseTransport: """Base class for transports.""" + __slots__ = ('_extra',) + def __init__(self, extra=None): if extra is None: extra = {} @@ -28,8 +29,8 @@ def close(self): Buffered data will be flushed asynchronously. No more data will be received. After all buffered data is flushed, the - protocol's connection_lost() method will (eventually) called - with None as its argument. + protocol's connection_lost() method will (eventually) be + called with None as its argument. """ raise NotImplementedError @@ -45,6 +46,12 @@ def get_protocol(self): class ReadTransport(BaseTransport): """Interface for read-only transports.""" + __slots__ = () + + def is_reading(self): + """Return True if the transport is receiving.""" + raise NotImplementedError + def pause_reading(self): """Pause the receiving end. @@ -65,6 +72,8 @@ def resume_reading(self): class WriteTransport(BaseTransport): """Interface for write-only transports.""" + __slots__ = () + def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -90,6 +99,12 @@ def get_write_buffer_size(self): """Return the current size of the write buffer.""" raise NotImplementedError + def get_write_buffer_limits(self): + """Get the high and low watermarks for write flow control. + Return a tuple (low, high) where low and high are + positive number of bytes.""" + raise NotImplementedError + def write(self, data): """Write some data bytes to the transport. @@ -104,7 +119,7 @@ def writelines(self, list_of_data): The default implementation concatenates the arguments and calls write() on the result. """ - data = compat.flatten_list_bytes(list_of_data) + data = b''.join(list_of_data) self.write(data) def write_eof(self): @@ -151,10 +166,14 @@ class Transport(ReadTransport, WriteTransport): except writelines(), which calls write() in a loop. """ + __slots__ = () + class DatagramTransport(BaseTransport): """Interface for datagram (UDP) transports.""" + __slots__ = () + def sendto(self, data, addr=None): """Send data to the transport. @@ -177,6 +196,8 @@ def abort(self): class SubprocessTransport(BaseTransport): + __slots__ = () + def get_pid(self): """Get subprocess id.""" raise NotImplementedError @@ -244,6 +265,8 @@ class _FlowControlMixin(Transport): resume_writing() may be called. 
""" + __slots__ = ('_loop', '_protocol_paused', '_high_water', '_low_water') + def __init__(self, extra=None, loop=None): super().__init__(extra) assert loop is not None @@ -259,7 +282,9 @@ def _maybe_pause_protocol(self): self._protocol_paused = True try: self._protocol.pause_writing() - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._loop.call_exception_handler({ 'message': 'protocol.pause_writing() failed', 'exception': exc, @@ -269,11 +294,13 @@ def _maybe_pause_protocol(self): def _maybe_resume_protocol(self): if (self._protocol_paused and - self.get_write_buffer_size() <= self._low_water): + self.get_write_buffer_size() <= self._low_water): self._protocol_paused = False try: self._protocol.resume_writing() - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._loop.call_exception_handler({ 'message': 'protocol.resume_writing() failed', 'exception': exc, @@ -287,14 +314,16 @@ def get_write_buffer_limits(self): def _set_write_buffer_limits(self, high=None, low=None): if high is None: if low is None: - high = 64*1024 + high = 64 * 1024 else: - high = 4*low + high = 4 * low if low is None: low = high // 4 + if not high >= low >= 0: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (high, low)) + raise ValueError( + f'high ({high!r}) must be >= low ({low!r}) must be >= 0') + self._high_water = high self._low_water = low diff --git a/Lib/asyncio/trsock.py b/Lib/asyncio/trsock.py new file mode 100644 index 0000000000..e9ebcc3261 --- /dev/null +++ b/Lib/asyncio/trsock.py @@ -0,0 +1,206 @@ +import socket +import warnings + + +class TransportSocket: + + """A socket-like wrapper for exposing real transport sockets. + + These objects can be safely returned by APIs like + `transport.get_extra_info('socket')`. All potentially disruptive + operations (like "socket.close()") are banned. + """ + + __slots__ = ('_sock',) + + def __init__(self, sock: socket.socket): + self._sock = sock + + def _na(self, what): + warnings.warn( + f"Using {what} on sockets returned from get_extra_info('socket') " + f"will be prohibited in asyncio 3.9. Please report your use case " + f"to bugs.python.org.", + DeprecationWarning, source=self) + + @property + def family(self): + return self._sock.family + + @property + def type(self): + return self._sock.type + + @property + def proto(self): + return self._sock.proto + + def __repr__(self): + s = ( + f"" + + def __getstate__(self): + raise TypeError("Cannot serialize asyncio.TransportSocket object") + + def fileno(self): + return self._sock.fileno() + + def dup(self): + return self._sock.dup() + + def get_inheritable(self): + return self._sock.get_inheritable() + + def shutdown(self, how): + # asyncio doesn't currently provide a high-level transport API + # to shutdown the connection. 
+ self._sock.shutdown(how) + + def getsockopt(self, *args, **kwargs): + return self._sock.getsockopt(*args, **kwargs) + + def setsockopt(self, *args, **kwargs): + self._sock.setsockopt(*args, **kwargs) + + def getpeername(self): + return self._sock.getpeername() + + def getsockname(self): + return self._sock.getsockname() + + def getsockbyname(self): + return self._sock.getsockbyname() + + def accept(self): + self._na('accept() method') + return self._sock.accept() + + def connect(self, *args, **kwargs): + self._na('connect() method') + return self._sock.connect(*args, **kwargs) + + def connect_ex(self, *args, **kwargs): + self._na('connect_ex() method') + return self._sock.connect_ex(*args, **kwargs) + + def bind(self, *args, **kwargs): + self._na('bind() method') + return self._sock.bind(*args, **kwargs) + + def ioctl(self, *args, **kwargs): + self._na('ioctl() method') + return self._sock.ioctl(*args, **kwargs) + + def listen(self, *args, **kwargs): + self._na('listen() method') + return self._sock.listen(*args, **kwargs) + + def makefile(self): + self._na('makefile() method') + return self._sock.makefile() + + def sendfile(self, *args, **kwargs): + self._na('sendfile() method') + return self._sock.sendfile(*args, **kwargs) + + def close(self): + self._na('close() method') + return self._sock.close() + + def detach(self): + self._na('detach() method') + return self._sock.detach() + + def sendmsg_afalg(self, *args, **kwargs): + self._na('sendmsg_afalg() method') + return self._sock.sendmsg_afalg(*args, **kwargs) + + def sendmsg(self, *args, **kwargs): + self._na('sendmsg() method') + return self._sock.sendmsg(*args, **kwargs) + + def sendto(self, *args, **kwargs): + self._na('sendto() method') + return self._sock.sendto(*args, **kwargs) + + def send(self, *args, **kwargs): + self._na('send() method') + return self._sock.send(*args, **kwargs) + + def sendall(self, *args, **kwargs): + self._na('sendall() method') + return self._sock.sendall(*args, **kwargs) + + def set_inheritable(self, *args, **kwargs): + self._na('set_inheritable() method') + return self._sock.set_inheritable(*args, **kwargs) + + def share(self, process_id): + self._na('share() method') + return self._sock.share(process_id) + + def recv_into(self, *args, **kwargs): + self._na('recv_into() method') + return self._sock.recv_into(*args, **kwargs) + + def recvfrom_into(self, *args, **kwargs): + self._na('recvfrom_into() method') + return self._sock.recvfrom_into(*args, **kwargs) + + def recvmsg_into(self, *args, **kwargs): + self._na('recvmsg_into() method') + return self._sock.recvmsg_into(*args, **kwargs) + + def recvmsg(self, *args, **kwargs): + self._na('recvmsg() method') + return self._sock.recvmsg(*args, **kwargs) + + def recvfrom(self, *args, **kwargs): + self._na('recvfrom() method') + return self._sock.recvfrom(*args, **kwargs) + + def recv(self, *args, **kwargs): + self._na('recv() method') + return self._sock.recv(*args, **kwargs) + + def settimeout(self, value): + if value == 0: + return + raise ValueError( + 'settimeout(): only 0 timeout is allowed on transport sockets') + + def gettimeout(self): + return 0 + + def setblocking(self, flag): + if not flag: + return + raise ValueError( + 'setblocking(): transport sockets cannot be blocking') + + def __enter__(self): + self._na('context manager protocol') + return self._sock.__enter__() + + def __exit__(self, *err): + self._na('context manager protocol') + return self._sock.__exit__(*err) diff --git a/Lib/asyncio/unix_events.py b/Lib/asyncio/unix_events.py index 
9db09b9d9b..eecbc101ee 100644 --- a/Lib/asyncio/unix_events.py +++ b/Lib/asyncio/unix_events.py @@ -1,7 +1,10 @@ """Selector event loop for Unix with signal handling.""" import errno +import io +import itertools import os +import selectors import signal import socket import stat @@ -10,25 +13,27 @@ import threading import warnings - from . import base_events from . import base_subprocess -from . import compat from . import constants from . import coroutines from . import events +from . import exceptions from . import futures from . import selector_events -from . import selectors +from . import tasks from . import transports -from .coroutines import coroutine from .log import logger -__all__ = ['SelectorEventLoop', - 'AbstractChildWatcher', 'SafeChildWatcher', - 'FastChildWatcher', 'DefaultEventLoopPolicy', - ] +__all__ = ( + 'SelectorEventLoop', + 'AbstractChildWatcher', 'SafeChildWatcher', + 'FastChildWatcher', 'PidfdChildWatcher', + 'MultiLoopChildWatcher', 'ThreadedChildWatcher', + 'DefaultEventLoopPolicy', +) + if sys.platform == 'win32': # pragma: no cover raise ImportError('Signals are not really supported on Windows') @@ -39,13 +44,6 @@ def _sighandler_noop(signum, frame): pass -try: - _fspath = os.fspath -except AttributeError: - # Python 3.5 or earlier - _fspath = lambda path: path - - class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop): """Unix event loop. @@ -56,13 +54,19 @@ def __init__(self, selector=None): super().__init__(selector) self._signal_handlers = {} - def _socketpair(self): - return socket.socketpair() - def close(self): super().close() - for sig in list(self._signal_handlers): - self.remove_signal_handler(sig) + if not sys.is_finalizing(): + for sig in list(self._signal_handlers): + self.remove_signal_handler(sig) + else: + if self._signal_handlers: + warnings.warn(f"Closing the loop {self!r} " + f"on interpreter shutdown " + f"stage, skipping signal handlers removal", + ResourceWarning, + source=self) + self._signal_handlers.clear() def _process_self_data(self, data): for signum in data: @@ -77,8 +81,8 @@ def add_signal_handler(self, sig, callback, *args): Raise ValueError if the signal number is invalid or uncatchable. Raise RuntimeError if there is a problem setting up the handler. """ - if (coroutines.iscoroutine(callback) - or coroutines.iscoroutinefunction(callback)): + if (coroutines.iscoroutine(callback) or + coroutines.iscoroutinefunction(callback)): raise TypeError("coroutines cannot be used " "with add_signal_handler()") self._check_signal(sig) @@ -92,12 +96,12 @@ def add_signal_handler(self, sig, callback, *args): except (ValueError, OSError) as exc: raise RuntimeError(str(exc)) - handle = events.Handle(callback, args, self) + handle = events.Handle(callback, args, self, None) self._signal_handlers[sig] = handle try: # Register a dummy signal handler to ask Python to write the signal - # number in the wakup file descriptor. _process_self_data() will + # number in the wakeup file descriptor. _process_self_data() will # read signal numbers from this file descriptor to handle signals. 
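# A hedged sketch (editor's illustration, not part of the patch) of what
# this wiring enables: once the dummy C-level handler and the wakeup fd
# are in place, a callback registered with add_signal_handler() runs
# inside the event loop rather than in signal-handler context:
#
#     loop.add_signal_handler(signal.SIGTERM, loop.stop)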
signal.signal(sig, _sighandler_noop) @@ -112,7 +116,7 @@ def add_signal_handler(self, sig, callback, *args): logger.info('set_wakeup_fd(-1) failed: %s', nexc) if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) + raise RuntimeError(f'sig {sig} cannot be caught') else: raise @@ -146,7 +150,7 @@ def remove_signal_handler(self, sig): signal.signal(sig, handler) except OSError as exc: if exc.errno == errno.EINVAL: - raise RuntimeError('sig {} cannot be caught'.format(sig)) + raise RuntimeError(f'sig {sig} cannot be caught') else: raise @@ -165,11 +169,10 @@ def _check_signal(self, sig): Raise RuntimeError if there is a problem setting up the handler. """ if not isinstance(sig, int): - raise TypeError('sig must be an int, not {!r}'.format(sig)) + raise TypeError(f'sig must be an int, not {sig!r}') - if not (1 <= sig < signal.NSIG): - raise ValueError( - 'sig {} out of range(1, {})'.format(sig, signal.NSIG)) + if sig not in signal.valid_signals(): + raise ValueError(f'invalid signal number {sig}') def _make_read_pipe_transport(self, pipe, protocol, waiter=None, extra=None): @@ -179,11 +182,17 @@ def _make_write_pipe_transport(self, pipe, protocol, waiter=None, extra=None): return _UnixWritePipeTransport(self, pipe, protocol, waiter, extra) - @coroutine - def _make_subprocess_transport(self, protocol, args, shell, - stdin, stdout, stderr, bufsize, - extra=None, **kwargs): + async def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): with events.get_child_watcher() as watcher: + if not watcher.is_active(): + # Check early. + # Raising exception before process creation + # prevents subprocess execution if the watcher + # is not ready to handle it. + raise RuntimeError("asyncio.get_child_watcher() is not activated, " + "subprocess support is not installed.") waiter = self.create_future() transp = _UnixSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, @@ -193,29 +202,24 @@ def _make_subprocess_transport(self, protocol, args, shell, watcher.add_child_handler(transp.get_pid(), self._child_watcher_callback, transp) try: - yield from waiter - except Exception as exc: - # Workaround CPython bug #23353: using yield/yield-from in an - # except block of a generator doesn't clear properly - # sys.exc_info() - err = exc - else: - err = None - - if err is not None: + await waiter + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: transp.close() - yield from transp._wait() - raise err + await transp._wait() + raise return transp def _child_watcher_callback(self, pid, returncode, transp): self.call_soon_threadsafe(transp._process_exited, returncode) - @coroutine - def create_unix_connection(self, protocol_factory, path, *, - ssl=None, sock=None, - server_hostname=None): + async def create_unix_connection( + self, protocol_factory, path=None, *, + ssl=None, sock=None, + server_hostname=None, + ssl_handshake_timeout=None): assert server_hostname is None or isinstance(server_hostname, str) if ssl: if server_hostname is None: @@ -224,16 +228,20 @@ def create_unix_connection(self, protocol_factory, path, *, else: if server_hostname is not None: raise ValueError('server_hostname is only meaningful with ssl') + if ssl_handshake_timeout is not None: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') if path is not None: if sock is not None: raise ValueError( 'path and sock can not be specified at the same time') + path = os.fspath(path) sock = 
socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) try: sock.setblocking(False) - yield from self.sock_connect(sock, path) + await self.sock_connect(sock, path) except: sock.close() raise @@ -242,28 +250,34 @@ def create_unix_connection(self, protocol_factory, path, *, if sock is None: raise ValueError('no path and sock were specified') if (sock.family != socket.AF_UNIX or - not base_events._is_stream_socket(sock)): + sock.type != socket.SOCK_STREAM): raise ValueError( - 'A UNIX Domain Stream Socket was expected, got {!r}' - .format(sock)) + f'A UNIX Domain Stream Socket was expected, got {sock!r}') sock.setblocking(False) - transport, protocol = yield from self._create_connection_transport( - sock, protocol_factory, ssl, server_hostname) + transport, protocol = await self._create_connection_transport( + sock, protocol_factory, ssl, server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout) return transport, protocol - @coroutine - def create_unix_server(self, protocol_factory, path=None, *, - sock=None, backlog=100, ssl=None): + async def create_unix_server( + self, protocol_factory, path=None, *, + sock=None, backlog=100, ssl=None, + ssl_handshake_timeout=None, + start_serving=True): if isinstance(ssl, bool): raise TypeError('ssl argument must be an SSLContext or None') + if ssl_handshake_timeout is not None and not ssl: + raise ValueError( + 'ssl_handshake_timeout is only meaningful with ssl') + if path is not None: if sock is not None: raise ValueError( 'path and sock can not be specified at the same time') - path = _fspath(path) + path = os.fspath(path) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) # Check for abstract socket. `str` and `bytes` paths are supported. @@ -275,7 +289,8 @@ def create_unix_server(self, protocol_factory, path=None, *, pass except OSError as err: # Directory may have permissions only to create socket. - logger.error('Unable to check or remove stale UNIX socket %r: %r', path, err) + logger.error('Unable to check or remove stale UNIX socket ' + '%r: %r', path, err) try: sock.bind(path) @@ -284,7 +299,7 @@ def create_unix_server(self, protocol_factory, path=None, *, if exc.errno == errno.EADDRINUSE: # Let's improve the error message by adding # with what exact address it occurs. - msg = 'Address {!r} is already in use'.format(path) + msg = f'Address {path!r} is already in use' raise OSError(errno.EADDRINUSE, msg) from None else: raise @@ -297,28 +312,125 @@ def create_unix_server(self, protocol_factory, path=None, *, 'path was not specified, and no sock specified') if (sock.family != socket.AF_UNIX or - not base_events._is_stream_socket(sock)): + sock.type != socket.SOCK_STREAM): raise ValueError( - 'A UNIX Domain Stream Socket was expected, got {!r}' - .format(sock)) + f'A UNIX Domain Stream Socket was expected, got {sock!r}') - server = base_events.Server(self, [sock]) - sock.listen(backlog) sock.setblocking(False) - self._start_serving(protocol_factory, sock, ssl, server) - return server + server = base_events.Server(self, [sock], protocol_factory, + ssl, backlog, ssl_handshake_timeout) + if start_serving: + server._start_serving() + # Skip one loop iteration so that all 'loop.add_reader' + # go through. 
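# A hedged usage sketch (editor's illustration, not part of the patch):
# the new start_serving flag lets a caller bind the socket now and begin
# accepting later; the protocol class and path below are hypothetical.
#
#     server = await loop.create_unix_server(
#         MyProtocol, '/tmp/example.sock', start_serving=False)
#     ...
#     await server.start_serving()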
+ await tasks.sleep(0) + return server -#if hasattr(os, 'set_blocking'): -# def _set_nonblocking(fd): -# os.set_blocking(fd, False) -#else: -# import fcntl + async def _sock_sendfile_native(self, sock, file, offset, count): + try: + os.sendfile + except AttributeError: + raise exceptions.SendfileNotAvailableError( + "os.sendfile() is not available") + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise exceptions.SendfileNotAvailableError("not a regular file") + try: + fsize = os.fstat(fileno).st_size + except OSError: + raise exceptions.SendfileNotAvailableError("not a regular file") + blocksize = count if count else fsize + if not blocksize: + return 0 # empty file + + fut = self.create_future() + self._sock_sendfile_native_impl(fut, None, sock, fileno, + offset, count, blocksize, 0) + return await fut + + def _sock_sendfile_native_impl(self, fut, registered_fd, sock, fileno, + offset, count, blocksize, total_sent): + fd = sock.fileno() + if registered_fd is not None: + # Remove the callback early. It should be rare that the + # selector says the fd is ready but the call still returns + # EAGAIN, and I am willing to take a hit in that case in + # order to simplify the common case. + self.remove_writer(registered_fd) + if fut.cancelled(): + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + return + if count: + blocksize = count - total_sent + if blocksize <= 0: + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_result(total_sent) + return -# def _set_nonblocking(fd): -# flags = fcntl.fcntl(fd, fcntl.F_GETFL) -# flags = flags | os.O_NONBLOCK -# fcntl.fcntl(fd, fcntl.F_SETFL, flags) + try: + sent = os.sendfile(fd, fileno, offset, blocksize) + except (BlockingIOError, InterruptedError): + if registered_fd is None: + self._sock_add_cancellation_callback(fut, sock) + self.add_writer(fd, self._sock_sendfile_native_impl, fut, + fd, sock, fileno, + offset, count, blocksize, total_sent) + except OSError as exc: + if (registered_fd is not None and + exc.errno == errno.ENOTCONN and + type(exc) is not ConnectionError): + # If we have an ENOTCONN and this isn't a first call to + # sendfile(), i.e. the connection was closed in the middle + # of the operation, normalize the error to ConnectionError + # to make it consistent across all Posix systems. + new_exc = ConnectionError( + "socket is not connected", errno.ENOTCONN) + new_exc.__cause__ = exc + exc = new_exc + if total_sent == 0: + # We can get here for different reasons, the main + # one being 'file' is not a regular mmap(2)-like + # file, in which case we'll fall back on using + # plain send(). 
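# A hedged sketch (editor's illustration, not part of the patch):
# SendfileNotAvailableError raised here is what lets the high-level
# wrapper degrade gracefully to a plain-send loop; the file name is
# hypothetical.
#
#     with open('payload.bin', 'rb') as f:
#         sent = await loop.sock_sendfile(sock, f, fallback=True)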
+ err = exceptions.SendfileNotAvailableError( + "os.sendfile call failed") + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_exception(err) + else: + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_exception(exc) + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_exception(exc) + else: + if sent == 0: + # EOF + self._sock_sendfile_update_filepos(fileno, offset, total_sent) + fut.set_result(total_sent) + else: + offset += sent + total_sent += sent + if registered_fd is None: + self._sock_add_cancellation_callback(fut, sock) + self.add_writer(fd, self._sock_sendfile_native_impl, fut, + fd, sock, fileno, + offset, count, blocksize, total_sent) + + def _sock_sendfile_update_filepos(self, fileno, offset, total_sent): + if total_sent > 0: + os.lseek(fileno, offset, os.SEEK_SET) + + def _sock_add_cancellation_callback(self, fut, sock): + def cb(fut): + if fut.cancelled(): + fd = sock.fileno() + if fd != -1: + self.remove_writer(fd) + fut.add_done_callback(cb) class _UnixReadPipeTransport(transports.ReadTransport): @@ -333,6 +445,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._fileno = pipe.fileno() self._protocol = protocol self._closing = False + self._paused = False mode = os.fstat(self._fileno).st_mode if not (stat.S_ISFIFO(mode) or @@ -343,7 +456,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): self._protocol = None raise ValueError("Pipe transport is for pipes/sockets only.") - _set_nonblocking(self._fileno) + os.set_blocking(self._fileno, False) self._loop.call_soon(self._protocol.connection_made, self) # only start reading when connection_made() has been called @@ -360,12 +473,11 @@ def __repr__(self): info.append('closed') elif self._closing: info.append('closing') - info.append('fd=%s' % self._fileno) + info.append(f'fd={self._fileno}') selector = getattr(self._loop, '_selector', None) if self._pipe is not None and selector is not None: polling = selector_events._test_selector_event( - selector, - self._fileno, selectors.EVENT_READ) + selector, self._fileno, selectors.EVENT_READ) if polling: info.append('polling') else: @@ -374,7 +486,7 @@ def __repr__(self): info.append('open') else: info.append('closed') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def _read_ready(self): try: @@ -395,10 +507,20 @@ def _read_ready(self): self._loop.call_soon(self._call_connection_lost, None) def pause_reading(self): + if self._closing or self._paused: + return + self._paused = True self._loop._remove_reader(self._fileno) + if self._loop.get_debug(): + logger.debug("%r pauses reading", self) def resume_reading(self): + if self._closing or not self._paused: + return + self._paused = False self._loop._add_reader(self._fileno, self._read_ready) + if self._loop.get_debug(): + logger.debug("%r resumes reading", self) def set_protocol(self, protocol): self._protocol = protocol @@ -413,15 +535,10 @@ def close(self): if not self._closing: self._close(None) - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. 
- if compat.PY34: - def __del__(self): - if self._pipe is not None: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self._pipe.close() + def __del__(self, _warn=warnings.warn): + if self._pipe is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self._pipe.close() def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only @@ -476,7 +593,7 @@ def __init__(self, loop, pipe, protocol, waiter=None, extra=None): raise ValueError("Pipe transport is only for " "pipes, sockets and character devices") - _set_nonblocking(self._fileno) + os.set_blocking(self._fileno, False) self._loop.call_soon(self._protocol.connection_made, self) # On AIX, the reader trick (to be notified when the read end of the @@ -498,24 +615,23 @@ def __repr__(self): info.append('closed') elif self._closing: info.append('closing') - info.append('fd=%s' % self._fileno) + info.append(f'fd={self._fileno}') selector = getattr(self._loop, '_selector', None) if self._pipe is not None and selector is not None: polling = selector_events._test_selector_event( - selector, - self._fileno, selectors.EVENT_WRITE) + selector, self._fileno, selectors.EVENT_WRITE) if polling: info.append('polling') else: info.append('idle') bufsize = self.get_write_buffer_size() - info.append('bufsize=%s' % bufsize) + info.append(f'bufsize={bufsize}') elif self._pipe is not None: info.append('open') else: info.append('closed') - return '<%s>' % ' '.join(info) + return '<{}>'.format(' '.join(info)) def get_write_buffer_size(self): return len(self._buffer) @@ -549,7 +665,9 @@ def write(self, data): n = os.write(self._fileno, data) except (BlockingIOError, InterruptedError): n = 0 - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._conn_lost += 1 self._fatal_error(exc, 'Fatal write error on pipe transport') return @@ -569,7 +687,9 @@ def _write_ready(self): n = os.write(self._fileno, self._buffer) except (BlockingIOError, InterruptedError): pass - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: self._buffer.clear() self._conn_lost += 1 # Remove writer here, _fatal_error() doesn't it @@ -614,22 +734,17 @@ def close(self): # write_eof is all what we needed to close the write pipe self.write_eof() - # On Python 3.3 and older, objects with a destructor part of a reference - # cycle are never destroyed. It's not more the case on Python 3.4 thanks - # to the PEP 442. 
- if compat.PY34: - def __del__(self): - if self._pipe is not None: - warnings.warn("unclosed transport %r" % self, ResourceWarning, - source=self) - self._pipe.close() + def __del__(self, _warn=warnings.warn): + if self._pipe is not None: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + self._pipe.close() def abort(self): self._close(None) def _fatal_error(self, exc, message='Fatal error on pipe transport'): # should be called by exception handler only - if isinstance(exc, base_events._FATAL_ERROR_IGNORE): + if isinstance(exc, OSError): if self._loop.get_debug(): logger.debug("%r: %s", self, message, exc_info=True) else: @@ -659,22 +774,6 @@ def _call_connection_lost(self, exc): self._loop = None -#if hasattr(os, 'set_inheritable'): -# # Python 3.4 and newer -# _set_inheritable = os.set_inheritable -#else: -# import fcntl -# -# def _set_inheritable(fd, inheritable): -# cloexec_flag = getattr(fcntl, 'FD_CLOEXEC', 1) -# -# old = fcntl.fcntl(fd, fcntl.F_GETFD) -# if not inheritable: -# fcntl.fcntl(fd, fcntl.F_SETFD, old | cloexec_flag) -# else: -# fcntl.fcntl(fd, fcntl.F_SETFD, old & ~cloexec_flag) - - class _UnixSubprocessTransport(base_subprocess.BaseSubprocessTransport): def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): @@ -685,19 +784,19 @@ def _start(self, args, shell, stdin, stdout, stderr, bufsize, **kwargs): # socket (which we use in order to detect closing of the # other end). Notably this is needed on AIX, and works # just fine on other platforms. - stdin, stdin_w = self._loop._socketpair() - - # Mark the write end of the stdin pipe as non-inheritable, - # needed by close_fds=False on Python 3.3 and older - # (Python 3.4 implements the PEP 446, socketpair returns - # non-inheritable sockets) - _set_inheritable(stdin_w.fileno(), False) - self._proc = subprocess.Popen( - args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, - universal_newlines=False, bufsize=bufsize, **kwargs) - if stdin_w is not None: - stdin.close() - self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + stdin, stdin_w = socket.socketpair() + try: + self._proc = subprocess.Popen( + args, shell=shell, stdin=stdin, stdout=stdout, stderr=stderr, + universal_newlines=False, bufsize=bufsize, **kwargs) + if stdin_w is not None: + stdin.close() + self._proc.stdin = open(stdin_w.detach(), 'wb', buffering=bufsize) + stdin_w = None + finally: + if stdin_w is not None: + stdin.close() + stdin_w.close() class AbstractChildWatcher: @@ -759,6 +858,15 @@ def close(self): """ raise NotImplementedError() + def is_active(self): + """Return ``True`` if the watcher is active and is used by the event loop. + + Return True if the watcher is installed and ready to handle process exit + notifications. + + """ + raise NotImplementedError() + def __enter__(self): """Enter the watcher's context and allow starting new processes @@ -770,6 +878,98 @@ def __exit__(self, a, b, c): raise NotImplementedError() +class PidfdChildWatcher(AbstractChildWatcher): + """Child watcher implementation using Linux's pid file descriptors. + + This child watcher polls process file descriptors (pidfds) to await child + process termination. In some respects, PidfdChildWatcher is a "Goldilocks" + child watcher implementation. It doesn't require signals or threads, doesn't + interfere with any processes launched outside the event loop, and scales + linearly with the number of subprocesses launched by the event loop. 
The + main disadvantage is that pidfds are specific to Linux, and only work on + recent (5.3+) kernels. + """ + + def __init__(self): + self._loop = None + self._callbacks = {} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + pass + + def is_active(self): + return self._loop is not None and self._loop.is_running() + + def close(self): + self.attach_loop(None) + + def attach_loop(self, loop): + if self._loop is not None and loop is None and self._callbacks: + warnings.warn( + 'A loop is being detached ' + 'from a child watcher with pending handlers', + RuntimeWarning) + for pidfd, _, _ in self._callbacks.values(): + self._loop._remove_reader(pidfd) + os.close(pidfd) + self._callbacks.clear() + self._loop = loop + + def add_child_handler(self, pid, callback, *args): + existing = self._callbacks.get(pid) + if existing is not None: + self._callbacks[pid] = existing[0], callback, args + else: + pidfd = os.pidfd_open(pid) + self._loop._add_reader(pidfd, self._do_wait, pid) + self._callbacks[pid] = pidfd, callback, args + + def _do_wait(self, pid): + pidfd, callback, args = self._callbacks.pop(pid) + self._loop._remove_reader(pidfd) + try: + _, status = os.waitpid(pid, 0) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). + returncode = 255 + logger.warning( + "child process pid %d exit status already read: " + " will report returncode 255", + pid) + else: + returncode = _compute_returncode(status) + + os.close(pidfd) + callback(pid, returncode, *args) + + def remove_child_handler(self, pid): + try: + pidfd, _, _ = self._callbacks.pop(pid) + except KeyError: + return False + self._loop._remove_reader(pidfd) + os.close(pidfd) + return True + + +def _compute_returncode(status): + if os.WIFSIGNALED(status): + # The child process died because of a signal. + return -os.WTERMSIG(status) + elif os.WIFEXITED(status): + # The child process exited (e.g sys.exit()). + return os.WEXITSTATUS(status) + else: + # The child exited, but we don't understand its status. + # This shouldn't happen, but if it does, let's just + # return that status; perhaps that helps debug it. + return status + + class BaseChildWatcher(AbstractChildWatcher): def __init__(self): @@ -779,6 +979,9 @@ def __init__(self): def close(self): self.attach_loop(None) + def is_active(self): + return self._loop is not None and self._loop.is_running() + def _do_waitpid(self, expected_pid): raise NotImplementedError() @@ -808,7 +1011,9 @@ def attach_loop(self, loop): def _sig_chld(self): try: self._do_waitpid_all() - except Exception as exc: + except (SystemExit, KeyboardInterrupt): + raise + except BaseException as exc: # self._loop should always be available here # as '_sig_chld' is added as a signal handler # in 'attach_loop' @@ -817,19 +1022,6 @@ def _sig_chld(self): 'exception': exc, }) - def _compute_returncode(self, status): - if os.WIFSIGNALED(status): - # The child process died because of a signal. - return -os.WTERMSIG(status) - elif os.WIFEXITED(status): - # The child process exited (e.g sys.exit()). - return os.WEXITSTATUS(status) - else: - # The child exited, but we don't understand its status. - # This shouldn't happen, but if it does, let's just - # return that status; perhaps that helps debug it. - return status - class SafeChildWatcher(BaseChildWatcher): """'Safe' child watcher implementation. 
@@ -853,11 +1045,6 @@ def __exit__(self, a, b, c):
         pass

     def add_child_handler(self, pid, callback, *args):
-        if self._loop is None:
-            raise RuntimeError(
-                "Cannot add child handler, "
-                "the child watcher does not have a loop attached")
-
         self._callbacks[pid] = (callback, args)

         # Prevent a race condition in case the child is already terminated.
@@ -893,7 +1080,7 @@ def _do_waitpid(self, expected_pid):
                 # The child process is still alive.
                 return

-            returncode = self._compute_returncode(status)
+            returncode = _compute_returncode(status)
             if self._loop.get_debug():
                 logger.debug('process %s exited with returncode %s',
                              expected_pid, returncode)
@@ -954,11 +1141,6 @@ def __exit__(self, a, b, c):

     def add_child_handler(self, pid, callback, *args):
         assert self._forks, "Must use the context manager"

-        if self._loop is None:
-            raise RuntimeError(
-                "Cannot add child handler, "
-                "the child watcher does not have a loop attached")
-
         with self._lock:
             try:
                 returncode = self._zombies.pop(pid)
@@ -991,7 +1173,7 @@ def _do_waitpid_all(self):
                 # A child process is still alive.
                 return

-            returncode = self._compute_returncode(status)
+            returncode = _compute_returncode(status)

             with self._lock:
                 try:
@@ -1020,6 +1202,220 @@ def _do_waitpid_all(self):
                 callback(pid, returncode, *args)


+class MultiLoopChildWatcher(AbstractChildWatcher):
+    """A watcher that doesn't require a running loop in the main thread.
+
+    This implementation registers a SIGCHLD signal handler on
+    instantiation (which may conflict with other code that
+    installs its own handler for this signal).
+
+    The solution is safe, but it has a significant overhead when
+    handling a large number of processes (*O(n)* each time a
+    SIGCHLD is received).
+    """
+
+    # Implementation note:
+    # The class keeps compatibility with the AbstractChildWatcher ABC.
+    # To achieve this it provides an attach_loop() method that ignores
+    # its loop argument, doesn't accept an explicit loop argument for
+    # add_child_handler()/remove_child_handler(), and instead retrieves
+    # the current loop via get_running_loop().
+
+    def __init__(self):
+        self._callbacks = {}
+        self._saved_sighandler = None
+
+    def is_active(self):
+        return self._saved_sighandler is not None
+
+    def close(self):
+        self._callbacks.clear()
+        if self._saved_sighandler is None:
+            return
+
+        handler = signal.getsignal(signal.SIGCHLD)
+        if handler != self._sig_chld:
+            logger.warning("SIGCHLD handler was changed by outside code")
+        else:
+            signal.signal(signal.SIGCHLD, self._saved_sighandler)
+        self._saved_sighandler = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    def add_child_handler(self, pid, callback, *args):
+        loop = events.get_running_loop()
+        self._callbacks[pid] = (loop, callback, args)
+
+        # Prevent a race condition in case the child is already terminated.
+        self._do_waitpid(pid)
+
+    def remove_child_handler(self, pid):
+        try:
+            del self._callbacks[pid]
+            return True
+        except KeyError:
+            return False
+
+    def attach_loop(self, loop):
+        # Don't save the loop; just initialize the watcher the first time
+        # this method is called.  The reason to do it here is that the
+        # unix policy calls attach_loop() only from the main thread, and
+        # only the main thread may subscribe to the SIGCHLD signal.
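# A hedged setup sketch (editor's illustration, not part of the patch):
# a program that runs event loops in worker threads would opt in from
# the main thread with
#
#     asyncio.set_child_watcher(asyncio.MultiLoopChildWatcher())
#
# after which the policy eventually calls attach_loop() below, installing
# the SIGCHLD handler exactly once.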
+        if self._saved_sighandler is not None:
+            return
+
+        self._saved_sighandler = signal.signal(signal.SIGCHLD, self._sig_chld)
+        if self._saved_sighandler is None:
+            logger.warning("Previous SIGCHLD handler was set by non-Python code, "
+                           "restoring the default handler on watcher close.")
+            self._saved_sighandler = signal.SIG_DFL
+
+        # Set SA_RESTART to limit EINTR occurrences.
+        signal.siginterrupt(signal.SIGCHLD, False)
+
+    def _do_waitpid_all(self):
+        for pid in list(self._callbacks):
+            self._do_waitpid(pid)
+
+    def _do_waitpid(self, expected_pid):
+        assert expected_pid > 0
+
+        try:
+            pid, status = os.waitpid(expected_pid, os.WNOHANG)
+        except ChildProcessError:
+            # The child process is already reaped
+            # (may happen if waitpid() is called elsewhere).
+            pid = expected_pid
+            returncode = 255
+            logger.warning(
+                "Unknown child process pid %d, will report returncode 255",
+                pid)
+            debug_log = False
+        else:
+            if pid == 0:
+                # The child process is still alive.
+                return
+
+            returncode = _compute_returncode(status)
+            debug_log = True
+        try:
+            loop, callback, args = self._callbacks.pop(pid)
+        except KeyError:  # pragma: no cover
+            # May happen if .remove_child_handler() is called
+            # after os.waitpid() returns.
+            logger.warning("Child watcher got an unexpected pid: %r",
+                           pid, exc_info=True)
+        else:
+            if loop.is_closed():
+                logger.warning("Loop %r that handles pid %r is closed", loop, pid)
+            else:
+                if debug_log and loop.get_debug():
+                    logger.debug('process %s exited with returncode %s',
+                                 expected_pid, returncode)
+                loop.call_soon_threadsafe(callback, pid, returncode, *args)
+
+    def _sig_chld(self, signum, frame):
+        try:
+            self._do_waitpid_all()
+        except (SystemExit, KeyboardInterrupt):
+            raise
+        except BaseException:
+            logger.warning('Unknown exception in SIGCHLD handler', exc_info=True)
+
+
+class ThreadedChildWatcher(AbstractChildWatcher):
+    """Threaded child watcher implementation.
+
+    The watcher uses one thread per process
+    to wait for the process to finish.
+
+    It doesn't require a subscription to a POSIX signal,
+    but thread creation is not free.
+
+    The watcher has O(1) complexity; its performance doesn't
+    depend on the number of spawned processes.
+    """
+
+    def __init__(self):
+        self._pid_counter = itertools.count(0)
+        self._threads = {}
+
+    def is_active(self):
+        return True
+
+    def close(self):
+        self._join_threads()
+
+    def _join_threads(self):
+        """Internal: Join all non-daemon threads"""
+        threads = [thread for thread in list(self._threads.values())
+                   if thread.is_alive() and not thread.daemon]
+        for thread in threads:
+            thread.join()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    def __del__(self, _warn=warnings.warn):
+        threads = [thread for thread in list(self._threads.values())
+                   if thread.is_alive()]
+        if threads:
+            _warn(f"{self.__class__} has registered but not finished child processes",
+                  ResourceWarning,
+                  source=self)
+
+    def add_child_handler(self, pid, callback, *args):
+        loop = events.get_running_loop()
+        thread = threading.Thread(target=self._do_waitpid,
+                                  name=f"waitpid-{next(self._pid_counter)}",
+                                  args=(loop, pid, callback, args),
+                                  daemon=True)
+        self._threads[pid] = thread
+        thread.start()
+
+    def remove_child_handler(self, pid):
+        # asyncio never calls remove_child_handler() !!!
+        # The method is a no-op but is implemented because
+        # abstract base classes require it.
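# A clarifying note (editor's illustration, not part of the patch):
# there is nothing to undo here anyway; each add_child_handler() call
# started a dedicated daemon thread that removes itself from
# self._threads and exits once the child has been reaped.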
+ return True + + def attach_loop(self, loop): + pass + + def _do_waitpid(self, loop, expected_pid, callback, args): + assert expected_pid > 0 + + try: + pid, status = os.waitpid(expected_pid, 0) + except ChildProcessError: + # The child process is already reaped + # (may happen if waitpid() is called elsewhere). + pid = expected_pid + returncode = 255 + logger.warning( + "Unknown child process pid %d, will report returncode 255", + pid) + else: + returncode = _compute_returncode(status) + if loop.get_debug(): + logger.debug('process %s exited with returncode %s', + expected_pid, returncode) + + if loop.is_closed(): + logger.warning("Loop %r that handles pid %r is closed", loop, pid) + else: + loop.call_soon_threadsafe(callback, pid, returncode, *args) + + self._threads.pop(expected_pid) + + class _UnixDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): """UNIX event loop policy with a watcher for child processes.""" _loop_factory = _UnixSelectorEventLoop @@ -1031,9 +1427,8 @@ def __init__(self): def _init_watcher(self): with events._lock: if self._watcher is None: # pragma: no branch - self._watcher = SafeChildWatcher() - if isinstance(threading.current_thread(), - threading._MainThread): + self._watcher = ThreadedChildWatcher() + if threading.current_thread() is threading.main_thread(): self._watcher.attach_loop(self._local._loop) def set_event_loop(self, loop): @@ -1046,14 +1441,14 @@ def set_event_loop(self, loop): super().set_event_loop(loop) - if self._watcher is not None and \ - isinstance(threading.current_thread(), threading._MainThread): + if (self._watcher is not None and + threading.current_thread() is threading.main_thread()): self._watcher.attach_loop(loop) def get_child_watcher(self): """Get the watcher for child processes. - If not yet set, a SafeChildWatcher object is automatically created. + If not yet set, a ThreadedChildWatcher object is automatically created. """ if self._watcher is None: self._init_watcher() @@ -1070,5 +1465,6 @@ def set_child_watcher(self, watcher): self._watcher = watcher + SelectorEventLoop = _UnixSelectorEventLoop DefaultEventLoopPolicy = _UnixDefaultEventLoopPolicy diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py index 2c68bc526a..da81ab435b 100644 --- a/Lib/asyncio/windows_events.py +++ b/Lib/asyncio/windows_events.py @@ -1,28 +1,36 @@ """Selector and proactor event loops for Windows.""" +import sys + +if sys.platform != 'win32': # pragma: no cover + raise ImportError('win32 only') + +import _overlapped import _winapi import errno import math +import msvcrt import socket import struct +import time import weakref from . import events from . import base_subprocess from . import futures +from . import exceptions from . import proactor_events from . import selector_events from . import tasks from . import windows_utils -# XXX RustPython TODO: _overlapped -# from . 
import _overlapped -from .coroutines import coroutine from .log import logger -__all__ = ['SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', - 'DefaultEventLoopPolicy', - ] +__all__ = ( + 'SelectorEventLoop', 'ProactorEventLoop', 'IocpProactor', + 'DefaultEventLoopPolicy', 'WindowsSelectorEventLoopPolicy', + 'WindowsProactorEventLoopPolicy', +) NULL = 0 @@ -53,7 +61,7 @@ def _repr_info(self): info = super()._repr_info() if self._ov is not None: state = 'pending' if self._ov.pending else 'completed' - info.insert(1, 'overlapped=<%s, %#x>' % (state, self._ov.address)) + info.insert(1, f'overlapped=<{state}, {self._ov.address:#x}>') return info def _cancel_overlapped(self): @@ -72,9 +80,9 @@ def _cancel_overlapped(self): self._loop.call_exception_handler(context) self._ov = None - def cancel(self): + def cancel(self, msg=None): self._cancel_overlapped() - return super().cancel() + return super().cancel(msg=msg) def set_exception(self, exception): super().set_exception(exception) @@ -109,12 +117,12 @@ def _poll(self): def _repr_info(self): info = super()._repr_info() - info.append('handle=%#x' % self._handle) + info.append(f'handle={self._handle:#x}') if self._handle is not None: state = 'signaled' if self._poll() else 'waiting' info.append(state) if self._wait_handle is not None: - info.append('wait_handle=%#x' % self._wait_handle) + info.append(f'wait_handle={self._wait_handle:#x}') return info def _unregister_wait_cb(self, fut): @@ -146,9 +154,9 @@ def _unregister_wait(self): self._unregister_wait_cb(None) - def cancel(self): + def cancel(self, msg=None): self._unregister_wait() - return super().cancel() + return super().cancel(msg=msg) def set_exception(self, exception): self._unregister_wait() @@ -297,9 +305,6 @@ def close(self): class _WindowsSelectorEventLoop(selector_events.BaseSelectorEventLoop): """Windows version of selector event loop.""" - def _socketpair(self): - return windows_utils.socketpair() - class ProactorEventLoop(proactor_events.BaseProactorEventLoop): """Windows version of proactor event loop using IOCP.""" @@ -309,20 +314,34 @@ def __init__(self, proactor=None): proactor = IocpProactor() super().__init__(proactor) - def _socketpair(self): - return windows_utils.socketpair() - - @coroutine - def create_pipe_connection(self, protocol_factory, address): + def run_forever(self): + try: + assert self._self_reading_future is None + self.call_soon(self._loop_self_reading) + super().run_forever() + finally: + if self._self_reading_future is not None: + ov = self._self_reading_future._ov + self._self_reading_future.cancel() + # self_reading_future was just cancelled so if it hasn't been + # finished yet, it never will be (it's possible that it has + # already finished and its callback is waiting in the queue, + # where it could still happen if the event loop is restarted). 
+ # Unregister it otherwise IocpProactor.close will wait for it + # forever + if ov is not None: + self._proactor._unregister(ov) + self._self_reading_future = None + + async def create_pipe_connection(self, protocol_factory, address): f = self._proactor.connect_pipe(address) - pipe = yield from f + pipe = await f protocol = protocol_factory() trans = self._make_duplex_pipe_transport(pipe, protocol, extra={'addr': address}) return trans, protocol - @coroutine - def start_serving_pipe(self, protocol_factory, address): + async def start_serving_pipe(self, protocol_factory, address): server = PipeServer(address) def loop_accept_pipe(f=None): @@ -358,7 +377,7 @@ def loop_accept_pipe(f=None): elif self._debug: logger.warning("Accept pipe failed on pipe %r", pipe, exc_info=True) - except futures.CancelledError: + except exceptions.CancelledError: if pipe: pipe.close() else: @@ -368,28 +387,22 @@ def loop_accept_pipe(f=None): self.call_soon(loop_accept_pipe) return [server] - @coroutine - def _make_subprocess_transport(self, protocol, args, shell, - stdin, stdout, stderr, bufsize, - extra=None, **kwargs): + async def _make_subprocess_transport(self, protocol, args, shell, + stdin, stdout, stderr, bufsize, + extra=None, **kwargs): waiter = self.create_future() transp = _WindowsSubprocessTransport(self, protocol, args, shell, stdin, stdout, stderr, bufsize, waiter=waiter, extra=extra, **kwargs) try: - yield from waiter - except Exception as exc: - # Workaround CPython bug #23353: using yield/yield-from in an - # except block of a generator doesn't clear properly sys.exc_info() - err = exc - else: - err = None - - if err is not None: + await waiter + except (SystemExit, KeyboardInterrupt): + raise + except BaseException: transp.close() - yield from transp._wait() - raise err + await transp._wait() + raise return transp @@ -407,10 +420,16 @@ def __init__(self, concurrency=0xffffffff): self._unregistered = [] self._stopped_serving = weakref.WeakSet() + def _check_closed(self): + if self._iocp is None: + raise RuntimeError('IocpProactor is closed') + def __repr__(self): - return ('<%s overlapped#=%s result#=%s>' - % (self.__class__.__name__, len(self._cache), - len(self._results))) + info = ['overlapped#=%s' % len(self._cache), + 'result#=%s' % len(self._results)] + if self._iocp is None: + info.append('closed') + return '<%s %s>' % (self.__class__.__name__, " ".join(info)) def set_loop(self, loop): self._loop = loop @@ -442,13 +461,75 @@ def finish_recv(trans, key, ov): try: return ov.getresult() except OSError as exc: - if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): raise ConnectionResetError(*exc.args) else: raise return self._register(ov, conn, finish_recv) + def recv_into(self, conn, buf, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + try: + if isinstance(conn, socket.socket): + ov.WSARecvInto(conn.fileno(), buf, flags) + else: + ov.ReadFileInto(conn.fileno(), buf) + except BrokenPipeError: + return self._result(0) + + def finish_recv(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): + raise ConnectionResetError(*exc.args) + else: + raise + + return self._register(ov, conn, finish_recv) + + def recvfrom(self, conn, nbytes, flags=0): + self._register_with_iocp(conn) + ov = _overlapped.Overlapped(NULL) + try: + ov.WSARecvFrom(conn.fileno(), 
+                           nbytes, flags)
+        except BrokenPipeError:
+            return self._result((b'', None))
+
+        def finish_recv(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
+                    raise ConnectionResetError(*exc.args)
+                else:
+                    raise
+
+        return self._register(ov, conn, finish_recv)
+
+    def sendto(self, conn, buf, flags=0, addr=None):
+        self._register_with_iocp(conn)
+        ov = _overlapped.Overlapped(NULL)
+
+        ov.WSASendTo(conn.fileno(), buf, flags, addr)
+
+        def finish_send(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
+                    raise ConnectionResetError(*exc.args)
+                else:
+                    raise
+
+        return self._register(ov, conn, finish_send)
+
     def send(self, conn, buf, flags=0):
         self._register_with_iocp(conn)
         ov = _overlapped.Overlapped(NULL)
@@ -461,7 +542,8 @@ def finish_send(trans, key, ov):
             try:
                 return ov.getresult()
             except OSError as exc:
-                if exc.winerror == _overlapped.ERROR_NETNAME_DELETED:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
                     raise ConnectionResetError(*exc.args)
                 else:
                     raise
@@ -483,12 +565,11 @@ def finish_accept(trans, key, ov):
             conn.settimeout(listener.gettimeout())
             return conn, conn.getpeername()
 
-        @coroutine
-        def accept_coro(future, conn):
+        async def accept_coro(future, conn):
             # Coroutine closing the accept socket if the future is cancelled
             try:
-                yield from future
-            except futures.CancelledError:
+                await future
+            except exceptions.CancelledError:
                 conn.close()
                 raise
 
@@ -498,6 +579,14 @@ def accept_coro(future, conn):
         return future
 
     def connect(self, conn, address):
+        if conn.type == socket.SOCK_DGRAM:
+            # WSAConnect will complete immediately for UDP sockets so we don't
+            # need to register any IOCP operation
+            _overlapped.WSAConnect(conn.fileno(), address)
+            fut = self._loop.create_future()
+            fut.set_result(None)
+            return fut
+
         self._register_with_iocp(conn)
         # The socket needs to be locally bound before we call ConnectEx().
         try:
@@ -520,6 +609,27 @@ def finish_connect(trans, key, ov):
 
         return self._register(ov, conn, finish_connect)
 
+    def sendfile(self, sock, file, offset, count):
+        self._register_with_iocp(sock)
+        ov = _overlapped.Overlapped(NULL)
+        offset_low = offset & 0xffff_ffff
+        offset_high = (offset >> 32) & 0xffff_ffff
+        ov.TransmitFile(sock.fileno(),
+                        msvcrt.get_osfhandle(file.fileno()),
+                        offset_low, offset_high,
+                        count, 0, 0)
+
+        def finish_sendfile(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
+                    raise ConnectionResetError(*exc.args)
+                else:
+                    raise
+        return self._register(ov, sock, finish_sendfile)
+
     def accept_pipe(self, pipe):
         self._register_with_iocp(pipe)
         ov = _overlapped.Overlapped(NULL)
@@ -537,13 +647,12 @@ def finish_accept_pipe(trans, key, ov):
 
         return self._register(ov, pipe, finish_accept_pipe)
 
-    @coroutine
-    def connect_pipe(self, address):
+    async def connect_pipe(self, address):
         delay = CONNECT_PIPE_INIT_DELAY
         while True:
-            # Unfortunately there is no way to do an overlapped connect to a pipe.
-            # Call CreateFile() in a loop until it doesn't fail with
-            # ERROR_PIPE_BUSY
+            # Unfortunately there is no way to do an overlapped connect to
+            # a pipe. Call CreateFile() in a loop until it doesn't fail with
+            # ERROR_PIPE_BUSY.
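The comment above describes the strategy: poll ConnectPipe() with a capped
exponential backoff. A minimal sketch of that pattern on its own, assuming a
hypothetical try_connect callable in place of _overlapped.ConnectPipe() and
ConnectionError in place of the ERROR_PIPE_BUSY winerror check (the actual
retry loop continues below):

    import asyncio

    # Illustrative values; the real constants are defined earlier in
    # asyncio/windows_events.py.
    CONNECT_PIPE_INIT_DELAY = 0.001
    CONNECT_PIPE_MAX_DELAY = 0.100

    async def connect_with_backoff(try_connect):
        delay = CONNECT_PIPE_INIT_DELAY
        while True:
            try:
                return try_connect()
            except ConnectionError:
                # Busy: double the delay, up to the cap, then try again.
                delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY)
                await asyncio.sleep(delay)

Once the cap is reached, each further attempt waits CONNECT_PIPE_MAX_DELAY, so
a persistently busy pipe is polled at a bounded rate rather than in a hot loop.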
try: handle = _overlapped.ConnectPipe(address) break @@ -553,7 +662,7 @@ def connect_pipe(self, address): # ConnectPipe() failed with ERROR_PIPE_BUSY: retry later delay = min(delay * 2, CONNECT_PIPE_MAX_DELAY) - yield from tasks.sleep(delay, loop=self._loop) + await tasks.sleep(delay) return windows_utils.PipeHandle(handle) @@ -573,6 +682,8 @@ def _wait_cancel(self, event, done_callback): return fut def _wait_for_handle(self, handle, timeout, _is_cancel): + self._check_closed() + if timeout is None: ms = _winapi.INFINITE else: @@ -615,6 +726,8 @@ def _register_with_iocp(self, obj): # that succeed immediately. def _register(self, ov, obj, callback): + self._check_closed() + # Return a future which will be set with the result of the # operation when it completes. The future's value is actually # the value returned by callback(). @@ -651,6 +764,7 @@ def _unregister(self, ov): already be signalled (pending in the proactor event queue). It is also safe if the event is never signalled (because it was cancelled). """ + self._check_closed() self._unregistered.append(ov) def _get_accept_socket(self, family): @@ -708,7 +822,7 @@ def _poll(self, timeout=None): f.set_result(value) self._results.append(f) - # Remove unregisted futures + # Remove unregistered futures for ov in self._unregistered: self._cache.pop(ov.address, None) self._unregistered.clear() @@ -720,6 +834,10 @@ def _stop_serving(self, obj): self._stopped_serving.add(obj) def close(self): + if self._iocp is None: + # already closed + return + # Cancel remaining registered operations. for address, (fut, ov, obj, callback) in list(self._cache.items()): if fut.cancelled(): @@ -742,14 +860,25 @@ def close(self): context['source_traceback'] = fut._source_traceback self._loop.call_exception_handler(context) + # Wait until all cancelled overlapped complete: don't exit with running + # overlapped to prevent a crash. Display progress every second if the + # loop is still running. 
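The drain loop implementing the comment above follows in the next hunk; the
throttling technique is a monotonic-clock deadline, so the progress message
fires at most once per interval no matter how often the loop iterates. The
same pattern in isolation, as a sketch where pending and process_once are
hypothetical stand-ins for the proactor's operation cache and _poll():

    import time

    def drain(pending, process_once, msg_update=1.0):
        # process_once() is assumed to handle some events, removing them
        # from `pending`, or time out after `msg_update` seconds.
        start_time = time.monotonic()
        next_msg = start_time + msg_update
        while pending:
            if next_msg <= time.monotonic():
                # At most one progress message per msg_update seconds.
                print('still draining after %.1f seconds'
                      % (time.monotonic() - start_time))
                next_msg = time.monotonic() + msg_update
            process_once(timeout=msg_update)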
+ msg_update = 1.0 + start_time = time.monotonic() + next_msg = start_time + msg_update while self._cache: - if not self._poll(1): - logger.debug('taking long time to close proactor') + if next_msg <= time.monotonic(): + logger.debug('%r is running after closing for %.1f seconds', + self, time.monotonic() - start_time) + next_msg = time.monotonic() + msg_update + + # handle a few events, or timeout + self._poll(msg_update) self._results = [] - if self._iocp is not None: - _winapi.CloseHandle(self._iocp) - self._iocp = None + + _winapi.CloseHandle(self._iocp) + self._iocp = None def __del__(self): self.close() @@ -773,8 +902,12 @@ def callback(f): SelectorEventLoop = _WindowsSelectorEventLoop -class _WindowsDefaultEventLoopPolicy(events.BaseDefaultEventLoopPolicy): +class WindowsSelectorEventLoopPolicy(events.BaseDefaultEventLoopPolicy): _loop_factory = SelectorEventLoop -DefaultEventLoopPolicy = _WindowsDefaultEventLoopPolicy +class WindowsProactorEventLoopPolicy(events.BaseDefaultEventLoopPolicy): + _loop_factory = ProactorEventLoop + + +DefaultEventLoopPolicy = WindowsProactorEventLoopPolicy diff --git a/Lib/asyncio/windows_utils.py b/Lib/asyncio/windows_utils.py index 7c63fb904b..ef277fac3e 100644 --- a/Lib/asyncio/windows_utils.py +++ b/Lib/asyncio/windows_utils.py @@ -1,6 +1,4 @@ -""" -Various Windows specific bits and pieces -""" +"""Various Windows specific bits and pieces.""" import sys @@ -11,13 +9,12 @@ import itertools import msvcrt import os -import socket import subprocess import tempfile import warnings -__all__ = ['socketpair', 'pipe', 'Popen', 'PIPE', 'PipeHandle'] +__all__ = 'pipe', 'Popen', 'PIPE', 'PipeHandle' # Constants/globals @@ -29,61 +26,14 @@ _mmap_counter = itertools.count() -if hasattr(socket, 'socketpair'): - # Since Python 3.5, socket.socketpair() is now also available on Windows - socketpair = socket.socketpair -else: - # Replacement for socket.socketpair() - def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): - """A socket pair usable as a self-pipe, for Windows. - - Origin: https://gist.github.com/4325783, by Geert Jansen. - Public domain. - """ - if family == socket.AF_INET: - host = '127.0.0.1' - elif family == socket.AF_INET6: - host = '::1' - else: - raise ValueError("Only AF_INET and AF_INET6 socket address " - "families are supported") - if type != socket.SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with setblocking(0) - # that prevents us from having to create a thread. 
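This fallback predates Python 3.5: socket.socketpair() has been available on
Windows since 3.5, which is why the hand-rolled version can be deleted
outright (the rest of the removed helper follows below). A quick check of the
standard replacement:

    import socket

    # On Python 3.5+ this works on Windows as well as POSIX.
    a, b = socket.socketpair()
    try:
        a.sendall(b'ping')
        assert b.recv(4) == b'ping'
    finally:
        a.close()
        b.close()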
- lsock = socket.socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen(1) - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket.socket(family, type, proto) - try: - csock.setblocking(False) - try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - csock.setblocking(True) - ssock, _ = lsock.accept() - except: - csock.close() - raise - finally: - lsock.close() - return (ssock, csock) - - # Replacement for os.pipe() using handles instead of fds def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE): """Like os.pipe() but with overlapped support and using handles not fds.""" - address = tempfile.mktemp(prefix=r'\\.\pipe\python-pipe-%d-%d-' % - (os.getpid(), next(_mmap_counter))) + address = tempfile.mktemp( + prefix=r'\\.\pipe\python-pipe-{:d}-{:d}-'.format( + os.getpid(), next(_mmap_counter))) if duplex: openmode = _winapi.PIPE_ACCESS_DUPLEX @@ -138,10 +88,10 @@ def __init__(self, handle): def __repr__(self): if self._handle is not None: - handle = 'handle=%r' % self._handle + handle = f'handle={self._handle!r}' else: handle = 'closed' - return '<%s %s>' % (self.__class__.__name__, handle) + return f'<{self.__class__.__name__} {handle}>' @property def handle(self): @@ -149,7 +99,7 @@ def handle(self): def fileno(self): if self._handle is None: - raise ValueError("I/O operatioon on closed pipe") + raise ValueError("I/O operation on closed pipe") return self._handle def close(self, *, CloseHandle=_winapi.CloseHandle): @@ -157,10 +107,9 @@ def close(self, *, CloseHandle=_winapi.CloseHandle): CloseHandle(self._handle) self._handle = None - def __del__(self): + def __del__(self, _warn=warnings.warn): if self._handle is not None: - warnings.warn("unclosed %r" % self, ResourceWarning, - source=self) + _warn(f"unclosed {self!r}", ResourceWarning, source=self) self.close() def __enter__(self): diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index e3d14d0d9a..5b82853102 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -2881,7 +2881,6 @@ def _testSendmsgAncillaryGenerator(self): self.assertEqual(self.sendmsgToServer([MSG], (o for o in [])), len(MSG)) - @unittest.skipIf(sys.platform == "darwin", "flaky on macOS") def testSendmsgArray(self): # Send data from an array instead of the usual bytes object. self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) @@ -2890,7 +2889,6 @@ def _testSendmsgArray(self): self.assertEqual(self.sendmsgToServer([array.array("B", MSG)]), len(MSG)) - @unittest.skipIf(sys.platform == "darwin", "flaky on macOS") def testSendmsgGather(self): # Send message data from more than one buffer (gather write). self.assertEqual(self.serv_sock.recv(len(MSG)), MSG) @@ -2953,7 +2951,6 @@ def _testSendmsgBadMultiCmsg(self): [MSG], [(0, 0, b""), object()]) self.sendToServer(b"done") - @unittest.skipIf(sys.platform == "darwin", "flaky on macOS") def testSendmsgExcessCmsgReject(self): # Check that sendmsg() rejects excess ancillary data items # when the number that can be sent is limited. @@ -4413,6 +4410,7 @@ class SendrecvmsgUnixStreamTestBase(SendrecvmsgConnectedBase, ConnectedStreamTestMixin, UnixStreamBase): pass +@unittest.skipIf(sys.platform == "darwin", "flaky on macOS") @requireAttrs(socket.socket, "sendmsg") @requireAttrs(socket, "AF_UNIX") class SendmsgUnixStreamTest(SendmsgStreamTests, SendrecvmsgUnixStreamTestBase):
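One detail worth noting in the PipeHandle.__del__ change above: warnings.warn
is bound as a default argument so the warning machinery stays reachable during
interpreter shutdown, when module globals may already have been cleared. A
sketch of the same pattern on a generic resource class (Resource and its
handle are illustrative, not part of the patch):

    import warnings

    class Resource:
        def __init__(self):
            self._handle = 42              # stand-in for a real OS handle

        def close(self):
            self._handle = None

        # The default argument is captured at definition time, so _warn
        # remains callable even if the module's globals are torn down
        # before this finalizer runs.
        def __del__(self, _warn=warnings.warn):
            if self._handle is not None:
                _warn(f'unclosed {self!r}', ResourceWarning, source=self)
                self.close()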