diff --git a/Lib/socket.py b/Lib/socket.py --- a/Lib/socket.py +++ b/Lib/socket.py @@ -47,7 +47,7 @@ import _socket from _socket import * -import os, sys, io +import os, sys, io, select try: import errno @@ -83,6 +83,8 @@ errorTab[10065] = "The host is unreachable." __all__.append("errorTab") +class _GiveupOnSendfile(Exception): pass + class socket(_socket.socket): @@ -184,6 +186,92 @@ text.mode = mode return text + def _sendfile_use_sendfile(self, file, blocksize): + offset = 0 + sockno = self.fileno() + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise _GiveupOnSendfile(err) # not a regular mmap-like file + + timeout = self.gettimeout() + if timeout == 0: + raise ValueError("non-blocking sockets are not supported") + # whenever possible use poll() instead of select() in + # order to avoid running out of fds + if hasattr(select, 'poll'): + if timeout != None: + timeout *= 1000 + pollster = select.poll() + pollster.register(sockno, select.POLLOUT) + def wait_for_fd(): + if pollster.poll(timeout) == []: + raise _socket.timeout('timed out') + else: + # call select() once in order to solicit EMFILE, in + # which case we'll exit immediately + try: + select.select([], [sockno], [], 0) + except OSError as err: + raise _GiveupOnSendfile(err) + + def wait_for_fd(): + fds = select.select([], [sockno], [], timeout) + if fds == ([], [], []): + raise _socket.timeout('timed out') + + while True: + wait_for_fd() # block until socket is writable + try: + sent = os.sendfile(sockno, fileno, offset, blocksize) + except BlockingIOError: + continue + except OSError as err: + if offset == 0: + # We can get here for different reasons, the main + # one being 'file' is not a regular mmap(2)-like + # file, in which case we'll fall back on using + # plain send(). + raise _GiveupOnSendfile(err) + else: + raise + else: + if sent == 0: + break + offset += sent + + def _sendfile_use_send(self, file, blocksize): + if self.gettimeout() == 0: + raise ValueError("non-blocking sockets are not supported") + while True: + buf = file.read(blocksize) + if not buf: + break + self.sendall(buf) + + def sendfile(self, file, blocksize=262144): + """Send a file attempting to use high-performance os.sendfile(), + in which case 'file' must be a mmap-like ('regular') file object. + If not socket.send() will be used as fallback. + + Return a tuple of 2 elements including: + - a bool indicating whether os.sendfile() was used + - an exception instance in case it wasn't on account of an + internal error, if any + + Raises ValueError if socket is non-blocking. + """ + if hasattr(os, 'sendfile'): + try: + self._sendfile_use_sendfile(file, blocksize) + return (True, None) + except _GiveupOnSendfile as exc: + self._sendfile_use_send(file, blocksize) + return (False, exc.args[0]) + else: + self._sendfile_use_send(file, blocksize) + return (False, None) + def _decref_socketios(self): if self._io_refs > 0: self._io_refs -= 1 diff --git a/Lib/ssl.py b/Lib/ssl.py --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -604,6 +604,11 @@ return None return self._sslobj.tls_unique_cb() + def sendfile(self, file, blocksize=262144): + # sendfile() works with plain sockets only + self._sendfile_use_send(file, blocksize) + return (False, None) + def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -4733,6 +4733,132 @@ self.assertRaises(OSError, sock.sendall, b'foo') +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileTest(ThreadedTCPSocketTest): + + FILESIZE = (10 * 1024 * 1024) # 10MB + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step; + total -= step; + if total: + yield total + + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, 262144): + f.write(b'x' * csize) + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + def _testSendfile(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], hasattr(os, 'sendfile')) + self.assertEqual(ret[1], None) + + def testSendfile(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + self.assertEqual(received, self.FILESIZE) + + # non regular file + + def _testSendfileNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(b'x' * 1*1024*1024) # 1MB + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], False) + self.assertIsInstance(ret[1], io.UnsupportedOperation) + + def testSendfileNonRegularFile(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + self.assertEqual(received, 1*1024*1024) + + # non blocking sockets are not supposed to work + + def _testSendfileNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + self.assertRaises(ValueError, sock.sendfile, file) + + def testSendfileNonBlocking(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # timeout (non-triggered) + + def _testSendfileWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], hasattr(os, 'sendfile')) + self.assertEqual(ret[1], None) + + def testSendfileWithTimeout(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + self.assertEqual(received, self.FILESIZE) + + # timeout (triggered) + + def _testSendfileWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + self.assertRaises(socket.timeout, sock.sendfile, file, blocksize=512) + + def testSendfileWithTimeout(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + # interrupt transfer halfway; this will cause a timeout + if received >= self.FILESIZE // 2: + break + + @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), "SOCK_CLOEXEC not defined") @unittest.skipUnless(fcntl, "module fcntl not available")