bpo-27456: Ensure TCP_NODELAY is set on linux (#4231) (#4898)

(cherry picked from commit e796b2fe26)
This commit is contained in:
Yury Selivanov 2017-12-15 21:53:08 -05:00 committed by GitHub
parent dab4cf210c
commit 572636d425
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 53 additions and 31 deletions

View File

@ -84,18 +84,24 @@ def _set_reuseport(sock):
'SO_REUSEPORT defined but not implemented.') 'SO_REUSEPORT defined but not implemented.')
def _is_stream_socket(sock): def _is_stream_socket(sock_type):
# Linux's socket.type is a bitmask that can include extra info if hasattr(socket, 'SOCK_NONBLOCK'):
# about socket, therefore we can't do simple # Linux's socket.type is a bitmask that can include extra info
# `sock_type == socket.SOCK_STREAM`. # about socket (like SOCK_NONBLOCK bit), therefore we can't do simple
return (sock.type & socket.SOCK_STREAM) == socket.SOCK_STREAM # `sock_type == socket.SOCK_STREAM`, see
# https://github.com/torvalds/linux/blob/v4.13/include/linux/net.h#L77
# for more details.
return (sock_type & 0xF) == socket.SOCK_STREAM
else:
return sock_type == socket.SOCK_STREAM
def _is_dgram_socket(sock): def _is_dgram_socket(sock_type):
# Linux's socket.type is a bitmask that can include extra info if hasattr(socket, 'SOCK_NONBLOCK'):
# about socket, therefore we can't do simple # See the comment in `_is_stream_socket`.
# `sock_type == socket.SOCK_DGRAM`. return (sock_type & 0xF) == socket.SOCK_DGRAM
return (sock.type & socket.SOCK_DGRAM) == socket.SOCK_DGRAM else:
return sock_type == socket.SOCK_DGRAM
def _ipaddr_info(host, port, family, type, proto): def _ipaddr_info(host, port, family, type, proto):
@ -108,14 +114,9 @@ def _ipaddr_info(host, port, family, type, proto):
host is None: host is None:
return None return None
if type == socket.SOCK_STREAM: if _is_stream_socket(type):
# Linux only:
# getaddrinfo() can raise when socket.type is a bit mask.
# So if socket.type is a bit mask of SOCK_STREAM, and say
# SOCK_NONBLOCK, we simply return None, which will trigger
# a call to getaddrinfo() letting it process this request.
proto = socket.IPPROTO_TCP proto = socket.IPPROTO_TCP
elif type == socket.SOCK_DGRAM: elif _is_dgram_socket(type):
proto = socket.IPPROTO_UDP proto = socket.IPPROTO_UDP
else: else:
return None return None
@ -789,7 +790,7 @@ class BaseEventLoop(events.AbstractEventLoop):
if sock is None: if sock is None:
raise ValueError( raise ValueError(
'host and port was not specified and no sock specified') 'host and port was not specified and no sock specified')
if not _is_stream_socket(sock): if not _is_stream_socket(sock.type):
# We allow AF_INET, AF_INET6, AF_UNIX as long as they # We allow AF_INET, AF_INET6, AF_UNIX as long as they
# are SOCK_STREAM. # are SOCK_STREAM.
# We support passing AF_UNIX sockets even though we have # We support passing AF_UNIX sockets even though we have
@ -841,7 +842,7 @@ class BaseEventLoop(events.AbstractEventLoop):
allow_broadcast=None, sock=None): allow_broadcast=None, sock=None):
"""Create datagram connection.""" """Create datagram connection."""
if sock is not None: if sock is not None:
if not _is_dgram_socket(sock): if not _is_dgram_socket(sock.type):
raise ValueError( raise ValueError(
'A UDP Socket was expected, got {!r}'.format(sock)) 'A UDP Socket was expected, got {!r}'.format(sock))
if (local_addr or remote_addr or if (local_addr or remote_addr or
@ -1054,7 +1055,7 @@ class BaseEventLoop(events.AbstractEventLoop):
else: else:
if sock is None: if sock is None:
raise ValueError('Neither host/port nor sock were specified') raise ValueError('Neither host/port nor sock were specified')
if not _is_stream_socket(sock): if not _is_stream_socket(sock.type):
raise ValueError( raise ValueError(
'A Stream Socket was expected, got {!r}'.format(sock)) 'A Stream Socket was expected, got {!r}'.format(sock))
sockets = [sock] sockets = [sock]
@ -1078,7 +1079,7 @@ class BaseEventLoop(events.AbstractEventLoop):
This method is a coroutine. When completed, the coroutine This method is a coroutine. When completed, the coroutine
returns a (transport, protocol) pair. returns a (transport, protocol) pair.
""" """
if not _is_stream_socket(sock): if not _is_stream_socket(sock.type):
raise ValueError( raise ValueError(
'A Stream Socket was expected, got {!r}'.format(sock)) 'A Stream Socket was expected, got {!r}'.format(sock))

View File

@ -43,7 +43,7 @@ def _test_selector_event(selector, fd, event):
if hasattr(socket, 'TCP_NODELAY'): if hasattr(socket, 'TCP_NODELAY'):
def _set_nodelay(sock): def _set_nodelay(sock):
if (sock.family in {socket.AF_INET, socket.AF_INET6} and if (sock.family in {socket.AF_INET, socket.AF_INET6} and
sock.type == socket.SOCK_STREAM and base_events._is_stream_socket(sock.type) and
sock.proto == socket.IPPROTO_TCP): sock.proto == socket.IPPROTO_TCP):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
else: else:

View File

@ -242,7 +242,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
if sock is None: if sock is None:
raise ValueError('no path and sock were specified') raise ValueError('no path and sock were specified')
if (sock.family != socket.AF_UNIX or if (sock.family != socket.AF_UNIX or
not base_events._is_stream_socket(sock)): not base_events._is_stream_socket(sock.type)):
raise ValueError( raise ValueError(
'A UNIX Domain Stream Socket was expected, got {!r}' 'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock)) .format(sock))
@ -297,7 +297,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
'path was not specified, and no sock specified') 'path was not specified, and no sock specified')
if (sock.family != socket.AF_UNIX or if (sock.family != socket.AF_UNIX or
not base_events._is_stream_socket(sock)): not base_events._is_stream_socket(sock.type)):
raise ValueError( raise ValueError(
'A UNIX Domain Stream Socket was expected, got {!r}' 'A UNIX Domain Stream Socket was expected, got {!r}'
.format(sock)) .format(sock))

View File

@ -116,13 +116,6 @@ class BaseEventTests(test_utils.TestCase):
self.assertIsNone( self.assertIsNone(
base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP)) base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP))
if hasattr(socket, 'SOCK_NONBLOCK'):
self.assertEqual(
None,
base_events._ipaddr_info(
'1.2.3.4', 1, INET, STREAM | socket.SOCK_NONBLOCK, TCP))
def test_port_parameter_types(self): def test_port_parameter_types(self):
# Test obscure kinds of arguments for "port". # Test obscure kinds of arguments for "port".
INET = socket.AF_INET INET = socket.AF_INET

