# Patch for CPython (Mercurial, parent revision 9552f8af321e) touching
# Lib/string.py and Lib/test/test_string.py.
#
# Lib/string.py
#   * _safe_getitem(mapping, key, default) returns mapping[key], or
#     `default` when the lookup raises KeyError.
#   * Template.__init__ additionally stores self._substitute and
#     self._safe_substitute, produced by the two _compile_* methods.
#   * Template._compile_substitute(template) walks
#     self.pattern.finditer(template) and assembles Python source made of
#     repr()'d literal text and f'''{mapping[<name>]!s}''' pieces, then
#     eval()s it as "lambda mapping: ...".  It returns None (no compiled
#     form) when a match is neither named/braced nor escaped, or when the
#     generated lookup text contains a backslash or a ''' sequence.
#   * Template._compile_safe_substitute(template) does the same, but looks
#     names up through _safe_getitem() with the matched text as the
#     default, and copies invalid-delimiter matches through verbatim.
#   * substitute() and safe_substitute() return the compiled lambda's
#     result when one exists; otherwise they continue into the existing
#     regex-callback code.
#
# Lib/test/test_string.py
#   * safe_substitute() assertions are added next to the existing
#     substitute() assertions for PathPattern and MyPattern.
#   * test_special_characters_in_name exercises names containing quotes,
#     backslash, tab, percent, dollar and braces.
#
# NOTE(review): the eval() input is generated only from repr() output and
# the guarded lookup text above; keep it that way, since template text
# may come from untrusted sources.
# NOTE(review): when a template yields no parts (the empty template) the
# generated expression falls back to '' -- "lambda mapping: " with an
# empty body is a SyntaxError and would make Template('') fail in
# __init__.
# NOTE(review): line structure was rebuilt from the hunk line counts; the
# position of the two context blank lines around _safe_getitem and the
# regex group names in test_special_characters_in_name (taken from the
# groups the patched code reads) should be confirmed against the original
# changeset.

diff -r 9552f8af321e Lib/string.py
--- a/Lib/string.py	Wed Sep 28 07:53:32 2016 +0300
+++ b/Lib/string.py	Thu Sep 29 12:01:54 2016 +0300
@@ -74,6 +74,12 @@ class _TemplateMetaclass(type):
         cls.pattern = _re.compile(pattern, cls.flags | _re.VERBOSE)
 
 
+def _safe_getitem(mapping, key, default):
+    try:
+        return mapping[key]
+    except KeyError:
+        return default
+
 class Template(metaclass=_TemplateMetaclass):
     """A string class for supporting $-substitutions."""
 
@@ -83,9 +89,62 @@ class Template(metaclass=_TemplateMetacl
 
     def __init__(self, template):
         self.template = template
+        self._substitute = self._compile_substitute(template)
+        self._safe_substitute = self._compile_safe_substitute(template)
 
     # Search for $$, $identifier, ${identifier}, and any bare $'s
 
+    def _compile_substitute(self, template):
+        parts = []
+        prev = 0
+        for mo in self.pattern.finditer(template):
+            literal = template[prev: mo.start()]
+            if literal:
+                parts.append(repr(literal))
+            prev = mo.end()
+            # Check the most common path first.
+            named = mo.group('named') or mo.group('braced')
+            if named is not None:
+                sub = "mapping[%r]" % named
+                if '\\' in sub or "'''" in sub:
+                    return None
+                parts.append("f'''{%s!s}'''" % sub)
+            elif mo.group('escaped') is not None:
+                parts.append(repr(self.delimiter))
+            else:
+                return None
+        literal = template[prev:]
+        if literal:
+            parts.append(repr(literal))
+        return eval('lambda mapping: ' + (''.join(parts) or "''"))
+
+    def _compile_safe_substitute(self, template):
+        parts = []
+        prev = 0
+        for mo in self.pattern.finditer(template):
+            literal = template[prev: mo.start()]
+            if literal:
+                parts.append(repr(literal))
+            prev = mo.end()
+            # Check the most common path first.
+            named = mo.group('named') or mo.group('braced')
+            if named is not None:
+                sub = "_safe_getitem(mapping, %r, %r)" % (named, mo.group())
+                if '\\' in sub or "'''" in sub:
+                    return None
+                parts.append("f'''{%s!s}'''" % sub)
+            elif mo.group('escaped') is not None:
+                parts.append(repr(self.delimiter))
+            elif mo.group('invalid') is not None:
+                parts.append(repr(mo.group()))
+            else:
+                return None
+        literal = template[prev:]
+        if literal:
+            parts.append(repr(literal))
+        return eval('lambda mapping, _safe_getitem=_safe_getitem: ' +
+                    (''.join(parts) or "''"))
+
     def _invalid(self, mo):
         i = mo.start('invalid')
         lines = self.template[:i].splitlines(keepends=True)
@@ -111,6 +170,10 @@ class Template(metaclass=_TemplateMetacl
             mapping = _ChainMap(kws, args[0])
         else:
             mapping = args[0]
+
+        if self._substitute is not None:
+            return self._substitute(mapping)
+
         # Helper function for .sub()
         def convert(mo):
             # Check the most common path first.
@@ -138,6 +201,10 @@ class Template(metaclass=_TemplateMetacl
             mapping = _ChainMap(kws, args[0])
         else:
             mapping = args[0]
+
+        if self._safe_substitute is not None:
+            return self._safe_substitute(mapping)
+
         # Helper function for .sub()
         def convert(mo):
             named = mo.group('named') or mo.group('braced')
diff -r 9552f8af321e Lib/test/test_string.py
--- a/Lib/test/test_string.py	Wed Sep 28 07:53:32 2016 +0300
+++ b/Lib/test/test_string.py	Thu Sep 29 12:01:54 2016 +0300
@@ -282,6 +282,9 @@ class TestTemplate(unittest.TestCase):
         m.bag.what = 'ham'
         s = PathPattern('$bag.foo.who likes to eat a bag of $bag.what')
         self.assertEqual(s.substitute(m), 'tim likes to eat a bag of ham')
+        self.assertEqual(s.safe_substitute(m), 'tim likes to eat a bag of ham')
+        del m.bag.foo.who
+        self.assertEqual(s.safe_substitute(m), '$bag.foo.who likes to eat a bag of ham')
 
     def test_pattern_override(self):
         class MyPattern(Template):
@@ -298,6 +301,9 @@ class TestTemplate(unittest.TestCase):
         m.bag.what = 'ham'
         s = MyPattern('@bag.foo.who likes to eat a bag of @bag.what')
         self.assertEqual(s.substitute(m), 'tim likes to eat a bag of ham')
+        self.assertEqual(s.safe_substitute(m), 'tim likes to eat a bag of ham')
+        del m.bag.foo.who
+        self.assertEqual(s.safe_substitute(m), '@bag.foo.who likes to eat a bag of ham')
 
         class BadPattern(Template):
             pattern = r"""
@@ -345,6 +351,31 @@ class TestTemplate(unittest.TestCase):
         val = t.safe_substitute({'location': 'Cleveland'})
         self.assertEqual(val, 'PyCon in Cleveland')
 
+    def test_special_characters_in_name(self):
+        class MyTemplate(Template):
+            pattern = r"""
+            @[[](?P<braced>[^]]*)[]] |
+            @(?P<named>[a-z]+) |
+            (?P<escaped>@@) |
+            (?P<invalid>@)
+            """
+        m = {
+            '\\': 'backslash',
+            "\t": 'tab',
+            '"': 'quotation mark',
+            "'": 'apostrophe',
+            '"""': 'triple quotation mark',
+            "'''": 'triple apostrophe',
+            "%": 'percent sign',
+            "$": 'dollar sign',
+            "{": 'left brace',
+            "}": 'right brace',
+        }
+        for k in m:
+            s = MyTemplate('<@[%s]>' % k)
+            self.assertEqual(s.substitute(m), '<%s>' % m[k])
+            self.assertEqual(s.safe_substitute(m), '<%s>' % m[k])
+
     def test_invalid_with_no_lines(self):
         # The error formatting for invalid templates
         # has a special case for no data that the default