Skip to content

[3.9] bpo-30064: Fix asyncio loop.sock_* race condition issue (GH-20369) #20460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions Lib/asyncio/selector_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def _add_reader(self, fd, callback, *args):
(handle, writer))
if reader is not None:
reader.cancel()
return handle

def _remove_reader(self, fd):
if self.is_closed():
Expand Down Expand Up @@ -302,6 +303,7 @@ def _add_writer(self, fd, callback, *args):
(reader, handle))
if writer is not None:
writer.cancel()
return handle

def _remove_writer(self, fd):
"""Remove a writer callback."""
Expand Down Expand Up @@ -329,7 +331,7 @@ def _remove_writer(self, fd):
def add_reader(self, fd, callback, *args):
"""Add a reader callback."""
self._ensure_fd_no_transport(fd)
return self._add_reader(fd, callback, *args)
self._add_reader(fd, callback, *args)

def remove_reader(self, fd):
"""Remove a reader callback."""
Expand All @@ -339,7 +341,7 @@ def remove_reader(self, fd):
def add_writer(self, fd, callback, *args):
"""Add a writer callback.."""
self._ensure_fd_no_transport(fd)
return self._add_writer(fd, callback, *args)
self._add_writer(fd, callback, *args)

def remove_writer(self, fd):
"""Remove a writer callback."""
Expand All @@ -362,13 +364,15 @@ async def sock_recv(self, sock, n):
pass
fut = self.create_future()
fd = sock.fileno()
self.add_reader(fd, self._sock_recv, fut, sock, n)
self._ensure_fd_no_transport(fd)
handle = self._add_reader(fd, self._sock_recv, fut, sock, n)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd))
functools.partial(self._sock_read_done, fd, handle=handle))
return await fut

def _sock_read_done(self, fd, fut):
self.remove_reader(fd)
def _sock_read_done(self, fd, fut, handle=None):
if handle is None or not handle.cancelled():
self.remove_reader(fd)

def _sock_recv(self, fut, sock, n):
# _sock_recv() can add itself as an I/O callback if the operation can't
Expand Down Expand Up @@ -401,9 +405,10 @@ async def sock_recv_into(self, sock, buf):
pass
fut = self.create_future()
fd = sock.fileno()
self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
self._ensure_fd_no_transport(fd)
handle = self._add_reader(fd, self._sock_recv_into, fut, sock, buf)
fut.add_done_callback(
functools.partial(self._sock_read_done, fd))
functools.partial(self._sock_read_done, fd, handle=handle))
return await fut

def _sock_recv_into(self, fut, sock, buf):
Expand Down Expand Up @@ -446,11 +451,12 @@ async def sock_sendall(self, sock, data):

fut = self.create_future()
fd = sock.fileno()
fut.add_done_callback(
functools.partial(self._sock_write_done, fd))
self._ensure_fd_no_transport(fd)
# use a trick with a list in closure to store a mutable state
self.add_writer(fd, self._sock_sendall, fut, sock,
memoryview(data), [n])
handle = self._add_writer(fd, self._sock_sendall, fut, sock,
memoryview(data), [n])
fut.add_done_callback(
functools.partial(self._sock_write_done, fd, handle=handle))
return await fut

def _sock_sendall(self, fut, sock, view, pos):
Expand Down Expand Up @@ -502,18 +508,21 @@ def _sock_connect(self, fut, sock, address):
# connection runs in background. We have to wait until the socket
# becomes writable to be notified when the connection succeed or
# fails.
self._ensure_fd_no_transport(fd)
handle = self._add_writer(
fd, self._sock_connect_cb, fut, sock, address)
fut.add_done_callback(
functools.partial(self._sock_write_done, fd))
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
functools.partial(self._sock_write_done, fd, handle=handle))
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
fut.set_exception(exc)
else:
fut.set_result(None)

def _sock_write_done(self, fd, fut):
self.remove_writer(fd)
def _sock_write_done(self, fd, fut, handle=None):
if handle is None or not handle.cancelled():
self.remove_writer(fd)

