mirror of https://github.com/python/cpython
Issue #28176: Fix callbacks race in asyncio.SelectorLoop.sock_connect.
This commit is contained in:
parent
4c5bf3bc52
commit
d6c6771fc9
|
@ -400,6 +400,7 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
data = data[n:]
|
||||
self.add_writer(fd, self._sock_sendall, fut, True, sock, data)
|
||||
|
||||
@coroutine
|
||||
def sock_connect(self, sock, address):
|
||||
"""Connect to a remote socket at address.
|
||||
|
||||
|
@ -408,24 +409,16 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
if self._debug and sock.gettimeout() != 0:
|
||||
raise ValueError("the socket must be non-blocking")
|
||||
|
||||
fut = self.create_future()
|
||||
if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX:
|
||||
self._sock_connect(fut, sock, address)
|
||||
else:
|
||||
if not hasattr(socket, 'AF_UNIX') or sock.family != socket.AF_UNIX:
|
||||
resolved = base_events._ensure_resolved(
|
||||
address, family=sock.family, proto=sock.proto, loop=self)
|
||||
resolved.add_done_callback(
|
||||
lambda resolved: self._on_resolved(fut, sock, resolved))
|
||||
|
||||
return fut
|
||||
|
||||
def _on_resolved(self, fut, sock, resolved):
|
||||
try:
|
||||
if not resolved.done():
|
||||
yield from resolved
|
||||
_, _, _, _, address = resolved.result()[0]
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
else:
|
||||
self._sock_connect(fut, sock, address)
|
||||
|
||||
fut = self.create_future()
|
||||
self._sock_connect(fut, sock, address)
|
||||
return (yield from fut)
|
||||
|
||||
def _sock_connect(self, fut, sock, address):
|
||||
fd = sock.fileno()
|
||||
|
@ -436,8 +429,8 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
|
|||
# connection runs in background. We have to wait until the socket
|
||||
# becomes writable to be notified when the connection succeed or
|
||||
# fails.
|
||||
fut.add_done_callback(functools.partial(self._sock_connect_done,
|
||||
fd))
|
||||
fut.add_done_callback(
|
||||
functools.partial(self._sock_connect_done, fd))
|
||||
self.add_writer(fd, self._sock_connect_cb, fut, sock, address)
|
||||
except Exception as exc:
|
||||
fut.set_exception(exc)
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
import errno
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import unittest
|
||||
from unittest import mock
|
||||
try:
|
||||
|
@ -337,18 +339,6 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
|
|||
(10, self.loop._sock_sendall, f, True, sock, b'data'),
|
||||
self.loop.add_writer.call_args[0])
|
||||
|
||||
def test_sock_connect(self):
|
||||
sock = test_utils.mock_nonblocking_socket()
|
||||
self.loop._sock_connect = mock.Mock()
|
||||
|
||||
f = self.loop.sock_connect(sock, ('127.0.0.1', 8080))
|
||||
self.assertIsInstance(f, asyncio.Future)
|
||||
self.loop._run_once()
|
||||
future_in, sock_in, address_in = self.loop._sock_connect.call_args[0]
|
||||
self.assertEqual(future_in, f)
|
||||
self.assertEqual(sock_in, sock)
|
||||
self.assertEqual(address_in, ('127.0.0.1', 8080))
|
||||
|
||||
def test_sock_connect_timeout(self):
|
||||
# asyncio issue #205: sock_connect() must unregister the socket on
|
||||
# timeout error
|
||||
|
@ -360,29 +350,34 @@ class BaseSelectorEventLoopTests(test_utils.TestCase):
|
|||
sock.connect.side_effect = BlockingIOError
|
||||
|
||||
# first call to sock_connect() registers the socket
|
||||
fut = self.loop.sock_connect(sock, ('127.0.0.1', 80))
|
||||
fut = self.loop.create_task(
|
||||
self.loop.sock_connect(sock, ('127.0.0.1', 80)))
|
||||
self.loop._run_once()
|
||||
self.assertTrue(sock.connect.called)
|
||||
self.assertTrue(self.loop.add_writer.called)
|
||||
self.assertEqual(len(fut._callbacks), 1)
|
||||
|
||||
# on timeout, the socket must be unregistered
|
||||
sock.connect.reset_mock()
|
||||
fut.set_exception(asyncio.TimeoutError)
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
fut.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
self.loop.run_until_complete(fut)
|
||||
self.assertTrue(self.loop.remove_writer.called)
|
||||
|
||||
def test_sock_connect_resolve_using_socket_params(self):
|
||||
@mock.patch('socket.getaddrinfo')
|
||||
def test_sock_connect_resolve_using_socket_params(self, m_gai):
|
||||
addr = ('need-resolution.com', 8080)
|
||||
sock = test_utils.mock_nonblocking_socket()
|
||||
self.loop.getaddrinfo = mock.Mock()
|
||||
self.loop.sock_connect(sock, addr)
|
||||
while not self.loop.getaddrinfo.called:
|
||||
m_gai.side_effect = (None, None, None, None, ('127.0.0.1', 0))
|
||||
m_gai._is_coroutine = False
|
||||
con = self.loop.create_task(self.loop.sock_connect(sock, addr))
|
||||
while not m_gai.called:
|
||||
self.loop._run_once()
|
||||
self.loop.getaddrinfo.assert_called_with(
|
||||
*addr, type=sock.type, family=sock.family, proto=sock.proto,
|
||||
flags=0)
|
||||
m_gai.assert_called_with(
|
||||
addr[0], addr[1], sock.family, sock.type, sock.proto, 0)
|
||||
|
||||
con.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
self.loop.run_until_complete(con)
|
||||
|
||||
def test__sock_connect(self):
|
||||
f = asyncio.Future(loop=self.loop)
|
||||
|
@ -1792,5 +1787,88 @@ class SelectorDatagramTransportTests(test_utils.TestCase):
|
|||
exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
|
||||
|
||||
|
||||
class SelectorLoopFunctionalTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
|
||||
@asyncio.coroutine
|
||||
def recv_all(self, sock, nbytes):
|
||||
buf = b''
|
||||
while len(buf) < nbytes:
|
||||
buf += yield from self.loop.sock_recv(sock, nbytes - len(buf))
|
||||
return buf
|
||||
|
||||
def test_sock_connect_sock_write_race(self):
|
||||
TIMEOUT = 3.0
|
||||
PAYLOAD = b'DATA' * 1024 * 1024
|
||||
|
||||
class Server(threading.Thread):
|
||||
def __init__(self, *args, srv_sock, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.srv_sock = srv_sock
|
||||
|
||||
def run(self):
|
||||
with self.srv_sock:
|
||||
srv_sock.listen(100)
|
||||
|
||||
sock, addr = self.srv_sock.accept()
|
||||
sock.settimeout(TIMEOUT)
|
||||
|
||||
with sock:
|
||||
sock.sendall(b'helo')
|
||||
|
||||
buf = bytearray()
|
||||
while len(buf) < len(PAYLOAD):
|
||||
pack = sock.recv(1024 * 65)
|
||||
if not pack:
|
||||
break
|
||||
buf.extend(pack)
|
||||
|
||||
@asyncio.coroutine
|
||||
def client(addr):
|
||||
sock = socket.socket()
|
||||
with sock:
|
||||
sock.setblocking(False)
|
||||
|
||||
started = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - started > TIMEOUT:
|
||||
self.fail('unable to connect to the socket')
|
||||
return
|
||||
try:
|
||||
yield from self.loop.sock_connect(sock, addr)
|
||||
except OSError:
|
||||
yield from asyncio.sleep(0.05, loop=self.loop)
|
||||
else:
|
||||
break
|
||||
|
||||
# Give 'Server' thread a chance to accept and send b'helo'
|
||||
time.sleep(0.1)
|
||||
|
||||
data = yield from self.recv_all(sock, 4)
|
||||
self.assertEqual(data, b'helo')
|
||||
yield from self.loop.sock_sendall(sock, PAYLOAD)
|
||||
|
||||
srv_sock = socket.socket()
|
||||
srv_sock.settimeout(TIMEOUT)
|
||||
srv_sock.bind(('127.0.0.1', 0))
|
||||
srv_addr = srv_sock.getsockname()
|
||||
|
||||
srv = Server(srv_sock=srv_sock, daemon=True)
|
||||
srv.start()
|
||||
|
||||
try:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv_addr), loop=self.loop,
|
||||
timeout=TIMEOUT))
|
||||
finally:
|
||||
srv.join()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue