From ccab67ba7901a3012ad66f0ffafac4ea925a1ff0 Mon Sep 17 00:00:00 2001 From: Nikita Sobolev Date: Wed, 12 Oct 2022 19:27:53 +0300 Subject: [PATCH] gh-97982: Factorize PyUnicode_Count() and unicode_count() code (#98025) Add unicode_count_impl() to factorize PyUnicode_Count() and unicode_count() code. --- Lib/test/test_unicode.py | 10 +++++ Objects/unicodeobject.c | 86 ++++++++++++---------------------------- 2 files changed, 36 insertions(+), 60 deletions(-) diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py index 5b816574f2c..15244cb949e 100644 --- a/Lib/test/test_unicode.py +++ b/Lib/test/test_unicode.py @@ -241,6 +241,10 @@ class UnicodeTest(string_tests.CommonTest, self.checkequal(0, 'a' * 10, 'count', 'a\u0102') self.checkequal(0, 'a' * 10, 'count', 'a\U00100304') self.checkequal(0, '\u0102' * 10, 'count', '\u0102\U00100304') + # test subclass + class MyStr(str): + pass + self.checkequal(3, MyStr('aaa'), 'count', 'a') def test_find(self): string_tests.CommonTest.test_find(self) @@ -3002,6 +3006,12 @@ class CAPITest(unittest.TestCase): self.assertEqual(unicode_count(uni, ch, 0, len(uni)), 1) self.assertEqual(unicode_count(st, ch, 0, len(st)), 0) + # subclasses should still work + class MyStr(str): + pass + + self.assertEqual(unicode_count(MyStr('aab'), 'a', 0, 3), 2) + # Test PyUnicode_FindChar() @support.cpython_only @unittest.skipIf(_testcapi is None, 'need _testcapi module') diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c index 51e660afba0..5b737f1b596 100644 --- a/Objects/unicodeobject.c +++ b/Objects/unicodeobject.c @@ -8964,21 +8964,20 @@ _PyUnicode_InsertThousandsGrouping( return count; } - -Py_ssize_t -PyUnicode_Count(PyObject *str, - PyObject *substr, - Py_ssize_t start, - Py_ssize_t end) +static Py_ssize_t +unicode_count_impl(PyObject *str, + PyObject *substr, + Py_ssize_t start, + Py_ssize_t end) { + assert(PyUnicode_Check(str)); + assert(PyUnicode_Check(substr)); + Py_ssize_t result; int kind1, kind2; const void *buf1 = NULL, *buf2 = NULL; Py_ssize_t len1, len2; - if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0) - return -1; - kind1 = PyUnicode_KIND(str); kind2 = PyUnicode_KIND(substr); if (kind1 < kind2) @@ -8998,6 +8997,7 @@ PyUnicode_Count(PyObject *str, goto onError; } + // We don't reuse `anylib_count` here because of the explicit casts. switch (kind1) { case PyUnicode_1BYTE_KIND: result = ucs1lib_count( @@ -9033,6 +9033,18 @@ PyUnicode_Count(PyObject *str, return -1; } +Py_ssize_t +PyUnicode_Count(PyObject *str, + PyObject *substr, + Py_ssize_t start, + Py_ssize_t end) +{ + if (ensure_unicode(str) < 0 || ensure_unicode(substr) < 0) + return -1; + + return unicode_count_impl(str, substr, start, end); +} + Py_ssize_t PyUnicode_Find(PyObject *str, PyObject *substr, @@ -10848,62 +10860,16 @@ unicode_count(PyObject *self, PyObject *args) PyObject *substring = NULL; /* initialize to fix a compiler warning */ Py_ssize_t start = 0; Py_ssize_t end = PY_SSIZE_T_MAX; - PyObject *result; - int kind1, kind2; - const void *buf1, *buf2; - Py_ssize_t len1, len2, iresult; + Py_ssize_t result; if (!parse_args_finds_unicode("count", args, &substring, &start, &end)) return NULL; - kind1 = PyUnicode_KIND(self); - kind2 = PyUnicode_KIND(substring); - if (kind1 < kind2) - return PyLong_FromLong(0); + result = unicode_count_impl(self, substring, start, end); + if (result == -1) + return NULL; - len1 = PyUnicode_GET_LENGTH(self); - len2 = PyUnicode_GET_LENGTH(substring); - ADJUST_INDICES(start, end, len1); - if (end - start < len2) - return PyLong_FromLong(0); - - buf1 = PyUnicode_DATA(self); - buf2 = PyUnicode_DATA(substring); - if (kind2 != kind1) { - buf2 = unicode_askind(kind2, buf2, len2, kind1); - if (!buf2) - return NULL; - } - switch (kind1) { - case PyUnicode_1BYTE_KIND: - iresult = ucs1lib_count( - ((const Py_UCS1*)buf1) + start, end - start, - buf2, len2, PY_SSIZE_T_MAX - ); - break; - case PyUnicode_2BYTE_KIND: - iresult = ucs2lib_count( - ((const Py_UCS2*)buf1) + start, end - start, - buf2, len2, PY_SSIZE_T_MAX - ); - break; - case PyUnicode_4BYTE_KIND: - iresult = ucs4lib_count( - ((const Py_UCS4*)buf1) + start, end - start, - buf2, len2, PY_SSIZE_T_MAX - ); - break; - default: - Py_UNREACHABLE(); - } - - result = PyLong_FromSsize_t(iresult); - - assert((kind2 == kind1) == (buf2 == PyUnicode_DATA(substring))); - if (kind2 != kind1) - PyMem_Free((void *)buf2); - - return result; + return PyLong_FromSsize_t(result); } /*[clinic input]