diff --git a/Doc/library/zipfile.rst b/Doc/library/zipfile.rst --- a/Doc/library/zipfile.rst +++ b/Doc/library/zipfile.rst @@ -219,14 +219,11 @@ ZipFile Objects .. note:: - If the ZipFile was created by passing in a file-like object as the first - argument to the constructor, then the object returned by :meth:`.open` shares the - ZipFile's file pointer. Under these circumstances, the object returned by - :meth:`.open` should not be used after any additional operations are performed - on the ZipFile object. If the ZipFile was created by passing in a string (the - filename) as the first argument to the constructor, then :meth:`.open` will - create a new file object that will be held by the ZipExtFile, allowing it to - operate independently of the ZipFile. + The returned ZipExtFile reuses the ZipFile's file object, and guarantees + concurrent reads from multiple members on a single thread will succeed. + If the ZipFile is managing its own file object (i.e. it was constructed + by passing a string ZIP file path), then it is guaranteed the file object + will remain open while any associated ZipExtFile exists. .. note:: diff --git a/Lib/zipfile.py b/Lib/zipfile.py --- a/Lib/zipfile.py +++ b/Lib/zipfile.py @@ -660,11 +660,11 @@ class ZipExtFile(io.BufferedIOBase): # Search for universal newlines or line chunks. PATTERN = re.compile(br'^(?P[^\r\n]+)|(?P\n|\r\n?)') - def __init__(self, fileobj, mode, zipinfo, decrypter=None, - close_fileobj=False): - self._fileobj = fileobj + def __init__(self, zipfile, mode, zipinfo, decrypter=None): + self._zipfile = zipfile + self._fileobj = zipfile.fp + self._fileobj_pos = self._fileobj.tell() self._decrypter = decrypter - self._close_fileobj = close_fileobj self._compress_type = zipinfo.compress_type self._compress_left = zipinfo.compress_size @@ -896,7 +896,9 @@ class ZipExtFile(io.BufferedIOBase): n = max(n, self.MIN_READ_SIZE) n = min(n, self._compress_left) + self._fileobj.seek(self._fileobj_pos) data = self._fileobj.read(n) + self._fileobj_pos += len(data) self._compress_left -= len(data) if not data: raise EOFError @@ -907,8 +909,10 @@ class ZipExtFile(io.BufferedIOBase): def close(self): try: - if self._close_fileobj: - self._fileobj.close() + if self._zipfile: + self._zipfile._unref() + self._zipfile = None + self._fileobj = None finally: super().close() @@ -948,6 +952,7 @@ class ZipFile: self.mode = key = mode.replace('b', '')[0] self.pwd = None self._comment = b'' + self._refs = 1 # Check if we were passed a file-like object if isinstance(file, str): @@ -1181,87 +1186,75 @@ class ZipFile: raise RuntimeError( "Attempt to read ZIP archive that was already closed") - # Only open a new file for instances where we were not - # given a file object in the constructor - if self._filePassed: - zef_file = self.fp + # Make sure we have an info object + if isinstance(name, ZipInfo): + # 'name' is already an info object + zinfo = name else: - zef_file = io.open(self.filename, 'rb') + # Get info object for name + zinfo = self.getinfo(name) + self.fp.seek(zinfo.header_offset, 0) - try: - # Make sure we have an info object - if isinstance(name, ZipInfo): - # 'name' is already an info object - zinfo = name + # Skip the file header: + fheader = self.fp.read(sizeFileHeader) + if len(fheader) != sizeFileHeader: + raise BadZipFile("Truncated file header") + fheader = struct.unpack(structFileHeader, fheader) + if fheader[_FH_SIGNATURE] != stringFileHeader: + raise BadZipFile("Bad magic number for file header") + + fname = self.fp.read(fheader[_FH_FILENAME_LENGTH]) + if fheader[_FH_EXTRA_FIELD_LENGTH]: + self.fp.read(fheader[_FH_EXTRA_FIELD_LENGTH]) + + if zinfo.flag_bits & 0x20: + # Zip 2.7: compressed patched data + raise NotImplementedError("compressed patched data (flag bit 5)") + + if zinfo.flag_bits & 0x40: + # strong encryption + raise NotImplementedError("strong encryption (flag bit 6)") + + if zinfo.flag_bits & 0x800: + # UTF-8 filename + fname_str = fname.decode("utf-8") + else: + fname_str = fname.decode("cp437") + + if fname_str != zinfo.orig_filename: + raise BadZipFile( + 'File name in directory %r and header %r differ.' + % (zinfo.orig_filename, fname)) + + # check for encrypted flag & handle password + is_encrypted = zinfo.flag_bits & 0x1 + zd = None + if is_encrypted: + if not pwd: + pwd = self.pwd + if not pwd: + raise RuntimeError("File %s is encrypted, password " + "required for extraction" % name) + + zd = _ZipDecrypter(pwd) + # The first 12 bytes in the cypher stream is an encryption header + # used to strengthen the algorithm. The first 11 bytes are + # completely random, while the 12th contains the MSB of the CRC, + # or the MSB of the file time depending on the header type + # and is used to check the correctness of the password. + header = self.fp.read(12) + h = list(map(zd, header[0:12])) + if zinfo.flag_bits & 0x8: + # compare against the file type from extended local headers + check_byte = (zinfo._raw_time >> 8) & 0xff else: - # Get info object for name - zinfo = self.getinfo(name) - zef_file.seek(zinfo.header_offset, 0) + # compare against the CRC otherwise + check_byte = (zinfo.CRC >> 24) & 0xff + if h[11] != check_byte: + raise RuntimeError("Bad password for file", name) - # Skip the file header: - fheader = zef_file.read(sizeFileHeader) - if len(fheader) != sizeFileHeader: - raise BadZipFile("Truncated file header") - fheader = struct.unpack(structFileHeader, fheader) - if fheader[_FH_SIGNATURE] != stringFileHeader: - raise BadZipFile("Bad magic number for file header") - - fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) - if fheader[_FH_EXTRA_FIELD_LENGTH]: - zef_file.read(fheader[_FH_EXTRA_FIELD_LENGTH]) - - if zinfo.flag_bits & 0x20: - # Zip 2.7: compressed patched data - raise NotImplementedError("compressed patched data (flag bit 5)") - - if zinfo.flag_bits & 0x40: - # strong encryption - raise NotImplementedError("strong encryption (flag bit 6)") - - if zinfo.flag_bits & 0x800: - # UTF-8 filename - fname_str = fname.decode("utf-8") - else: - fname_str = fname.decode("cp437") - - if fname_str != zinfo.orig_filename: - raise BadZipFile( - 'File name in directory %r and header %r differ.' - % (zinfo.orig_filename, fname)) - - # check for encrypted flag & handle password - is_encrypted = zinfo.flag_bits & 0x1 - zd = None - if is_encrypted: - if not pwd: - pwd = self.pwd - if not pwd: - raise RuntimeError("File %s is encrypted, password " - "required for extraction" % name) - - zd = _ZipDecrypter(pwd) - # The first 12 bytes in the cypher stream is an encryption header - # used to strengthen the algorithm. The first 11 bytes are - # completely random, while the 12th contains the MSB of the CRC, - # or the MSB of the file time depending on the header type - # and is used to check the correctness of the password. - header = zef_file.read(12) - h = list(map(zd, header[0:12])) - if zinfo.flag_bits & 0x8: - # compare against the file type from extended local headers - check_byte = (zinfo._raw_time >> 8) & 0xff - else: - # compare against the CRC otherwise - check_byte = (zinfo.CRC >> 24) & 0xff - if h[11] != check_byte: - raise RuntimeError("Bad password for file", name) - - return ZipExtFile(zef_file, mode, zinfo, zd, - close_fileobj=not self._filePassed) - except: - if not self._filePassed: - zef_file.close() - raise + self._refs += 1 + return ZipExtFile(self, mode, zinfo, zd) def extract(self, member, path=None, pwd=None): """Extract a member from the archive to the current working directory, @@ -1511,6 +1504,13 @@ class ZipFile: self.filelist.append(zinfo) self.NameToInfo[zinfo.filename] = zinfo + def _unref(self): + """Called by ZipExtFile instances to deference the file object managed + by this ZipFile.""" + self._refs -= 1 + if self._refs == 0: + self.close() + def __del__(self): """Call the "close()" method in case the user forgot.""" self.close()