diff --git a/Doc/library/asyncio-eventloop.rst b/Doc/library/asyncio-eventloop.rst index 834a4e85c2f..fe162236e01 100644 --- a/Doc/library/asyncio-eventloop.rst +++ b/Doc/library/asyncio-eventloop.rst @@ -543,6 +543,37 @@ Creating listening connections .. versionadded:: 3.5.3 +File Transferring +----------------- + +.. coroutinemethod:: AbstractEventLoop.sendfile(sock, transport, \ + offset=0, count=None, \ + *, fallback=True) + + Send a *file* to *transport*, return the total number of bytes + which were sent. + + The method uses high-performance :meth:`os.sendfile` if available. + + *file* must be a regular file object opened in binary mode. + + *offset* tells from where to start reading the file. If specified, + *count* is the total number of bytes to transmit as opposed to + sending the file until EOF is reached. File position is updated on + return or also in case of error in which case :meth:`file.tell() + ` can be used to figure out the number of bytes + which were sent. + + *fallback* set to ``True`` makes asyncio to manually read and send + the file when the platform does not support the sendfile syscall + (e.g. Windows or SSL socket on Unix). + + Raise :exc:`SendfileNotAvailableError` if the system does not support + *sendfile* syscall and *fallback* is ``False``. + + .. versionadded:: 3.7 + + TLS Upgrade ----------- diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 94eb3089e93..f532dc42132 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -38,8 +38,10 @@ from . import constants from . import coroutines from . import events from . import futures +from . import protocols from . import sslproto from . import tasks +from . import transports from .log import logger @@ -155,6 +157,75 @@ def _run_until_complete_cb(fut): futures._get_loop(fut).stop() + +class _SendfileFallbackProtocol(protocols.Protocol): + def __init__(self, transp): + if not isinstance(transp, transports._FlowControlMixin): + raise TypeError("transport should be _FlowControlMixin instance") + self._transport = transp + self._proto = transp.get_protocol() + self._should_resume_reading = transp.is_reading() + self._should_resume_writing = transp._protocol_paused + transp.pause_reading() + transp.set_protocol(self) + if self._should_resume_writing: + self._write_ready_fut = self._transport._loop.create_future() + else: + self._write_ready_fut = None + + async def drain(self): + if self._transport.is_closing(): + raise ConnectionError("Connection closed by peer") + fut = self._write_ready_fut + if fut is None: + return + await fut + + def connection_made(self, transport): + raise RuntimeError("Invalid state: " + "connection should have been established already.") + + def connection_lost(self, exc): + if self._write_ready_fut is not None: + # Never happens if peer disconnects after sending the whole content + # Thus disconnection is always an exception from user perspective + if exc is None: + self._write_ready_fut.set_exception( + ConnectionError("Connection is closed by peer")) + else: + self._write_ready_fut.set_exception(exc) + self._proto.connection_lost(exc) + + def pause_writing(self): + if self._write_ready_fut is not None: + return + self._write_ready_fut = self._transport._loop.create_future() + + def resume_writing(self): + if self._write_ready_fut is None: + return + self._write_ready_fut.set_result(False) + self._write_ready_fut = None + + def data_received(self, data): + raise RuntimeError("Invalid state: reading should be paused") + + def eof_received(self): + raise RuntimeError("Invalid state: reading should be paused") + + async def restore(self): + self._transport.set_protocol(self._proto) + if self._should_resume_reading: + self._transport.resume_reading() + if self._write_ready_fut is not None: + # Cancel the future. + # Basically it has no effect because protocol is switched back, + # no code should wait for it anymore. + self._write_ready_fut.cancel() + if self._should_resume_writing: + self._proto.resume_writing() + + class Server(events.AbstractServer): def __init__(self, loop, sockets, protocol_factory, ssl_context, backlog, @@ -926,6 +997,77 @@ class BaseEventLoop(events.AbstractEventLoop): return transport, protocol + async def sendfile(self, transport, file, offset=0, count=None, + *, fallback=True): + """Send a file to transport. + + Return the total number of bytes which were sent. + + The method uses high-performance os.sendfile if available. + + file must be a regular file object opened in binary mode. + + offset tells from where to start reading the file. If specified, + count is the total number of bytes to transmit as opposed to + sending the file until EOF is reached. File position is updated on + return or also in case of error in which case file.tell() + can be used to figure out the number of bytes + which were sent. + + fallback set to True makes asyncio to manually read and send + the file when the platform does not support the sendfile syscall + (e.g. Windows or SSL socket on Unix). + + Raise SendfileNotAvailableError if the system does not support + sendfile syscall and fallback is False. + """ + if transport.is_closing(): + raise RuntimeError("Transport is closing") + mode = getattr(transport, '_sendfile_compatible', + constants._SendfileMode.UNSUPPORTED) + if mode is constants._SendfileMode.UNSUPPORTED: + raise RuntimeError( + f"sendfile is not supported for transport {transport!r}") + if mode is constants._SendfileMode.TRY_NATIVE: + try: + return await self._sendfile_native(transport, file, + offset, count) + except events.SendfileNotAvailableError as exc: + if not fallback: + raise + # the mode is FALLBACK or fallback is True + return await self._sendfile_fallback(transport, file, + offset, count) + + async def _sendfile_native(self, transp, file, offset, count): + raise events.SendfileNotAvailableError( + "sendfile syscall is not supported") + + async def _sendfile_fallback(self, transp, file, offset, count): + if offset: + file.seek(offset) + blocksize = min(count, 16384) if count else 16384 + buf = bytearray(blocksize) + total_sent = 0 + proto = _SendfileFallbackProtocol(transp) + try: + while True: + if count: + blocksize = min(count - total_sent, blocksize) + if blocksize <= 0: + return total_sent + view = memoryview(buf)[:blocksize] + read = file.readinto(view) + if not read: + return total_sent # EOF + await proto.drain() + transp.write(view) + total_sent += read + finally: + if total_sent > 0 and hasattr(file, 'seek'): + file.seek(offset + total_sent) + await proto.restore() + async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, diff --git a/Lib/asyncio/constants.py b/Lib/asyncio/constants.py index 0ad974ff2fb..739b0a70c13 100644 --- a/Lib/asyncio/constants.py +++ b/Lib/asyncio/constants.py @@ -1,3 +1,5 @@ +import enum + # After the connection is lost, log warnings after this many write()s. LOG_THRESHOLD_FOR_CONNLOST_WRITES = 5 @@ -11,3 +13,10 @@ DEBUG_STACK_DEPTH = 10 # Number of seconds to wait for SSL handshake to complete SSL_HANDSHAKE_TIMEOUT = 10.0 + +# The enum should be here to break circular dependencies between +# base_events and sslproto +class _SendfileMode(enum.Enum): + UNSUPPORTED = enum.auto() + TRY_NATIVE = enum.auto() + FALLBACK = enum.auto() diff --git a/Lib/asyncio/events.py b/Lib/asyncio/events.py index 7aa3de02c95..bdefcf62a05 100644 --- a/Lib/asyncio/events.py +++ b/Lib/asyncio/events.py @@ -354,6 +354,14 @@ class AbstractEventLoop: """ raise NotImplementedError + async def sendfile(self, transport, file, offset=0, count=None, + *, fallback=True): + """Send a file through a transport. + + Return an amount of sent bytes. + """ + raise NotImplementedError + async def start_tls(self, transport, protocol, sslcontext, *, server_side=False, server_hostname=None, diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index ab1285b7999..6d27e532387 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -180,7 +180,12 @@ class _ProactorReadPipeTransport(_ProactorBasePipeTransport, assert self._read_fut is fut or (self._read_fut is None and self._closing) self._read_fut = None - data = fut.result() # deliver data later in "finally" clause + if fut.done(): + # deliver data later in "finally" clause + data = fut.result() + else: + # the future will be replaced by next proactor.recv call + fut.cancel() if self._closing: # since close() has been called we ignore any read data @@ -345,6 +350,8 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, transports.Transport): """Transport for connected sockets.""" + _sendfile_compatible = constants._SendfileMode.FALLBACK + def _set_extra(self, sock): self._extra['socket'] = sock diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index 9446ae6a3bc..5956f2d993e 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -540,6 +540,20 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): else: fut.set_result((conn, address)) + async def _sendfile_native(self, transp, file, offset, count): + del self._transports[transp._sock_fd] + resume_reading = transp.is_reading() + transp.pause_reading() + await transp._make_empty_waiter() + try: + return await self.sock_sendfile(transp._sock, file, offset, count, + fallback=False) + finally: + transp._reset_empty_waiter() + if resume_reading: + transp.resume_reading() + self._transports[transp._sock_fd] = transp + def _process_events(self, event_list): for key, mask in event_list: fileobj, (reader, writer) = key.fileobj, key.data @@ -695,12 +709,14 @@ class _SelectorTransport(transports._FlowControlMixin, class _SelectorSocketTransport(_SelectorTransport): _start_tls_compatible = True + _sendfile_compatible = constants._SendfileMode.TRY_NATIVE def __init__(self, loop, sock, protocol, waiter=None, extra=None, server=None): super().__init__(loop, sock, protocol, extra, server) self._eof = False self._paused = False + self._empty_waiter = None # Disable the Nagle algorithm -- small writes will be # sent without waiting for the TCP ACK. This generally @@ -765,6 +781,8 @@ class _SelectorSocketTransport(_SelectorTransport): f'not {type(data).__name__!r}') if self._eof: raise RuntimeError('Cannot call write() after write_eof()') + if self._empty_waiter is not None: + raise RuntimeError('unable to write; sendfile is in progress') if not data: return @@ -807,12 +825,16 @@ class _SelectorSocketTransport(_SelectorTransport): self._loop._remove_writer(self._sock_fd) self._buffer.clear() self._fatal_error(exc, 'Fatal write error on socket transport') + if self._empty_waiter is not None: + self._empty_waiter.set_exception(exc) else: if n: del self._buffer[:n] self._maybe_resume_protocol() # May append to buffer. if not self._buffer: self._loop._remove_writer(self._sock_fd) + if self._empty_waiter is not None: + self._empty_waiter.set_result(None) if self._closing: self._call_connection_lost(None) elif self._eof: @@ -828,6 +850,23 @@ class _SelectorSocketTransport(_SelectorTransport): def can_write_eof(self): return True + def _call_connection_lost(self, exc): + super()._call_connection_lost(exc) + if self._empty_waiter is not None: + self._empty_waiter.set_exception( + ConnectionError("Connection is closed by peer")) + + def _make_empty_waiter(self): + if self._empty_waiter is not None: + raise RuntimeError("Empty waiter is already set") + self._empty_waiter = self._loop.create_future() + if not self._buffer: + self._empty_waiter.set_result(None) + return self._empty_waiter + + def _reset_empty_waiter(self): + self._empty_waiter = None + class _SelectorDatagramTransport(_SelectorTransport): diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py index 1130bced8ae..863b54313cc 100644 --- a/Lib/asyncio/sslproto.py +++ b/Lib/asyncio/sslproto.py @@ -282,6 +282,8 @@ class _SSLPipe(object): class _SSLProtocolTransport(transports._FlowControlMixin, transports.Transport): + _sendfile_compatible = constants._SendfileMode.FALLBACK + def __init__(self, loop, ssl_protocol): self._loop = loop # SSLProtocol instance @@ -365,6 +367,11 @@ class _SSLProtocolTransport(transports._FlowControlMixin, """Return the current size of the write buffer.""" return self._ssl_protocol._transport.get_write_buffer_size() + @property + def _protocol_paused(self): + # Required for sendfile fallback pause_writing/resume_writing logic + return self._ssl_protocol._transport._protocol_paused + def write(self, data): """Write some data bytes to the transport. diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py index 890fce8b405..f91fcddb2aa 100644 --- a/Lib/asyncio/windows_events.py +++ b/Lib/asyncio/windows_events.py @@ -425,7 +425,8 @@ class IocpProactor: try: return ov.getresult() except OSError as exc: - if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): raise ConnectionResetError(*exc.args) else: raise @@ -447,7 +448,8 @@ class IocpProactor: try: return ov.getresult() except OSError as exc: - if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): raise ConnectionResetError(*exc.args) else: raise @@ -466,7 +468,8 @@ class IocpProactor: try: return ov.getresult() except OSError as exc: - if exc.winerror == _overlapped.ERROR_NETNAME_DELETED: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): raise ConnectionResetError(*exc.args) else: raise diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index 6489f50f272..ab6560c70b9 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -1788,7 +1788,7 @@ class RunningLoopTests(unittest.TestCase): outer_loop.close() -class BaseLoopSendfileTests(test_utils.TestCase): +class BaseLoopSockSendfileTests(test_utils.TestCase): DATA = b"12345abcde" * 16 * 1024 # 160 KiB @@ -1799,9 +1799,11 @@ class BaseLoopSendfileTests(test_utils.TestCase): self.closed = False self.data = bytearray() self.fut = loop.create_future() + self.transport = None def connection_made(self, transport): self.started = True + self.transport = transport def data_received(self, data): self.data.extend(data) @@ -1809,6 +1811,7 @@ class BaseLoopSendfileTests(test_utils.TestCase): def connection_lost(self, exc): self.closed = True self.fut.set_result(None) + self.transport = None async def wait_closed(self): await self.fut @@ -1853,6 +1856,10 @@ class BaseLoopSendfileTests(test_utils.TestCase): def cleanup(): server.close() self.run_loop(server.wait_closed()) + sock.close() + if proto.transport is not None: + proto.transport.close() + self.run_loop(proto.wait_closed()) self.addCleanup(cleanup) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index cf217538a06..0981bd6ac91 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -26,6 +26,7 @@ if sys.platform != 'win32': import tty import asyncio +from asyncio import base_events from asyncio import coroutines from asyncio import events from asyncio import proactor_events @@ -2090,14 +2091,308 @@ class SubprocessTestsMixin: self.loop.run_until_complete(connect(shell=False)) +class MySendfileProto(MyBaseProto): + + def __init__(self, loop=None, close_after=0): + super().__init__(loop) + self.data = bytearray() + self.close_after = close_after + + def data_received(self, data): + self.data.extend(data) + super().data_received(data) + if self.close_after and self.nbytes >= self.close_after: + self.transport.close() + + +class SendfileMixin: + # Note: sendfile via SSL transport is equal to sendfile fallback + + DATA = b"12345abcde" * 160 * 1024 # 160 KiB + + @classmethod + def setUpClass(cls): + with open(support.TESTFN, 'wb') as fp: + fp.write(cls.DATA) + super().setUpClass() + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + super().tearDownClass() + + def setUp(self): + self.file = open(support.TESTFN, 'rb') + self.addCleanup(self.file.close) + super().setUp() + + def run_loop(self, coro): + return self.loop.run_until_complete(coro) + + def prepare(self, *, is_ssl=False, close_after=0): + port = support.find_unused_port() + srv_proto = MySendfileProto(loop=self.loop, close_after=close_after) + if is_ssl: + srv_ctx = test_utils.simple_server_sslcontext() + cli_ctx = test_utils.simple_client_sslcontext() + else: + srv_ctx = None + cli_ctx = None + srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # reduce recv socket buffer size to test on relative small data sets + srv_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024) + srv_sock.bind((support.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) + + if is_ssl: + server_hostname = support.HOST + else: + server_hostname = None + cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # reduce send socket buffer size to test on relative small data sets + cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + cli_sock.connect((support.HOST, port)) + cli_proto = MySendfileProto(loop=self.loop) + tr, pr = self.run_loop(self.loop.create_connection( + lambda: cli_proto, sock=cli_sock, + ssl=cli_ctx, server_hostname=server_hostname)) + + def cleanup(): + srv_proto.transport.close() + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.run_loop(cli_proto.done) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + return srv_proto, cli_proto + + @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") + def test_sendfile_not_supported(self): + tr, pr = self.run_loop( + self.loop.create_datagram_endpoint( + lambda: MyDatagramProto(loop=self.loop), + family=socket.AF_INET)) + try: + with self.assertRaisesRegex(RuntimeError, "not supported"): + self.run_loop( + self.loop.sendfile(tr, self.file)) + self.assertEqual(0, self.file.tell()) + finally: + # don't use self.addCleanup because it produces resource warning + tr.close() + + def test_sendfile(self): + srv_proto, cli_proto = self.prepare() + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_force_fallback(self): + srv_proto, cli_proto = self.prepare() + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_force_unsupported_native(self): + if sys.platform == 'win32': + if isinstance(self.loop, asyncio.ProactorEventLoop): + self.skipTest("Fails on proactor event loop") + srv_proto, cli_proto = self.prepare() + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + with self.assertRaisesRegex(events.SendfileNotAvailableError, + "not supported"): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, + fallback=False)) + + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(srv_proto.nbytes, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sendfile_ssl(self): + srv_proto, cli_proto = self.prepare(is_ssl=True) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_for_closing_transp(self): + srv_proto, cli_proto = self.prepare() + cli_proto.transport.close() + with self.assertRaisesRegex(RuntimeError, "is closing"): + self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + self.assertEqual(srv_proto.nbytes, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sendfile_pre_and_post_data(self): + srv_proto, cli_proto = self.prepare() + PREFIX = b'zxcvbnm' * 1024 + SUFFIX = b'0987654321' * 1024 + cli_proto.transport.write(PREFIX) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.write(SUFFIX) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_ssl_pre_and_post_data(self): + srv_proto, cli_proto = self.prepare(is_ssl=True) + PREFIX = b'zxcvbnm' * 1024 + SUFFIX = b'0987654321' * 1024 + cli_proto.transport.write(PREFIX) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.write(SUFFIX) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_partial(self): + srv_proto, cli_proto = self.prepare() + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, 100) + self.assertEqual(srv_proto.nbytes, 100) + self.assertEqual(srv_proto.data, self.DATA[1000:1100]) + self.assertEqual(self.file.tell(), 1100) + + def test_sendfile_ssl_partial(self): + srv_proto, cli_proto = self.prepare(is_ssl=True) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, 100) + self.assertEqual(srv_proto.nbytes, 100) + self.assertEqual(srv_proto.data, self.DATA[1000:1100]) + self.assertEqual(self.file.tell(), 1100) + + def test_sendfile_close_peer_after_receiving(self): + srv_proto, cli_proto = self.prepare(close_after=len(self.DATA)) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + cli_proto.transport.close() + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_ssl_close_peer_after_receiving(self): + srv_proto, cli_proto = self.prepare(is_ssl=True, + close_after=len(self.DATA)) + ret = self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(srv_proto.nbytes, len(self.DATA)) + self.assertEqual(srv_proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sendfile_close_peer_in_middle_of_receiving(self): + srv_proto, cli_proto = self.prepare(close_after=1024) + with self.assertRaises(ConnectionError): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + + self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), + srv_proto.nbytes) + self.assertTrue(1024 <= self.file.tell() < len(self.DATA), + self.file.tell()) + + def test_sendfile_fallback_close_peer_in_middle_of_receiving(self): + + def sendfile_native(transp, file, offset, count): + # to raise SendfileNotAvailableError + return base_events.BaseEventLoop._sendfile_native( + self.loop, transp, file, offset, count) + + self.loop._sendfile_native = sendfile_native + + srv_proto, cli_proto = self.prepare(close_after=1024) + with self.assertRaises(ConnectionError): + self.run_loop( + self.loop.sendfile(cli_proto.transport, self.file)) + self.run_loop(srv_proto.done) + + self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), + srv_proto.nbytes) + self.assertTrue(1024 <= self.file.tell() < len(self.DATA), + self.file.tell()) + + @unittest.skipIf(not hasattr(os, 'sendfile'), + "Don't have native sendfile support") + def test_sendfile_prevents_bare_write(self): + srv_proto, cli_proto = self.prepare() + fut = self.loop.create_future() + + async def coro(): + fut.set_result(None) + return await self.loop.sendfile(cli_proto.transport, self.file) + + t = self.loop.create_task(coro()) + self.run_loop(fut) + with self.assertRaisesRegex(RuntimeError, + "sendfile is in progress"): + cli_proto.transport.write(b'data') + ret = self.run_loop(t) + self.assertEqual(ret, len(self.DATA)) + + if sys.platform == 'win32': - class SelectEventLoopTests(EventLoopTestsMixin, test_utils.TestCase): + class SelectEventLoopTests(EventLoopTestsMixin, + SendfileMixin, + test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop() class ProactorEventLoopTests(EventLoopTestsMixin, + SendfileMixin, SubprocessTestsMixin, test_utils.TestCase): @@ -2125,7 +2420,7 @@ if sys.platform == 'win32': else: import selectors - class UnixEventLoopTestsMixin(EventLoopTestsMixin): + class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin): def setUp(self): super().setUp() watcher = asyncio.SafeChildWatcher() @@ -2556,7 +2851,9 @@ class AbstractEventLoopTests(unittest.TestCase): with self.assertRaises(NotImplementedError): await loop.sock_accept(f) with self.assertRaises(NotImplementedError): - await loop.sock_sendfile(f, mock.Mock()) + await loop.sock_sendfile(f, f) + with self.assertRaises(NotImplementedError): + await loop.sendfile(f, f) with self.assertRaises(NotImplementedError): await loop.connect_read_pipe(f, mock.sentinel.pipe) with self.assertRaises(NotImplementedError): diff --git a/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst b/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst new file mode 100644 index 00000000000..d7433fa3cb1 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-01-22-18-18-44.bpo-32622.A1D6FP.rst @@ -0,0 +1 @@ +Add :meth:`asyncio.AbstractEventLoop.sendfile` method. diff --git a/Modules/overlapped.c b/Modules/overlapped.c index e66e8566840..447a337fdd1 100644 --- a/Modules/overlapped.c +++ b/Modules/overlapped.c @@ -1436,6 +1436,7 @@ PyInit__overlapped(void) WINAPI_CONSTANT(F_DWORD, ERROR_IO_PENDING); WINAPI_CONSTANT(F_DWORD, ERROR_NETNAME_DELETED); + WINAPI_CONSTANT(F_DWORD, ERROR_OPERATION_ABORTED); WINAPI_CONSTANT(F_DWORD, ERROR_SEM_TIMEOUT); WINAPI_CONSTANT(F_DWORD, ERROR_PIPE_BUSY); WINAPI_CONSTANT(F_DWORD, INFINITE);