diff -r 8c8315bac6a8 Doc/library/socket.rst --- a/Doc/library/socket.rst Sun Apr 20 09:45:00 2014 -0700 +++ b/Doc/library/socket.rst Tue Apr 22 19:38:36 2014 +0200 @@ -1145,6 +1145,22 @@ .. versionadded:: 3.3 +.. method:: socket.sendfile(file, blocksize=262144, offset=0, use_fallback=True) + + Send a file until EOF is reached attempting to use high-performance + :mod:`os.sendfile` in which case *file* must be a regular file object + opened in binary mode; if not and *use_fallback* is ``True`` :meth:`send` + will be used instead. + File position is updated on return or also in case of error in which case + :meth:`file.tell() ` can be used to figure out the number of + bytes which were transmitted. + *blocksize* is the maximum number of bytes to transmit at one time, *offset* + tells from where to start reading the file. + The socket must be of :const:`SOCK_STREAM` type. + Non-blocking sockets are not supported. + Return the total number of bytes which were transmitted. + + .. versionadded:: 3.5 .. method:: socket.set_inheritable(inheritable) diff -r 8c8315bac6a8 Lib/socket.py --- a/Lib/socket.py Sun Apr 20 09:45:00 2014 -0700 +++ b/Lib/socket.py Tue Apr 22 19:38:36 2014 +0200 @@ -47,7 +47,7 @@ import _socket from _socket import * -import os, sys, io +import os, sys, io, selectors from enum import IntEnum try: @@ -109,6 +109,9 @@ __all__.append("errorTab") +class _GiveupOnSendfile(Exception): pass + + class socket(_socket.socket): """A subclass of _socket.socket adding the makefile() method.""" @@ -233,6 +236,121 @@ text.mode = mode return text + if hasattr(os, 'sendfile'): + + def _sendfile_use_sendfile(self, file, blocksize=262144, offset=0, + use_fallback=False): + sockno = self.fileno() + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + if use_fallback: + raise _GiveupOnSendfile(err) # not a regular file + raise err from None + + timeout = self.gettimeout() + if timeout == 0: + raise ValueError("non-blocking sockets are not supported") + # poll/select have the advantage of not requiring any + # extra file descriptor, contrarily to epoll/kqueue + # (also, they require a single syscall). + if hasattr(selectors, 'PollSelector'): + selector = selectors.PollSelector() + else: + selector = selectors.SelectSelector() + selector.register(sockno, selectors.EVENT_WRITE) + + original_offset = offset + try: + while True: + if timeout and not selector.select(timeout): + raise _socket.timeout('timed out') + try: + sent = os.sendfile(sockno, fileno, offset, blocksize) + except BlockingIOError: + if not timeout: + # Block until the socket is ready to send some + # data; avoids hogging CPU resources. + selector.select() + continue + except OSError as err: + if offset == 0: + # We can get here for different reasons, the main + # one being 'file' is not a regular mmap(2)-like + # file, in which case we'll fall back on using + # plain send(). + if use_fallback: + raise _GiveupOnSendfile(err) + raise err from None + raise + else: + if sent == 0: + break + offset += sent + return offset + finally: + if original_offset != offset and hasattr(file, 'seek'): + file.seek(offset) + else: + def _sendfile_use_sendfile(self, file, blocksize=262144, offset=0, + use_fallback=False): + if use_fallback: + raise _GiveupOnSendfile + raise NotImplementedError("os.sendfile() is not available") + + def _sendfile_use_send(self, file, blocksize=262144, offset=0): + if self.gettimeout() == 0: + raise ValueError("non-blocking sockets are not supported") + if offset: + file.seek(offset) + original_offset = offset + try: + while True: + data = memoryview(file.read(blocksize)) + if not data: + break + while True: + try: + sent = self.send(data) + except BlockingIOError: + continue + else: + offset += sent + if sent < len(data): + data = data[sent:] + else: + break + return offset + finally: + if original_offset != offset and hasattr(file, 'seek'): + file.seek(offset) + + def sendfile(self, file, blocksize=262144, offset=0, use_fallback=True): + """sendfile(file[, blocksize[, offset[, use_fallback]]]) -> sent + + Send a file until EOF is reached attempting to use high-performance + os.sendfile() in which case *file* must be a regular file object + opened in binary mode; if not and *use_fallback* is True send() + will be used instead. + File position is updated on return or also in case of error in + which case file.tell() can be used to figure out the number of + bytes which were transmitted. + *blocksize* is the maximum number of bytes to transmit at one time, + *offset* tells from where to start reading the file. + The socket must be of SOCK_STREAM type. + Non-blocking sockets are not supported. + Return the total number of bytes which were transmitted. + """ + if 'b' not in getattr(file, 'mode', 'b'): + raise ValueError("file should be opened in binary mode") + if not self.type & SOCK_STREAM: + raise ValueError("only SOCK_STREAM type sockets are supported") + try: + return self._sendfile_use_sendfile(file, blocksize, offset, + use_fallback) + except _GiveupOnSendfile as exc: + return self._sendfile_use_send(file, blocksize, offset) + def _decref_socketios(self): if self._io_refs > 0: self._io_refs -= 1 diff -r 8c8315bac6a8 Lib/ssl.py --- a/Lib/ssl.py Sun Apr 20 09:45:00 2014 -0700 +++ b/Lib/ssl.py Tue Apr 22 19:38:36 2014 +0200 @@ -710,6 +710,16 @@ else: return socket.sendall(self, data, flags) + def sendfile(self, file, blocksize=262144, offset=0): + """Sends a file, possibly by using os.sendfile() if this is a + clear-text socket. + """ + if self._sslobj is None: + # os.sendfile() works with plain sockets only + return super().sendfile(file, blocksize, offset) + else: + return self._sendfile_use_send(file, blocksize, offset) + def recv(self, buflen=1024, flags=0): self._checkClosed() if self._sslobj: diff -r 8c8315bac6a8 Lib/test/test_socket.py --- a/Lib/test/test_socket.py Sun Apr 20 09:45:00 2014 -0700 +++ b/Lib/test/test_socket.py Tue Apr 22 19:38:36 2014 +0200 @@ -19,6 +19,8 @@ import math import pickle import struct +import random +import string try: import multiprocessing except ImportError: @@ -5074,6 +5076,183 @@ source.close() +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileTest(ThreadedTCPSocketTest): + + FILESIZE = (10 * 1024 * 1024) # 10MB + BUFSIZE = 8192 + FILEDATA = b"" + TIMEOUT = 2 + SUPPORT_SENDFILE = hasattr(os, "sendfile") + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step + total -= step + if total: + yield total + + chunk = b"".join([random.choice(string.ascii_letters).encode() \ + for i in range(cls.BUFSIZE)]) + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, cls.BUFSIZE): + f.write(chunk) + with open(support.TESTFN, 'rb') as f: + cls.FILEDATA = f.read() + assert len(cls.FILEDATA) == cls.FILESIZE + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + def accept_conn(self): + self.serv.settimeout(self.TIMEOUT) + conn, addr = self.serv.accept() + conn.settimeout(self.TIMEOUT) + self.addCleanup(conn.close) + return conn + + def recv_data(self, conn): + received = [] + while True: + chunk = conn.recv(self.BUFSIZE) + if not chunk: + break + received.append(chunk) + return b''.join(received) + + # regular transfer + + def _testRegular(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sent = sock.sendfile(file, use_fallback=not self.SUPPORT_SENDFILE) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + + def testRegular(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non regular file + + def _testNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(self.FILEDATA) + with socket.create_connection(address) as sock, file as file: + sent = sock.sendfile(file) + self.assertEqual(sent, self.FILESIZE) + self.assertEqual(file.tell(), self.FILESIZE) + self.assertRaises((io.UnsupportedOperation, NotImplementedError), + sock.sendfile, file, use_fallback=False) + + def testNonRegularFile(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # non blocking sockets are not supposed to work + + def _testNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + self.assertRaises(ValueError, sock.sendfile, file) + self.assertRaises(ValueError, sock._sendfile_use_send, file) + + def testNonBlocking(self): + conn = self.accept_conn() + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # offset + + def _testOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file, offset=5000) + self.assertEqual(file.tell(), self.FILESIZE) + + def testOffset(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE - 5000) + self.assertEqual(data, self.FILEDATA[5000:]) + + # test offset also for the send() implementation + + if SUPPORT_SENDFILE: + + def _testOffsetSend(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + ret = sock._sendfile_use_send(file, offset=5000) + + def testOffsetSend(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE - 5000) + self.assertEqual(data, self.FILEDATA[5000:]) + + # timeout (non-triggered) + + def _testWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + sent = sock.sendfile(file) + self.assertEqual(sent, self.FILESIZE) + + def testWithTimeout(self): + conn = self.accept_conn() + data = self.recv_data(conn) + self.assertEqual(len(data), self.FILESIZE) + self.assertEqual(data, self.FILEDATA) + + # timeout (triggered) + + def _testWithTimeoutTriggered(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + self.assertRaises(socket.timeout, sock.sendfile, file, + blocksize=512) + + def testWithTimeoutTriggered(self): + conn = self.accept_conn() + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + # interrupt transfer halfway; this will cause a timeout + if received >= self.FILESIZE // 2: + break + + # misc + + def _test_not_stream_type(self): + pass + + def test_not_stream_type(self): + file = open(support.TESTFN, 'rb') + with file: + with socket.socket(type=socket.SOCK_DGRAM) as s: + self.assertRaises(ValueError, s.sendfile, file) + + def test_main(): tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest, TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ] @@ -5126,6 +5305,7 @@ InterruptedRecvTimeoutTest, InterruptedSendTimeoutTest, TestSocketSharing, + SendfileTest, ]) thread_info = support.threading_setup() diff -r 8c8315bac6a8 Lib/test/test_ssl.py --- a/Lib/test/test_ssl.py Sun Apr 20 09:45:00 2014 -0700 +++ b/Lib/test/test_ssl.py Tue Apr 22 19:38:36 2014 +0200 @@ -2856,6 +2856,23 @@ self.assertRaises(ValueError, s.read, 1024) self.assertRaises(ValueError, s.write, b'hello') + def test_sendfile(self): + TEST_DATA = b"x" * 512 + with open(support.TESTFN, 'wb') as f: + f.write(TEST_DATA) + self.addCleanup(support.unlink, support.TESTFN) + context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(CERTFILE) + context.load_cert_chain(CERTFILE) + server = ThreadedEchoServer(context=context, chatty=False) + with server: + with context.wrap_socket(socket.socket()) as s: + s.connect((HOST, server.port)) + with open(support.TESTFN, 'rb') as file: + s.sendfile(file) + self.assertEqual(s.recv(1024), TEST_DATA) + def test_main(verbose=False): if support.verbose: