diff --git a/Lib/test/test_zipfile.py b/Lib/test/test_zipfile.py index 0b3a694..b58bccf 100644 --- a/Lib/test/test_zipfile.py +++ b/Lib/test/test_zipfile.py @@ -368,6 +368,7 @@ class TestsWithSourceFile(unittest.TestCase): produces the expected result.""" with zipfile.ZipFile(TESTFN2, "w") as zipfp: zipfp.write(TESTFN) + with zipfile.ZipFile(TESTFN2, "r") as zipfp: with open(TESTFN, "rb") as f: self.assertEqual(zipfp.read(TESTFN), f.read()) @@ -902,6 +903,7 @@ class OtherTests(unittest.TestCase): string and doesn't advance file pointer.""" with zipfile.ZipFile(TESTFN, mode="w") as zipf: zipf.writestr("foo.txt", "O, for a Muse of Fire!") + with zipfile.ZipFile(TESTFN, mode="r") as zipf: # read the data to make sure the file is there with zipf.open("foo.txt") as f: for i in range(FIXEDTEST_SIZE): diff --git a/Lib/zipfile.py b/Lib/zipfile.py index 6ca269f..70465ed 100644 --- a/Lib/zipfile.py +++ b/Lib/zipfile.py @@ -476,10 +476,17 @@ class ZipExtFile(io.BufferedIOBase): PATTERN = re.compile(br'^(?P[^\r\n]+)|(?P\n|\r\n?)') def __init__(self, fileobj, mode, zipinfo, decrypter=None, - close_fileobj=False): - self._fileobj = fileobj + close_fileobj=False, whole_zipfile=None): + self._zipfile = whole_zipfile + if whole_zipfile is not None: + self._fileobj = whole_zipfile.fp + else: + # deprecated + self._fileobj = fileobj + self._close_fileobj = close_fileobj + + self._fileobj_pos = self._fileobj.tell() self._decrypter = decrypter - self._close_fileobj = close_fileobj self._compress_type = zipinfo.compress_type self._compress_size = zipinfo.compress_size @@ -611,6 +618,7 @@ class ZipExtFile(io.BufferedIOBase): # Read from file. if self._compress_left > 0 and n > len_readbuffer + len(self._unconsumed): + self._fileobj.seek(self._fileobj_pos) nbytes = n - len_readbuffer - len(self._unconsumed) nbytes = max(nbytes, self.MIN_READ_SIZE) nbytes = min(nbytes, self._compress_left) @@ -628,6 +636,7 @@ class ZipExtFile(io.BufferedIOBase): else: # Prepare deflated bytes for decompression. self._unconsumed += data + self._fileobj_pos = self._fileobj.tell() # Handle unconsumed data. if (len(self._unconsumed) > 0 and n > len_readbuffer and @@ -653,8 +662,14 @@ class ZipExtFile(io.BufferedIOBase): def close(self): try: - if self._close_fileobj: - self._fileobj.close() + if self._zipfile is not None: + if self._fileobj is not None: + self._zipfile._childclose(self._fileobj) + self._fileobj = None + else: + # deprecated + if self._close_fileobj: + self._fileobj.close() finally: super().close() @@ -691,6 +706,7 @@ class ZipFile: raise RuntimeError("That compression method is not supported") self._allowZip64 = allowZip64 + self._children = 0 # Number of ZipExtFiles open from this zip file. self._didModify = False self.debug = 0 # Level of printing: 0 through 3 self.NameToInfo = {} # Find file info given name @@ -901,13 +917,6 @@ class ZipFile: raise RuntimeError( "Attempt to read ZIP archive that was already closed") - # Only open a new file for instances where we were not - # given a file object in the constructor - if self._filePassed: - zef_file = self.fp - else: - zef_file = io.open(self.filename, 'rb') - # Make sure we have an info object if isinstance(name, ZipInfo): # 'name' is already an info object @@ -917,20 +926,18 @@ class ZipFile: try: zinfo = self.getinfo(name) except KeyError: - if not self._filePassed: - zef_file.close() raise - zef_file.seek(zinfo.header_offset, 0) + self.fp.seek(zinfo.header_offset, 0) # Skip the file header: - fheader = zef_file.read(sizeFileHeader) + fheader = self.fp.read(sizeFileHeader) if fheader[0:4] != stringFileHeader: raise BadZipFile("Bad magic number for file header") fheader = struct.unpack(structFileHeader, fheader) - fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) + fname = self.fp.read(fheader[_FH_FILENAME_LENGTH]) if fheader[_FH_EXTRA_FIELD_LENGTH]: - zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH]) + self.fp.read(fheader[_FH_EXTRA_FIELD_LENGTH]) if zinfo.flag_bits & 0x800: # UTF-8 filename @@ -939,8 +946,6 @@ class ZipFile: fname_str = fname.decode("cp437") if fname_str != zinfo.orig_filename: - if not self._filePassed: - zef_file.close() raise BadZipFile( 'File name in directory %r and header %r differ.' % (zinfo.orig_filename, fname)) @@ -952,8 +957,6 @@ class ZipFile: if not pwd: pwd = self.pwd if not pwd: - if not self._filePassed: - zef_file.close() raise RuntimeError("File %s is encrypted, " "password required for extraction" % name) @@ -963,7 +966,7 @@ class ZipFile: # completely random, while the 12th contains the MSB of the CRC, # or the MSB of the file time depending on the header type # and is used to check the correctness of the password. - header = zef_file.read(12) + header = self.fp.read(12) h = list(map(zd, header[0:12])) if zinfo.flag_bits & 0x8: # compare against the file type from extended local headers @@ -972,12 +975,10 @@ class ZipFile: # compare against the CRC otherwise check_byte = (zinfo.CRC >> 24) & 0xff if h[11] != check_byte: - if not self._filePassed: - zef_file.close() raise RuntimeError("Bad password for file", name) - return ZipExtFile(zef_file, mode, zinfo, zd, - close_fileobj=not self._filePassed) + self._children += 1 + return ZipExtFile(None, mode, zinfo, zd, whole_zipfile=self) def extract(self, member, path=None, pwd=None): """Extract a member from the archive to the current working directory, @@ -1301,10 +1302,21 @@ class ZipFile: self.fp.write(self.comment) self.fp.flush() - if not self._filePassed: - self.fp.close() + self._fpclose(self.fp) self.fp = None + def _childclose(self, fp): + """ A child ZipExtFile is being closed. """ + if self._children <= 0: + raise RuntimeError("This can't happen.") + self._children -= 1 + if self.fp is None: + self._fpclose(fp) + + def _fpclose(self, fp): + if not self._filePassed and self._children == 0: + fp.close() + class PyZipFile(ZipFile): """Class to create ZIP archives with Python library files and packages."""