diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 7b130ca..0bc1525 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -11,6 +11,7 @@ import sys import unittest import warnings from test import support, string_tests +import _string # Error handling (bad decoder return) def search_function(encoding): @@ -1516,6 +1517,33 @@ class UnicodeTest(string_tests.CommonTest, self.assertEqual(wchar, nonbmp + '\0') +class StringModuleTest(unittest.TestCase): + def test_formatter_parser(self): + def parse(format): + return list(_string.formatter_parser(format)) + formatter = parse("prefix {2!s}xxx{0:^+10.3f}{obj.attr!s}") + self.assertEqual(formatter, [ + ('prefix ', '2', '', 's'), + ('xxx', '0', '^+10.3f', None), + ('', 'obj.attr', '', 's'), + ]) + self.assertRaises(TypeError, _string.formatter_parser, 1) + + def test_formatter_field_name_split(self): + def split(name): + items = list(_string.formatter_field_name_split(name)) + items[1] = list(items[1]) + return items + self.assertEqual(split("obj.arg"), ["obj", [(True, 'arg')]]) + self.assertEqual(split("obj.arg[key1][key2]"), [ + "obj", + [(True, 'arg'), + (False, 'key1'), + (False, 'key2'), + ]]) + self.assertRaises(TypeError, _string.formatter_field_name_split, 1) + + def test_main(): support.run_unittest(__name__) diff --git a/Objects/stringlib/string_format.h b/Objects/stringlib/string_format.h index 4053545..2c09e26 100644 --- a/Objects/stringlib/string_format.h +++ b/Objects/stringlib/string_format.h @@ -1192,6 +1192,11 @@ formatter_parser(PyObject *ignored, STRINGLIB_OBJECT *self) { formatteriterobject *it; + if (!PyUnicode_Check(self)) { + PyErr_Format(PyExc_TypeError, "expect str, not %s", Py_TYPE(self)->tp_name); + return NULL; + } + it = PyObject_New(formatteriterobject, &PyFormatterIter_Type); if (it == NULL) return NULL; @@ -1332,6 +1337,11 @@ formatter_field_name_split(PyObject *ignored, STRINGLIB_OBJECT *self) PyObject *first_obj = NULL; PyObject *result = NULL; + if (!PyUnicode_Check(self)) { + PyErr_Format(PyExc_TypeError, "expect str, not %s", Py_TYPE(self)->tp_name); + return NULL; + } + it = PyObject_New(fieldnameiterobject, &PyFieldNameIter_Type); if (it == NULL) return NULL;