bpo-33530: Implement Happy Eyeballs in asyncio, v2 (GH-7237)
Added two keyword arguments, `delay` and `interleave`, to `BaseEventLoop.create_connection`. Happy eyeballs is activated if `delay` is specified. We now have documentation for the new arguments. `staggered_race()` is in its own module, but not exported to the main asyncio package. https://bugs.python.org/issue33530
This commit is contained in:
parent
c4d92c8ada
commit
88f07a804a
|
@ -397,9 +397,27 @@ Opening network connections
|
||||||
If given, these should all be integers from the corresponding
|
If given, these should all be integers from the corresponding
|
||||||
:mod:`socket` module constants.
|
:mod:`socket` module constants.
|
||||||
|
|
||||||
|
* *happy_eyeballs_delay*, if given, enables Happy Eyeballs for this
|
||||||
|
connection. It should
|
||||||
|
be a floating-point number representing the amount of time in seconds
|
||||||
|
to wait for a connection attempt to complete, before starting the next
|
||||||
|
attempt in parallel. This is the "Connection Attempt Delay" as defined
|
||||||
|
in :rfc:`8305`. A sensible default value recommended by the RFC is ``0.25``
|
||||||
|
(250 milliseconds).
|
||||||
|
|
||||||
|
* *interleave* controls address reordering when a host name resolves to
|
||||||
|
multiple IP addresses.
|
||||||
|
If ``0`` or unspecified, no reordering is done, and addresses are
|
||||||
|
tried in the order returned by :meth:`getaddrinfo`. If a positive integer
|
||||||
|
is specified, the addresses are interleaved by address family, and the
|
||||||
|
given integer is interpreted as "First Address Family Count" as defined
|
||||||
|
in :rfc:`8305`. The default is ``0`` if *happy_eyeballs_delay* is not
|
||||||
|
specified, and ``1`` if it is.
|
||||||
|
|
||||||
* *sock*, if given, should be an existing, already connected
|
* *sock*, if given, should be an existing, already connected
|
||||||
:class:`socket.socket` object to be used by the transport.
|
:class:`socket.socket` object to be used by the transport.
|
||||||
If *sock* is given, none of *host*, *port*, *family*, *proto*, *flags*
|
If *sock* is given, none of *host*, *port*, *family*, *proto*, *flags*,
|
||||||
|
*happy_eyeballs_delay*, *interleave*
|
||||||
and *local_addr* should be specified.
|
and *local_addr* should be specified.
|
||||||
|
|
||||||
* *local_addr*, if given, is a ``(local_host, local_port)`` tuple used
|
* *local_addr*, if given, is a ``(local_host, local_port)`` tuple used
|
||||||
|
@ -410,6 +428,10 @@ Opening network connections
|
||||||
to wait for the TLS handshake to complete before aborting the connection.
|
to wait for the TLS handshake to complete before aborting the connection.
|
||||||
``60.0`` seconds if ``None`` (default).
|
``60.0`` seconds if ``None`` (default).
|
||||||
|
|
||||||
|
.. versionadded:: 3.8
|
||||||
|
|
||||||
|
The *happy_eyeballs_delay* and *interleave* parameters.
|
||||||
|
|
||||||
.. versionadded:: 3.7
|
.. versionadded:: 3.7
|
||||||
|
|
||||||
The *ssl_handshake_timeout* parameter.
|
The *ssl_handshake_timeout* parameter.
|
||||||
|
|
|
@ -16,6 +16,7 @@ to modify the meaning of the API call itself.
|
||||||
import collections
|
import collections
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import functools
|
||||||
import heapq
|
import heapq
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
|
@ -41,6 +42,7 @@ from . import exceptions
|
||||||
from . import futures
|
from . import futures
|
||||||
from . import protocols
|
from . import protocols
|
||||||
from . import sslproto
|
from . import sslproto
|
||||||
|
from . import staggered
|
||||||
from . import tasks
|
from . import tasks
|
||||||
from . import transports
|
from . import transports
|
||||||
from .log import logger
|
from .log import logger
|
||||||
|
@ -159,6 +161,28 @@ def _ipaddr_info(host, port, family, type, proto):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _interleave_addrinfos(addrinfos, first_address_family_count=1):
|
||||||
|
"""Interleave list of addrinfo tuples by family."""
|
||||||
|
# Group addresses by family
|
||||||
|
addrinfos_by_family = collections.OrderedDict()
|
||||||
|
for addr in addrinfos:
|
||||||
|
family = addr[0]
|
||||||
|
if family not in addrinfos_by_family:
|
||||||
|
addrinfos_by_family[family] = []
|
||||||
|
addrinfos_by_family[family].append(addr)
|
||||||
|
addrinfos_lists = list(addrinfos_by_family.values())
|
||||||
|
|
||||||
|
reordered = []
|
||||||
|
if first_address_family_count > 1:
|
||||||
|
reordered.extend(addrinfos_lists[0][:first_address_family_count - 1])
|
||||||
|
del addrinfos_lists[0][:first_address_family_count - 1]
|
||||||
|
reordered.extend(
|
||||||
|
a for a in itertools.chain.from_iterable(
|
||||||
|
itertools.zip_longest(*addrinfos_lists)
|
||||||
|
) if a is not None)
|
||||||
|
return reordered
|
||||||
|
|
||||||
|
|
||||||
def _run_until_complete_cb(fut):
|
def _run_until_complete_cb(fut):
|
||||||
if not fut.cancelled():
|
if not fut.cancelled():
|
||||||
exc = fut.exception()
|
exc = fut.exception()
|
||||||
|
@ -871,12 +895,49 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
"offset must be a non-negative integer (got {!r})".format(
|
"offset must be a non-negative integer (got {!r})".format(
|
||||||
offset))
|
offset))
|
||||||
|
|
||||||
|
async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None):
|
||||||
|
"""Create, bind and connect one socket."""
|
||||||
|
my_exceptions = []
|
||||||
|
exceptions.append(my_exceptions)
|
||||||
|
family, type_, proto, _, address = addr_info
|
||||||
|
sock = None
|
||||||
|
try:
|
||||||
|
sock = socket.socket(family=family, type=type_, proto=proto)
|
||||||
|
sock.setblocking(False)
|
||||||
|
if local_addr_infos is not None:
|
||||||
|
for _, _, _, _, laddr in local_addr_infos:
|
||||||
|
try:
|
||||||
|
sock.bind(laddr)
|
||||||
|
break
|
||||||
|
except OSError as exc:
|
||||||
|
msg = (
|
||||||
|
f'error while attempting to bind on '
|
||||||
|
f'address {laddr!r}: '
|
||||||
|
f'{exc.strerror.lower()}'
|
||||||
|
)
|
||||||
|
exc = OSError(exc.errno, msg)
|
||||||
|
my_exceptions.append(exc)
|
||||||
|
else: # all bind attempts failed
|
||||||
|
raise my_exceptions.pop()
|
||||||
|
await self.sock_connect(sock, address)
|
||||||
|
return sock
|
||||||
|
except OSError as exc:
|
||||||
|
my_exceptions.append(exc)
|
||||||
|
if sock is not None:
|
||||||
|
sock.close()
|
||||||
|
raise
|
||||||
|
except:
|
||||||
|
if sock is not None:
|
||||||
|
sock.close()
|
||||||
|
raise
|
||||||
|
|
||||||
async def create_connection(
|
async def create_connection(
|
||||||
self, protocol_factory, host=None, port=None,
|
self, protocol_factory, host=None, port=None,
|
||||||
*, ssl=None, family=0,
|
*, ssl=None, family=0,
|
||||||
proto=0, flags=0, sock=None,
|
proto=0, flags=0, sock=None,
|
||||||
local_addr=None, server_hostname=None,
|
local_addr=None, server_hostname=None,
|
||||||
ssl_handshake_timeout=None):
|
ssl_handshake_timeout=None,
|
||||||
|
happy_eyeballs_delay=None, interleave=None):
|
||||||
"""Connect to a TCP server.
|
"""Connect to a TCP server.
|
||||||
|
|
||||||
Create a streaming transport connection to a given Internet host and
|
Create a streaming transport connection to a given Internet host and
|
||||||
|
@ -911,6 +972,10 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'ssl_handshake_timeout is only meaningful with ssl')
|
'ssl_handshake_timeout is only meaningful with ssl')
|
||||||
|
|
||||||
|
if happy_eyeballs_delay is not None and interleave is None:
|
||||||
|
# If using happy eyeballs, default to interleave addresses by family
|
||||||
|
interleave = 1
|
||||||
|
|
||||||
if host is not None or port is not None:
|
if host is not None or port is not None:
|
||||||
if sock is not None:
|
if sock is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -929,43 +994,31 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
flags=flags, loop=self)
|
flags=flags, loop=self)
|
||||||
if not laddr_infos:
|
if not laddr_infos:
|
||||||
raise OSError('getaddrinfo() returned empty list')
|
raise OSError('getaddrinfo() returned empty list')
|
||||||
|
else:
|
||||||
|
laddr_infos = None
|
||||||
|
|
||||||
|
if interleave:
|
||||||
|
infos = _interleave_addrinfos(infos, interleave)
|
||||||
|
|
||||||
exceptions = []
|
exceptions = []
|
||||||
for family, type, proto, cname, address in infos:
|
if happy_eyeballs_delay is None:
|
||||||
try:
|
# not using happy eyeballs
|
||||||
sock = socket.socket(family=family, type=type, proto=proto)
|
for addrinfo in infos:
|
||||||
sock.setblocking(False)
|
try:
|
||||||
if local_addr is not None:
|
sock = await self._connect_sock(
|
||||||
for _, _, _, _, laddr in laddr_infos:
|
exceptions, addrinfo, laddr_infos)
|
||||||
try:
|
break
|
||||||
sock.bind(laddr)
|
except OSError:
|
||||||
break
|
continue
|
||||||
except OSError as exc:
|
else: # using happy eyeballs
|
||||||
msg = (
|
sock, _, _ = await staggered.staggered_race(
|
||||||
f'error while attempting to bind on '
|
(functools.partial(self._connect_sock,
|
||||||
f'address {laddr!r}: '
|
exceptions, addrinfo, laddr_infos)
|
||||||
f'{exc.strerror.lower()}'
|
for addrinfo in infos),
|
||||||
)
|
happy_eyeballs_delay, loop=self)
|
||||||
exc = OSError(exc.errno, msg)
|
|
||||||
exceptions.append(exc)
|
if sock is None:
|
||||||
else:
|
exceptions = [exc for sub in exceptions for exc in sub]
|
||||||
sock.close()
|
|
||||||
sock = None
|
|
||||||
continue
|
|
||||||
if self._debug:
|
|
||||||
logger.debug("connect %r to %r", sock, address)
|
|
||||||
await self.sock_connect(sock, address)
|
|
||||||
except OSError as exc:
|
|
||||||
if sock is not None:
|
|
||||||
sock.close()
|
|
||||||
exceptions.append(exc)
|
|
||||||
except:
|
|
||||||
if sock is not None:
|
|
||||||
sock.close()
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if len(exceptions) == 1:
|
if len(exceptions) == 1:
|
||||||
raise exceptions[0]
|
raise exceptions[0]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -298,7 +298,8 @@ class AbstractEventLoop:
|
||||||
*, ssl=None, family=0, proto=0,
|
*, ssl=None, family=0, proto=0,
|
||||||
flags=0, sock=None, local_addr=None,
|
flags=0, sock=None, local_addr=None,
|
||||||
server_hostname=None,
|
server_hostname=None,
|
||||||
ssl_handshake_timeout=None):
|
ssl_handshake_timeout=None,
|
||||||
|
happy_eyeballs_delay=None, interleave=None):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def create_server(
|
async def create_server(
|
||||||
|
|
|
@ -0,0 +1,147 @@
|
||||||
|
"""Support for running coroutines in parallel with staggered start times."""
|
||||||
|
|
||||||
|
__all__ = 'staggered_race',
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import typing
|
||||||
|
|
||||||
|
from . import events
|
||||||
|
from . import futures
|
||||||
|
from . import locks
|
||||||
|
from . import tasks
|
||||||
|
|
||||||
|
|
||||||
|
async def staggered_race(
|
||||||
|
coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]],
|
||||||
|
delay: typing.Optional[float],
|
||||||
|
*,
|
||||||
|
loop: events.AbstractEventLoop = None,
|
||||||
|
) -> typing.Tuple[
|
||||||
|
typing.Any,
|
||||||
|
typing.Optional[int],
|
||||||
|
typing.List[typing.Optional[Exception]]
|
||||||
|
]:
|
||||||
|
"""Run coroutines with staggered start times and take the first to finish.
|
||||||
|
|
||||||
|
This method takes an iterable of coroutine functions. The first one is
|
||||||
|
started immediately. From then on, whenever the immediately preceding one
|
||||||
|
fails (raises an exception), or when *delay* seconds has passed, the next
|
||||||
|
coroutine is started. This continues until one of the coroutines complete
|
||||||
|
successfully, in which case all others are cancelled, or until all
|
||||||
|
coroutines fail.
|
||||||
|
|
||||||
|
The coroutines provided should be well-behaved in the following way:
|
||||||
|
|
||||||
|
* They should only ``return`` if completed successfully.
|
||||||
|
|
||||||
|
* They should always raise an exception if they did not complete
|
||||||
|
successfully. In particular, if they handle cancellation, they should
|
||||||
|
probably reraise, like this::
|
||||||
|
|
||||||
|
try:
|
||||||
|
# do work
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# undo partially completed work
|
||||||
|
raise
|
||||||
|
|
||||||
|
Args:
|
||||||
|
coro_fns: an iterable of coroutine functions, i.e. callables that
|
||||||
|
return a coroutine object when called. Use ``functools.partial`` or
|
||||||
|
lambdas to pass arguments.
|
||||||
|
|
||||||
|
delay: amount of time, in seconds, between starting coroutines. If
|
||||||
|
``None``, the coroutines will run sequentially.
|
||||||
|
|
||||||
|
loop: the event loop to use.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple *(winner_result, winner_index, exceptions)* where
|
||||||
|
|
||||||
|
- *winner_result*: the result of the winning coroutine, or ``None``
|
||||||
|
if no coroutines won.
|
||||||
|
|
||||||
|
- *winner_index*: the index of the winning coroutine in
|
||||||
|
``coro_fns``, or ``None`` if no coroutines won. If the winning
|
||||||
|
coroutine may return None on success, *winner_index* can be used
|
||||||
|
to definitively determine whether any coroutine won.
|
||||||
|
|
||||||
|
- *exceptions*: list of exceptions returned by the coroutines.
|
||||||
|
``len(exceptions)`` is equal to the number of coroutines actually
|
||||||
|
started, and the order is the same as in ``coro_fns``. The winning
|
||||||
|
coroutine's entry is ``None``.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
|
||||||
|
loop = loop or events.get_running_loop()
|
||||||
|
enum_coro_fns = enumerate(coro_fns)
|
||||||
|
winner_result = None
|
||||||
|
winner_index = None
|
||||||
|
exceptions = []
|
||||||
|
running_tasks = []
|
||||||
|
|
||||||
|
async def run_one_coro(
|
||||||
|
previous_failed: typing.Optional[locks.Event]) -> None:
|
||||||
|
# Wait for the previous task to finish, or for delay seconds
|
||||||
|
if previous_failed is not None:
|
||||||
|
with contextlib.suppress(futures.TimeoutError):
|
||||||
|
# Use asyncio.wait_for() instead of asyncio.wait() here, so
|
||||||
|
# that if we get cancelled at this point, Event.wait() is also
|
||||||
|
# cancelled, otherwise there will be a "Task destroyed but it is
|
||||||
|
# pending" later.
|
||||||
|
await tasks.wait_for(previous_failed.wait(), delay)
|
||||||
|
# Get the next coroutine to run
|
||||||
|
try:
|
||||||
|
this_index, coro_fn = next(enum_coro_fns)
|
||||||
|
except StopIteration:
|
||||||
|
return
|
||||||
|
# Start task that will run the next coroutine
|
||||||
|
this_failed = locks.Event()
|
||||||
|
next_task = loop.create_task(run_one_coro(this_failed))
|
||||||
|
running_tasks.append(next_task)
|
||||||
|
assert len(running_tasks) == this_index + 2
|
||||||
|
# Prepare place to put this coroutine's exceptions if not won
|
||||||
|
exceptions.append(None)
|
||||||
|
assert len(exceptions) == this_index + 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await coro_fn()
|
||||||
|
except Exception as e:
|
||||||
|
exceptions[this_index] = e
|
||||||
|
this_failed.set() # Kickstart the next coroutine
|
||||||
|
else:
|
||||||
|
# Store winner's results
|
||||||
|
nonlocal winner_index, winner_result
|
||||||
|
assert winner_index is None
|
||||||
|
winner_index = this_index
|
||||||
|
winner_result = result
|
||||||
|
# Cancel all other tasks. We take care to not cancel the current
|
||||||
|
# task as well. If we do so, then since there is no `await` after
|
||||||
|
# here and CancelledError are usually thrown at one, we will
|
||||||
|
# encounter a curious corner case where the current task will end
|
||||||
|
# up as done() == True, cancelled() == False, exception() ==
|
||||||
|
# asyncio.CancelledError. This behavior is specified in
|
||||||
|
# https://bugs.python.org/issue30048
|
||||||
|
for i, t in enumerate(running_tasks):
|
||||||
|
if i != this_index:
|
||||||
|
t.cancel()
|
||||||
|
|
||||||
|
first_task = loop.create_task(run_one_coro(None))
|
||||||
|
running_tasks.append(first_task)
|
||||||
|
try:
|
||||||
|
# Wait for a growing list of tasks to all finish: poor man's version of
|
||||||
|
# curio's TaskGroup or trio's nursery
|
||||||
|
done_count = 0
|
||||||
|
while done_count != len(running_tasks):
|
||||||
|
done, _ = await tasks.wait(running_tasks)
|
||||||
|
done_count = len(done)
|
||||||
|
# If run_one_coro raises an unhandled exception, it's probably a
|
||||||
|
# programming error, and I want to see it.
|
||||||
|
if __debug__:
|
||||||
|
for d in done:
|
||||||
|
if d.done() and not d.cancelled() and d.exception():
|
||||||
|
raise d.exception()
|
||||||
|
return winner_result, winner_index, exceptions
|
||||||
|
finally:
|
||||||
|
# Make sure no tasks are left running if we leave this function
|
||||||
|
for t in running_tasks:
|
||||||
|
t.cancel()
|
|
@ -0,0 +1,3 @@
|
||||||
|
Implemented Happy Eyeballs in `asyncio.create_connection()`. Added two new
|
||||||
|
arguments, *happy_eyeballs_delay* and *interleave*,
|
||||||
|
to specify Happy Eyeballs behavior.
|
Loading…
Reference in New Issue