bpo-32622: Implement loop.sendfile() (#5271)

This commit is contained in:
Andrew Svetlov 2018-01-27 21:22:47 +02:00 committed by GitHub
parent f13f12d8da
commit 7c684073f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 560 additions and 8 deletions

View File

@ -543,6 +543,37 @@ Creating listening connections
.. versionadded:: 3.5.3 .. versionadded:: 3.5.3
File Transferring
-----------------
.. coroutinemethod:: AbstractEventLoop.sendfile(sock, transport, \
offset=0, count=None, \
*, fallback=True)
Send a *file* to *transport*, return the total number of bytes
which were sent.
The method uses high-performance :meth:`os.sendfile` if available.
*file* must be a regular file object opened in binary mode.
*offset* tells from where to start reading the file. If specified,
*count* is the total number of bytes to transmit as opposed to
sending the file until EOF is reached. File position is updated on
return or also in case of error in which case :meth:`file.tell()
<io.IOBase.tell>` can be used to figure out the number of bytes
which were sent.
*fallback* set to ``True`` makes asyncio to manually read and send
the file when the platform does not support the sendfile syscall
(e.g. Windows or SSL socket on Unix).
Raise :exc:`SendfileNotAvailableError` if the system does not support
*sendfile* syscall and *fallback* is ``False``.
.. versionadded:: 3.7
TLS Upgrade TLS Upgrade
----------- -----------

View File

@ -38,8 +38,10 @@ from . import constants
from . import coroutines from . import coroutines
from . import events from . import events
from . import futures from . import futures
from . import protocols
from . import sslproto from . import sslproto
from . import tasks from . import tasks
from . import transports
from .log import logger from .log import logger
@ -155,6 +157,75 @@ def _run_until_complete_cb(fut):
futures._get_loop(fut).stop() futures._get_loop(fut).stop()
class _SendfileFallbackProtocol(protocols.Protocol):
def __init__(self, transp):
if not isinstance(transp, transports._FlowControlMixin):
raise TypeError("transport should be _FlowControlMixin instance")
self._transport = transp
self._proto = transp.get_protocol()
self._should_resume_reading = transp.is_reading()
self._should_resume_writing = transp._protocol_paused
transp.pause_reading()
transp.set_protocol(self)
if self._should_resume_writing:
self._write_ready_fut = self._transport._loop.create_future()
else:
self._write_ready_fut = None
async def drain(self):
if self._transport.is_closing():
raise ConnectionError("Connection closed by peer")
fut = self._write_ready_fut
if fut is None:
return
await fut
def connection_made(self, transport):
raise RuntimeError("Invalid state: "
"connection should have been established already.")
def connection_lost(self, exc):
if self._write_ready_fut is not None:
# Never happens if peer disconnects after sending the whole content
# Thus disconnection is always an exception from user perspective
if exc is None:
self._write_ready_fut.set_exception(
ConnectionError("Connection is closed by peer"))
else:
self._write_ready_fut.set_exception(exc)
self._proto.connection_lost(exc)
def pause_writing(self):
if self._write_ready_fut is not None:
return
self._write_ready_fut = self._transport._loop.create_future()
def resume_writing(self):
if self._write_ready_fut is None:
return
self._write_ready_fut.set_result(False)
self._write_ready_fut = None
def data_received(self, data):
raise RuntimeError("Invalid state: reading should be paused")
def eof_received(self):
raise RuntimeError("Invalid state: reading should be paused")
async def restore(self):
self._transport.set_protocol(self._proto)
if self._should_resume_reading:
self._transport.resume_reading()
if self._write_ready_fut is not None:
# Cancel the future.
# Basically it has no effect because protocol is switched back,
# no code should wait for it anymore.
self._write_ready_fut.cancel()
if self._should_resume_writing:
self._proto.resume_writing()
class Server(events.AbstractServer): class Server(events.AbstractServer):
def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog,
@ -926,6 +997,77 @@ class BaseEventLoop(events.AbstractEventLoop):
return transport, protocol return transport, protocol
async def sendfile(self, transport, file, offset=0, count=None,
*, fallback=True):
"""Send a file to transport.
Return the total number of bytes which were sent.
The method uses high-performance os.sendfile if available.
file must be a regular file object opened in binary mode.
offset tells from where to start reading the file. If specified,
count is the total number of bytes to transmit as opposed to
sending the file until EOF is reached. File position is updated on
return or also in case of error in which case file.tell()
can be used to figure out the number of bytes
which were sent.
fallback set to True makes asyncio to manually read and send
the file when the platform does not support the sendfile syscall
(e.g. Windows or SSL socket on Unix).
Raise SendfileNotAvailableError if the system does not support
sendfile syscall and fallback is False.
"""
if transport.is_closing():
raise RuntimeError("Transport is closing")
mode = getattr(transport, '_sendfile_compatible',
constants._SendfileMode.UNSUPPORTED)
if mode is constants._SendfileMode.UNSUPPORTED:
raise RuntimeError(
f"sendfile is not supported for transport {transport!r}")
if mode is constants._SendfileMode.TRY_NATIVE:
try:
return await self._sendfile_native(transport, file,
offset, count)
except events.SendfileNotAvailableError as exc:
if not fallback:
raise
# the mode is FALLBACK or fallback is True
return await self._sendfile_fallback(transport, file,
offset, count)
async def _sendfile_native(self, transp, file, offset, count):
raise events.SendfileNotAvailableError(
"sendfile syscall is not supported")
async def _sendfile_fallback(self, transp, file, offset, count):
if offset:
file.seek(offset)
blocksize = min(count, 16384) if count else 16384
buf = bytearray(blocksize)
total_sent = 0
proto = _SendfileFallbackProtocol(transp)
try:
while True:
if count:
blocksize = min(count - total_sent, blocksize)
if blocksize <= 0:
return total_sent
view = memoryview(buf)[:blocksize]
read = file.readinto(view)
if not read:
return total_sent # EOF
await proto.drain()
transp.write(view)
total_sent += read
finally:
if total_sent > 0 and hasattr(file, 'seek'):
file.seek(offset + total_sent)
await proto.restore()
async def start_tls(self, transport, protocol, sslcontext, *, async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False, server_side=False,
server_hostname=None, server_hostname=None,

View File

@ -1,3 +1,5 @@
import enum
# After the connection is lost, log warnings after this many write()s. # After the connection is lost, log warnings after this many write()s.
LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5
@ -11,3 +13,10 @@ DEBUG_STACK_DEPTH = 10
# Number of seconds to wait for SSL handshake to complete # Number of seconds to wait for SSL handshake to complete
SSL_HANDSHAKE_TIMEOUT = 10.0 SSL_HANDSHAKE_TIMEOUT = 10.0
# The enum should be here to break circular dependencies between
# base_events and sslproto
class _SendfileMode(enum.Enum):
UNSUPPORTED = enum.auto()
TRY_NATIVE = enum.auto()
FALLBACK = enum.auto()

View File

@ -354,6 +354,14 @@ class AbstractEventLoop:
""" """
raise NotImplementedError raise NotImplementedError
async def sendfile(self, transport, file, offset=0, count=None,
*, fallback=True):
"""Send a file through a transport.
Return an amount of sent bytes.
"""
raise NotImplementedError
async def start_tls(self, transport, protocol, sslcontext, *, async def start_tls(self, transport, protocol, sslcontext, *,
server_side=False, server_side=False,
server_hostname=None, server_hostname=None,

View File

@ -180,7 +180,12 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport,
assert self._read_fut is fut or (self._read_fut is None and assert self._read_fut is fut or (self._read_fut is None and
self._closing) self._closing)
self._read_fut = None self._read_fut = None
data = fut.result() # deliver data later in "finally" clause if fut.done():
# deliver data later in "finally" clause
data = fut.result()
else:
# the future will be replaced by next proactor.recv call
fut.cancel()
if self._closing: if self._closing:
# since close() has been called we ignore any read data # since close() has been called we ignore any read data
@ -345,6 +350,8 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport,
transports.Transport): transports.Transport):
"""Transport for connected sockets.""" """Transport for connected sockets."""
_sendfile_compatible = constants._SendfileMode.FALLBACK
def _set_extra(self, sock): def _set_extra(self, sock):
self._extra['socket'] = sock self._extra['socket'] = sock

View File

@ -540,6 +540,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop):
else: else:
fut.set_result((conn, address)) fut.set_result((conn, address))
async def _sendfile_native(self, transp, file, offset, count):
del self._transports[transp._sock_fd]
resume_reading = transp.is_reading()
transp.pause_reading()
await transp._make_empty_waiter()
try:
return await self.sock_sendfile(transp._sock, file, offset, count,
fallback=False)
finally:
transp._reset_empty_waiter()
if resume_reading:
transp.resume_reading()
self._transports[transp._sock_fd] = transp
def _process_events(self, event_list): def _process_events(self, event_list):
for key, mask in event_list: for key, mask in event_list:
fileobj, (reader, writer) = key.fileobj, key.data fileobj, (reader, writer) = key.fileobj, key.data
@ -695,12 +709,14 @@ class _SelectorTransport(transports._FlowControlMixin,
class _SelectorSocketTransport(_SelectorTransport): class _SelectorSocketTransport(_SelectorTransport):
_start_tls_compatible = True _start_tls_compatible = True
_sendfile_compatible = constants._SendfileMode.TRY_NATIVE
def __init__(self, loop, sock, protocol, waiter=None, def __init__(self, loop, sock, protocol, waiter=None,
extra=None, server=None): extra=None, server=None):
super().__init__(loop, sock, protocol, extra, server) super().__init__(loop, sock, protocol, extra, server)
self._eof = False self._eof = False
self._paused = False self._paused = False
self._empty_waiter = None
# Disable the Nagle algorithm -- small writes will be # Disable the Nagle algorithm -- small writes will be
# sent without waiting for the TCP ACK. This generally # sent without waiting for the TCP ACK. This generally
@ -765,6 +781,8 @@ class _SelectorSocketTransport(_SelectorTransport):
f'not {type(data).__name__!r}') f'not {type(data).__name__!r}')
if self._eof: if self._eof:
raise RuntimeError('Cannot call write() after write_eof()') raise RuntimeError('Cannot call write() after write_eof()')
if self._empty_waiter is not None:
raise RuntimeError('unable to write; sendfile is in progress')
if not data: if not data:
return return
@ -807,12 +825,16 @@ class _SelectorSocketTransport(_SelectorTransport):
self._loop._remove_writer(self._sock_fd) self._loop._remove_writer(self._sock_fd)
self._buffer.clear() self._buffer.clear()
self._fatal_error(exc, 'Fatal write error on socket transport') self._fatal_error(exc, 'Fatal write error on socket transport')
if self._empty_waiter is not None:
self._empty_waiter.set_exception(exc)
else: else:
if n: if n:
del self._buffer[:n] del self._buffer[:n]
self._maybe_resume_protocol() # May append to buffer. self._maybe_resume_protocol() # May append to buffer.
if not self._buffer: if not self._buffer:
self._loop._remove_writer(self._sock_fd) self._loop._remove_writer(self._sock_fd)
if self._empty_waiter is not None:
self._empty_waiter.set_result(None)
if self._closing: if self._closing:
self._call_connection_lost(None) self._call_connection_lost(None)
elif self._eof: elif self._eof:
@ -828,6 +850,23 @@ class _SelectorSocketTransport(_SelectorTransport):
def can_write_eof(self): def can_write_eof(self):
return True return True
def _call_connection_lost(self, exc):
super()._call_connection_lost(exc)
if self._empty_waiter is not None:
self._empty_waiter.set_exception(
ConnectionError("Connection is closed by peer"))
def _make_empty_waiter(self):
if self._empty_waiter is not None:
raise RuntimeError("Empty waiter is already set")
self._empty_waiter = self._loop.create_future()
if not self._buffer:
self._empty_waiter.set_result(None)
return self._empty_waiter
def _reset_empty_waiter(self):
self._empty_waiter = None
class _SelectorDatagramTransport(_SelectorTransport): class _SelectorDatagramTransport(_SelectorTransport):

