diff --git a/Doc/library/socket.rst b/Doc/library/socket.rst --- a/Doc/library/socket.rst +++ b/Doc/library/socket.rst @@ -1145,6 +1145,24 @@ .. versionadded:: 3.3 +.. method:: socket.sendfile(file, blocksize=262144, offset=0) + + Send a file attempting to use high-performance :mod:`os.sendfile`, in which + case *file* must be a regular file object; if not :meth:`send` will be used + as fallback. + File position is updated on return or also in case of error, therefore + :meth:`file.tell() ` can be used to figure out the number + of bytes which were transmitted. + *file* is a binary file object opened for reading, *blocksize* is the + maximum number of bytes to transmit at one time, *offset* tells from where + to start reading the file. + When the whole file has been sent return a tuple of 2 elements including: + a bool indicating whether :mod:`os.sendfile` was used and an exception + instance in case it couldn't on account of an internal error, if any, else + *None*. + Non blocking sockets are not supported. + + .. versionadded:: 3.5 .. method:: socket.set_inheritable(inheritable) diff --git a/Lib/socket.py b/Lib/socket.py --- a/Lib/socket.py +++ b/Lib/socket.py @@ -47,7 +47,7 @@ import _socket from _socket import * -import os, sys, io +import os, sys, io, selectors from enum import IntEnum try: @@ -109,6 +109,9 @@ __all__.append("errorTab") +class _GiveupOnSendfile(Exception): pass + + class socket(_socket.socket): """A subclass of _socket.socket adding the makefile() method.""" @@ -233,6 +236,118 @@ text.mode = mode return text + if hasattr(os, 'sendfile'): + + def _sendfile_use_sendfile(self, file, blocksize, offset): + sockno = self.fileno() + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise _GiveupOnSendfile(err) # not a regular file + + timeout = self.gettimeout() + if timeout == 0: + raise ValueError("non-blocking sockets are not supported") + if timeout: + # poll/select have the advantage of not requiring any + # extra file descriptor, contrarily to epoll/kqueue + # (also, they require a single syscall). + if hasattr(selectors, 'PollSelector'): + selector = selectors.PollSelector() + else: + selector = selectors.SelectSelector() + selector.register(sockno, selectors.EVENT_WRITE) + + original_offset = offset + try: + while True: + if timeout and not selector.select(timeout): + raise _socket.timeout('timed out') + try: + sent = os.sendfile(sockno, fileno, offset, blocksize) + except BlockingIOError: + continue + except OSError as err: + if offset == 0: + # We can get here for different reasons, the main + # one being 'file' is not a regular mmap(2)-like + # file, in which case we'll fall back on using + # plain send(). + raise _GiveupOnSendfile(err) + raise + else: + if sent == 0: + break + offset += sent + finally: + # Moving the file offset is also necessary to behave + # like when using plain send(). + if original_offset != offset and hasattr(file, 'seek'): + file.seek(offset) + + def _sendfile_use_send(self, file, blocksize, offset): + if self.gettimeout() == 0: + raise ValueError("non-blocking sockets are not supported") + if offset: + file.seek(offset) + original_offset = offset + try: + while True: + data = memoryview(file.read(blocksize)) + if not data: + break + while True: + try: + sent = self.send(data) + except BlockingIOError: + continue + else: + offset += sent + if sent < len(data): + data = data[sent:] + else: + break + finally: + if original_offset != offset and hasattr(file, 'seek'): + file.seek(offset) + + def sendfile(self, file, blocksize=262144, offset=0): + """sendfile(file[, blocksize[, offset]]) -> (succeded, exc) + + Send a file attempting to use high-performance os.sendfile(), + in which case 'file' must be a regular file object; if not + socket.send() will be used as fallback. + File position is updated on return or also in case of error, + therefore file.tell() can be used to figure out the number + of bytes which were transmitted. + Arguments: + + - file: a binary file object opened for reading. + - blocksize: the maximum number of bytes to transmit at one + time (default 262144) + - offset: from where to start reading the file (default 0) + + When the whole file has been sent return a tuple of 2 + elements including: + - a bool indicating whether os.sendfile() was used + - an exception instance in case os.sendfile() couldn't be + used on account of an internal error, if any, else None. + + Non blocking sockets are not supported. + """ + if 'b' not in getattr(file, 'mode', 'b'): + raise ValueError("file should be opened in binary mode") + if hasattr(self, '_sendfile_use_sendfile'): + try: + self._sendfile_use_sendfile(file, blocksize, offset) + return (True, None) + except _GiveupOnSendfile as exc: + self._sendfile_use_send(file, blocksize, offset) + return (False, exc.args[0]) + else: + self._sendfile_use_send(file, blocksize, offset) + return (False, None) + def _decref_socketios(self): if self._io_refs > 0: self._io_refs -= 1 diff --git a/Lib/ssl.py b/Lib/ssl.py --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -710,6 +710,17 @@ else: return socket.sendall(self, data, flags) + def sendfile(self, file, blocksize=262144, offset=0): + """Sends a file, possibly by using os.sendfile() if this is a + clear-text socket. + """ + if self._sslobj is None: + # os.sendfile() works with plain sockets only + return super().sendfile(file, blocksize, offset) + else: + self._sendfile_use_send(file, blocksize, offset) + return (False, None) + def recv(self, buflen=1024, flags=0): self._checkClosed() if self._sslobj: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -19,6 +19,8 @@ import math import pickle import struct +import random +import string try: import multiprocessing except ImportError: @@ -5074,6 +5076,155 @@ source.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileTest(ThreadedTCPSocketTest): + + FILESIZE = (10 * 1024 * 1024) # 10MB + BUFSIZE = 8192 + FILEDATA = b"" + TIMEOUT = 2 + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step + total -= step + if total: + yield total + + chunk = b"".join([random.choice(string.ascii_letters).encode() \ + for i in range(cls.BUFSIZE)]) + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, cls.BUFSIZE): + f.write(chunk) + with open(support.TESTFN, 'rb') as f: + cls.FILEDATA = f.read() + assert len(cls.FILEDATA) == cls.FILESIZE + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + def accept_conn(self): + self.serv.settimeout(self.TIMEOUT) + conn, addr = self.serv.accept() + conn.settimeout(self.TIMEOUT) + self.addCleanup(conn.close) + return conn + + def recv_data(self, conn): + received = [] + while True: + chunk = conn.recv(self.BUFSIZE) + if not chunk: + break + received.append(chunk) + return b''.join(received) + + # regular transfer + + def _testRegular(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], hasattr(os, 'sendfile')) + self.assertEqual(ret[1], None) + + def testRegular(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non regular file + + def _testNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(self.FILEDATA) + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], False) + if hasattr(os, 'sendfile'): + self.assertIsInstance(ret[1], io.UnsupportedOperation) + else: + self.assertIs(ret[1], None) + + def testNonRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non blocking sockets are not supposed to work + + def _testNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + self.assertRaises(ValueError, sock.sendfile, file) + + def testNonBlocking(self): + conn = self.accept_conn() + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # offset + + def _testOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file, offset=self.FILESIZE // 2) + self.assertEqual(ret[0], hasattr(os, 'sendfile')) + + def testOffset(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE // 2) + self.assertEqual(data, self.FILEDATA[self.FILESIZE // 2:]) + + # timeout (non-triggered) + + def _testWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], hasattr(os, 'sendfile')) + + def _testWithTimeout(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # timeout (triggered) + + def _testWithTimeoutTriggered(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + self.assertRaises(socket.timeout, sock.sendfile, file, + blocksize=512) + + def testWithTimeoutTriggered(self): + conn = self.accept_conn() + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + # interrupt transfer halfway; this will cause a timeout + if received >= self.FILESIZE // 2: + break + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -5126,6 +5277,7 @@ InterruptedRecvTimeoutTest, InterruptedSendTimeoutTest, TestSocketSharing, + SendfileTest, ]) thread_info = support.threading_setup() diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -2856,6 +2856,24 @@ self.assertRaises(ValueError, s.read, 1024) self.assertRaises(ValueError, s.write, b'hello') + def test_sendfile(self): + TEST_DATA = b"x" * 512 + with open(support.TESTFN, 'wb') as f: + f.write(TEST_DATA) + self.addCleanup(support.unlink, support.TESTFN) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + with open(support.TESTFN, 'rb') as file: + ret = s.sendfile(file) + self.assertEqual(s.recv(1024), TEST_DATA) + self.assertEqual(ret[0], False) + def test_main(verbose=False): if support.verbose: