"""Tests for sendfile functionality.""" import asyncio import os import socket import sys import tempfile import unittest from asyncio import base_events from asyncio import constants from unittest import mock from test import support from test.support import socket_helper from test.test_asyncio import utils as test_utils try: import ssl except ImportError: ssl = None def tearDownModule(): asyncio.set_event_loop_policy(None) class MySendfileProto(asyncio.Protocol): def __init__(self, loop=None, close_after=0): self.transport = None self.state = 'INITIAL' self.nbytes = 0 if loop is not None: self.connected = loop.create_future() self.done = loop.create_future() self.data = bytearray() self.close_after = close_after def connection_made(self, transport): self.transport = transport assert self.state == 'INITIAL', self.state self.state = 'CONNECTED' if self.connected: self.connected.set_result(None) def eof_received(self): assert self.state == 'CONNECTED', self.state self.state = 'EOF' def connection_lost(self, exc): assert self.state in ('CONNECTED', 'EOF'), self.state self.state = 'CLOSED' if self.done: self.done.set_result(None) def data_received(self, data): assert self.state == 'CONNECTED', self.state self.nbytes += len(data) self.data.extend(data) super().data_received(data) if self.close_after and self.nbytes >= self.close_after: self.transport.close() class MyProto(asyncio.Protocol): def __init__(self, loop): self.started = False self.closed = False self.data = bytearray() self.fut = loop.create_future() self.transport = None def connection_made(self, transport): self.started = True self.transport = transport def data_received(self, data): self.data.extend(data) def connection_lost(self, exc): self.closed = True self.fut.set_result(None) async def wait_closed(self): await self.fut class SendfileBase: # 128 KiB plus small unaligned to buffer chunk DATA = b"SendfileBaseData" * (1024 * 8 + 1) # Reduce socket buffer size to test on relative small data sets. BUF_SIZE = 4 * 1024 # 4 KiB def create_event_loop(self): raise NotImplementedError @classmethod def setUpClass(cls): with open(support.TESTFN, 'wb') as fp: fp.write(cls.DATA) super().setUpClass() @classmethod def tearDownClass(cls): support.unlink(support.TESTFN) super().tearDownClass() def setUp(self): self.file = open(support.TESTFN, 'rb') self.addCleanup(self.file.close) self.loop = self.create_event_loop() self.set_event_loop(self.loop) super().setUp() def tearDown(self): # just in case if we have transport close callbacks if not self.loop.is_closed(): test_utils.run_briefly(self.loop) self.doCleanups() support.gc_collect() super().tearDown() def run_loop(self, coro): return self.loop.run_until_complete(coro) class SockSendfileMixin(SendfileBase): @classmethod def setUpClass(cls): cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16 super().setUpClass() @classmethod def tearDownClass(cls): constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize super().tearDownClass() def make_socket(self, cleanup=True): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setblocking(False) if cleanup: self.addCleanup(sock.close) return sock def reduce_receive_buffer_size(self, sock): # Reduce receive socket buffer size to test on relative # small data sets. sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE) def reduce_send_buffer_size(self, sock, transport=None): # Reduce send socket buffer size to test on relative small data sets. # On macOS, SO_SNDBUF is reset by connect(). So this method # should be called after the socket is connected. sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE) if transport is not None: transport.set_write_buffer_limits(high=self.BUF_SIZE) def prepare_socksendfile(self): proto = MyProto(self.loop) port = socket_helper.find_unused_port() srv_sock = self.make_socket(cleanup=False) srv_sock.bind((socket_helper.HOST, port)) server = self.run_loop(self.loop.create_server( lambda: proto, sock=srv_sock)) self.reduce_receive_buffer_size(srv_sock) sock = self.make_socket() self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port))) self.reduce_send_buffer_size(sock) def cleanup(): if proto.transport is not None: # can be None if the task was cancelled before # connection_made callback proto.transport.close() self.run_loop(proto.wait_closed()) server.close() self.run_loop(server.wait_closed()) self.addCleanup(cleanup) return sock, proto def test_sock_sendfile_success(self): sock, proto = self.prepare_socksendfile() ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) sock.close() self.run_loop(proto.wait_closed()) self.assertEqual(ret, len(self.DATA)) self.assertEqual(proto.data, self.DATA) self.assertEqual(self.file.tell(), len(self.DATA)) def test_sock_sendfile_with_offset_and_count(self): sock, proto = self.prepare_socksendfile() ret = self.run_loop(self.loop.sock_sendfile(sock, self.file, 1000, 2000)) sock.close() self.run_loop(proto.wait_closed()) self.assertEqual(proto.data, self.DATA[1000:3000]) self.assertEqual(self.file.tell(), 3000) self.assertEqual(ret, 2000) def test_sock_sendfile_zero_size(self): sock, proto = self.prepare_socksendfile() with tempfile.TemporaryFile() as f: ret = self.run_loop(self.loop.sock_sendfile(sock, f, 0, None)) sock.close() self.run_loop(proto.wait_closed()) self.assertEqual(ret, 0) self.assertEqual(self.file.tell(), 0) def test_sock_sendfile_mix_with_regular_send(self): buf = b"mix_regular_send" * (4 * 1024) # 64 KiB sock, proto = self.prepare_socksendfile() self.run_loop(self.loop.sock_sendall(sock, buf)) ret = self.run_loop(self.loop.sock_sendfile(sock, self.file)) self.run_loop(self.loop.sock_sendall(sock, buf)) sock.close() self.run_loop(proto.wait_closed()) self.assertEqual(ret, len(self.DATA)) expected = buf + self.DATA + buf self.assertEqual(proto.data, expected) self.assertEqual(self.file.tell(), len(self.DATA)) class SendfileMixin(SendfileBase): # Note: sendfile via SSL transport is equal to sendfile fallback def prepare_sendfile(self, *, is_ssl=False, close_after=0): port = socket_helper.find_unused_port() srv_proto = MySendfileProto(loop=self.loop, close_after=close_after) if is_ssl: if not ssl: self.skipTest("No ssl module") srv_ctx = test_utils.simple_server_sslcontext() cli_ctx = test_utils.simple_client_sslcontext() else: srv_ctx = None cli_ctx = None srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) srv_sock.bind((socket_helper.HOST, port)) server = self.run_loop(self.loop.create_server( lambda: srv_proto, sock=srv_sock, ssl=srv_ctx)) self.reduce_receive_buffer_size(srv_sock) if is_ssl: server_hostname = socket_helper.HOST else: server_hostname = None cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) cli_sock.connect((socket_helper.HOST, port)) cli_proto = MySendfileProto(loop=self.loop) tr, pr = self.run_loop(self.loop.create_connection( lambda: cli_proto, sock=cli_sock, ssl=cli_ctx, server_hostname=server_hostname)) self.reduce_send_buffer_size(cli_sock, transport=tr) def cleanup(): srv_proto.transport.close() cli_proto.transport.close() self.run_loop(srv_proto.done) self.run_loop(cli_proto.done) server.close() self.run_loop(server.wait_closed()) self.addCleanup(cleanup) return srv_proto, cli_proto @unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported") def test_sendfile_not_supported(self): tr, pr = self.run_loop( self.loop.create_datagram_endpoint( asyncio.DatagramProtocol, family=socket.AF_INET)) try: with self.assertRaisesRegex(RuntimeError, "not supported"): self.run_loop( self.loop.sendfile(tr, self.file)) self.assertEqual(0, self.file.tell()) finally: # don't use self.addCleanup because it produces resource warning tr.close() def test_sendfile(self): srv_proto, cli_proto = self.prepare_sendfile() ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(ret, len(self.DATA)) self.assertEqual(srv_proto.nbytes, len(self.DATA)) self.assertEqual(srv_proto.data, self.DATA) self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_force_fallback(self): srv_proto, cli_proto = self.prepare_sendfile() def sendfile_native(transp, file, offset, count): # to raise SendfileNotAvailableError return base_events.BaseEventLoop._sendfile_native( self.loop, transp, file, offset, count) self.loop._sendfile_native = sendfile_native ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(ret, len(self.DATA)) self.assertEqual(srv_proto.nbytes, len(self.DATA)) self.assertEqual(srv_proto.data, self.DATA) self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_force_unsupported_native(self): if sys.platform == 'win32': if isinstance(self.loop, asyncio.ProactorEventLoop): self.skipTest("Fails on proactor event loop") srv_proto, cli_proto = self.prepare_sendfile() def sendfile_native(transp, file, offset, count): # to raise SendfileNotAvailableError return base_events.BaseEventLoop._sendfile_native( self.loop, transp, file, offset, count) self.loop._sendfile_native = sendfile_native with self.assertRaisesRegex(asyncio.SendfileNotAvailableError, "not supported"): self.run_loop( self.loop.sendfile(cli_proto.transport, self.file, fallback=False)) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(srv_proto.nbytes, 0) self.assertEqual(self.file.tell(), 0) def test_sendfile_ssl(self): srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(ret, len(self.DATA)) self.assertEqual(srv_proto.nbytes, len(self.DATA)) self.assertEqual(srv_proto.data, self.DATA) self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_for_closing_transp(self): srv_proto, cli_proto = self.prepare_sendfile() cli_proto.transport.close() with self.assertRaisesRegex(RuntimeError, "is closing"): self.run_loop(self.loop.sendfile(cli_proto.transport, self.file)) self.run_loop(srv_proto.done) self.assertEqual(srv_proto.nbytes, 0) self.assertEqual(self.file.tell(), 0) def test_sendfile_pre_and_post_data(self): srv_proto, cli_proto = self.prepare_sendfile() PREFIX = b'PREFIX__' * 1024 # 8 KiB SUFFIX = b'--SUFFIX' * 1024 # 8 KiB cli_proto.transport.write(PREFIX) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.write(SUFFIX) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(ret, len(self.DATA)) self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_ssl_pre_and_post_data(self): srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) PREFIX = b'zxcvbnm' * 1024 SUFFIX = b'0987654321' * 1024 cli_proto.transport.write(PREFIX) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.write(SUFFIX) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(ret, len(self.DATA)) self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX) self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_partial(self): srv_proto, cli_proto = self.prepare_sendfile() ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(ret, 100) self.assertEqual(srv_proto.nbytes, 100) self.assertEqual(srv_proto.data, self.DATA[1000:1100]) self.assertEqual(self.file.tell(), 1100) def test_sendfile_ssl_partial(self): srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file, 1000, 100)) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(ret, 100) self.assertEqual(srv_proto.nbytes, 100) self.assertEqual(srv_proto.data, self.DATA[1000:1100]) self.assertEqual(self.file.tell(), 1100) def test_sendfile_close_peer_after_receiving(self): srv_proto, cli_proto = self.prepare_sendfile( close_after=len(self.DATA)) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) cli_proto.transport.close() self.run_loop(srv_proto.done) self.assertEqual(ret, len(self.DATA)) self.assertEqual(srv_proto.nbytes, len(self.DATA)) self.assertEqual(srv_proto.data, self.DATA) self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_ssl_close_peer_after_receiving(self): srv_proto, cli_proto = self.prepare_sendfile( is_ssl=True, close_after=len(self.DATA)) ret = self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) self.run_loop(srv_proto.done) self.assertEqual(ret, len(self.DATA)) self.assertEqual(srv_proto.nbytes, len(self.DATA)) self.assertEqual(srv_proto.data, self.DATA) self.assertEqual(self.file.tell(), len(self.DATA)) def test_sendfile_close_peer_in_the_middle_of_receiving(self): srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) with self.assertRaises(ConnectionError): self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) self.run_loop(srv_proto.done) self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), srv_proto.nbytes) self.assertTrue(1024 <= self.file.tell() < len(self.DATA), self.file.tell()) self.assertTrue(cli_proto.transport.is_closing()) def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self): def sendfile_native(transp, file, offset, count): # to raise SendfileNotAvailableError return base_events.BaseEventLoop._sendfile_native( self.loop, transp, file, offset, count) self.loop._sendfile_native = sendfile_native srv_proto, cli_proto = self.prepare_sendfile(close_after=1024) with self.assertRaises(ConnectionError): self.run_loop( self.loop.sendfile(cli_proto.transport, self.file)) self.run_loop(srv_proto.done) self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA), srv_proto.nbytes) self.assertTrue(1024 <= self.file.tell() < len(self.DATA), self.file.tell()) @unittest.skipIf(not hasattr(os, 'sendfile'), "Don't have native sendfile support") def test_sendfile_prevents_bare_write(self): srv_proto, cli_proto = self.prepare_sendfile() fut = self.loop.create_future() async def coro(): fut.set_result(None) return await self.loop.sendfile(cli_proto.transport, self.file) t = self.loop.create_task(coro()) self.run_loop(fut) with self.assertRaisesRegex(RuntimeError, "sendfile is in progress"): cli_proto.transport.write(b'data') ret = self.run_loop(t) self.assertEqual(ret, len(self.DATA)) def test_sendfile_no_fallback_for_fallback_transport(self): transport = mock.Mock() transport.is_closing.side_effect = lambda: False transport._sendfile_compatible = constants._SendfileMode.FALLBACK with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'): self.loop.run_until_complete( self.loop.sendfile(transport, None, fallback=False)) class SendfileTestsBase(SendfileMixin, SockSendfileMixin): pass if sys.platform == 'win32': class SelectEventLoopTests(SendfileTestsBase, test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop() class ProactorEventLoopTests(SendfileTestsBase, test_utils.TestCase): def create_event_loop(self): return asyncio.ProactorEventLoop() else: import selectors if hasattr(selectors, 'KqueueSelector'): class KqueueEventLoopTests(SendfileTestsBase, test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop( selectors.KqueueSelector()) if hasattr(selectors, 'EpollSelector'): class EPollEventLoopTests(SendfileTestsBase, test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.EpollSelector()) if hasattr(selectors, 'PollSelector'): class PollEventLoopTests(SendfileTestsBase, test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.PollSelector()) # Should always exist. class SelectEventLoopTests(SendfileTestsBase, test_utils.TestCase): def create_event_loop(self): return asyncio.SelectorEventLoop(selectors.SelectSelector())