bpo-40007: Make asyncio.transport.writelines on selector use sendmsg

This commit is contained in:
tzickel 2020-03-18 17:39:24 +02:00
parent 56bfdebfb1
commit 8ea34d2bc7
2 changed files with 100 additions and 3 deletions

View File

@ -29,6 +29,9 @@ from . import trsock
from .log import logger from .log import logger
sendmsg = getattr(socket.socket, "sendmsg", False)
def _test_selector_event(selector, fd, event): def _test_selector_event(selector, fd, event):
# Test if the selector is monitoring 'event' events # Test if the selector is monitoring 'event' events
# for the file descriptor 'fd'. # for the file descriptor 'fd'.
@ -746,6 +749,7 @@ class _SelectorTransport(transports._FlowControlMixin,
class _SelectorSocketTransport(_SelectorTransport): class _SelectorSocketTransport(_SelectorTransport):
_buffer_factory = list
_start_tls_compatible = True _start_tls_compatible = True
_sendfile_compatible = constants._SendfileMode.TRY_NATIVE _sendfile_compatible = constants._SendfileMode.TRY_NATIVE
@ -921,7 +925,62 @@ class _SelectorSocketTransport(_SelectorTransport):
self._loop._add_writer(self._sock_fd, self._write_ready) self._loop._add_writer(self._sock_fd, self._write_ready)
# Add it to the buffer. # Add it to the buffer.
self._buffer.extend(data) self._buffer.append(data)
self._maybe_pause_protocol()
@staticmethod
def _calculate_leftovers(n, items):
leftovers = []
whole = False
for item in items:
if whole:
leftovers.append(item)
continue
n -= len(item)
if n >= 0:
continue
leftovers.append(memoryview(item)[n:])
whole = True
return leftovers
def writelines(self, lines):
if not sendmsg:
return self.write(b''.join(lines))
if self._eof:
raise RuntimeError('Cannot call write() after write_eof()')
if self._empty_waiter is not None:
raise RuntimeError('unable to write; sendfile is in progress')
if not lines:
return
if self._conn_lost:
if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
logger.warning('socket.send() raised exception.')
self._conn_lost += 1
return
if not self._buffer:
# Optimization: try to send now.
try:
n = self._sock.sendmsg(lines)
except OSError:
return self.write(b''.join(lines))
except (BlockingIOError, InterruptedError):
pass
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
self._fatal_error(exc, 'Fatal write error on socket transport')
return
else:
lines = self._calculate_leftovers(n, lines)
if not lines:
return
# Not all was written; register write handler.
self._loop._add_writer(self._sock_fd, self._write_ready)
# Add it to the buffer.
self._buffer.extend(lines)
self._maybe_pause_protocol() self._maybe_pause_protocol()
def _write_ready(self): def _write_ready(self):
@ -929,8 +988,13 @@ class _SelectorSocketTransport(_SelectorTransport):
if self._conn_lost: if self._conn_lost:
return return
if sendmsg:
return self._write_vectored_self()
try: try:
n = self._sock.send(self._buffer) tmp = b''.join(self._buffer)
n = self._sock.send(tmp)
except (BlockingIOError, InterruptedError): except (BlockingIOError, InterruptedError):
pass pass
except (SystemExit, KeyboardInterrupt): except (SystemExit, KeyboardInterrupt):
@ -943,7 +1007,7 @@ class _SelectorSocketTransport(_SelectorTransport):
self._empty_waiter.set_exception(exc) self._empty_waiter.set_exception(exc)
else: else:
if n: if n:
del self._buffer[:n] self._buffer = [tmp[:n]]
self._maybe_resume_protocol() # May append to buffer. self._maybe_resume_protocol() # May append to buffer.
if not self._buffer: if not self._buffer:
self._loop._remove_writer(self._sock_fd) self._loop._remove_writer(self._sock_fd)
@ -954,6 +1018,38 @@ class _SelectorSocketTransport(_SelectorTransport):
elif self._eof: elif self._eof:
self._sock.shutdown(socket.SHUT_WR) self._sock.shutdown(socket.SHUT_WR)
def _write_vectored_self(self):
try:
try:
n = self._sock.sendmsg(self._buffer)
except OSError:
self._buffer = [b''.join(self._buffer)]
n = self._sock.sendmsg(self._buffer)
except (BlockingIOError, InterruptedError):
pass
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as exc:
self._loop._remove_writer(self._sock_fd)
self._buffer.clear()
self._fatal_error(exc, 'Fatal write error on socket transport')
if self._empty_waiter is not None:
self._empty_waiter.set_exception(exc)
else:
self._buffer = self._calculate_leftovers(n, self._buffer)
self._maybe_resume_protocol() # May append to buffer.
if not self._buffer:
self._loop._remove_writer(self._sock_fd)
if self._empty_waiter is not None:
self._empty_waiter.set_result(None)
if self._closing:
self._call_connection_lost(None)
elif self._eof:
self._sock.shutdown(socket.SHUT_WR)
def get_write_buffer_size(self):
return sum(len(data) for data in self._buffer)
def write_eof(self): def write_eof(self):
if self._closing or self._eof: if self._closing or self._eof:
return return

View File

@ -0,0 +1 @@
Make asyncio.transport.writelines on selector use sendmsg