diff --git a/Lib/socket.py b/Lib/socket.py --- a/Lib/socket.py +++ b/Lib/socket.py @@ -47,7 +47,7 @@ import _socket from _socket import * -import os, sys, io +import os, sys, io, select try: import errno @@ -83,6 +83,8 @@ errorTab[10065] = "The host is unreachable." __all__.append("errorTab") +class _GiveupOnSendfile(Exception): pass + class socket(_socket.socket): @@ -184,6 +186,109 @@ text.mode = mode return text + def _sendfile_use_sendfile(self, file, blocksize, offset): + sockno = self.fileno() + try: + fileno = file.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise _GiveupOnSendfile(err) # not a regular mmap-like file + + timeout = self.gettimeout() + if timeout == 0: + raise ValueError("non-blocking sockets are not supported") + if timeout: + # whenever possible use poll() instead of select() in + # order to avoid running out of fds + if hasattr(select, 'poll'): + if timeout is not None: + timeout *= 1000 + pollster = select.poll() + pollster.register(sockno, select.POLLOUT) + def wait_for_fd(): + if pollster.poll(timeout) == []: + raise _socket.timeout('timed out') + else: + # call select() once in order to solicit ValueError in + # case we run out of fds + try: + select.select([], [sockno], [], 0) + except ValueError: + raise _GiveupOnSendfile(err) + + def wait_for_fd(): + fds = select.select([], [sockno], [], timeout) + if fds == ([], [], []): + raise _socket.timeout('timed out') + + while True: + if timeout: + wait_for_fd() # block until socket is writable + try: + sent = os.sendfile(sockno, fileno, offset, blocksize) + except BlockingIOError: + continue + except OSError as err: + if offset == 0: + # We can get here for different reasons, the main + # one being 'file' is not a regular mmap(2)-like + # file, in which case we'll fall back on using + # plain send(). + raise _GiveupOnSendfile(err) + else: + file.seek(offset) + raise + except Exception: + file.seek(offset) + raise + else: + if sent == 0: + break + offset += sent + file.seek(offset) + + def _sendfile_use_send(self, file, blocksize, offset): + if self.gettimeout() == 0: + raise ValueError("non-blocking sockets are not supported") + if offset: + file.seek(offset) + while True: + chunk = file.read(blocksize) + if not chunk: + break + self.sendall(chunk) + + def sendfile(self, file, blocksize=262144, offset=0): + """sendfile(file[, blocksize[, offset]]) -> (succeded, exc) + + Send a file attempting to use high-performance os.sendfile(), + in which case 'file' must be a regular file object. + If not socket.send() will be used as fallback. + File position is updated on return or in case of error so + tell() can be used to know how many bytes were transmitted. + + - file: a file object opened for reading. + - blocksize: the maximum number of bytes to transmit at one + time (default 262144) + - offset: from where to start reading the file (default 0) + + Return a tuple of 2 elements including: + - a bool indicating whether os.sendfile() was used + - an exception instance in case it wasn't on account of an + internal error, if any + + Raises ValueError if socket is non-blocking. + """ + if hasattr(os, 'sendfile'): + try: + self._sendfile_use_sendfile(file, blocksize, offset) + return (True, None) + except _GiveupOnSendfile as exc: + self._sendfile_use_send(file, blocksize, offset) + return (False, exc.args[0]) + else: + self._sendfile_use_send(file, blocksize, offset) + return (False, None) + def _decref_socketios(self): if self._io_refs > 0: self._io_refs -= 1 diff --git a/Lib/ssl.py b/Lib/ssl.py --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -52,8 +52,6 @@ PROTOCOL_SSLv3 PROTOCOL_SSLv23 PROTOCOL_TLSv1 -PROTOCOL_TLSv1_1 -PROTOCOL_TLSv1_2 The following constants identify various SSL alert message descriptions as per http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6 @@ -112,7 +110,8 @@ from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN -from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 +from _ssl import (PROTOCOL_SSLv3, PROTOCOL_SSLv23, + PROTOCOL_TLSv1) from _ssl import _OPENSSL_API_VERSION @@ -129,14 +128,6 @@ else: _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2" -try: - from _ssl import PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2 -except ImportError: - pass -else: - _PROTOCOL_NAMES[PROTOCOL_TLSv1_1] = "TLSv1.1" - _PROTOCOL_NAMES[PROTOCOL_TLSv1_2] = "TLSv1.2" - from socket import getnameinfo as _getnameinfo from socket import socket, AF_INET, SOCK_STREAM, create_connection import base64 # for DER-to-PEM translation @@ -420,17 +411,18 @@ raise ValueError( "non-zero flags not allowed in calls to send() on %s" % self.__class__) - try: - v = self._sslobj.write(data) - except SSLError as x: - if x.args[0] == SSL_ERROR_WANT_READ: - return 0 - elif x.args[0] == SSL_ERROR_WANT_WRITE: - return 0 + while True: + try: + v = self._sslobj.write(data) + except SSLError as x: + if x.args[0] == SSL_ERROR_WANT_READ: + return 0 + elif x.args[0] == SSL_ERROR_WANT_WRITE: + return 0 + else: + raise else: - raise - else: - return v + return v else: return socket.send(self, data, flags) @@ -612,6 +604,11 @@ return None return self._sslobj.tls_unique_cb() + def sendfile(self, file, blocksize=262144): + # sendfile() works with plain sockets only + self._sendfile_use_send(file, blocksize) + return (False, None) + def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -4733,6 +4733,157 @@ self.assertRaises(OSError, sock.sendall, b'foo') +@unittest.skipUnless(thread, 'Threading required for this test.') +class SendfileTest(ThreadedTCPSocketTest): + + FILESIZE = (10 * 1024 * 1024) # 10MB + + @classmethod + def setUpClass(cls): + def chunks(total, step): + assert total >= step + while total > step: + yield step; + total -= step; + if total: + yield total + + with open(support.TESTFN, 'wb') as f: + for csize in chunks(cls.FILESIZE, 262144): + f.write(b'x' * csize) + + @classmethod + def tearDownClass(cls): + support.unlink(support.TESTFN) + + # regular transfer + + def _testRegular(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], hasattr(os, 'sendfile')) + self.assertEqual(ret[1], None) + + def testRegular(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + self.assertEqual(received, self.FILESIZE) + + # non regular file + + def _testNonRegularFile(self): + address = self.serv.getsockname() + file = io.BytesIO(b'x' * 1*1024*1024) # 1MB + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], False) + if hasattr(os, 'sendfile'): + self.assertIsInstance(ret[1], io.UnsupportedOperation) + else: + self.assertIs(ret[1], None) + + def testNonRegularFile(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + self.assertEqual(received, 1*1024*1024) + + # non blocking sockets are not supposed to work + + def _testNonBlocking(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + sock.setblocking(False) + self.assertRaises(ValueError, sock.sendfile, file) + + def testNonBlocking(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + if conn.recv(8192): + self.fail('was not supposed to receive any data') + + # offset + + def _testOffset(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address) as sock, file as file: + ret = sock.sendfile(file, offset=self.FILESIZE // 2) + self.assertEqual(ret[0], hasattr(os, 'sendfile')) + + def testOffset(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + self.assertEqual(received, self.FILESIZE // 2) + + # timeout (non-triggered) + + def _testWithTimeout(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=2) as sock, file as file: + ret = sock.sendfile(file) + self.assertEqual(ret[0], hasattr(os, 'sendfile')) + + def _testWithTimeout(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + self.assertEqual(received, self.FILESIZE) + + # timeout (triggered) + + def _testWithTimeoutTriggered(self): + address = self.serv.getsockname() + file = open(support.TESTFN, 'rb') + with socket.create_connection(address, timeout=0.01) as sock, \ + file as file: + self.assertRaises(socket.timeout, sock.sendfile, file, blocksize=512) + + def testWithTimeoutTriggered(self): + conn, addr = self.serv.accept() + self.addCleanup(conn.close) + received = 0 + while True: + chunk = conn.recv(8192) + if not chunk: + break + received += len(chunk) + self.assertEqual(chunk, b'x' * len(chunk)) + # interrupt transfer halfway; this will cause a timeout + if received >= self.FILESIZE // 2: + break + + @unittest.skipUnless(hasattr(socket, "SOCK_CLOEXEC"), "SOCK_CLOEXEC not defined") @unittest.skipUnless(fcntl, "module fcntl not available")