diff -r 19b2c54e5f09 Lib/test/test_zipfile.py --- a/Lib/test/test_zipfile.py Wed Nov 12 10:23:44 2014 -0500 +++ b/Lib/test/test_zipfile.py Fri Nov 14 21:08:54 2014 +0200 @@ -1,3 +1,4 @@ +import contextlib import io import os import sys @@ -25,6 +26,9 @@ SMALL_TEST_DATA = [('_ziptest1', '1q2w3e ('ziptest2dir/ziptest3dir/_ziptest3', 'azsxdcfvgb'), ('ziptest2dir/ziptest3dir/ziptest4dir/_ziptest3', '6y7u8i9o0p')] +def getrandbytes(size): + return getrandbits(8 * size).to_bytes(size, 'little') + def get_files(test): yield TESTFN2 with TemporaryFile() as f: @@ -289,7 +293,7 @@ class AbstractTestsWithSourceFile: # than requested. for test_size in (1, 4095, 4096, 4097, 16384): file_size = test_size + 1 - junk = getrandbits(8 * file_size).to_bytes(file_size, 'little') + junk = getrandbytes(file_size) with zipfile.ZipFile(io.BytesIO(), "w", self.compression) as zipf: zipf.writestr('foo', junk) with zipf.open('foo', 'r') as fp: @@ -1666,46 +1670,93 @@ class LzmaTestsWithRandomBinaryFiles(Abs @requires_zlib class TestsWithMultipleOpens(unittest.TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): + cls.data1 = b'111' + getrandbytes(10000) + cls.data2 = b'222' + getrandbytes(10000) + + def make_test_archive(self): # Create the ZIP archive with zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_DEFLATED) as zipfp: - zipfp.writestr('ones', '1'*FIXEDTEST_SIZE) - zipfp.writestr('twos', '2'*FIXEDTEST_SIZE) + zipfp.writestr('ones', self.data1) + zipfp.writestr('twos', self.data2) def test_same_file(self): # Verify that (when the ZipFile is in control of creating file objects) # multiple open() calls can be made without interfering with each other. + self.make_test_archive() with zipfile.ZipFile(TESTFN2, mode="r") as zipf: with zipf.open('ones') as zopen1, zipf.open('ones') as zopen2: data1 = zopen1.read(500) data2 = zopen2.read(500) - data1 += zopen1.read(500) - data2 += zopen2.read(500) + data1 += zopen1.read() + data2 += zopen2.read() self.assertEqual(data1, data2) + self.assertEqual(data1, self.data1) def test_different_file(self): # Verify that (when the ZipFile is in control of creating file objects) # multiple open() calls can be made without interfering with each other. + self.make_test_archive() with zipfile.ZipFile(TESTFN2, mode="r") as zipf: with zipf.open('ones') as zopen1, zipf.open('twos') as zopen2: data1 = zopen1.read(500) data2 = zopen2.read(500) - data1 += zopen1.read(500) - data2 += zopen2.read(500) - self.assertEqual(data1, b'1'*FIXEDTEST_SIZE) - self.assertEqual(data2, b'2'*FIXEDTEST_SIZE) + data1 += zopen1.read() + data2 += zopen2.read() + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) def test_interleaved(self): # Verify that (when the ZipFile is in control of creating file objects) # multiple open() calls can be made without interfering with each other. + self.make_test_archive() with zipfile.ZipFile(TESTFN2, mode="r") as zipf: with zipf.open('ones') as zopen1, zipf.open('twos') as zopen2: data1 = zopen1.read(500) data2 = zopen2.read(500) - data1 += zopen1.read(500) - data2 += zopen2.read(500) - self.assertEqual(data1, b'1'*FIXEDTEST_SIZE) - self.assertEqual(data2, b'2'*FIXEDTEST_SIZE) + data1 += zopen1.read() + data2 += zopen2.read() + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) + + def test_read_after_close(self): + self.make_test_archive() + with contextlib.ExitStack() as stack: + with zipfile.ZipFile(TESTFN2, 'r') as zipf: + zopen1 = stack.enter_context(zipf.open('ones')) + zopen2 = stack.enter_context(zipf.open('twos')) + data1 = zopen1.read(500) + data2 = zopen2.read(500) + data1 += zopen1.read() + data2 += zopen2.read() + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) + + def test_read_after_write(self): + with zipfile.ZipFile(TESTFN2, 'w', zipfile.ZIP_DEFLATED) as zipf: + zipf.writestr('ones', self.data1) + zipf.writestr('twos', self.data2) + with zipf.open('ones') as zopen1: + data1 = zopen1.read(500) + self.assertEqual(data1, self.data1[:500]) + with zipfile.ZipFile(TESTFN2, 'r') as zipf: + data1 = zipf.read('ones') + data2 = zipf.read('twos') + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) + + def test_write_after_read(self): + with zipfile.ZipFile(TESTFN2, "w", zipfile.ZIP_DEFLATED) as zipf: + zipf.writestr('ones', self.data1) + with zipf.open('ones') as zopen1: + zopen1.read(500) + zipf.writestr('twos', self.data2) + with zipfile.ZipFile(TESTFN2, 'r') as zipf: + data1 = zipf.read('ones') + data2 = zipf.read('twos') + self.assertEqual(data1, self.data1) + self.assertEqual(data2, self.data2) def tearDown(self): unlink(TESTFN2) diff -r 19b2c54e5f09 Lib/zipfile.py --- a/Lib/zipfile.py Wed Nov 12 10:23:44 2014 -0500 +++ b/Lib/zipfile.py Fri Nov 14 21:08:54 2014 +0200 @@ -646,6 +646,25 @@ def _get_decompressor(compress_type): raise NotImplementedError("compression type %d" % (compress_type,)) +class _SharedFile: + def __init__(self, file, pos, close): + self._file = file + self._pos = pos + self._close = close + + def read(self, n=-1): + self._file.seek(self._pos) + data = self._file.read(n) + self._pos = self._file.tell() + return data + + def close(self): + if self._file is not None: + fileobj = self._file + self._file = None + self._close(fileobj) + + class ZipExtFile(io.BufferedIOBase): """File-like object for reading an archive member. Is returned by ZipFile.open(). @@ -945,7 +964,7 @@ class ZipFile: self.NameToInfo = {} # Find file info given name self.filelist = [] # List of ZipInfo instances for archive self.compression = compression # Method of compression - self.mode = key = mode.replace('b', '')[0] + self.mode = mode self.pwd = None self._comment = b'' @@ -954,28 +973,33 @@ class ZipFile: # No, it's a filename self._filePassed = 0 self.filename = file - modeDict = {'r' : 'rb', 'w': 'wb', 'a' : 'r+b'} - try: - self.fp = io.open(file, modeDict[mode]) - except OSError: - if mode == 'a': - mode = key = 'w' - self.fp = io.open(file, modeDict[mode]) - else: + modeDict = {'r' : 'rb', 'w': 'w+b', 'a' : 'r+b', + 'r+b': 'w+b', 'w+b': 'wb'} + filemode = modeDict[mode] + while True: + try: + self.fp = io.open(file, filemode) + except OSError: + if filemode in modeDict: + filemode = modeDict[filemode] + continue raise + break else: self._filePassed = 1 self.fp = file self.filename = getattr(file, 'name', None) + self._fileRefCnt = 1 try: - if key == 'r': + if mode == 'r': self._RealGetContents() - elif key == 'w': + elif mode == 'w': # set the modified flag so central directory gets written # even if no files are added to the archive self._didModify = True - elif key == 'a': + self.start_dir = 0 + elif mode == 'a': try: # See if file is a zip file self._RealGetContents() @@ -988,13 +1012,13 @@ class ZipFile: # set the modified flag so central directory gets written # even if no files are added to the archive self._didModify = True + self.start_dir = self.fp.tell() else: raise RuntimeError('Mode must be "r", "w" or "a"') except: fp = self.fp self.fp = None - if not self._filePassed: - fp.close() + self._fpclose(fp) raise def __enter__(self): @@ -1181,23 +1205,17 @@ class ZipFile: raise RuntimeError( "Attempt to read ZIP archive that was already closed") - # Only open a new file for instances where we were not - # given a file object in the constructor - if self._filePassed: - zef_file = self.fp + # Make sure we have an info object + if isinstance(name, ZipInfo): + # 'name' is already an info object + zinfo = name else: - zef_file = io.open(self.filename, 'rb') + # Get info object for name + zinfo = self.getinfo(name) + self._fileRefCnt += 1 + zef_file = _SharedFile(self.fp, zinfo.header_offset, self._fpclose) try: - # Make sure we have an info object - if isinstance(name, ZipInfo): - # 'name' is already an info object - zinfo = name - else: - # Get info object for name - zinfo = self.getinfo(name) - zef_file.seek(zinfo.header_offset, 0) - # Skip the file header: fheader = zef_file.read(sizeFileHeader) if len(fheader) != sizeFileHeader: @@ -1256,11 +1274,9 @@ class ZipFile: if h[11] != check_byte: raise RuntimeError("Bad password for file", name) - return ZipExtFile(zef_file, mode, zinfo, zd, - close_fileobj=not self._filePassed) + return ZipExtFile(zef_file, mode, zinfo, zd, True) except: - if not self._filePassed: - zef_file.close() + zef_file.close() raise def extract(self, member, path=None, pwd=None): @@ -1394,6 +1410,7 @@ class ZipFile: zinfo.file_size = st.st_size zinfo.flag_bits = 0x00 + self.fp.seek(self.start_dir, 0) zinfo.header_offset = self.fp.tell() # Start of header bytes if zinfo.compress_type == ZIP_LZMA: # Compressed data includes an end-of-stream (EOS) marker @@ -1410,6 +1427,7 @@ class ZipFile: self.filelist.append(zinfo) self.NameToInfo[zinfo.filename] = zinfo self.fp.write(zinfo.FileHeader(False)) + self.start_dir = self.fp.tell() return cmpr = _get_compressor(zinfo.compress_type) @@ -1448,10 +1466,10 @@ class ZipFile: raise RuntimeError('Compressed size larger than uncompressed size') # Seek backwards and write file header (which will now include # correct CRC and file sizes) - position = self.fp.tell() # Preserve current position in file + self.start_dir = self.fp.tell() # Preserve current position in file self.fp.seek(zinfo.header_offset, 0) self.fp.write(zinfo.FileHeader(zip64)) - self.fp.seek(position, 0) + self.fp.seek(self.start_dir, 0) self.filelist.append(zinfo) self.NameToInfo[zinfo.filename] = zinfo @@ -1480,6 +1498,7 @@ class ZipFile: "Attempt to write to ZIP archive that was already closed") zinfo.file_size = len(data) # Uncompressed size + self.fp.seek(self.start_dir, 0) zinfo.header_offset = self.fp.tell() # Start of header data if compress_type is not None: zinfo.compress_type = compress_type @@ -1508,6 +1527,7 @@ class ZipFile: self.fp.write(struct.pack(fmt, zinfo.CRC, zinfo.compress_size, zinfo.file_size)) self.fp.flush() + self.start_dir = self.fp.tell() self.filelist.append(zinfo) self.NameToInfo[zinfo.filename] = zinfo @@ -1523,7 +1543,7 @@ class ZipFile: try: if self.mode in ("w", "a") and self._didModify: # write ending records - pos1 = self.fp.tell() + self.fp.seek(self.start_dir, 0) for zinfo in self.filelist: # write central directory dt = zinfo.date_time dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2] @@ -1589,8 +1609,8 @@ class ZipFile: pos2 = self.fp.tell() # Write end-of-zip-archive record centDirCount = len(self.filelist) - centDirSize = pos2 - pos1 - centDirOffset = pos1 + centDirSize = pos2 - self.start_dir + centDirOffset = self.start_dir requires_zip64 = None if centDirCount > ZIP_FILECOUNT_LIMIT: requires_zip64 = "Files count" @@ -1626,8 +1646,13 @@ class ZipFile: finally: fp = self.fp self.fp = None - if not self._filePassed: - fp.close() + self._fpclose(fp) + + def _fpclose(self, fp): + assert self._fileRefCnt > 0 + self._fileRefCnt -= 1 + if not self._fileRefCnt and not self._filePassed: + fp.close() class PyZipFile(ZipFile):