diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 7b130ca16a7..e2d67cd949c 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -11,6 +11,7 @@ import sys import unittest import warnings from test import support, string_tests +import _string # Error handling (bad decoder return) def search_function(encoding): @@ -1516,6 +1517,57 @@ class UnicodeTest(string_tests.CommonTest, self.assertEqual(wchar, nonbmp + '\0') +class StringModuleTest(unittest.TestCase): + def test_formatter_parser(self): + def parse(format): + return list(_string.formatter_parser(format)) + + formatter = parse("prefix {2!s}xxx{0:^+10.3f}{obj.attr!s} {z[0]!s:10}") + self.assertEqual(formatter, [ + ('prefix ', '2', '', 's'), + ('xxx', '0', '^+10.3f', None), + ('', 'obj.attr', '', 's'), + (' ', 'z[0]', '10', 's'), + ]) + + formatter = parse("prefix {} suffix") + self.assertEqual(formatter, [ + ('prefix ', '', '', None), + (' suffix', None, None, None), + ]) + + formatter = parse("str") + self.assertEqual(formatter, [ + ('str', None, None, None), + ]) + + formatter = parse("") + self.assertEqual(formatter, []) + + formatter = parse("{0}") + self.assertEqual(formatter, [ + ('', '0', '', None), + ]) + + self.assertRaises(TypeError, _string.formatter_parser, 1) + + def test_formatter_field_name_split(self): + def split(name): + items = list(_string.formatter_field_name_split(name)) + items[1] = list(items[1]) + return items + self.assertEqual(split("obj"), ["obj", []]) + self.assertEqual(split("obj.arg"), ["obj", [(True, 'arg')]]) + self.assertEqual(split("obj[key]"), ["obj", [(False, 'key')]]) + self.assertEqual(split("obj.arg[key1][key2]"), [ + "obj", + [(True, 'arg'), + (False, 'key1'), + (False, 'key2'), + ]]) + self.assertRaises(TypeError, _string.formatter_field_name_split, 1) + + def test_main(): support.run_unittest(__name__) diff --git a/Objects/stringlib/string_format.h b/Objects/stringlib/string_format.h index 40535457f38..c1a6d1d5317 100644 --- a/Objects/stringlib/string_format.h +++ b/Objects/stringlib/string_format.h @@ -1192,6 +1192,11 @@ formatter_parser(PyObject *ignored, STRINGLIB_OBJECT *self) { formatteriterobject *it; + if (!PyUnicode_Check(self)) { + PyErr_Format(PyExc_TypeError, "expected str, got %s", Py_TYPE(self)->tp_name); + return NULL; + } + it = PyObject_New(formatteriterobject, &PyFormatterIter_Type); if (it == NULL) return NULL; @@ -1332,6 +1337,11 @@ formatter_field_name_split(PyObject *ignored, STRINGLIB_OBJECT *self) PyObject *first_obj = NULL; PyObject *result = NULL; + if (!PyUnicode_Check(self)) { + PyErr_Format(PyExc_TypeError, "expected str, got %s", Py_TYPE(self)->tp_name); + return NULL; + } + it = PyObject_New(fieldnameiterobject, &PyFieldNameIter_Type); if (it == NULL) return NULL;