diff --git a/Lib/tarfile.py b/Lib/tarfile.py
index ba3e95f281..03b7799445 100755
--- a/Lib/tarfile.py
+++ b/Lib/tarfile.py
@@ -143,6 +143,20 @@ PAX_NUMBER_FIELDS = {
     "size": int
 }
 
+# SafeTarFile-related string constants.
+WARN_ABSOLUTE_NAME = "absolute name"
+WARN_RELATIVE_NAME = "relative name"
+WARN_DUPLICATE_NAME = "duplicate name"
+WARN_ABSOLUTE_LINKNAME = "absolute linkname"
+WARN_RELATIVE_LINKNAME = "relative linkname"
+WARN_SETUID_SET = "setuid set"
+WARN_SETGID_SET = "setgid set"
+WARN_CHARACTER_DEVICE = "character device"
+WARN_BLOCK_DEVICE = "block device"
+
+LIMIT_MAX_FILES = "file limit exceeded"
+LIMIT_MAX_SIZE = "space limit exceeded"
+
 #---------------------------------------------------------
 # initialization
 #---------------------------------------------------------
@@ -296,6 +310,19 @@ class InvalidHeaderError(HeaderError):
 class SubsequentHeaderError(HeaderError):
     """Exception for missing and invalid extended headers."""
     pass
+class SecurityError(TarError):
+    """Exception for potentially dangerous contents."""
+    def __init__(self, tarinfo, warning):
+        self.tarinfo = tarinfo
+        self.warning = warning
+    def __str__(self):
+        return "%s: %s" % (self.tarinfo, self.warning)
+class LimitError(SecurityError):
+    """Exception for an exceeded limit."""
+    def __init__(self, warning):
+        super().__init__(None, warning)
+    def __str__(self):
+        return self.warning
 
 #---------------------------
 # internal stream interface
@@ -2418,6 +2445,197 @@ class TarFile(object):
                 self.fileobj.close()
             self.closed = True
 
