From 88f07a804a0adc0b6ee87687b59d8416113c7331 Mon Sep 17 00:00:00 2001 From: twisteroid ambassador Date: Sun, 5 May 2019 19:14:35 +0800 Subject: [PATCH] bpo-33530: Implement Happy Eyeballs in asyncio, v2 (GH-7237) Added two keyword arguments, `delay` and `interleave`, to `BaseEventLoop.create_connection`. Happy eyeballs is activated if `delay` is specified. We now have documentation for the new arguments. `staggered_race()` is in its own module, but not exported to the main asyncio package. https://bugs.python.org/issue33530 --- Doc/library/asyncio-eventloop.rst | 24 ++- Lib/asyncio/base_events.py | 125 ++++++++++----- Lib/asyncio/events.py | 3 +- Lib/asyncio/staggered.py | 147 ++++++++++++++++++ .../2018-05-29-18-34-53.bpo-33530._4Q_bi.rst | 3 + 5 files changed, 264 insertions(+), 38 deletions(-) create mode 100644 Lib/asyncio/staggered.py create mode 100644 Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index e2b31245392..06f673be790 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -397,9 +397,27 @@ Opening network connections If given, these should all be integers from the corresponding :mod:`socket` module constants. + * *happy_eyeballs_delay*, if given, enables Happy Eyeballs for this + connection. It should + be a floating-point number representing the amount of time in seconds + to wait for a connection attempt to complete, before starting the next + attempt in parallel. This is the "Connection Attempt Delay" as defined + in :rfc:`8305`. A sensible default value recommended by the RFC is ``0.25`` + (250 milliseconds). + + * *interleave* controls address reordering when a host name resolves to + multiple IP addresses. + If ``0`` or unspecified, no reordering is done, and addresses are + tried in the order returned by :meth:`getaddrinfo`. If a positive integer + is specified, the addresses are interleaved by address family, and the + given integer is interpreted as "First Address Family Count" as defined + in :rfc:`8305`. The default is ``0`` if *happy_eyeballs_delay* is not + specified, and ``1`` if it is. + * *sock*, if given, should be an existing, already connected :class:`socket.socket` object to be used by the transport. - If *sock* is given, none of *host*, *port*, *family*, *proto*, *flags* + If *sock* is given, none of *host*, *port*, *family*, *proto*, *flags*, + *happy_eyeballs_delay*, *interleave* and *local_addr* should be specified. * *local_addr*, if given, is a ``(local_host, local_port)`` tuple used @@ -410,6 +428,10 @@ Opening network connections to wait for the TLS handshake to complete before aborting the connection. ``60.0`` seconds if ``None`` (default). + .. versionadded:: 3.8 + + The *happy_eyeballs_delay* and *interleave* parameters. + .. versionadded:: 3.7 The *ssl_handshake_timeout* parameter. diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 9b4b846131d..c58906f8b48 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -16,6 +16,7 @@ to modify the meaning of the API call itself. import collections import collections.abc import concurrent.futures +import functools import heapq import itertools import os @@ -41,6 +42,7 @@ from . import exceptions from . import futures from . import protocols from . import sslproto +from . import staggered from . import tasks from . import transports from .log import logger @@ -159,6 +161,28 @@ def _ipaddr_info(host, port, family, type, proto): return None +def _interleave_addrinfos(addrinfos, first_address_family_count=1): + """Interleave list of addrinfo tuples by family.""" + # Group addresses by family + addrinfos_by_family = collections.OrderedDict() + for addr in addrinfos: + family = addr[0] + if family not in addrinfos_by_family: + addrinfos_by_family[family] = [] + addrinfos_by_family[family].append(addr) + addrinfos_lists = list(addrinfos_by_family.values()) + + reordered = [] + if first_address_family_count > 1: + reordered.extend(addrinfos_lists[0][:first_address_family_count - 1]) + del addrinfos_lists[0][:first_address_family_count - 1] + reordered.extend( + a for a in itertools.chain.from_iterable( + itertools.zip_longest(*addrinfos_lists) + ) if a is not None) + return reordered + + def _run_until_complete_cb(fut): if not fut.cancelled(): exc = fut.exception() @@ -871,12 +895,49 @@ class BaseEventLoop(events.AbstractEventLoop): "offset must be a non-negative integer (got {!r})".format( offset)) + async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None): + """Create, bind and connect one socket.""" + my_exceptions = [] + exceptions.append(my_exceptions) + family, type_, proto, _, address = addr_info + sock = None + try: + sock = socket.socket(family=family, type=type_, proto=proto) + sock.setblocking(False) + if local_addr_infos is not None: + for _, _, _, _, laddr in local_addr_infos: + try: + sock.bind(laddr) + break + except OSError as exc: + msg = ( + f'error while attempting to bind on ' + f'address {laddr!r}: ' + f'{exc.strerror.lower()}' + ) + exc = OSError(exc.errno, msg) + my_exceptions.append(exc) + else: # all bind attempts failed + raise my_exceptions.pop() + await self.sock_connect(sock, address) + return sock + except OSError as exc: + my_exceptions.append(exc) + if sock is not None: + sock.close() + raise + except: + if sock is not None: + sock.close() + raise + async def create_connection( self, protocol_factory, host=None, port=None, *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): """Connect to a TCP server. Create a streaming transport connection to a given Internet host and @@ -911,6 +972,10 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'ssl_handshake_timeout is only meaningful with ssl') + if happy_eyeballs_delay is not None and interleave is None: + # If using happy eyeballs, default to interleave addresses by family + interleave = 1 + if host is not None or port is not None: if sock is not None: raise ValueError( @@ -929,43 +994,31 @@ class BaseEventLoop(events.AbstractEventLoop): flags=flags, loop=self) if not laddr_infos: raise OSError('getaddrinfo() returned empty list') + else: + laddr_infos = None + + if interleave: + infos = _interleave_addrinfos(infos, interleave) exceptions = [] - for family, type, proto, cname, address in infos: - try: - sock = socket.socket(family=family, type=type, proto=proto) - sock.setblocking(False) - if local_addr is not None: - for _, _, _, _, laddr in laddr_infos: - try: - sock.bind(laddr) - break - except OSError as exc: - msg = ( - f'error while attempting to bind on ' - f'address {laddr!r}: ' - f'{exc.strerror.lower()}' - ) - exc = OSError(exc.errno, msg) - exceptions.append(exc) - else: - sock.close() - sock = None - continue - if self._debug: - logger.debug("connect %r to %r", sock, address) - await self.sock_connect(sock, address) - except OSError as exc: - if sock is not None: - sock.close() - exceptions.append(exc) - except: - if sock is not None: - sock.close() - raise - else: - break - else: + if happy_eyeballs_delay is None: + # not using happy eyeballs + for addrinfo in infos: + try: + sock = await self._connect_sock( + exceptions, addrinfo, laddr_infos) + break + except OSError: + continue + else: # using happy eyeballs + sock, _, _ = await staggered.staggered_race( + (functools.partial(self._connect_sock, + exceptions, addrinfo, laddr_infos) + for addrinfo in infos), + happy_eyeballs_delay, loop=self) + + if sock is None: + exceptions = [exc for sub in exceptions for exc in sub] if len(exceptions) == 1: raise exceptions[0] else: diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index 163b868afee..9a923514db0 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -298,7 +298,8 @@ class AbstractEventLoop: *, ssl=None, family=0, proto=0, flags=0, sock=None, local_addr=None, server_hostname=None, - ssl_handshake_timeout=None): + ssl_handshake_timeout=None, + happy_eyeballs_delay=None, interleave=None): raise NotImplementedError async def create_server( diff --git a/Lib/asyncio/staggered.py b/Lib/asyncio/staggered.py new file mode 100644 index 00000000000..feec681b437 --- /dev/null +++ b/Lib/asyncio/staggered.py @@ -0,0 +1,147 @@ +"""Support for running coroutines in parallel with staggered start times.""" + +__all__ = 'staggered_race', + +import contextlib +import typing + +from . import events +from . import futures +from . import locks +from . import tasks + + +async def staggered_race( + coro_fns: typing.Iterable[typing.Callable[[], typing.Awaitable]], + delay: typing.Optional[float], + *, + loop: events.AbstractEventLoop = None, +) -> typing.Tuple[ + typing.Any, + typing.Optional[int], + typing.List[typing.Optional[Exception]] +]: + """Run coroutines with staggered start times and take the first to finish. + + This method takes an iterable of coroutine functions. The first one is + started immediately. From then on, whenever the immediately preceding one + fails (raises an exception), or when *delay* seconds has passed, the next + coroutine is started. This continues until one of the coroutines complete + successfully, in which case all others are cancelled, or until all + coroutines fail. + + The coroutines provided should be well-behaved in the following way: + + * They should only ``return`` if completed successfully. + + * They should always raise an exception if they did not complete + successfully. In particular, if they handle cancellation, they should + probably reraise, like this:: + + try: + # do work + except asyncio.CancelledError: + # undo partially completed work + raise + + Args: + coro_fns: an iterable of coroutine functions, i.e. callables that + return a coroutine object when called. Use ``functools.partial`` or + lambdas to pass arguments. + + delay: amount of time, in seconds, between starting coroutines. If + ``None``, the coroutines will run sequentially. + + loop: the event loop to use. + + Returns: + tuple *(winner_result, winner_index, exceptions)* where + + - *winner_result*: the result of the winning coroutine, or ``None`` + if no coroutines won. + + - *winner_index*: the index of the winning coroutine in + ``coro_fns``, or ``None`` if no coroutines won. If the winning + coroutine may return None on success, *winner_index* can be used + to definitively determine whether any coroutine won. + + - *exceptions*: list of exceptions returned by the coroutines. + ``len(exceptions)`` is equal to the number of coroutines actually + started, and the order is the same as in ``coro_fns``. The winning + coroutine's entry is ``None``. + + """ + # TODO: when we have aiter() and anext(), allow async iterables in coro_fns. + loop = loop or events.get_running_loop() + enum_coro_fns = enumerate(coro_fns) + winner_result = None + winner_index = None + exceptions = [] + running_tasks = [] + + async def run_one_coro( + previous_failed: typing.Optional[locks.Event]) -> None: + # Wait for the previous task to finish, or for delay seconds + if previous_failed is not None: + with contextlib.suppress(futures.TimeoutError): + # Use asyncio.wait_for() instead of asyncio.wait() here, so + # that if we get cancelled at this point, Event.wait() is also + # cancelled, otherwise there will be a "Task destroyed but it is + # pending" later. + await tasks.wait_for(previous_failed.wait(), delay) + # Get the next coroutine to run + try: + this_index, coro_fn = next(enum_coro_fns) + except StopIteration: + return + # Start task that will run the next coroutine + this_failed = locks.Event() + next_task = loop.create_task(run_one_coro(this_failed)) + running_tasks.append(next_task) + assert len(running_tasks) == this_index + 2 + # Prepare place to put this coroutine's exceptions if not won + exceptions.append(None) + assert len(exceptions) == this_index + 1 + + try: + result = await coro_fn() + except Exception as e: + exceptions[this_index] = e + this_failed.set() # Kickstart the next coroutine + else: + # Store winner's results + nonlocal winner_index, winner_result + assert winner_index is None + winner_index = this_index + winner_result = result + # Cancel all other tasks. We take care to not cancel the current + # task as well. If we do so, then since there is no `await` after + # here and CancelledError are usually thrown at one, we will + # encounter a curious corner case where the current task will end + # up as done() == True, cancelled() == False, exception() == + # asyncio.CancelledError. This behavior is specified in + # https://bugs.python.org/issue30048 + for i, t in enumerate(running_tasks): + if i != this_index: + t.cancel() + + first_task = loop.create_task(run_one_coro(None)) + running_tasks.append(first_task) + try: + # Wait for a growing list of tasks to all finish: poor man's version of + # curio's TaskGroup or trio's nursery + done_count = 0 + while done_count != len(running_tasks): + done, _ = await tasks.wait(running_tasks) + done_count = len(done) + # If run_one_coro raises an unhandled exception, it's probably a + # programming error, and I want to see it. + if __debug__: + for d in done: + if d.done() and not d.cancelled() and d.exception(): + raise d.exception() + return winner_result, winner_index, exceptions + finally: + # Make sure no tasks are left running if we leave this function + for t in running_tasks: + t.cancel() diff --git a/Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst b/Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst new file mode 100644 index 00000000000..747219b1bfb --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-05-29-18-34-53.bpo-33530._4Q_bi.rst @@ -0,0 +1,3 @@ +Implemented Happy Eyeballs in `asyncio.create_connection()`. Added two new +arguments, *happy_eyeballs_delay* and *interleave*, +to specify Happy Eyeballs behavior.