From b9bf913ab32d27d221fb765fd90d64d07e926000 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Mon, 5 Oct 2015 09:15:28 -0700 Subject: [PATCH] Issue #23972: updates to asyncio datagram API. By Chris Laws. --- Doc/library/asyncio-eventloop.rst | 46 +++++- Lib/asyncio/base_events.py | 174 ++++++++++++++-------- Lib/asyncio/events.py | 40 ++++- Lib/test/test_asyncio/test_base_events.py | 140 ++++++++++++++++- Lib/test/test_asyncio/test_events.py | 52 +++++++ Misc/ACKS | 1 + Misc/NEWS | 6 + 7 files changed, 385 insertions(+), 74 deletions(-) diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index 656eca636e4..e867c80d4ed 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -283,17 +283,50 @@ Creating connections (:class:`StreamReader`, :class:`StreamWriter`) instead of a protocol. -.. coroutinemethod:: BaseEventLoop.create_datagram_endpoint(protocol_factory, local_addr=None, remote_addr=None, \*, family=0, proto=0, flags=0) +.. coroutinemethod:: BaseEventLoop.create_datagram_endpoint(protocol_factory, local_addr=None, remote_addr=None, \*, family=0, proto=0, flags=0, reuse_address=None, reuse_port=None, allow_broadcast=None, sock=None) Create datagram connection: socket family :py:data:`~socket.AF_INET` or :py:data:`~socket.AF_INET6` depending on *host* (or *family* if specified), - socket type :py:data:`~socket.SOCK_DGRAM`. + socket type :py:data:`~socket.SOCK_DGRAM`. *protocol_factory* must be a + callable returning a :ref:`protocol ` instance. This method is a :ref:`coroutine ` which will try to establish the connection in the background. When successful, the coroutine returns a ``(transport, protocol)`` pair. - See the :meth:`BaseEventLoop.create_connection` method for parameters. + Options changing how the connection is created: + + * *local_addr*, if given, is a ``(local_host, local_port)`` tuple used + to bind the socket to locally. The *local_host* and *local_port* + are looked up using :meth:`getaddrinfo`. + + * *remote_addr*, if given, is a ``(remote_host, remote_port)`` tuple used + to connect the socket to a remote address. The *remote_host* and + *remote_port* are looked up using :meth:`getaddrinfo`. + + * *family*, *proto*, *flags* are the optional address family, protocol + and flags to be passed through to :meth:`getaddrinfo` for *host* + resolution. If given, these should all be integers from the + corresponding :mod:`socket` module constants. + + * *reuse_address* tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified will automatically be set to True on + UNIX. + + * *reuse_port* tells the kernel to allow this endpoint to be bound to the + same port as other existing endpoints are bound to, so long as they all + set this flag when being created. This option is not supported on Windows + and some UNIX's. If the :py:data:`~socket.SO_REUSEPORT` constant is not + defined then this capability is unsupported. + + * *allow_broadcast* tells the kernel to allow this endpoint to send + messages to the broadcast address. + + * *sock* can optionally be specified in order to use a preexisting, + already connected, :class:`socket.socket` object to be used by the + transport. If specified, *local_addr* and *remote_addr* should be omitted + (must be :const:`None`). On Windows with :class:`ProactorEventLoop`, this method is not supported. @@ -320,7 +353,7 @@ Creating connections Creating listening connections ------------------------------ -.. coroutinemethod:: BaseEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None) +.. coroutinemethod:: BaseEventLoop.create_server(protocol_factory, host=None, port=None, \*, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, sock=None, backlog=100, ssl=None, reuse_address=None, reuse_port=None) Create a TCP server (socket type :data:`~socket.SOCK_STREAM`) bound to *host* and *port*. @@ -359,6 +392,11 @@ Creating listening connections expire. If not specified will automatically be set to True on UNIX. + * *reuse_port* tells the kernel to allow this endpoint to be bound to the + same port as other existing endpoints are bound to, so long as they all + set this flag when being created. This option is not supported on + Windows. + This method is a :ref:`coroutine `. On Windows with :class:`ProactorEventLoop`, SSL/TLS is not supported. diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index a50e00522cc..af9c8811bb4 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -700,75 +700,109 @@ class BaseEventLoop(events.AbstractEventLoop): @coroutine def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0): + family=0, proto=0, flags=0, + reuse_address=None, reuse_port=None, + allow_broadcast=None, sock=None): """Create datagram connection.""" - if not (local_addr or remote_addr): - if family == 0: - raise ValueError('unexpected address family') - addr_pairs_info = (((family, proto), (None, None)),) - else: - # join address by (family, protocol) - addr_infos = collections.OrderedDict() - for idx, addr in ((0, local_addr), (1, remote_addr)): - if addr is not None: - assert isinstance(addr, tuple) and len(addr) == 2, ( - '2-tuple is expected') - - infos = yield from self.getaddrinfo( - *addr, family=family, type=socket.SOCK_DGRAM, - proto=proto, flags=flags) - if not infos: - raise OSError('getaddrinfo() returned empty list') - - for fam, _, pro, _, address in infos: - key = (fam, pro) - if key not in addr_infos: - addr_infos[key] = [None, None] - addr_infos[key][idx] = address - - # each addr has to have info for each (family, proto) pair - addr_pairs_info = [ - (key, addr_pair) for key, addr_pair in addr_infos.items() - if not ((local_addr and addr_pair[0] is None) or - (remote_addr and addr_pair[1] is None))] - - if not addr_pairs_info: - raise ValueError('can not get address information') - - exceptions = [] - - for ((family, proto), - (local_address, remote_address)) in addr_pairs_info: - sock = None + if sock is not None: + if (local_addr or remote_addr or + family or proto or flags or + reuse_address or reuse_port or allow_broadcast): + # show the problematic kwargs in exception msg + opts = dict(local_addr=local_addr, remote_addr=remote_addr, + family=family, proto=proto, flags=flags, + reuse_address=reuse_address, reuse_port=reuse_port, + allow_broadcast=allow_broadcast) + problems = ', '.join( + '{}={}'.format(k, v) for k, v in opts.items() if v) + raise ValueError( + 'socket modifier keyword arguments can not be used ' + 'when sock is specified. ({})'.format(problems)) + sock.setblocking(False) r_addr = None - try: - sock = socket.socket( - family=family, type=socket.SOCK_DGRAM, proto=proto) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.setblocking(False) - - if local_addr: - sock.bind(local_address) - if remote_addr: - yield from self.sock_connect(sock, remote_address) - r_addr = remote_address - except OSError as exc: - if sock is not None: - sock.close() - exceptions.append(exc) - except: - if sock is not None: - sock.close() - raise - else: - break else: - raise exceptions[0] + if not (local_addr or remote_addr): + if family == 0: + raise ValueError('unexpected address family') + addr_pairs_info = (((family, proto), (None, None)),) + else: + # join address by (family, protocol) + addr_infos = collections.OrderedDict() + for idx, addr in ((0, local_addr), (1, remote_addr)): + if addr is not None: + assert isinstance(addr, tuple) and len(addr) == 2, ( + '2-tuple is expected') + + infos = yield from self.getaddrinfo( + *addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags) + if not infos: + raise OSError('getaddrinfo() returned empty list') + + for fam, _, pro, _, address in infos: + key = (fam, pro) + if key not in addr_infos: + addr_infos[key] = [None, None] + addr_infos[key][idx] = address + + # each addr has to have info for each (family, proto) pair + addr_pairs_info = [ + (key, addr_pair) for key, addr_pair in addr_infos.items() + if not ((local_addr and addr_pair[0] is None) or + (remote_addr and addr_pair[1] is None))] + + if not addr_pairs_info: + raise ValueError('can not get address information') + + exceptions = [] + + if reuse_address is None: + reuse_address = os.name == 'posix' and sys.platform != 'cygwin' + + for ((family, proto), + (local_address, remote_address)) in addr_pairs_info: + sock = None + r_addr = None + try: + sock = socket.socket( + family=family, type=socket.SOCK_DGRAM, proto=proto) + if reuse_address: + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + if reuse_port: + if not hasattr(socket, 'SO_REUSEPORT'): + raise ValueError( + 'reuse_port not supported by socket module') + else: + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + if allow_broadcast: + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + sock.setblocking(False) + + if local_addr: + sock.bind(local_address) + if remote_addr: + yield from self.sock_connect(sock, remote_address) + r_addr = remote_address + except OSError as exc: + if sock is not None: + sock.close() + exceptions.append(exc) + except: + if sock is not None: + sock.close() + raise + else: + break + else: + raise exceptions[0] protocol = protocol_factory() waiter = futures.Future(loop=self) - transport = self._make_datagram_transport(sock, protocol, r_addr, - waiter) + transport = self._make_datagram_transport( + sock, protocol, r_addr, waiter) if self._debug: if local_addr: logger.info("Datagram endpoint local_addr=%r remote_addr=%r " @@ -804,7 +838,8 @@ class BaseEventLoop(events.AbstractEventLoop): sock=None, backlog=100, ssl=None, - reuse_address=None): + reuse_address=None, + reuse_port=None): """Create a TCP server. The host parameter can be a string, in that case the TCP server is bound @@ -857,8 +892,15 @@ class BaseEventLoop(events.AbstractEventLoop): continue sockets.append(sock) if reuse_address: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, - True) + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR, True) + if reuse_port: + if not hasattr(socket, 'SO_REUSEPORT'): + raise ValueError( + 'reuse_port not supported by socket module') + else: + sock.setsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT, True) # Disable IPv4/IPv6 dual stack support (enabled by # default on Linux) which makes a single socket # listen on both address families. diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index 1e42ddd03b5..176a8466984 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -297,7 +297,8 @@ class AbstractEventLoop: def create_server(self, protocol_factory, host=None, port=None, *, family=socket.AF_UNSPEC, flags=socket.AI_PASSIVE, - sock=None, backlog=100, ssl=None, reuse_address=None): + sock=None, backlog=100, ssl=None, reuse_address=None, + reuse_port=None): """A coroutine which creates a TCP server bound to host and port. The return value is a Server object which can be used to stop @@ -327,6 +328,11 @@ class AbstractEventLoop: TIME_WAIT state, without waiting for its natural timeout to expire. If not specified will automatically be set to True on UNIX. + + reuse_port tells the kernel to allow this endpoint to be bound to + the same port as other existing endpoints are bound to, so long as + they all set this flag when being created. This option is not + supported on Windows. """ raise NotImplementedError @@ -358,7 +364,37 @@ class AbstractEventLoop: def create_datagram_endpoint(self, protocol_factory, local_addr=None, remote_addr=None, *, - family=0, proto=0, flags=0): + family=0, proto=0, flags=0, + reuse_address=None, reuse_port=None, + allow_broadcast=None, sock=None): + """A coroutine which creates a datagram endpoint. + + This method will try to establish the endpoint in the background. + When successful, the coroutine returns a (transport, protocol) pair. + + protocol_factory must be a callable returning a protocol instance. + + socket family AF_INET or socket.AF_INET6 depending on host (or + family if specified), socket type SOCK_DGRAM. + + reuse_address tells the kernel to reuse a local socket in + TIME_WAIT state, without waiting for its natural timeout to + expire. If not specified it will automatically be set to True on + UNIX. + + reuse_port tells the kernel to allow this endpoint to be bound to + the same port as other existing endpoints are bound to, so long as + they all set this flag when being created. This option is not + supported on Windows and some UNIX's. If the + :py:data:`~socket.SO_REUSEPORT` constant is not defined then this + capability is unsupported. + + allow_broadcast tells the kernel to allow this endpoint to send + messages to the broadcast address. + + sock can optionally be specified in order to use a preexisting + socket object. + """ raise NotImplementedError # Pipes and subprocesses. diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index b1f1e56c2cf..156844001d0 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -3,6 +3,7 @@ import errno import logging import math +import os import socket import sys import threading @@ -790,11 +791,11 @@ class MyProto(asyncio.Protocol): class MyDatagramProto(asyncio.DatagramProtocol): done = None - def __init__(self, create_future=False): + def __init__(self, create_future=False, loop=None): self.state = 'INITIAL' self.nbytes = 0 if create_future: - self.done = asyncio.Future() + self.done = asyncio.Future(loop=loop) def connection_made(self, transport): self.transport = transport @@ -1099,6 +1100,19 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): f = self.loop.create_server(MyProto, '0.0.0.0', 0) self.assertRaises(OSError, self.loop.run_until_complete, f) + @mock.patch('asyncio.base_events.socket') + def test_create_server_nosoreuseport(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.SOCK_STREAM = socket.SOCK_STREAM + m_socket.SOL_SOCKET = socket.SOL_SOCKET + del m_socket.SO_REUSEPORT + m_socket.socket.return_value = mock.Mock() + + f = self.loop.create_server( + MyProto, '0.0.0.0', 0, reuse_port=True) + + self.assertRaises(ValueError, self.loop.run_until_complete, f) + @mock.patch('asyncio.base_events.socket') def test_create_server_cant_bind(self, m_socket): @@ -1199,6 +1213,128 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): self.assertRaises(Err, self.loop.run_until_complete, fut) self.assertTrue(m_sock.close.called) + def test_create_datagram_endpoint_sock(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + fut = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True, loop=self.loop), + sock=sock) + transport, protocol = self.loop.run_until_complete(fut) + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + + def test_create_datagram_endpoint_sock_sockopts(self): + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, local_addr=('127.0.0.1', 0), sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, remote_addr=('127.0.0.1', 0), sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, family=1, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, proto=1, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, flags=1, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, reuse_address=True, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, reuse_port=True, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + fut = self.loop.create_datagram_endpoint( + MyDatagramProto, allow_broadcast=True, sock=object()) + self.assertRaises(ValueError, self.loop.run_until_complete, fut) + + def test_create_datagram_endpoint_sockopts(self): + # Socket options should not be applied unless asked for. + # SO_REUSEADDR defaults to on for UNIX. + # SO_REUSEPORT is not available on all platforms. + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True, loop=self.loop), + local_addr=('127.0.0.1', 0)) + transport, protocol = self.loop.run_until_complete(coro) + sock = transport.get_extra_info('socket') + + reuse_address_default_on = ( + os.name == 'posix' and sys.platform != 'cygwin') + reuseport_supported = hasattr(socket, 'SO_REUSEPORT') + + if reuse_address_default_on: + self.assertTrue( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR)) + else: + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR)) + if reuseport_supported: + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_BROADCAST)) + + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(create_future=True, loop=self.loop), + local_addr=('127.0.0.1', 0), + reuse_address=True, + reuse_port=reuseport_supported, + allow_broadcast=True) + transport, protocol = self.loop.run_until_complete(coro) + sock = transport.get_extra_info('socket') + + self.assertTrue( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEADDR)) + if reuseport_supported: + self.assertTrue( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + else: + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + self.assertTrue( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_BROADCAST)) + + transport.close() + self.loop.run_until_complete(protocol.done) + self.assertEqual('CLOSED', protocol.state) + + @mock.patch('asyncio.base_events.socket') + def test_create_datagram_endpoint_nosoreuseport(self, m_socket): + m_socket.getaddrinfo = socket.getaddrinfo + m_socket.SOCK_DGRAM = socket.SOCK_DGRAM + m_socket.SOL_SOCKET = socket.SOL_SOCKET + del m_socket.SO_REUSEPORT + m_socket.socket.return_value = mock.Mock() + + coro = self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + local_addr=('127.0.0.1', 0), + reuse_address=False, + reuse_port=True) + + self.assertRaises(ValueError, self.loop.run_until_complete, coro) + def test_accept_connection_retry(self): sock = mock.Mock() sock.accept.side_effect = BlockingIOError() diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index 9801d22223b..141fde74e6b 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -814,6 +814,32 @@ class EventLoopTestsMixin: # close server server.close() + @unittest.skipUnless(hasattr(socket, 'SO_REUSEPORT'), 'No SO_REUSEPORT') + def test_create_server_reuse_port(self): + proto = MyProto(self.loop) + f = self.loop.create_server( + lambda: proto, '0.0.0.0', 0) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + self.assertFalse( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + server.close() + + test_utils.run_briefly(self.loop) + + proto = MyProto(self.loop) + f = self.loop.create_server( + lambda: proto, '0.0.0.0', 0, reuse_port=True) + server = self.loop.run_until_complete(f) + self.assertEqual(len(server.sockets), 1) + sock = server.sockets[0] + self.assertTrue( + sock.getsockopt( + socket.SOL_SOCKET, socket.SO_REUSEPORT)) + server.close() + def _make_unix_server(self, factory, **kwargs): path = test_utils.gen_unix_socket_path() self.addCleanup(lambda: os.path.exists(path) and os.unlink(path)) @@ -1264,6 +1290,32 @@ class EventLoopTestsMixin: self.assertEqual('CLOSED', client.state) server.transport.close() + def test_create_datagram_endpoint_sock(self): + sock = None + local_address = ('127.0.0.1', 0) + infos = self.loop.run_until_complete( + self.loop.getaddrinfo( + *local_address, type=socket.SOCK_DGRAM)) + for family, type, proto, cname, address in infos: + try: + sock = socket.socket(family=family, type=type, proto=proto) + sock.setblocking(False) + sock.bind(address) + except: + pass + else: + break + else: + assert False, 'Can not create socket.' + + f = self.loop.create_connection( + lambda: MyDatagramProto(loop=self.loop), sock=sock) + tr, pr = self.loop.run_until_complete(f) + self.assertIsInstance(tr, asyncio.Transport) + self.assertIsInstance(pr, MyDatagramProto) + tr.close() + self.loop.run_until_complete(pr.done) + def test_internal_fds(self): loop = self.create_event_loop() if not isinstance(loop, selector_events.BaseSelectorEventLoop): diff --git a/Misc/ACKS b/Misc/ACKS index cae34e6b369..a40545a89b9 100644 --- a/Misc/ACKS +++ b/Misc/ACKS @@ -789,6 +789,7 @@ Ben Laurie Simon Law Julia Lawall Chris Lawrence +Chris Laws Brian Leair Mathieu Leduc-Hamel Amandine Lee diff --git a/Misc/NEWS b/Misc/NEWS index 80edd05bcfa..d7dd962b8b5 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -90,6 +90,12 @@ Core and Builtins Library ------- +- Issue #23972: Updates asyncio datagram create method allowing reuseport + and reuseaddr socket options to be set prior to binding the socket. + Mirroring the existing asyncio create_server method the reuseaddr option + for datagram sockets defaults to True if the O/S is 'posix' (except if the + platform is Cygwin). Patch by Chris Laws. + - Issue #25304: Add asyncio.run_coroutine_threadsafe(). This lets you submit a coroutine to a loop from another thread, returning a concurrent.futures.Future. By Vincent Michel.