def _sock_connect_cb(self, fut, sock, address):
if fut.done():
Expand Down
131 changes: 131 additions & 0 deletions Lib/test/test_asyncio/test_sock_lowlevel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import socket
import time
import asyncio
import sys
from asyncio import proactor_events
Expand Down Expand Up @@ -122,6 +123,136 @@ def test_sock_client_ops(self):
sock = socket.socket()
self._basetest_sock_recv_into(httpd, sock)

async def _basetest_sock_recv_racing(self, httpd, sock):
sock.setblocking(False)
await self.loop.sock_connect(sock, httpd.address)

task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
await asyncio.sleep(0)
task.cancel()

asyncio.create_task(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
data = await self.loop.sock_recv(sock, 1024)
# consume data
await self.loop.sock_recv(sock, 1024)

self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

async def _basetest_sock_recv_into_racing(self, httpd, sock):
sock.setblocking(False)
await self.loop.sock_connect(sock, httpd.address)

data = bytearray(1024)
with memoryview(data) as buf:
task = asyncio.create_task(
self.loop.sock_recv_into(sock, buf[:1024]))
await asyncio.sleep(0)
task.cancel()

task = asyncio.create_task(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
# consume data
await self.loop.sock_recv_into(sock, buf[nbytes:])
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))

await task

async def _basetest_sock_send_racing(self, listener, sock):
listener.bind(('127.0.0.1', 0))
listener.listen(1)

# make connection
sock.setblocking(False)
task = asyncio.create_task(
self.loop.sock_connect(sock, listener.getsockname()))
await asyncio.sleep(0)
server = listener.accept()[0]
server.setblocking(False)

with server:
await task

# fill the buffer
with self.assertRaises(BlockingIOError):
while True:
sock.send(b' ' * 5)

# cancel a blocked sock_sendall
task = asyncio.create_task(
self.loop.sock_sendall(sock, b'hello'))
await asyncio.sleep(0)
task.cancel()

# clear the buffer
async def recv_until():
data = b''
while not data:
data = await self.loop.sock_recv(server, 1024)
data = data.strip()
return data
task = asyncio.create_task(recv_until())

# immediately register another sock_sendall
await self.loop.sock_sendall(sock, b'world')
data = await task
# ProactorEventLoop could deliver hello
self.assertTrue(data.endswith(b'world'))

async def _basetest_sock_connect_racing(self, listener, sock):
listener.bind(('127.0.0.1', 0))
addr = listener.getsockname()
sock.setblocking(False)

task = asyncio.create_task(self.loop.sock_connect(sock, addr))
await asyncio.sleep(0)
task.cancel()

listener.listen(1)
i = 0
while True:
try:
await self.loop.sock_connect(sock, addr)
break
except ConnectionRefusedError: # on Linux we need another retry
await self.loop.sock_connect(sock, addr)
break
except OSError as e: # on Windows we need more retries
# A connect request was made on an already connected socket
if getattr(e, 'winerror', 0) == 10056:
break

# https://stackoverflow.com/a/54437602/3316267
if getattr(e, 'winerror', 0) != 10022:
raise
i += 1
if i >= 128:
raise # too many retries
# avoid touching event loop to maintain race condition
time.sleep(0.01)

def test_sock_client_racing(self):
with test_utils.run_test_server() as httpd:
sock = socket.socket()
with sock:
self.loop.run_until_complete(asyncio.wait_for(
self._basetest_sock_recv_racing(httpd, sock), 10))
sock = socket.socket()
with sock:
self.loop.run_until_complete(asyncio.wait_for(
self._basetest_sock_recv_into_racing(httpd, sock), 10))
listener = socket.socket()
sock = socket.socket()
with listener, sock:
self.loop.run_until_complete(asyncio.wait_for(
self._basetest_sock_send_racing(listener, sock), 10))
listener = socket.socket()
sock = socket.socket()
with listener, sock:
self.loop.run_until_complete(asyncio.wait_for(
self._basetest_sock_connect_racing(listener, sock), 10))

async def _basetest_huge_content(self, address):
sock = socket.socket()
sock.setblocking(False)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix asyncio ``loop.sock_*`` race condition issue