View File

@ -282,6 +282,8 @@ class _SSLPipe(object):
class _SSLProtocolTransport(transports._FlowControlMixin, class _SSLProtocolTransport(transports._FlowControlMixin,
transports.Transport): transports.Transport):
_sendfile_compatible = constants._SendfileMode.FALLBACK
def __init__(self, loop, ssl_protocol): def __init__(self, loop, ssl_protocol):
self._loop = loop self._loop = loop
# SSLProtocol instance # SSLProtocol instance
@ -365,6 +367,11 @@ class _SSLProtocolTransport(transports._FlowControlMixin,
"""Return the current size of the write buffer.""" """Return the current size of the write buffer."""
return self._ssl_protocol._transport.get_write_buffer_size() return self._ssl_protocol._transport.get_write_buffer_size()
@property
def _protocol_paused(self):
# Required for sendfile fallback pause_writing/resume_writing logic
return self._ssl_protocol._transport._protocol_paused
def write(self, data): def write(self, data):
"""Write some data bytes to the transport. """Write some data bytes to the transport.

View File

@ -425,7 +425,8 @@ class IocpProactor:
try: try:
return ov.getresult() return ov.getresult()
except OSError as exc: except OSError as exc:
if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
_overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args) raise ConnectionResetError(*exc.args)
else: else:
raise raise
@ -447,7 +448,8 @@ class IocpProactor:
try: try:
return ov.getresult() return ov.getresult()
except OSError as exc: except OSError as exc:
if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
_overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args) raise ConnectionResetError(*exc.args)
else: else:
raise raise
@ -466,7 +468,8 @@ class IocpProactor:
try: try:
return ov.getresult() return ov.getresult()
except OSError as exc: except OSError as exc:
if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED,
_overlapped.ERROR_OPERATION_ABORTED):
raise ConnectionResetError(*exc.args) raise ConnectionResetError(*exc.args)
else: else:
raise raise

