From a19fb3c6aaa7632410d1d9dcb395d7101d124da4 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sun, 25 Feb 2018 19:32:14 +0300 Subject: [PATCH] bpo-32622: Native sendfile on windows (#5565) * Support sendfile on Windows Proactor event loop naively. --- Lib/asyncio/proactor_events.py | 70 ++++++- Lib/asyncio/windows_events.py | 22 +++ Lib/test/test_asyncio/test_events.py | 187 ++++++++++++++---- Lib/test/test_asyncio/test_proactor_events.py | 116 +++++++++++ Lib/test/test_asyncio/test_unix_events.py | 64 +----- .../2018-02-06-17-58-15.bpo-32622.AE0Jz7.rst | 1 + Modules/overlapped.c | 64 +++++- 7 files changed, 431 insertions(+), 93 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2018-02-06-17-58-15.bpo-32622.AE0Jz7.rst diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index 10ca6f8967f..b675c8200ce 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -6,11 +6,14 @@ proactor is only implemented on Windows with IOCP. __all__ = 'BaseProactorEventLoop', +import io +import os import socket import warnings from . import base_events from . import constants +from . import events from . import futures from . import protocols from . import sslproto @@ -107,6 +110,11 @@ class _ProactorBasePipeTransport(transports._FlowControlMixin, self._force_close(exc) def _force_close(self, exc): + if self._empty_waiter is not None: + if exc is None: + self._empty_waiter.set_result(None) + else: + self._empty_waiter.set_exception(exc) if self._closing: return self._closing = True @@ -327,6 +335,10 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, _start_tls_compatible = True + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._empty_waiter = None + def write(self, data): if not isinstance(data, (bytes, bytearray, memoryview)): raise TypeError( @@ -334,6 +346,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, f"not {type(data).__name__}") if self._eof_written: raise RuntimeError('write_eof() already called') + if self._empty_waiter is not None: + raise RuntimeError('unable to write; sendfile is in progress') if not data: return @@ -393,6 +407,8 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, self._maybe_pause_protocol() else: self._write_fut.add_done_callback(self._loop_writing) + if self._empty_waiter is not None and self._write_fut is None: + self._empty_waiter.set_result(None) except ConnectionResetError as exc: self._force_close(exc) except OSError as exc: @@ -407,6 +423,17 @@ class _ProactorBaseWritePipeTransport(_ProactorBasePipeTransport, def abort(self): self._force_close(None) + def _make_empty_waiter(self): + if self._empty_waiter is not None: + raise RuntimeError("Empty waiter is already set") + self._empty_waiter = self._loop.create_future() + if self._write_fut is None: + self._empty_waiter.set_result(None) + return self._empty_waiter + + def _reset_empty_waiter(self): + self._empty_waiter = None + class _ProactorWritePipeTransport(_ProactorBaseWritePipeTransport): def __init__(self, *args, **kw): @@ -447,7 +474,7 @@ class _ProactorSocketTransport(_ProactorReadPipeTransport, transports.Transport): """Transport for connected sockets.""" - _sendfile_compatible = constants._SendfileMode.FALLBACK + _sendfile_compatible = constants._SendfileMode.TRY_NATIVE def _set_extra(self, sock): self._extra['socket'] = sock @@ -556,6 +583,47 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): async def sock_accept(self, sock): return await self._proactor.accept(sock) + async def _sock_sendfile_native(self, sock, file, offset, count): + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise events.SendfileNotAvailableError("not a regular file") + try: + fsize = os.fstat(fileno).st_size + except OSError as err: + raise events.SendfileNotAvailableError("not a regular file") + blocksize = count if count else fsize + if not blocksize: + return 0 # empty file + + blocksize = min(blocksize, 0xffff_ffff) + end_pos = min(offset + count, fsize) if count else fsize + offset = min(offset, fsize) + total_sent = 0 + try: + while True: + blocksize = min(end_pos - offset, blocksize) + if blocksize <= 0: + return total_sent + await self._proactor.sendfile(sock, file, offset, blocksize) + offset += blocksize + total_sent += blocksize + finally: + if total_sent > 0: + file.seek(offset) + + async def _sendfile_native(self, transp, file, offset, count): + resume_reading = transp.is_reading() + transp.pause_reading() + await transp._make_empty_waiter() + try: + return await self.sock_sendfile(transp._sock, file, offset, count, + fallback=False) + finally: + transp._reset_empty_waiter() + if resume_reading: + transp.resume_reading() + def _close_self_pipe(self): if self._self_reading_future is not None: self._self_reading_future.cancel() diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py index f91fcddb2aa..d22edec51ef 100644 --- a/Lib/asyncio/windows_events.py +++ b/Lib/asyncio/windows_events.py @@ -4,6 +4,7 @@ import _overlapped import _winapi import errno import math +import msvcrt import socket import struct import weakref @@ -527,6 +528,27 @@ class IocpProactor: return self._register(ov, conn, finish_connect) + def sendfile(self, sock, file, offset, count): + self._register_with_iocp(sock) + ov = _overlapped.Overlapped(NULL) + offset_low = offset & 0xffff_ffff + offset_high = (offset >> 32) & 0xffff_ffff + ov.TransmitFile(sock.fileno(), + msvcrt.get_osfhandle(file.fileno()), + offset_low, offset_high, + count, 0, 0) + + def finish_sendfile(trans, key, ov): + try: + return ov.getresult() + except OSError as exc: + if exc.winerror in (_overlapped.ERROR_NETNAME_DELETED, + _overlapped.ERROR_OPERATION_ABORTED): + raise ConnectionResetError(*exc.args) + else: + raise + return self._register(ov, sock, finish_sendfile) + def accept_pipe(self, pipe): self._register_with_iocp(pipe) ov = _overlapped.Overlapped(NULL) diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index f5995974c68..6accbdae8b3 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -15,6 +15,7 @@ except ImportError: ssl = None import subprocess import sys +import tempfile import threading import time import errno @@ -2092,22 +2093,7 @@ class SubprocessTestsMixin: self.loop.run_until_complete(connect(shell=False)) -class MySendfileProto(MyBaseProto): - - def __init__(self, loop=None, close_after=0): - super().__init__(loop) - self.data = bytearray() - self.close_after = close_after - - def data_received(self, data): - self.data.extend(data) - super().data_received(data) - if self.close_after and self.nbytes >= self.close_after: - self.transport.close() - - -class SendfileMixin: - # Note: sendfile via SSL transport is equal to sendfile fallback +class SendfileBase: DATA = b"12345abcde" * 160 * 1024 # 160 KiB @@ -2130,9 +2116,134 @@ class SendfileMixin: def run_loop(self, coro): return self.loop.run_until_complete(coro) - def prepare(self, *, is_ssl=False, close_after=0): + +class SockSendfileMixin(SendfileBase): + + class MyProto(asyncio.Protocol): + + def __init__(self, loop): + self.started = False + self.closed = False + self.data = bytearray() + self.fut = loop.create_future() + self.transport = None + + def connection_made(self, transport): + self.started = True + self.transport = transport + + def data_received(self, data): + self.data.extend(data) + + def connection_lost(self, exc): + self.closed = True + self.fut.set_result(None) + + async def wait_closed(self): + await self.fut + + def make_socket(self, cleanup=True): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024) + if cleanup: + self.addCleanup(sock.close) + return sock + + def prepare_socksendfile(self): + sock = self.make_socket() + proto = self.MyProto(self.loop) port = support.find_unused_port() - srv_proto = MySendfileProto(loop=self.loop, close_after=close_after) + srv_sock = self.make_socket(cleanup=False) + srv_sock.bind((support.HOST, port)) + server = self.run_loop(self.loop.create_server( + lambda: proto, sock=srv_sock)) + self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) + + def cleanup(): + if proto.transport is not None: + # can be None if the task was cancelled before + # connection_made callback + proto.transport.close() + self.run_loop(proto.wait_closed()) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + + return sock, proto + + def test_sock_sendfile_success(self): + sock, proto = self.prepare_socksendfile() + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + self.assertEqual(proto.data, self.DATA) + self.assertEqual(self.file.tell(), len(self.DATA)) + + def test_sock_sendfile_with_offset_and_count(self): + sock, proto = self.prepare_socksendfile() + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, + 1000, 2000)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(proto.data, self.DATA[1000:3000]) + self.assertEqual(self.file.tell(), 3000) + self.assertEqual(ret, 2000) + + def test_sock_sendfile_zero_size(self): + sock, proto = self.prepare_socksendfile() + with tempfile.TemporaryFile() as f: + ret = self.run_loop(self.loop.sock_sendfile(sock, f, + 0, None)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, 0) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_mix_with_regular_send(self): + buf = b'1234567890' * 1024 * 1024 # 10 MB + sock, proto = self.prepare_socksendfile() + self.run_loop(self.loop.sock_sendall(sock, buf)) + ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) + self.run_loop(self.loop.sock_sendall(sock, buf)) + sock.close() + self.run_loop(proto.wait_closed()) + + self.assertEqual(ret, len(self.DATA)) + expected = buf + self.DATA + buf + self.assertEqual(proto.data, expected) + self.assertEqual(self.file.tell(), len(self.DATA)) + + +class SendfileMixin(SendfileBase): + + class MySendfileProto(MyBaseProto): + + def __init__(self, loop=None, close_after=0): + super().__init__(loop) + self.data = bytearray() + self.close_after = close_after + + def data_received(self, data): + self.data.extend(data) + super().data_received(data) + if self.close_after and self.nbytes >= self.close_after: + self.transport.close() + + + # Note: sendfile via SSL transport is equal to sendfile fallback + + def prepare_sendfile(self, *, is_ssl=False, close_after=0): + port = support.find_unused_port() + srv_proto = self.MySendfileProto(loop=self.loop, + close_after=close_after) if is_ssl: if not ssl: self.skipTest("No ssl module") @@ -2156,7 +2267,7 @@ class SendfileMixin: # reduce send socket buffer size to test on relative small data sets cli_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) cli_sock.connect((support.HOST, port)) - cli_proto = MySendfileProto(loop=self.loop) + cli_proto = self.MySendfileProto(loop=self.loop) tr, pr = self.run_loop(self.loop.create_connection( lambda: cli_proto, sock=cli_sock, ssl=cli_ctx, server_hostname=server_hostname)) @@ -2189,7 +2300,7 @@ class SendfileMixin: tr.close() def test_sendfile(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() @@ -2200,7 +2311,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_force_fallback(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() def sendfile_native(transp, file, offset, count): # to raise SendfileNotAvailableError @@ -2222,7 +2333,7 @@ class SendfileMixin: if sys.platform == 'win32': if isinstance(self.loop, asyncio.ProactorEventLoop): self.skipTest("Fails on proactor event loop") - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() def sendfile_native(transp, file, offset, count): # to raise SendfileNotAvailableError @@ -2243,7 +2354,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), 0) def test_sendfile_ssl(self): - srv_proto, cli_proto = self.prepare(is_ssl=True) + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() @@ -2254,7 +2365,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_for_closing_transp(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() cli_proto.transport.close() with self.assertRaisesRegex(RuntimeError, "is closing"): self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) @@ -2263,7 +2374,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), 0) def test_sendfile_pre_and_post_data(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() PREFIX = b'zxcvbnm' * 1024 SUFFIX = b'0987654321' * 1024 cli_proto.transport.write(PREFIX) @@ -2277,7 +2388,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_ssl_pre_and_post_data(self): - srv_proto, cli_proto = self.prepare(is_ssl=True) + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) PREFIX = b'zxcvbnm' * 1024 SUFFIX = b'0987654321' * 1024 cli_proto.transport.write(PREFIX) @@ -2291,7 +2402,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_partial(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) cli_proto.transport.close() @@ -2302,7 +2413,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), 1100) def test_sendfile_ssl_partial(self): - srv_proto, cli_proto = self.prepare(is_ssl=True) + srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) cli_proto.transport.close() @@ -2313,7 +2424,8 @@ class SendfileMixin: self.assertEqual(self.file.tell(), 1100) def test_sendfile_close_peer_after_receiving(self): - srv_proto, cli_proto = self.prepare(close_after=len(self.DATA)) + srv_proto, cli_proto = self.prepare_sendfile( + close_after=len(self.DATA)) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() @@ -2324,8 +2436,8 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_ssl_close_peer_after_receiving(self): - srv_proto, cli_proto = self.prepare(is_ssl=True, - close_after=len(self.DATA)) + srv_proto, cli_proto = self.prepare_sendfile( + is_ssl=True, close_after=len(self.DATA)) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) self.run_loop(srv_proto.done) @@ -2335,7 +2447,7 @@ class SendfileMixin: self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_close_peer_in_middle_of_receiving(self): - srv_proto, cli_proto = self.prepare(close_after=1024) + srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) with self.assertRaises(ConnectionError): self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) @@ -2345,6 +2457,7 @@ class SendfileMixin: srv_proto.nbytes) self.assertTrue(1024 <= self.file.tell() < len(self.DATA), self.file.tell()) + self.assertTrue(cli_proto.transport.is_closing()) def test_sendfile_fallback_close_peer_in_middle_of_receiving(self): @@ -2355,7 +2468,7 @@ class SendfileMixin: self.loop._sendfile_native = sendfile_native - srv_proto, cli_proto = self.prepare(close_after=1024) + srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) with self.assertRaises(ConnectionError): self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) @@ -2369,7 +2482,7 @@ class SendfileMixin: @unittest.skipIf(not hasattr(os, 'sendfile'), "Don't have native sendfile support") def test_sendfile_prevents_bare_write(self): - srv_proto, cli_proto = self.prepare() + srv_proto, cli_proto = self.prepare_sendfile() fut = self.loop.create_future() async def coro(): @@ -2397,6 +2510,7 @@ if sys.platform == 'win32': class SelectEventLoopTests(EventLoopTestsMixin, SendfileMixin, + SockSendfileMixin, test_utils.TestCase): def create_event_loop(self): @@ -2404,6 +2518,7 @@ if sys.platform == 'win32': class ProactorEventLoopTests(EventLoopTestsMixin, SendfileMixin, + SockSendfileMixin, SubprocessTestsMixin, test_utils.TestCase): @@ -2431,7 +2546,9 @@ if sys.platform == 'win32': else: import selectors - class UnixEventLoopTestsMixin(EventLoopTestsMixin, SendfileMixin): + class UnixEventLoopTestsMixin(EventLoopTestsMixin, + SendfileMixin, + SockSendfileMixin): def setUp(self): super().setUp() watcher = asyncio.SafeChildWatcher() diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py index f627dfce0e1..98e698983ea 100644 --- a/Lib/test/test_asyncio/test_proactor_events.py +++ b/Lib/test/test_asyncio/test_proactor_events.py @@ -1,14 +1,18 @@ """Tests for proactor_events.py""" +import io import socket import unittest +import sys from unittest import mock import asyncio +from asyncio import events from asyncio.proactor_events import BaseProactorEventLoop from asyncio.proactor_events import _ProactorSocketTransport from asyncio.proactor_events import _ProactorWritePipeTransport from asyncio.proactor_events import _ProactorDuplexPipeTransport +from test import support from test.test_asyncio import utils as test_utils @@ -775,5 +779,117 @@ class BaseProactorEventLoopTests(test_utils.TestCase): self.assertFalse(future2.cancel.called) +@unittest.skipIf(sys.platform != 'win32', + 'Proactor is supported on Windows only') +class ProactorEventLoopUnixSockSendfileTests(test_utils.TestCase): + DATA = b"12345abcde" * 16 * 1024 # 160 KiB + + class MyProto(asyncio.Protocol): + + def __init__(self, loop): + self.started = False + self.closed = False + self.data = bytearray() + self.fut = loop.create_future() + self.transport = None + + def connection_made(self, transport): + self.started = True + self.transport = transport + + def data_received(self, data): + self.data.extend(data) + + def connection_lost(self, exc): + self.closed = True + self.fut.set_result(None) + + async def wait_closed(self): + await self.fut + + @classmethod + def setUpClass(cls): + with open(support.TESTFN, 'wb') as fp: + fp.write(cls.DATA) + super().setUpClass() + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + super().tearDownClass() + + def setUp(self): + self.loop = asyncio.ProactorEventLoop() + self.set_event_loop(self.loop) + self.addCleanup(self.loop.close) + self.file = open(support.TESTFN, 'rb') + self.addCleanup(self.file.close) + super().setUp() + + def make_socket(self, cleanup=True): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024) + if cleanup: + self.addCleanup(sock.close) + return sock + + def run_loop(self, coro): + return self.loop.run_until_complete(coro) + + def prepare(self): + sock = self.make_socket() + proto = self.MyProto(self.loop) + port = support.find_unused_port() + srv_sock = self.make_socket(cleanup=False) + srv_sock.bind(('127.0.0.1', port)) + server = self.run_loop(self.loop.create_server( + lambda: proto, sock=srv_sock)) + self.run_loop(self.loop.sock_connect(sock, srv_sock.getsockname())) + + def cleanup(): + if proto.transport is not None: + # can be None if the task was cancelled before + # connection_made callback + proto.transport.close() + self.run_loop(proto.wait_closed()) + + server.close() + self.run_loop(server.wait_closed()) + + self.addCleanup(cleanup) + + return sock, proto + + def test_sock_sendfile_not_a_file(self): + sock, proto = self.prepare() + f = object() + with self.assertRaisesRegex(events.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_iobuffer(self): + sock, proto = self.prepare() + f = io.BytesIO() + with self.assertRaisesRegex(events.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + def test_sock_sendfile_not_regular_file(self): + sock, proto = self.prepare() + f = mock.Mock() + f.fileno.return_value = -1 + with self.assertRaisesRegex(events.SendfileNotAvailableError, + "not a regular file"): + self.run_loop(self.loop._sock_sendfile_native(sock, f, + 0, None)) + self.assertEqual(self.file.tell(), 0) + + if __name__ == '__main__': unittest.main() diff --git a/Lib/test/test_asyncio/test_unix_events.py b/Lib/test/test_asyncio/test_unix_events.py index 5bd76d30d2d..104f9959379 100644 --- a/Lib/test/test_asyncio/test_unix_events.py +++ b/Lib/test/test_asyncio/test_unix_events.py @@ -466,10 +466,13 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): self.addCleanup(self.file.close) super().setUp() - def make_socket(self, blocking=False): + def make_socket(self, cleanup=True): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setblocking(blocking) - self.addCleanup(sock.close) + sock.setblocking(False) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024) + if cleanup: + self.addCleanup(sock.close) return sock def run_loop(self, coro): @@ -479,8 +482,10 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): sock = self.make_socket() proto = self.MyProto(self.loop) port = support.find_unused_port() + srv_sock = self.make_socket(cleanup=False) + srv_sock.bind((support.HOST, port)) server = self.run_loop(self.loop.create_server( - lambda: proto, support.HOST, port)) + lambda: proto, sock=srv_sock)) self.run_loop(self.loop.sock_connect(sock, (support.HOST, port))) def cleanup(): @@ -497,27 +502,6 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): return sock, proto - def test_sock_sendfile_success(self): - sock, proto = self.prepare() - ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) - sock.close() - self.run_loop(proto.wait_closed()) - - self.assertEqual(ret, len(self.DATA)) - self.assertEqual(proto.data, self.DATA) - self.assertEqual(self.file.tell(), len(self.DATA)) - - def test_sock_sendfile_with_offset_and_count(self): - sock, proto = self.prepare() - ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, - 1000, 2000)) - sock.close() - self.run_loop(proto.wait_closed()) - - self.assertEqual(proto.data, self.DATA[1000:3000]) - self.assertEqual(self.file.tell(), 3000) - self.assertEqual(ret, 2000) - def test_sock_sendfile_not_available(self): sock, proto = self.prepare() with mock.patch('asyncio.unix_events.os', spec=[]): @@ -555,36 +539,6 @@ class SelectorEventLoopUnixSockSendfileTests(test_utils.TestCase): 0, None)) self.assertEqual(self.file.tell(), 0) - def test_sock_sendfile_zero_size(self): - sock, proto = self.prepare() - fname = support.TESTFN + '.suffix' - with open(fname, 'wb') as f: - pass # make zero sized file - f = open(fname, 'rb') - self.addCleanup(f.close) - self.addCleanup(support.unlink, fname) - ret = self.run_loop(self.loop._sock_sendfile_native(sock, f, - 0, None)) - sock.close() - self.run_loop(proto.wait_closed()) - - self.assertEqual(ret, 0) - self.assertEqual(self.file.tell(), 0) - - def test_sock_sendfile_mix_with_regular_send(self): - buf = b'1234567890' * 1024 * 1024 # 10 MB - sock, proto = self.prepare() - self.run_loop(self.loop.sock_sendall(sock, buf)) - ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) - self.run_loop(self.loop.sock_sendall(sock, buf)) - sock.close() - self.run_loop(proto.wait_closed()) - - self.assertEqual(ret, len(self.DATA)) - expected = buf + self.DATA + buf - self.assertEqual(proto.data, expected) - self.assertEqual(self.file.tell(), len(self.DATA)) - def test_sock_sendfile_cancel1(self): sock, proto = self.prepare() diff --git a/Misc/NEWS.d/next/Library/2018-02-06-17-58-15.bpo-32622.AE0Jz7.rst b/Misc/NEWS.d/next/Library/2018-02-06-17-58-15.bpo-32622.AE0Jz7.rst new file mode 100644 index 00000000000..456a6dc5595 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2018-02-06-17-58-15.bpo-32622.AE0Jz7.rst @@ -0,0 +1 @@ +Implement native fast sendfile for Windows proactor event loop. diff --git a/Modules/overlapped.c b/Modules/overlapped.c index 447a337fdd1..ae7cddadd02 100644 --- a/Modules/overlapped.c +++ b/Modules/overlapped.c @@ -39,7 +39,7 @@ enum {TYPE_NONE, TYPE_NOT_STARTED, TYPE_READ, TYPE_READINTO, TYPE_WRITE, TYPE_ACCEPT, TYPE_CONNECT, TYPE_DISCONNECT, TYPE_CONNECT_NAMED_PIPE, - TYPE_WAIT_NAMED_PIPE_AND_CONNECT}; + TYPE_WAIT_NAMED_PIPE_AND_CONNECT, TYPE_TRANSMIT_FILE}; typedef struct { PyObject_HEAD @@ -89,6 +89,7 @@ SetFromWindowsErr(DWORD err) static LPFN_ACCEPTEX Py_AcceptEx = NULL; static LPFN_CONNECTEX Py_ConnectEx = NULL; static LPFN_DISCONNECTEX Py_DisconnectEx = NULL; +static LPFN_TRANSMITFILE Py_TransmitFile = NULL; static BOOL (CALLBACK *Py_CancelIoEx)(HANDLE, LPOVERLAPPED) = NULL; #define GET_WSA_POINTER(s, x) \ @@ -102,6 +103,7 @@ initialize_function_pointers(void) GUID GuidAcceptEx = WSAID_ACCEPTEX; GUID GuidConnectEx = WSAID_CONNECTEX; GUID GuidDisconnectEx = WSAID_DISCONNECTEX; + GUID GuidTransmitFile = WSAID_TRANSMITFILE; HINSTANCE hKernel32; SOCKET s; DWORD dwBytes; @@ -114,7 +116,8 @@ initialize_function_pointers(void) if (!GET_WSA_POINTER(s, AcceptEx) || !GET_WSA_POINTER(s, ConnectEx) || - !GET_WSA_POINTER(s, DisconnectEx)) + !GET_WSA_POINTER(s, DisconnectEx) || + !GET_WSA_POINTER(s, TransmitFile)) { closesocket(s); SetFromWindowsErr(WSAGetLastError()); @@ -1194,6 +1197,61 @@ Overlapped_DisconnectEx(OverlappedObject *self, PyObject *args) } } +PyDoc_STRVAR( + Overlapped_TransmitFile_doc, + "TransmitFile(socket, file, offset, offset_high, " + "count_to_write, count_per_send, flags) " + "-> Overlapped[None]\n\n" + "Transmit file data over a connected socket."); + +static PyObject * +Overlapped_TransmitFile(OverlappedObject *self, PyObject *args) +{ + SOCKET Socket; + HANDLE File; + DWORD offset; + DWORD offset_high; + DWORD count_to_write; + DWORD count_per_send; + DWORD flags; + BOOL ret; + DWORD err; + + if (!PyArg_ParseTuple(args, + F_HANDLE F_HANDLE F_DWORD F_DWORD + F_DWORD F_DWORD F_DWORD, + &Socket, &File, &offset, &offset_high, + &count_to_write, &count_per_send, + &flags)) + return NULL; + + if (self->type != TYPE_NONE) { + PyErr_SetString(PyExc_ValueError, "operation already attempted"); + return NULL; + } + + self->type = TYPE_TRANSMIT_FILE; + self->handle = (HANDLE)Socket; + self->overlapped.Offset = offset; + self->overlapped.OffsetHigh = offset_high; + + Py_BEGIN_ALLOW_THREADS + ret = Py_TransmitFile(Socket, File, count_to_write, count_per_send, + &self->overlapped, + NULL, flags); + Py_END_ALLOW_THREADS + + self->error = err = ret ? ERROR_SUCCESS : WSAGetLastError(); + switch (err) { + case ERROR_SUCCESS: + case ERROR_IO_PENDING: + Py_RETURN_NONE; + default: + self->type = TYPE_NOT_STARTED; + return SetFromWindowsErr(err); + } +} + PyDoc_STRVAR( Overlapped_ConnectNamedPipe_doc, "ConnectNamedPipe(handle) -> Overlapped[None]\n\n" @@ -1303,6 +1361,8 @@ static PyMethodDef Overlapped_methods[] = { METH_VARARGS, Overlapped_ConnectEx_doc}, {"DisconnectEx", (PyCFunction) Overlapped_DisconnectEx, METH_VARARGS, Overlapped_DisconnectEx_doc}, + {"TransmitFile", (PyCFunction) Overlapped_TransmitFile, + METH_VARARGS, Overlapped_TransmitFile_doc}, {"ConnectNamedPipe", (PyCFunction) Overlapped_ConnectNamedPipe, METH_VARARGS, Overlapped_ConnectNamedPipe_doc}, {NULL}