+class SafeTarFile(TarFile):
+    """A subclass of TarFile that safeguards against malicious data.
+
+       Members are checked while the archive is iterated. Problems are
+       reported as WARN_* strings; exceeding max_files or max_total
+       raises LimitError.
+    """
+
+    def __init__(self, *args, ignore_warnings=None,
+                 max_files=100000, max_total=1024**3, **kwargs):
+        """Accept the same arguments as TarFile plus three keywords:
+           ignore_warnings is an iterable of WARN_* strings that are
+           never reported; max_files and max_total limit the number of
+           members and the sum of their sizes. A false value (None or
+           0) disables the respective limit.
+        """
+        super().__init__(*args, **kwargs)
+
+        self.ignore_warnings = set(ignore_warnings or ())
+        self.max_files = max_files
+        self.max_total = max_total
+        # Maps the effective name of each symlink seen so far to its
+        # target, expressed relative to the symlink's directory.
+        self.symlink_effective_name_map = {}
+
+    def __iter__(self):
+        """Safe iterator over the TarFile, that raises a SecurityError
+           exception on the first warning.
+        """
+        for tarinfo, warnings in self.analyze():
+            if warnings:
+                # min() instead of set.pop(): the reported warning must
+                # not depend on string hash randomization.
+                raise SecurityError(tarinfo, min(warnings))
+            yield tarinfo
+
+    def analyze(self):
+        """Generate (TarInfo, warnings) tuples, where warnings is the
+           set of WARN_* strings that are not ignored. Raise LimitError
+           if max_files or max_total is exceeded.
+        """
+        # Start every pass from scratch, otherwise a second pass would
+        # see the names, sizes and symlinks collected by the first one.
+        self.names = set()
+        self.total = 0
+        self.symlink_effective_name_map = {}
+
+        for tarinfo in super().__iter__():
+            warnings = set(self._check_member(tarinfo))
+            yield tarinfo, warnings - self.ignore_warnings
+
+    def filter(self):
+        """Generate the TarInfo objects that have no warnings.
+        """
+        for tarinfo, warnings in self.analyze():
+            if not warnings:
+                yield tarinfo
+
+    def is_safe(self):
+        """Return True if the archive should be safe to extract, i.e.
+           no member has a warning and no limit is exceeded.
+        """
+        try:
+            for tarinfo, warnings in self.analyze():
+                if warnings:
+                    return False
+        except LimitError:
+            return False
+        return True
+
+    def _check_member(self, tarinfo):
+        """Check a single TarInfo object for problems and generate the
+           matching WARN_* strings. Raise LimitError if a limit is
+           exceeded. Override this in a subclass if you want to add
+           more checks.
+        """
+        # ">" rather than "==", so that the check also fires when
+        # self.members already holds more than max_files entries.
+        if self.max_files and len(self.members) > self.max_files:
+            raise LimitError(LIMIT_MAX_FILES)
+
+        # Accumulate: the limit applies to the sum of all member sizes.
+        self.total += tarinfo.size
+        if self.max_total and self.total > self.max_total:
+            raise LimitError(LIMIT_MAX_SIZE)
+
+        effective_name = self._get_effective_name(tarinfo.name)
+        # Forget a symlink that was recorded under the same name.
+        self.symlink_effective_name_map.pop(effective_name, None)
+
+        yield from self._check_all(tarinfo, effective_name)
+
+        if tarinfo.issym():
+            # os.path.relpath() below never returns an absolute path, so
+            # an absolute target must be reported before it is rewritten.
+            if os.path.isabs(tarinfo.linkname):
+                yield WARN_ABSOLUTE_LINKNAME
+
+            effective_linkname = self._get_effective_name(tarinfo.linkname)
+            cwd = os.path.dirname(effective_name)
+            # NOTE(review): relpath() resolves both paths against the
+            # process's working directory -- confirm that is intended.
+            relative_effective_linkname = os.path.relpath(effective_linkname, cwd)
+            self.symlink_effective_name_map[effective_name] = relative_effective_linkname
+            yield from self._check_symlink(effective_name, relative_effective_linkname)
+        elif tarinfo.islnk():
+            yield from self._check_link(tarinfo)
+        elif tarinfo.ischr() or tarinfo.isblk():
+            yield from self._check_device(tarinfo)
+
+    def _get_effective_name(self, given_name):
+        """Return given_name with every leading path that matches a
+           recorded symlink replaced by that symlink's target. Names
+           without a path separator are returned unchanged.
+        """
+        namelist = given_name.split(os.path.sep)
+        if len(namelist) == 1:
+            return given_name
+
+        effective_name = ""
+        last = len(namelist) - 1
+        for i, name in enumerate(namelist):
+            # An empty component stands for a separator in given_name.
+            effective_name += name or os.path.sep
+            effective_name = os.path.normpath(effective_name)
+            if effective_name in self.symlink_effective_name_map:
+                effective_name = self.symlink_effective_name_map[effective_name]
+
+            if i < last and not effective_name.endswith("/"):
+                effective_name += "/"
+
+        return effective_name
+
+    @staticmethod
+    def _is_outside(path):
+        """Return True if the normalized path starts with a ".."
+           component, i.e. points above its starting directory.
+        """
+        path = os.path.normpath(path)
+        # Compare whole components, so that a name like "..foo" is not
+        # mistaken for a parent directory reference.
+        return path == os.pardir or path.startswith(os.pardir + os.path.sep)
+
+    def _check_all(self, tarinfo, effective_name):
+        """Generate the warnings that apply to every kind of member.
+        """
+        if os.path.isabs(effective_name):
+            yield WARN_ABSOLUTE_NAME
+
+        if self._is_outside(effective_name):
+            yield WARN_RELATIVE_NAME
+
+        if effective_name in self.names:
+            yield WARN_DUPLICATE_NAME
+        else:
+            self.names.add(effective_name)
+
+        if tarinfo.isreg() and tarinfo.mode & stat.S_ISUID:
+            yield WARN_SETUID_SET
+
+        if tarinfo.isreg() and tarinfo.mode & stat.S_ISGID:
+            yield WARN_SETGID_SET
+
+    def _check_symlink(self, effective_name, effective_linkname):
+        """Generate the warnings for a symlink, given its effective
+           name and its target relative to the symlink's directory.
+        """
+        if os.path.isabs(effective_linkname):
+            yield WARN_ABSOLUTE_LINKNAME
+
+        linkname = os.path.join(os.path.dirname(effective_name), effective_linkname)
+        if self._is_outside(linkname):
+            yield WARN_RELATIVE_LINKNAME
+
+    def _check_link(self, tarinfo):
+        """Generate the warnings for a hard link member.
+        """
+        if os.path.isabs(tarinfo.linkname):
+            yield WARN_ABSOLUTE_LINKNAME
+
+        if self._is_outside(tarinfo.linkname):
+            yield WARN_RELATIVE_LINKNAME
+
+    def _check_device(self, tarinfo):
+        """Generate the warning for a character or block device.
+        """
+        if tarinfo.ischr():
+            yield WARN_CHARACTER_DEVICE
+        elif tarinfo.isblk():
+            yield WARN_BLOCK_DEVICE
+
 #--------------------
 # exported functions
 #--------------------
@@ -2433,6 +2651,7 @@ def is_tarfile(name):
         return False
 
 open = TarFile.open
+safe_open = SafeTarFile.open
 
 
 def main():
@@ -2458,7 +2677,7 @@ def main():
     if args.test is not None:
         src = args.test
         if is_tarfile(src):
-            with open(src, 'r') as tar:
+            with SafeTarFile.open(src, 'r') as tar:
                 tar.getmembers()
                 print(tar.getmembers(), file=sys.stderr)
             if args.verbose:
@@ -2469,7 +2688,7 @@ def main():
     elif args.list is not None:
         src = args.list
         if is_tarfile(src):
-            with TarFile.open(src, 'r:*') as tf:
+            with SafeTarFile.open(src, 'r:*') as tf:
                 tf.list(verbose=args.verbose)
         else:
             parser.exit(1, '{!r} is not a tar archive.\n'.format(src))
@@ -2484,7 +2703,7 @@ def main():
             parser.exit(1, parser.format_help())
 
         if is_tarfile(src):
-            with TarFile.open(src, 'r:*') as tf:
+            with SafeTarFile.open(src, 'r:*') as tf:
                 tf.extractall(path=curdir)
             if args.verbose:
                 if curdir == '.':
@@ -2515,7 +2734,7 @@ def main():
         tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w'
         tar_files = args.create
 
-        with TarFile.open(tar_name, tar_mode) as tf:
+        with SafeTarFile.open(tar_name, tar_mode) as tf:
             for file_name in tar_files:
                 tf.add(file_name)
 