bpo-30064: Fix asyncio loop.sock_* race condition issue (#20369)
This commit is contained in:
parent
526e23f153
commit
210a137396
|
@ -266,6 +266,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
(handle, writer))
|
(handle, writer))
|
||||||
if reader is not None:
|
if reader is not None:
|
||||||
reader.cancel()
|
reader.cancel()
|
||||||
|
return handle
|
||||||
|
|
||||||
def _remove_reader(self, fd):
|
def _remove_reader(self, fd):
|
||||||
if self.is_closed():
|
if self.is_closed():
|
||||||
|
@ -302,6 +303,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
(reader, handle))
|
(reader, handle))
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
writer.cancel()
|
writer.cancel()
|
||||||
|
return handle
|
||||||
|
|
||||||
def _remove_writer(self, fd):
|
def _remove_writer(self, fd):
|
||||||
"""Remove a writer callback."""
|
"""Remove a writer callback."""
|
||||||
|
@ -329,7 +331,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
def add_reader(self, fd, callback, *args):
|
def add_reader(self, fd, callback, *args):
|
||||||
"""Add a reader callback."""
|
"""Add a reader callback."""
|
||||||
self._ensure_fd_no_transport(fd)
|
self._ensure_fd_no_transport(fd)
|
||||||
return self._add_reader(fd, callback, *args)
|
self._add_reader(fd, callback, *args)
|
||||||
|
|
||||||
def remove_reader(self, fd):
|
def remove_reader(self, fd):
|
||||||
"""Remove a reader callback."""
|
"""Remove a reader callback."""
|
||||||
|
@ -339,7 +341,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
def add_writer(self, fd, callback, *args):
|
def add_writer(self, fd, callback, *args):
|
||||||
"""Add a writer callback.."""
|
"""Add a writer callback.."""
|
||||||
self._ensure_fd_no_transport(fd)
|
self._ensure_fd_no_transport(fd)
|
||||||
return self._add_writer(fd, callback, *args)
|
self._add_writer(fd, callback, *args)
|
||||||
|
|
||||||
def remove_writer(self, fd):
|
def remove_writer(self, fd):
|
||||||
"""Remove a writer callback."""
|
"""Remove a writer callback."""
|
||||||
|
@ -362,12 +364,14 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
pass
|
pass
|
||||||
fut = self.create_future()
|
fut = self.create_future()
|
||||||
fd = sock.fileno()
|
fd = sock.fileno()
|
||||||
self.add_reader(fd, self._sock_recv, fut, sock, n)
|
self._ensure_fd_no_transport(fd)
|
||||||
|
handle = self._add_reader(fd, self._sock_recv, fut, sock, n)
|
||||||
fut.add_done_callback(
|
fut.add_done_callback(
|
||||||
functools.partial(self._sock_read_done, fd))
|
functools.partial(self._sock_read_done, fd, handle=handle))
|
||||||
return await fut
|
return await fut
|
||||||
|
|
||||||
def _sock_read_done(self, fd, fut):
|
def _sock_read_done(self, fd, fut, handle=None):
|
||||||
|
if handle is None or not handle.cancelled():
|
||||||
self.remove_reader(fd)
|
self.remove_reader(fd)
|
||||||
|
|
||||||
def _sock_recv(self, fut, sock, n):
|
def _sock_recv(self, fut, sock, n):
|
||||||
|
@ -401,9 +405,10 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
pass
|
pass
|
||||||
fut = self.create_future()
|
fut = self.create_future()
|
||||||
fd = sock.fileno()
|
fd = sock.fileno()
|
||||||
self.add_reader(fd, self._sock_recv_into, fut, sock, buf)
|
self._ensure_fd_no_transport(fd)
|
||||||
|
handle = self._add_reader(fd, self._sock_recv_into, fut, sock, buf)
|
||||||
fut.add_done_callback(
|
fut.add_done_callback(
|
||||||
functools.partial(self._sock_read_done, fd))
|
functools.partial(self._sock_read_done, fd, handle=handle))
|
||||||
return await fut
|
return await fut
|
||||||
|
|
||||||
def _sock_recv_into(self, fut, sock, buf):
|
def _sock_recv_into(self, fut, sock, buf):
|
||||||
|
@ -446,11 +451,12 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
|
|
||||||
fut = self.create_future()
|
fut = self.create_future()
|
||||||
fd = sock.fileno()
|
fd = sock.fileno()
|
||||||
fut.add_done_callback(
|
self._ensure_fd_no_transport(fd)
|
||||||
functools.partial(self._sock_write_done, fd))
|
|
||||||
# use a trick with a list in closure to store a mutable state
|
# use a trick with a list in closure to store a mutable state
|
||||||
self.add_writer(fd, self._sock_sendall, fut, sock,
|
handle = self._add_writer(fd, self._sock_sendall, fut, sock,
|
||||||
memoryview(data), [n])
|
memoryview(data), [n])
|
||||||
|
fut.add_done_callback(
|
||||||
|
functools.partial(self._sock_write_done, fd, handle=handle))
|
||||||
return await fut
|
return await fut
|
||||||
|
|
||||||
def _sock_sendall(self, fut, sock, view, pos):
|
def _sock_sendall(self, fut, sock, view, pos):
|
||||||
|
@ -502,9 +508,11 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
# connection runs in background. We have to wait until the socket
|
# connection runs in background. We have to wait until the socket
|
||||||
# becomes writable to be notified when the connection succeed or
|
# becomes writable to be notified when the connection succeed or
|
||||||
# fails.
|
# fails.
|
||||||
|
self._ensure_fd_no_transport(fd)
|
||||||
|
handle = self._add_writer(
|
||||||
|
fd, self._sock_connect_cb, fut, sock, address)
|
||||||
fut.add_done_callback(
|
fut.add_done_callback(
|
||||||
functools.partial(self._sock_write_done, fd))
|
functools.partial(self._sock_write_done, fd, handle=handle))
|
||||||
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
|
|
||||||
except (SystemExit, KeyboardInterrupt):
|
except (SystemExit, KeyboardInterrupt):
|
||||||
raise
|
raise
|
||||||
except BaseException as exc:
|
except BaseException as exc:
|
||||||
|
@ -512,7 +520,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
||||||
else:
|
else:
|
||||||
fut.set_result(None)
|
fut.set_result(None)
|
||||||
|
|
||||||
def _sock_write_done(self, fd, fut):
|
def _sock_write_done(self, fd, fut, handle=None):
|
||||||
|
if handle is None or not handle.cancelled():
|
||||||
self.remove_writer(fd)
|
self.remove_writer(fd)
|
||||||
|
|
||||||
def _sock_connect_cb(self, fut, sock, address):
|
def _sock_connect_cb(self, fut, sock, address):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import socket
|
import socket
|
||||||
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from asyncio import proactor_events
|
from asyncio import proactor_events
|
||||||
|
@ -122,6 +123,136 @@ class BaseSockTestsMixin:
|
||||||
sock = socket.socket()
|
sock = socket.socket()
|
||||||
self._basetest_sock_recv_into(httpd, sock)
|
self._basetest_sock_recv_into(httpd, sock)
|
||||||
|
|
||||||
|
async def _basetest_sock_recv_racing(self, httpd, sock):
|
||||||
|
sock.setblocking(False)
|
||||||
|
await self.loop.sock_connect(sock, httpd.address)
|
||||||
|
|
||||||
|
task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
asyncio.create_task(
|
||||||
|
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||||
|
data = await self.loop.sock_recv(sock, 1024)
|
||||||
|
# consume data
|
||||||
|
await self.loop.sock_recv(sock, 1024)
|
||||||
|
|
||||||
|
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||||
|
|
||||||
|
async def _basetest_sock_recv_into_racing(self, httpd, sock):
|
||||||
|
sock.setblocking(False)
|
||||||
|
await self.loop.sock_connect(sock, httpd.address)
|
||||||
|
|
||||||
|
data = bytearray(1024)
|
||||||
|
with memoryview(data) as buf:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self.loop.sock_recv_into(sock, buf[:1024]))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||||
|
nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
|
||||||
|
# consume data
|
||||||
|
await self.loop.sock_recv_into(sock, buf[nbytes:])
|
||||||
|
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||||
|
|
||||||
|
await task
|
||||||
|
|
||||||
|
async def _basetest_sock_send_racing(self, listener, sock):
|
||||||
|
listener.bind(('127.0.0.1', 0))
|
||||||
|
listener.listen(1)
|
||||||
|
|
||||||
|
# make connection
|
||||||
|
sock.setblocking(False)
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self.loop.sock_connect(sock, listener.getsockname()))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
server = listener.accept()[0]
|
||||||
|
server.setblocking(False)
|
||||||
|
|
||||||
|
with server:
|
||||||
|
await task
|
||||||
|
|
||||||
|
# fill the buffer
|
||||||
|
with self.assertRaises(BlockingIOError):
|
||||||
|
while True:
|
||||||
|
sock.send(b' ' * 5)
|
||||||
|
|
||||||
|
# cancel a blocked sock_sendall
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self.loop.sock_sendall(sock, b'hello'))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# clear the buffer
|
||||||
|
async def recv_until():
|
||||||
|
data = b''
|
||||||
|
while not data:
|
||||||
|
data = await self.loop.sock_recv(server, 1024)
|
||||||
|
data = data.strip()
|
||||||
|
return data
|
||||||
|
task = asyncio.create_task(recv_until())
|
||||||
|
|
||||||
|
# immediately register another sock_sendall
|
||||||
|
await self.loop.sock_sendall(sock, b'world')
|
||||||
|
data = await task
|
||||||
|
# ProactorEventLoop could deliver hello
|
||||||
|
self.assertTrue(data.endswith(b'world'))
|
||||||
|
|
||||||
|
async def _basetest_sock_connect_racing(self, listener, sock):
|
||||||
|
listener.bind(('127.0.0.1', 0))
|
||||||
|
addr = listener.getsockname()
|
||||||
|
sock.setblocking(False)
|
||||||
|
|
||||||
|
task = asyncio.create_task(self.loop.sock_connect(sock, addr))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
listener.listen(1)
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await self.loop.sock_connect(sock, addr)
|
||||||
|
break
|
||||||
|
except ConnectionRefusedError: # on Linux we need another retry
|
||||||
|
await self.loop.sock_connect(sock, addr)
|
||||||
|
break
|
||||||
|
except OSError as e: # on Windows we need more retries
|
||||||
|
# A connect request was made on an already connected socket
|
||||||
|
if getattr(e, 'winerror', 0) == 10056:
|
||||||
|
break
|
||||||
|
|
||||||
|
# https://stackoverflow.com/a/54437602/3316267
|
||||||
|
if getattr(e, 'winerror', 0) != 10022:
|
||||||
|
raise
|
||||||
|
i += 1
|
||||||
|
if i >= 128:
|
||||||
|
raise # too many retries
|
||||||
|
# avoid touching event loop to maintain race condition
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def test_sock_client_racing(self):
|
||||||
|
with test_utils.run_test_server() as httpd:
|
||||||
|
sock = socket.socket()
|
||||||
|
with sock:
|
||||||
|
self.loop.run_until_complete(asyncio.wait_for(
|
||||||
|
self._basetest_sock_recv_racing(httpd, sock), 10))
|
||||||
|
sock = socket.socket()
|
||||||
|
with sock:
|
||||||
|
self.loop.run_until_complete(asyncio.wait_for(
|
||||||
|
self._basetest_sock_recv_into_racing(httpd, sock), 10))
|
||||||
|
listener = socket.socket()
|
||||||
|
sock = socket.socket()
|
||||||
|
with listener, sock:
|
||||||
|
self.loop.run_until_complete(asyncio.wait_for(
|
||||||
|
self._basetest_sock_send_racing(listener, sock), 10))
|
||||||
|
listener = socket.socket()
|
||||||
|
sock = socket.socket()
|
||||||
|
with listener, sock:
|
||||||
|
self.loop.run_until_complete(asyncio.wait_for(
|
||||||
|
self._basetest_sock_connect_racing(listener, sock), 10))
|
||||||
|
|
||||||
async def _basetest_huge_content(self, address):
|
async def _basetest_huge_content(self, address):
|
||||||
sock = socket.socket()
|
sock = socket.socket()
|
||||||
sock.setblocking(False)
|
sock.setblocking(False)
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Fix asyncio ``loop.sock_*`` race condition issue
|
Loading…
Reference in New Issue