# Patch summary (commentary only; everything from the "diff --git" header
# onward is left untouched).
#
# Purpose: let FTP.storbinary() upload through os.sendfile() when it is
# available, falling back to the read()/sendall() loop otherwise.
#
# Lib/ftplib.py
#   * imports io and select.
#   * adds _GiveupOnSendfile(Error), raised internally by the sendfile path
#     and caught in storbinary().
#   * adds the FTP class attribute use_sendfile = hasattr(os, 'sendfile').
#   * adds FTP._storbinary_send(fp, conn, blocksize, callback): reads
#     blocksize chunks from fp, sends each with conn.sendall(), and calls
#     callback(buf) after each chunk when a callback is given.
#   * adds FTP._storbinary_sendfile(fp, conn, blocksize): needs fp.fileno();
#     waits until the data socket is writable (select.poll() when present,
#     otherwise select.select()), raising socket.timeout('timed out') when
#     the wait returns nothing; then calls os.sendfile() in a loop until it
#     returns 0.  BlockingIOError retries the loop.  An OSError while offset
#     is still 0 is converted to _GiveupOnSendfile; a later OSError is
#     re-raised.
#   * storbinary() uses _storbinary_send() when a callback is passed, when
#     use_sendfile is false, or when conn is an _SSLSocket; otherwise it
#     tries _storbinary_sendfile() and, on _GiveupOnSendfile, prints a
#     message if self.debugging is set and falls back to _storbinary_send().
#
# Lib/test/test_ftplib.py
#   * imports unittest and TESTFN.
#   * adds test_storbinary_sendfile (skipped unless os.sendfile exists):
#     writes test data to TESTFN, uploads it with debugging = 2, asserts
#     that 'sendfile' does not appear in captured stdout and that the
#     server's last_received_data equals the test data.
#   * adds support.unlink(TESTFN) to the existing finally: cleanup.
#
# NOTE(review): this text appears to have had its newlines collapsed into
#   spaces (hunk headers and hunk bodies share a physical line), so it is
#   presumably not appliable as-is -- restore the line-oriented original
#   before running patch / git apply.
# NOTE(review): "if timeout != None" would more idiomatically be
#   "if timeout is not None".
# NOTE(review): _storbinary_sendfile() always starts at offset = 0, while
#   _storbinary_send() uses fp.read(); if the caller has already seeked fp
#   (e.g. for a rest= resume) the two paths presumably send different
#   data -- confirm and consider starting from fp's current position.
# NOTE(review): test_storbinary_sendfile re-opens TESTFN as f with no
#   visible close; confirm, and consider a with-block.
diff --git a/Lib/ftplib.py b/Lib/ftplib.py --- a/Lib/ftplib.py +++ b/Lib/ftplib.py @@ -36,10 +36,12 @@ # Modified by Giampaolo Rodola' to add TLS support. # +import io import os import sys import socket import warnings +import select from socket import _GLOBAL_DEFAULT_TIMEOUT __all__ = ["FTP", "Netrc"] @@ -58,6 +60,8 @@ class error_temp(Error): pass # 4xx errors class error_perm(Error): pass # 5xx errors class error_proto(Error): pass # response does not begin with [1-5] +class _GiveupOnSendfile(Error): pass # internal, used by storbinary() + # All exceptions (hopefully) that may be raised here and that aren't @@ -102,6 +106,7 @@ welcome = None passiveserver = 1 encoding = "latin-1" + use_sendfile = hasattr(os, 'sendfile') # Initialization method (called by class instantiation). # Initialize host to localhost, port to standard ftp port @@ -477,6 +482,68 @@ conn.unwrap() return self.voidresp() + def _storbinary_send(self, fp, conn, blocksize, callback): + while True: + buf = fp.read(blocksize) + if not buf: + break + conn.sendall(buf) + if callback: + callback(buf) + + def _storbinary_sendfile(self, fp, conn, blocksize): + offset = 0 + sockno = conn.fileno() + try: + fileno = fp.fileno() + except (AttributeError, io.UnsupportedOperation) as err: + raise _GiveupOnSendfile(err) # not a regular mmap-like file + + timeout = self.timeout if self.timeout != \ + _GLOBAL_DEFAULT_TIMEOUT else None + # whenever possible use poll() instead of select() in + # order to avoid running out of fds + if hasattr(select, 'poll'): + if timeout != None: + timeout *= 1000 + pollster = select.poll() + pollster.register(sockno, select.POLLOUT) + def wait_for_fd(): + if pollster.poll(timeout) == []: + raise socket.timeout('timed out') + else: + # call select() once in order to solicit EMFILE, in + # which case we'll exit immediately + try: + select.select([], [sockno], [], 0) + except OSError as err: + raise _GiveupOnSendfile(err) + + def wait_for_fd(): + fds = select.select([], [sockno], 
[], timeout) + if fds == ([], [], []): + raise socket.timeout('timed out') + + while True: + wait_for_fd() # block until socket is writable + try: + sent = os.sendfile(sockno, fileno, offset, blocksize) + except BlockingIOError: + continue + except OSError as err: + if offset == 0: + # We can get here for different reasons, the main + # one being fp is not a regular mmap(2)-like file, + # in which case we'll fall back on using plain + # send(). + raise _GiveupOnSendfile(err) + else: + raise + else: + if sent == 0: + break + offset += sent + def storbinary(self, cmd, fp, blocksize=8192, callback=None, rest=None): """Store a file in binary mode. A new port is created for you. @@ -494,13 +561,16 @@ """ self.voidcmd('TYPE I') with self.transfercmd(cmd, rest) as conn: - while 1: - buf = fp.read(blocksize) - if not buf: - break - conn.sendall(buf) - if callback: - callback(buf) + if callback or not self.use_sendfile or isinstance(conn, _SSLSocket): + self._storbinary_send(fp, conn, blocksize, callback) + else: + try: + self._storbinary_sendfile(fp, conn, blocksize) + except _GiveupOnSendfile as err: + if self.debugging: + print("couldn't use sendfile() %r; falling back on " \ + "using send" % err) + self._storbinary_send(fp, conn, blocksize, callback) # shutdown ssl layer if isinstance(conn, _SSLSocket): conn.unwrap() diff --git a/Lib/test/test_ftplib.py b/Lib/test/test_ftplib.py --- a/Lib/test/test_ftplib.py +++ b/Lib/test/test_ftplib.py @@ -11,6 +11,7 @@ import errno import os import time +import unittest try: import ssl except ImportError: @@ -18,7 +19,7 @@ from unittest import TestCase from test import support -from test.support import HOST +from test.support import HOST, TESTFN threading = support.import_module('threading') # the dummy data returned by server over the data channel when @@ -578,6 +579,21 @@ self.client.storbinary('stor', f, rest=r) self.assertEqual(self.server.handler_instance.rest, str(r)) + @unittest.skipUnless(hasattr(os, 'sendfile'), 
'os.sendfile() not available') + def test_storbinary_sendfile(self): + with open(TESTFN, 'wb+') as f: + test_data = 'abcde12345\r\n' * 100000 + f.write(test_data.encode('ascii')) + f = open(TESTFN, 'rb') + f.seek(0) + self.client.debugging = 2 + with support.captured_stdout() as output: + self.client.storbinary('stor', f) + output.seek(0) + self.assertNotIn('sendfile', output.read()) + self.assertEqual(self.server.handler_instance.last_received_data, + test_data) + def test_storlines(self): f = io.BytesIO(RETR_DATA.replace('\r\n', '\n').encode('ascii')) self.client.storlines('stor', f) @@ -1009,6 +1025,7 @@ support.run_unittest(*tests) finally: support.threading_cleanup(*thread_info) + support.unlink(TESTFN) if __name__ == '__main__':