View File

@ -17,6 +17,7 @@ from asyncio.selector_events import _SelectorTransport
from asyncio.selector_events import _SelectorSslTransport from asyncio.selector_events import _SelectorSslTransport
from asyncio.selector_events import _SelectorSocketTransport from asyncio.selector_events import _SelectorSocketTransport
from asyncio.selector_events import _SelectorDatagramTransport from asyncio.selector_events import _SelectorDatagramTransport
from asyncio.selector_events import _set_nodelay
MOCK_ANY = mock.ANY MOCK_ANY = mock.ANY
@ -1829,5 +1830,31 @@ class SelectorDatagramTransportTests(test_utils.TestCase):
'Fatal error on transport\nprotocol:.*\ntransport:.*'), 'Fatal error on transport\nprotocol:.*\ntransport:.*'),
exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY)) exc_info=(ConnectionRefusedError, MOCK_ANY, MOCK_ANY))
class TestSelectorUtils(test_utils.TestCase):
def check_set_nodelay(self, sock):
opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
self.assertFalse(opt)
_set_nodelay(sock)
opt = sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
self.assertTrue(opt)
@unittest.skipUnless(hasattr(socket, 'TCP_NODELAY'),
'need socket.TCP_NODELAY')
def test_set_nodelay(self):
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP)
with sock:
self.check_set_nodelay(sock)
sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP)
with sock:
sock.setblocking(False)
self.check_set_nodelay(sock)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -0,0 +1 @@
Ensure TCP_NODELAY is set on Linux. Tests by Victor Stinner.