View File

@ -1788,7 +1788,7 @@ class RunningLoopTests(unittest.TestCase):
outer_loop.close() outer_loop.close()
class BaseLoopSendfileTests(test_utils.TestCase): class BaseLoopSockSendfileTests(test_utils.TestCase):
DATA = b"12345abcde" * 16 * 1024 # 160 KiB DATA = b"12345abcde" * 16 * 1024 # 160 KiB
@ -1799,9 +1799,11 @@ class BaseLoopSendfileTests(test_utils.TestCase):
self.closed = False self.closed = False
self.data = bytearray() self.data = bytearray()
self.fut = loop.create_future() self.fut = loop.create_future()
self.transport = None
def connection_made(self, transport): def connection_made(self, transport):
self.started = True self.started = True
self.transport = transport
def data_received(self, data): def data_received(self, data):
self.data.extend(data) self.data.extend(data)
@ -1809,6 +1811,7 @@ class BaseLoopSendfileTests(test_utils.TestCase):
def connection_lost(self, exc): def connection_lost(self, exc):
self.closed = True self.closed = True
self.fut.set_result(None) self.fut.set_result(None)
self.transport = None
async def wait_closed(self): async def wait_closed(self):
await self.fut await self.fut
@ -1853,6 +1856,10 @@ class BaseLoopSendfileTests(test_utils.TestCase):
def cleanup(): def cleanup():
server.close() server.close()
self.run_loop(server.wait_closed()) self.run_loop(server.wait_closed())
sock.close()
if proto.transport is not None:
proto.transport.close()
self.run_loop(proto.wait_closed())
self.addCleanup(cleanup) self.addCleanup(cleanup)

