fix issue #17552: add socket.sendfile() method allowing to send a file over a socket by using high-performance os.sendfile() on UNIX. Patch by Giampaolo Rodola'

This commit is contained in:
Giampaolo Rodola' 2014-06-11 03:54:30 +02:00
parent b398d33c65
commit 915d14190e
9 changed files with 483 additions and 2 deletions

View File

@ -1092,6 +1092,10 @@ or `the MSDN <http://msdn.microsoft.com/en-us/library/z0kc8e3z.aspx>`_ on Window
Availability: Unix. Availability: Unix.
.. note::
For a higher-level version of this see :meth:`socket.socket.sendfile`.
.. versionadded:: 3.3 .. versionadded:: 3.3

View File

@ -1148,6 +1148,21 @@ to sockets.
.. versionadded:: 3.3 .. versionadded:: 3.3
.. method:: socket.sendfile(file, offset=0, count=None)
Send a file until EOF is reached by using high-performance
:func:`os.sendfile` and return the total number of bytes which were sent.
*file* must be a regular file object opened in binary mode. If
:func:`os.sendfile` is not available (e.g. Windows) or *file* is not a
regular file, :meth:`send` will be used instead. *offset* tells from where to
start reading the file. If specified, *count* is the total number of bytes
to transmit as opposed to sending the file until EOF is reached. The file
position is updated on return and also in case of error, in which case
:meth:`file.tell() <io.IOBase.tell>` can be used to figure out the number of
bytes which were sent. The socket must be of :const:`SOCK_STREAM` type.
Non-blocking sockets are not supported.
.. versionadded:: 3.5
.. method:: socket.set_inheritable(inheritable) .. method:: socket.set_inheritable(inheritable)

View File

@ -789,6 +789,9 @@ SSL sockets provide the following methods of :ref:`socket-objects`:
(but passing a non-zero ``flags`` argument is not allowed) (but passing a non-zero ``flags`` argument is not allowed)
- :meth:`~socket.socket.send()`, :meth:`~socket.socket.sendall()` (with - :meth:`~socket.socket.send()`, :meth:`~socket.socket.sendall()` (with
the same limitation) the same limitation)
- :meth:`~socket.socket.sendfile()` (but :func:`os.sendfile` will be used
for plain-text sockets only, else :meth:`~socket.socket.send()` will be used)
.. versionadded:: 3.5
- :meth:`~socket.socket.shutdown()` - :meth:`~socket.socket.shutdown()`
However, since the SSL (and TLS) protocol has its own framing atop However, since the SSL (and TLS) protocol has its own framing atop

View File

@ -181,9 +181,18 @@ signal
* Different constants of :mod:`signal` module are now enumeration values using * Different constants of :mod:`signal` module are now enumeration values using
the :mod:`enum` module. This allows meaningful names to be printed during the :mod:`enum` module. This allows meaningful names to be printed during
debugging, instead of integer “magic numbers”. (contribute by Giampaolo debugging, instead of integer “magic numbers”. (contributed by Giampaolo
Rodola' in :issue:`21076`) Rodola' in :issue:`21076`)
socket
------
* New :meth:`socket.socket.sendfile` method allows sending a file over a socket
by using the high-performance :func:`os.sendfile` function on UNIX, resulting
in uploads being from 2x to 3x faster than when using plain
:meth:`socket.socket.send`.
(contributed by Giampaolo Rodola' in :issue:`17552`)
xmlrpc xmlrpc
------ ------

View File

@ -47,7 +47,7 @@ the setsockopt() and getsockopt() methods.
import _socket import _socket
from _socket import * from _socket import *
import os, sys, io import os, sys, io, selectors
from enum import IntEnum from enum import IntEnum
try: try:
@ -109,6 +109,9 @@ if sys.platform.lower().startswith("win"):
__all__.append("errorTab") __all__.append("errorTab")
class _GiveupOnSendfile(Exception):
    """Raised internally when os.sendfile() cannot be used for a transfer.

    socket.sendfile() catches it and retries the transfer with plain send().
    """
class socket(_socket.socket): class socket(_socket.socket):
"""A subclass of _socket.socket adding the makefile() method.""" """A subclass of _socket.socket adding the makefile() method."""
@ -233,6 +236,149 @@ class socket(_socket.socket):
text.mode = mode text.mode = mode
return text return text
if hasattr(os, 'sendfile'):

    def _sendfile_use_sendfile(self, file, offset=0, count=None):
        """Send *file* over this socket with os.sendfile().

        *offset* is the file position to start reading from and *count*,
        when given, is the maximum number of bytes to transmit (otherwise
        the file is sent until EOF).  Return the total number of bytes sent.

        Raise _GiveupOnSendfile when os.sendfile() cannot be used for this
        file (no usable file descriptor, fstat() failure, or os.sendfile()
        failing before any byte went out) so the caller can fall back on
        send().  Raise ValueError for non-blocking sockets and the socket
        timeout exception when the socket does not become writable in time.

        If at least one byte was sent and *file* is seekable, its position
        is moved to just past the last byte sent, even on error.
        """
        self._check_sendfile_params(file, offset, count)
        sockno = self.fileno()
        try:
            fileno = file.fileno()
        except (AttributeError, io.UnsupportedOperation) as err:
            raise _GiveupOnSendfile(err)  # not a regular file
        try:
            fsize = os.fstat(fileno).st_size
        except OSError as err:
            # 'err' must be bound here: the name bound by the previous
            # except clause no longer exists at this point.
            raise _GiveupOnSendfile(err)  # not a regular file
        if not fsize:
            return 0  # empty file
        blocksize = fsize if not count else count

        timeout = self.gettimeout()
        if timeout == 0:
            raise ValueError("non-blocking sockets are not supported")
        # poll/select have the advantage of not requiring any
        # extra file descriptor, contrarily to epoll/kqueue
        # (also, they require a single syscall).
        if hasattr(selectors, 'PollSelector'):
            selector = selectors.PollSelector()
        else:
            selector = selectors.SelectSelector()
        selector.register(sockno, selectors.EVENT_WRITE)

        total_sent = 0
        # localize variable access to minimize overhead
        selector_select = selector.select
        os_sendfile = os.sendfile
        try:
            while True:
                # With a timeout set, wait for writability first and give
                # up if the socket is still not ready when it expires.
                if timeout and not selector_select(timeout):
                    raise _socket.timeout('timed out')
                if count:
                    blocksize = count - total_sent
                    if blocksize <= 0:
                        break
                try:
                    sent = os_sendfile(sockno, fileno, offset, blocksize)
                except BlockingIOError:
                    if not timeout:
                        # Block until the socket is ready to send some
                        # data; avoids hogging CPU resources.
                        selector_select()
                    continue
                except OSError as err:
                    if total_sent == 0:
                        # We can get here for different reasons, the main
                        # one being 'file' is not a regular mmap(2)-like
                        # file, in which case we'll fall back on using
                        # plain send().
                        raise _GiveupOnSendfile(err)
                    raise err from None
                else:
                    if sent == 0:
                        break  # EOF
                    offset += sent
                    total_sent += sent
            return total_sent
        finally:
            # 'offset' has been advanced by every successful sendfile()
            # call, so it is the position right after the last byte sent.
            if total_sent > 0 and hasattr(file, 'seek'):
                file.seek(offset)
else:

    def _sendfile_use_sendfile(self, file, offset=0, count=None):
        """Always give up: os.sendfile() does not exist on this platform."""
        raise _GiveupOnSendfile(
            "os.sendfile() not available on this platform")
def _sendfile_use_send(self, file, offset=0, count=None):
    """Send *file* over this socket by reading it and calling send().

    Reading starts at *offset*; if *count* is given, at most that many
    bytes are transmitted, otherwise the file is sent until EOF.  Return
    the total number of bytes sent.  Raise ValueError for non-blocking
    sockets.  If anything was sent and *file* is seekable, its position is
    set to offset + bytes sent, even on error.
    """
    self._check_sendfile_params(file, offset, count)
    if self.gettimeout() == 0:
        raise ValueError("non-blocking sockets are not supported")
    if offset:
        file.seek(offset)
    chunk_size = min(count, 8192) if count else 8192
    total_sent = 0
    # bind the bound methods once; they are called on every iteration
    read_chunk = file.read
    send_bytes = self.send
    try:
        while True:
            if count:
                # never read more than what is still left to transmit
                chunk_size = min(count - total_sent, chunk_size)
                if chunk_size <= 0:
                    break
            pending = memoryview(read_chunk(chunk_size))
            if not pending:
                break  # EOF
            # send() may accept only part of the buffer: keep going until
            # the whole chunk has been handed over.
            while pending:
                try:
                    accepted = send_bytes(pending)
                except BlockingIOError:
                    continue
                total_sent += accepted
                pending = pending[accepted:]
        return total_sent
    finally:
        if total_sent > 0 and hasattr(file, 'seek'):
            file.seek(offset + total_sent)
def _check_sendfile_params(self, file, offset, count):
    """Validate the arguments shared by both sendfile implementations.

    Raise ValueError if *file* reports a non-binary mode, if this socket
    is not of SOCK_STREAM type, or if *count* is an integer <= 0.  Raise
    TypeError if *count* is neither None nor an int.  *offset* is accepted
    but not checked here.
    """
    # Objects without a 'mode' attribute are given the benefit of the doubt.
    mode = getattr(file, 'mode', 'b')
    if 'b' not in mode:
        raise ValueError("file should be opened in binary mode")
    if not self.type & SOCK_STREAM:
        raise ValueError("only SOCK_STREAM type sockets are supported")
    if count is None:
        return
    if not isinstance(count, int):
        raise TypeError(
            "count must be a positive integer (got {!r})".format(count))
    if count <= 0:
        raise ValueError(
            "count must be a positive integer (got {!r})".format(count))
def sendfile(self, file, offset=0, count=None):
    """sendfile(file[, offset[, count]]) -> sent

    Transmit a file over the socket until EOF, preferring the
    high-performance os.sendfile(), and return the total number of
    bytes sent.

    *file* must be a regular file object opened in binary mode.  When
    os.sendfile() is unavailable (e.g. Windows) or *file* is not a
    regular file, the data is sent with socket.send() instead.
    *offset* is the position in the file where reading starts.
    *count*, if specified, is the total number of bytes to transmit
    rather than sending everything up to EOF.
    The file position is updated on return and also on error, in
    which case file.tell() tells how many bytes were sent.
    The socket must be of SOCK_STREAM type.
    Non-blocking sockets are not supported.
    """
    try:
        sent = self._sendfile_use_sendfile(file, offset, count)
    except _GiveupOnSendfile:
        # os.sendfile() is not usable here: fall back on plain send().
        sent = self._sendfile_use_send(file, offset, count)
    return sent
def _decref_socketios(self): def _decref_socketios(self):
if self._io_refs > 0: if self._io_refs > 0:
self._io_refs -= 1 self._io_refs -= 1

View File

@ -700,6 +700,16 @@ class SSLSocket(socket):
else: else:
return socket.sendall(self, data, flags) return socket.sendall(self, data, flags)
def sendfile(self, file, offset=0, count=None):
    """Send a file and return the total number of bytes sent.

    os.sendfile() may only be used while this is a clear-text socket
    (no SSL object attached); otherwise the data goes through send().
    """
    if self._sslobj is not None:
        # os.sendfile() works with plain sockets only
        return self._sendfile_use_send(file, offset, count)
    return super().sendfile(file, offset, count)
def recv(self, buflen=1024, flags=0): def recv(self, buflen=1024, flags=0):
self._checkClosed() self._checkClosed()
if self._sslobj: if self._sslobj:

View File

@ -19,6 +19,8 @@ import signal
import math import math
import pickle import pickle
import struct import struct
import random
import string
try: try:
import multiprocessing import multiprocessing
except ImportError: except ImportError:
@ -5077,6 +5079,275 @@ class TestSocketSharing(SocketTCPTest):
source.close() source.close()
@unittest.skipUnless(thread, 'Threading required for this test.')
class SendfileUsingSendTest(ThreadedTCPSocketTest):
    """
    Test the send() implementation of socket.sendfile().

    Each scenario is a pair of methods: ``_testXxx`` connects to the
    listening socket and sends the file, while ``testXxx`` accepts the
    connection and checks what was received.
    NOTE(review): the pairing/threading itself is presumably provided by
    ThreadedTCPSocketTest -- confirm against the base class.
    """
    FILESIZE = (10 * 1024 * 1024)  # 10MB
    BUFSIZE = 8192
    FILEDATA = b""
    TIMEOUT = 2

    @classmethod
    def setUpClass(cls):
        """Create the FILESIZE-byte fixture file and cache its content."""
        def chunks(total, step):
            # Yield the sizes of successive writes adding up to 'total'.
            assert total >= step
            while total > step:
                yield step
                total -= step
            if total:
                yield total

        chunk = b"".join([random.choice(string.ascii_letters).encode()
                          for i in range(cls.BUFSIZE)])
        with open(support.TESTFN, 'wb') as f:
            for csize in chunks(cls.FILESIZE, cls.BUFSIZE):
                # Honour the requested size so the file is exactly
                # FILESIZE bytes even when it is not a multiple of BUFSIZE.
                f.write(chunk[:csize])
        with open(support.TESTFN, 'rb') as f:
            cls.FILEDATA = f.read()
            assert len(cls.FILEDATA) == cls.FILESIZE

    @classmethod
    def tearDownClass(cls):
        """Remove the fixture file."""
        support.unlink(support.TESTFN)

    def accept_conn(self):
        """Accept one connection; both sockets get TIMEOUT as timeout."""
        self.serv.settimeout(self.TIMEOUT)
        conn, addr = self.serv.accept()
        conn.settimeout(self.TIMEOUT)
        self.addCleanup(conn.close)
        return conn

    def recv_data(self, conn):
        """Read from *conn* until the peer closes; return all the bytes."""
        received = []
        while True:
            chunk = conn.recv(self.BUFSIZE)
            if not chunk:
                break
            received.append(chunk)
        return b''.join(received)

    def meth_from_sock(self, sock):
        # Depending on the mixin class being run return either send()
        # or sendfile() method implementation.
        return getattr(sock, "_sendfile_use_send")

    # regular file

    def _testRegularFile(self):
        address = self.serv.getsockname()
        file = open(support.TESTFN, 'rb')
        with socket.create_connection(address) as sock, file as file:
            meth = self.meth_from_sock(sock)
            sent = meth(file)
            self.assertEqual(sent, self.FILESIZE)
            self.assertEqual(file.tell(), self.FILESIZE)

    def testRegularFile(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), self.FILESIZE)
        self.assertEqual(data, self.FILEDATA)

    # non regular file

    def _testNonRegularFile(self):
        address = self.serv.getsockname()
        file = io.BytesIO(self.FILEDATA)
        with socket.create_connection(address) as sock, file as file:
            # The public method must transparently fall back on send()...
            sent = sock.sendfile(file)
            self.assertEqual(sent, self.FILESIZE)
            self.assertEqual(file.tell(), self.FILESIZE)
            # ...because the os.sendfile() implementation gives up.
            self.assertRaises(socket._GiveupOnSendfile,
                              sock._sendfile_use_sendfile, file)

    def testNonRegularFile(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), self.FILESIZE)
        self.assertEqual(data, self.FILEDATA)

    # empty file

    def _testEmptyFileSend(self):
        address = self.serv.getsockname()
        filename = support.TESTFN + "2"
        with open(filename, 'wb'):
            self.addCleanup(support.unlink, filename)
        file = open(filename, 'rb')
        with socket.create_connection(address) as sock, file as file:
            meth = self.meth_from_sock(sock)
            sent = meth(file)
            self.assertEqual(sent, 0)
            self.assertEqual(file.tell(), 0)

    def testEmptyFileSend(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(data, b"")

    # offset

    def _testOffset(self):
        address = self.serv.getsockname()
        file = open(support.TESTFN, 'rb')
        with socket.create_connection(address) as sock, file as file:
            meth = self.meth_from_sock(sock)
            sent = meth(file, offset=5000)
            self.assertEqual(sent, self.FILESIZE - 5000)
            self.assertEqual(file.tell(), self.FILESIZE)

    def testOffset(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), self.FILESIZE - 5000)
        self.assertEqual(data, self.FILEDATA[5000:])

    # count

    def _testCount(self):
        address = self.serv.getsockname()
        file = open(support.TESTFN, 'rb')
        with socket.create_connection(address, timeout=2) as sock, file as file:
            count = 5000007
            meth = self.meth_from_sock(sock)
            sent = meth(file, count=count)
            self.assertEqual(sent, count)
            self.assertEqual(file.tell(), count)

    def testCount(self):
        count = 5000007
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), count)
        self.assertEqual(data, self.FILEDATA[:count])

    # count small

    def _testCountSmall(self):
        address = self.serv.getsockname()
        file = open(support.TESTFN, 'rb')
        with socket.create_connection(address, timeout=2) as sock, file as file:
            count = 1
            meth = self.meth_from_sock(sock)
            sent = meth(file, count=count)
            self.assertEqual(sent, count)
            self.assertEqual(file.tell(), count)

    def testCountSmall(self):
        count = 1
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), count)
        self.assertEqual(data, self.FILEDATA[:count])

    # count + offset

    def _testCountWithOffset(self):
        address = self.serv.getsockname()
        file = open(support.TESTFN, 'rb')
        with socket.create_connection(address, timeout=2) as sock, file as file:
            count = 100007
            meth = self.meth_from_sock(sock)
            sent = meth(file, offset=2007, count=count)
            self.assertEqual(sent, count)
            self.assertEqual(file.tell(), count + 2007)

    def testCountWithOffset(self):
        count = 100007
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), count)
        self.assertEqual(data, self.FILEDATA[2007:count+2007])

    # non blocking sockets are not supposed to work

    def _testNonBlocking(self):
        address = self.serv.getsockname()
        file = open(support.TESTFN, 'rb')
        with socket.create_connection(address) as sock, file as file:
            sock.setblocking(False)
            meth = self.meth_from_sock(sock)
            self.assertRaises(ValueError, meth, file)
            self.assertRaises(ValueError, sock.sendfile, file)

    def testNonBlocking(self):
        conn = self.accept_conn()
        if conn.recv(8192):
            self.fail('was not supposed to receive any data')

    # timeout (non-triggered)

    def _testWithTimeout(self):
        address = self.serv.getsockname()
        file = open(support.TESTFN, 'rb')
        with socket.create_connection(address, timeout=2) as sock, file as file:
            meth = self.meth_from_sock(sock)
            sent = meth(file)
            self.assertEqual(sent, self.FILESIZE)

    def testWithTimeout(self):
        conn = self.accept_conn()
        data = self.recv_data(conn)
        self.assertEqual(len(data), self.FILESIZE)
        self.assertEqual(data, self.FILEDATA)

    # timeout (triggered)

    def _testWithTimeoutTriggeredSend(self):
        address = self.serv.getsockname()
        file = open(support.TESTFN, 'rb')
        with socket.create_connection(address, timeout=0.01) as sock, \
                file as file:
            meth = self.meth_from_sock(sock)
            self.assertRaises(socket.timeout, meth, file)

    def testWithTimeoutTriggeredSend(self):
        conn = self.accept_conn()
        # Read only once and stop, so the sender ends up waiting.
        # NOTE(review): 88192 looks like a typo for 8192 -- confirm.
        conn.recv(88192)

    # errors

    def _test_errors(self):
        # Nothing to do on the connecting side for this scenario.
        pass

    def test_errors(self):
        with open(support.TESTFN, 'rb') as file:
            with socket.socket(type=socket.SOCK_DGRAM) as s:
                meth = self.meth_from_sock(s)
                self.assertRaisesRegex(
                    ValueError, "SOCK_STREAM", meth, file)
        with open(support.TESTFN, 'rt') as file:
            with socket.socket() as s:
                meth = self.meth_from_sock(s)
                self.assertRaisesRegex(
                    ValueError, "binary mode", meth, file)
        with open(support.TESTFN, 'rb') as file:
            with socket.socket() as s:
                meth = self.meth_from_sock(s)
                self.assertRaisesRegex(TypeError, "positive integer",
                                       meth, file, count='2')
                self.assertRaisesRegex(TypeError, "positive integer",
                                       meth, file, count=0.1)
                self.assertRaisesRegex(ValueError, "positive integer",
                                       meth, file, count=0)
                self.assertRaisesRegex(ValueError, "positive integer",
                                       meth, file, count=-1)
@unittest.skipUnless(thread, 'Threading required for this test.')
@unittest.skipUnless(hasattr(os, "sendfile"),
                     'os.sendfile() required for this test.')
class SendfileUsingSendfileTest(SendfileUsingSendTest):
    """
    Run the inherited scenarios against the os.sendfile() based
    implementation of socket.sendfile().
    """

    def meth_from_sock(self, sock):
        # Select the os.sendfile() implementation instead of the send() one.
        return sock._sendfile_use_sendfile
def test_main(): def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
@ -5129,6 +5400,8 @@ def test_main():
InterruptedRecvTimeoutTest, InterruptedRecvTimeoutTest,
InterruptedSendTimeoutTest, InterruptedSendTimeoutTest,
TestSocketSharing, TestSocketSharing,
SendfileUsingSendTest,
SendfileUsingSendfileTest,
]) ])
thread_info = support.threading_setup() thread_info = support.threading_setup()

View File

@ -2957,6 +2957,23 @@ else:
self.assertRaises(ValueError, s.read, 1024) self.assertRaises(ValueError, s.read, 1024)
self.assertRaises(ValueError, s.write, b'hello') self.assertRaises(ValueError, s.write, b'hello')
def test_sendfile(self):
    """sendfile() over an SSL-wrapped socket delivers the file content.

    A 512-byte file is sent to the echo server and the bytes read back
    must match what was written to the file.
    """
    payload = b"x" * 512
    with open(support.TESTFN, 'wb') as f:
        f.write(payload)
    self.addCleanup(support.unlink, support.TESTFN)

    context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    context.verify_mode = ssl.CERT_REQUIRED
    context.load_verify_locations(CERTFILE)
    context.load_cert_chain(CERTFILE)

    server = ThreadedEchoServer(context=context, chatty=False)
    with server, context.wrap_socket(socket.socket()) as s:
        s.connect((HOST, server.port))
        with open(support.TESTFN, 'rb') as file:
            s.sendfile(file)
            self.assertEqual(s.recv(1024), payload)
def test_main(verbose=False): def test_main(verbose=False):
if support.verbose: if support.verbose:

View File

@ -92,6 +92,10 @@ Core and Builtins
Library Library
------- -------
- Issue #17552: new socket.sendfile() method allowing to send a file over a
socket by using high-performance os.sendfile() on UNIX.
Patch by Giampaolo Rodola'.
- Issue #18039: dbm.dump.open() now always creates a new database when the - Issue #18039: dbm.dump.open() now always creates a new database when the
flag has the value 'n'. Patch by Claudiu Popa. flag has the value 'n'. Patch by Claudiu Popa.