View File

@ -26,6 +26,7 @@ if sys.platform != 'win32':
import tty import tty
import asyncio import asyncio
from asyncio import base_events
from asyncio import coroutines from asyncio import coroutines
from asyncio import events from asyncio import events
from asyncio import proactor_events from asyncio import proactor_events
@ -2090,14 +2091,308 @@ class SubprocessTestsMixin:
self.loop.run_until_complete(connect(shell=False)) self.loop.run_until_complete(connect(shell=False))
class MySendfileProto(MyBaseProto):
def __init__(self, loop=None, close_after=0):
super().__init__(loop)
self.data = bytearray()
self.close_after = close_after
def data_received(self, data):
self.data.extend(data)
super().data_received(data)
if self.close_after and self.nbytes >= self.close_after:
self.transport.close()
class SendfileMixin:
# Note: sendfile via SSL transport is equal to sendfile fallback
DATA = b"12345abcde" * 160 * 1024 # 160 KiB
@classmethod
def setUpClass(cls):
with open(support.TESTFN, 'wb') as fp:
fp.write(cls.DATA)
super().setUpClass()
@classmethod
def tearDownClass(cls):
support.unlink(support.TESTFN)
super().tearDownClass()
def setUp(self):
self.file = open(support.TESTFN, 'rb')
self.addCleanup(self.file.close)
super().setUp()
def run_loop(self, coro):
return self.loop.run_until_complete(coro)
def prepare(self, *, is_ssl=False, close_after=0):
port = support.find_unused_port()
srv_proto = MySendfileProto(loop=self.loop, close_after=close_after)
if is_ssl:
srv_ctx = test_utils.simple_server_sslcontext()
cli_ctx = test_utils.simple_client_sslcontext()
else:
srv_ctx = None
cli_ctx = None
srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# reduce recv socket buffer size to test on relative small data sets
srv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024)
srv_sock.bind((support.HOST, port))
server = self.run_loop(self.loop.create_server(
lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
if is_ssl:
server_hostname = support.HOST
else:
server_hostname = None
cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# reduce send socket buffer size to test on relative small data sets
cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
cli_sock.connect((support.HOST, port))
cli_proto = MySendfileProto(loop=self.loop)
tr, pr = self.run_loop(self.loop.create_connection(
lambda: cli_proto, sock=cli_sock,
ssl=cli_ctx, server_hostname=server_hostname))
def cleanup():
srv_proto.transport.close()
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.run_loop(cli_proto.done)
server.close()
self.run_loop(server.wait_closed())
self.addCleanup(cleanup)
return srv_proto, cli_proto
@unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
def test_sendfile_not_supported(self):
tr, pr = self.run_loop(
self.loop.create_datagram_endpoint(
lambda: MyDatagramProto(loop=self.loop),
family=socket.AF_INET))
try:
with self.assertRaisesRegex(RuntimeError, "not supported"):
self.run_loop(
self.loop.sendfile(tr, self.file))
self.assertEqual(0, self.file.tell())
finally:
# don't use self.addCleanup because it produces resource warning
tr.close()
def test_sendfile(self):
srv_proto, cli_proto = self.prepare()
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_force_fallback(self):
srv_proto, cli_proto = self.prepare()
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
return base_events.BaseEventLoop._sendfile_native(
self.loop, transp, file, offset, count)
self.loop._sendfile_native = sendfile_native
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_force_unsupported_native(self):
if sys.platform == 'win32':
if isinstance(self.loop, asyncio.ProactorEventLoop):
self.skipTest("Fails on proactor event loop")
srv_proto, cli_proto = self.prepare()
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
return base_events.BaseEventLoop._sendfile_native(
self.loop, transp, file, offset, count)
self.loop._sendfile_native = sendfile_native
with self.assertRaisesRegex(events.SendfileNotAvailableError,
"not supported"):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file,
fallback=False))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(srv_proto.nbytes, 0)
self.assertEqual(self.file.tell(), 0)
def test_sendfile_ssl(self):
srv_proto, cli_proto = self.prepare(is_ssl=True)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_for_closing_transp(self):
srv_proto, cli_proto = self.prepare()
cli_proto.transport.close()
with self.assertRaisesRegex(RuntimeError, "is closing"):
self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
self.assertEqual(srv_proto.nbytes, 0)
self.assertEqual(self.file.tell(), 0)
def test_sendfile_pre_and_post_data(self):
srv_proto, cli_proto = self.prepare()
PREFIX = b'zxcvbnm' * 1024
SUFFIX = b'0987654321' * 1024
cli_proto.transport.write(PREFIX)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.write(SUFFIX)
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_ssl_pre_and_post_data(self):
srv_proto, cli_proto = self.prepare(is_ssl=True)
PREFIX = b'zxcvbnm' * 1024
SUFFIX = b'0987654321' * 1024
cli_proto.transport.write(PREFIX)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.write(SUFFIX)
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_partial(self):
srv_proto, cli_proto = self.prepare()
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, 100)
self.assertEqual(srv_proto.nbytes, 100)
self.assertEqual(srv_proto.data, self.DATA[1000:1100])
self.assertEqual(self.file.tell(), 1100)
def test_sendfile_ssl_partial(self):
srv_proto, cli_proto = self.prepare(is_ssl=True)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, 100)
self.assertEqual(srv_proto.nbytes, 100)
self.assertEqual(srv_proto.data, self.DATA[1000:1100])
self.assertEqual(self.file.tell(), 1100)
def test_sendfile_close_peer_after_receiving(self):
srv_proto, cli_proto = self.prepare(close_after=len(self.DATA))
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_ssl_close_peer_after_receiving(self):
srv_proto, cli_proto = self.prepare(is_ssl=True,
close_after=len(self.DATA))
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_close_peer_in_middle_of_receiving(self):
srv_proto, cli_proto = self.prepare(close_after=1024)
with self.assertRaises(ConnectionError):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
srv_proto.nbytes)
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
self.file.tell())
def test_sendfile_fallback_close_peer_in_middle_of_receiving(self):
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
return base_events.BaseEventLoop._sendfile_native(
self.loop, transp, file, offset, count)
self.loop._sendfile_native = sendfile_native
srv_proto, cli_proto = self.prepare(close_after=1024)
with self.assertRaises(ConnectionError):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
srv_proto.nbytes)
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
self.file.tell())
@unittest.skipIf(not hasattr(os, 'sendfile'),
"Don't have native sendfile support")
def test_sendfile_prevents_bare_write(self):
srv_proto, cli_proto = self.prepare()
fut = self.loop.create_future()
async def coro():
fut.set_result(None)
return await self.loop.sendfile(cli_proto.transport, self.file)
t = self.loop.create_task(coro())
self.run_loop(fut)
with self.assertRaisesRegex(RuntimeError,
"sendfile is in progress"):
cli_proto.transport.write(b'data')
ret = self.run_loop(t)
self.assertEqual(ret, len(self.DATA))
if sys.platform == 'win32': if sys.platform == 'win32':
class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase): class SelectEventLoopTests(EventLoopTestsMixin,
SendfileMixin,
test_utils.TestCase):
def create_event_loop(self): def create_event_loop(self):
return asyncio.SelectorEventLoop() return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(EventLoopTestsMixin, class ProactorEventLoopTests(EventLoopTestsMixin,
SendfileMixin,
SubprocessTestsMixin, SubprocessTestsMixin,
test_utils.TestCase): test_utils.TestCase):
@ -2125,7 +2420,7 @@ if sys.platform == 'win32':
else: else:
import selectors import selectors
class UnixEventLoopTestsMixin(EventLoopTestsMixin): class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
watcher = asyncio.SafeChildWatcher() watcher = asyncio.SafeChildWatcher()
@ -2556,7 +2851,9 @@ class AbstractEventLoopTests(unittest.TestCase):
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
await loop.sock_accept(f) await loop.sock_accept(f)
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
await loop.sock_sendfile(f, mock.Mock()) await loop.sock_sendfile(f, f)
with self.assertRaises(NotImplementedError):
await loop.sendfile(f, f)
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):
await loop.connect_read_pipe(f, mock.sentinel.pipe) await loop.connect_read_pipe(f, mock.sentinel.pipe)
with self.assertRaises(NotImplementedError): with self.assertRaises(NotImplementedError):

View File

@ -0,0 +1 @@
Add :meth:`asyncio.AbstractEventLoop.sendfile` method.

View File

@ -1436,6 +1436,7 @@ PyInit__overlapped(void)
WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING);
WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED); WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED);
WINAPI_CONSTANT(F_DWORD, ERROR_OPERATION_ABORTED);
WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT);
WINAPI_CONSTANT(F_DWORD, ERROR_PIPE_BUSY); WINAPI_CONSTANT(F_DWORD, ERROR_PIPE_BUSY);
WINAPI_CONSTANT(F_DWORD, INFINITE); WINAPI_CONSTANT(F_DWORD, INFINITE);