Issue #12170: The count(), find(), rfind(), index() and rindex() methods

of bytes and bytearray objects now accept an integer between 0 and 255
as their first argument.  Patch by Petri Lehtinen.
This commit is contained in:
Antoine Pitrou 2011-10-20 23:54:17 +02:00
parent 407cfd1a26
commit ac65d96777
7 changed files with 262 additions and 52 deletions

View File

@ -1805,6 +1805,12 @@ the objects to strings, they have a :func:`decode` method.
Wherever one of these methods needs to interpret the bytes as characters Wherever one of these methods needs to interpret the bytes as characters
(e.g. the :func:`is...` methods), the ASCII character set is assumed. (e.g. the :func:`is...` methods), the ASCII character set is assumed.
.. versionadded:: 3.3
The methods :func:`count`, :func:`find`, :func:`index`,
:func:`rfind` and :func:`rindex` have additional semantics compared to
the corresponding string methods: they also accept an integer in the
range 0 to 255 (a byte) as their first argument.
.. note:: .. note::
The methods on bytes and bytearray objects don't accept strings as their The methods on bytes and bytearray objects don't accept strings as their

View File

@ -28,6 +28,11 @@ class BaseTest(unittest.TestCase):
# Change in subclasses to change the behaviour of fixtesttype() # Change in subclasses to change the behaviour of fixtesttype()
type2test = None type2test = None
# Whether the "contained items" of the container are integers in
# range(0, 256) (i.e. bytes, bytearray) or strings of length 1
# (str)
contains_bytes = False
# All tests pass their arguments to the testing methods # All tests pass their arguments to the testing methods
# as str objects. fixtesttype() can be used to propagate # as str objects. fixtesttype() can be used to propagate
# these arguments to the appropriate type # these arguments to the appropriate type
@ -117,7 +122,11 @@ class BaseTest(unittest.TestCase):
self.checkequal(0, '', 'count', 'xx', sys.maxsize, 0) self.checkequal(0, '', 'count', 'xx', sys.maxsize, 0)
self.checkraises(TypeError, 'hello', 'count') self.checkraises(TypeError, 'hello', 'count')
self.checkraises(TypeError, 'hello', 'count', 42)
if self.contains_bytes:
self.checkequal(0, 'hello', 'count', 42)
else:
self.checkraises(TypeError, 'hello', 'count', 42)
# For a variety of combinations, # For a variety of combinations,
# verify that str.count() matches an equivalent function # verify that str.count() matches an equivalent function
@ -163,7 +172,11 @@ class BaseTest(unittest.TestCase):
self.checkequal( 2, 'rrarrrrrrrrra', 'find', 'a', None, 6) self.checkequal( 2, 'rrarrrrrrrrra', 'find', 'a', None, 6)
self.checkraises(TypeError, 'hello', 'find') self.checkraises(TypeError, 'hello', 'find')
self.checkraises(TypeError, 'hello', 'find', 42)
if self.contains_bytes:
self.checkequal(-1, 'hello', 'find', 42)
else:
self.checkraises(TypeError, 'hello', 'find', 42)
self.checkequal(0, '', 'find', '') self.checkequal(0, '', 'find', '')
self.checkequal(-1, '', 'find', '', 1, 1) self.checkequal(-1, '', 'find', '', 1, 1)
@ -217,7 +230,11 @@ class BaseTest(unittest.TestCase):
self.checkequal( 2, 'rrarrrrrrrrra', 'rfind', 'a', None, 6) self.checkequal( 2, 'rrarrrrrrrrra', 'rfind', 'a', None, 6)
self.checkraises(TypeError, 'hello', 'rfind') self.checkraises(TypeError, 'hello', 'rfind')
self.checkraises(TypeError, 'hello', 'rfind', 42)
if self.contains_bytes:
self.checkequal(-1, 'hello', 'rfind', 42)
else:
self.checkraises(TypeError, 'hello', 'rfind', 42)
# For a variety of combinations, # For a variety of combinations,
# verify that str.rfind() matches __contains__ # verify that str.rfind() matches __contains__
@ -264,7 +281,11 @@ class BaseTest(unittest.TestCase):
self.checkequal( 2, 'rrarrrrrrrrra', 'index', 'a', None, 6) self.checkequal( 2, 'rrarrrrrrrrra', 'index', 'a', None, 6)
self.checkraises(TypeError, 'hello', 'index') self.checkraises(TypeError, 'hello', 'index')
self.checkraises(TypeError, 'hello', 'index', 42)
if self.contains_bytes:
self.checkraises(ValueError, 'hello', 'index', 42)
else:
self.checkraises(TypeError, 'hello', 'index', 42)
def test_rindex(self): def test_rindex(self):
self.checkequal(12, 'abcdefghiabc', 'rindex', '') self.checkequal(12, 'abcdefghiabc', 'rindex', '')
@ -286,7 +307,11 @@ class BaseTest(unittest.TestCase):
self.checkequal( 2, 'rrarrrrrrrrra', 'rindex', 'a', None, 6) self.checkequal( 2, 'rrarrrrrrrrra', 'rindex', 'a', None, 6)
self.checkraises(TypeError, 'hello', 'rindex') self.checkraises(TypeError, 'hello', 'rindex')
self.checkraises(TypeError, 'hello', 'rindex', 42)
if self.contains_bytes:
self.checkraises(ValueError, 'hello', 'rindex', 42)
else:
self.checkraises(TypeError, 'hello', 'rindex', 42)
def test_lower(self): def test_lower(self):
self.checkequal('hello', 'HeLLo', 'lower') self.checkequal('hello', 'HeLLo', 'lower')

View File

@ -293,10 +293,27 @@ class BaseBytesTest(unittest.TestCase):
def test_count(self): def test_count(self):
b = self.type2test(b'mississippi') b = self.type2test(b'mississippi')
i = 105
p = 112
w = 119
self.assertEqual(b.count(b'i'), 4) self.assertEqual(b.count(b'i'), 4)
self.assertEqual(b.count(b'ss'), 2) self.assertEqual(b.count(b'ss'), 2)
self.assertEqual(b.count(b'w'), 0) self.assertEqual(b.count(b'w'), 0)
self.assertEqual(b.count(i), 4)
self.assertEqual(b.count(w), 0)
self.assertEqual(b.count(b'i', 6), 2)
self.assertEqual(b.count(b'p', 6), 2)
self.assertEqual(b.count(b'i', 1, 3), 1)
self.assertEqual(b.count(b'p', 7, 9), 1)
self.assertEqual(b.count(i, 6), 2)
self.assertEqual(b.count(p, 6), 2)
self.assertEqual(b.count(i, 1, 3), 1)
self.assertEqual(b.count(p, 7, 9), 1)
def test_startswith(self): def test_startswith(self):
b = self.type2test(b'hello') b = self.type2test(b'hello')
self.assertFalse(self.type2test().startswith(b"anything")) self.assertFalse(self.type2test().startswith(b"anything"))
@ -327,35 +344,81 @@ class BaseBytesTest(unittest.TestCase):
def test_find(self): def test_find(self):
b = self.type2test(b'mississippi') b = self.type2test(b'mississippi')
i = 105
w = 119
self.assertEqual(b.find(b'ss'), 2) self.assertEqual(b.find(b'ss'), 2)
self.assertEqual(b.find(b'ss', 3), 5)
self.assertEqual(b.find(b'ss', 1, 7), 2)
self.assertEqual(b.find(b'ss', 1, 3), -1)
self.assertEqual(b.find(b'w'), -1) self.assertEqual(b.find(b'w'), -1)
self.assertEqual(b.find(b'mississippian'), -1) self.assertEqual(b.find(b'mississippian'), -1)
self.assertEqual(b.find(i), 1)
self.assertEqual(b.find(w), -1)
self.assertEqual(b.find(b'ss', 3), 5)
self.assertEqual(b.find(b'ss', 1, 7), 2)
self.assertEqual(b.find(b'ss', 1, 3), -1)
self.assertEqual(b.find(i, 6), 7)
self.assertEqual(b.find(i, 1, 3), 1)
self.assertEqual(b.find(w, 1, 3), -1)
def test_rfind(self): def test_rfind(self):
b = self.type2test(b'mississippi') b = self.type2test(b'mississippi')
i = 105
w = 119
self.assertEqual(b.rfind(b'ss'), 5) self.assertEqual(b.rfind(b'ss'), 5)
self.assertEqual(b.rfind(b'ss', 3), 5)
self.assertEqual(b.rfind(b'ss', 0, 6), 2)
self.assertEqual(b.rfind(b'w'), -1) self.assertEqual(b.rfind(b'w'), -1)
self.assertEqual(b.rfind(b'mississippian'), -1) self.assertEqual(b.rfind(b'mississippian'), -1)
self.assertEqual(b.rfind(i), 10)
self.assertEqual(b.rfind(w), -1)
self.assertEqual(b.rfind(b'ss', 3), 5)
self.assertEqual(b.rfind(b'ss', 0, 6), 2)
self.assertEqual(b.rfind(i, 1, 3), 1)
self.assertEqual(b.rfind(i, 3, 9), 7)
self.assertEqual(b.rfind(w, 1, 3), -1)
def test_index(self): def test_index(self):
b = self.type2test(b'world') b = self.type2test(b'mississippi')
self.assertEqual(b.index(b'w'), 0) i = 105
self.assertEqual(b.index(b'orl'), 1) w = 119
self.assertRaises(ValueError, b.index, b'worm')
self.assertRaises(ValueError, b.index, b'ldo') self.assertEqual(b.index(b'ss'), 2)
self.assertRaises(ValueError, b.index, b'w')
self.assertRaises(ValueError, b.index, b'mississippian')
self.assertEqual(b.index(i), 1)
self.assertRaises(ValueError, b.index, w)
self.assertEqual(b.index(b'ss', 3), 5)
self.assertEqual(b.index(b'ss', 1, 7), 2)
self.assertRaises(ValueError, b.index, b'ss', 1, 3)
self.assertEqual(b.index(i, 6), 7)
self.assertEqual(b.index(i, 1, 3), 1)
self.assertRaises(ValueError, b.index, w, 1, 3)
def test_rindex(self): def test_rindex(self):
# XXX could be more rigorous b = self.type2test(b'mississippi')
b = self.type2test(b'world') i = 105
self.assertEqual(b.rindex(b'w'), 0) w = 119
self.assertEqual(b.rindex(b'orl'), 1)
self.assertRaises(ValueError, b.rindex, b'worm') self.assertEqual(b.rindex(b'ss'), 5)
self.assertRaises(ValueError, b.rindex, b'ldo') self.assertRaises(ValueError, b.rindex, b'w')
self.assertRaises(ValueError, b.rindex, b'mississippian')
self.assertEqual(b.rindex(i), 10)
self.assertRaises(ValueError, b.rindex, w)
self.assertEqual(b.rindex(b'ss', 3), 5)
self.assertEqual(b.rindex(b'ss', 0, 6), 2)
self.assertEqual(b.rindex(i, 1, 3), 1)
self.assertEqual(b.rindex(i, 3, 9), 7)
self.assertRaises(ValueError, b.rindex, w, 1, 3)
def test_replace(self): def test_replace(self):
b = self.type2test(b'mississippi') b = self.type2test(b'mississippi')
@ -552,6 +615,14 @@ class BaseBytesTest(unittest.TestCase):
self.assertEqual(True, b.startswith(h, None, -2)) self.assertEqual(True, b.startswith(h, None, -2))
self.assertEqual(False, b.startswith(x, None, None)) self.assertEqual(False, b.startswith(x, None, None))
def test_integer_arguments_out_of_byte_range(self):
# Integer arguments outside range(0, 256) -- here -1, 256 and 9999 --
# must raise ValueError from each of count, find, index, rfind and rindex.
b = self.type2test(b'hello')
for method in (b.count, b.find, b.index, b.rfind, b.rindex):
self.assertRaises(ValueError, method, -1)
self.assertRaises(ValueError, method, 256)
self.assertRaises(ValueError, method, 9999)
def test_find_etc_raise_correct_error_messages(self): def test_find_etc_raise_correct_error_messages(self):
# issue 11828 # issue 11828
b = self.type2test(b'hello') b = self.type2test(b'hello')
@ -1161,9 +1232,11 @@ class FixedStringTest(test.string_tests.BaseTest):
class ByteArrayAsStringTest(FixedStringTest): class ByteArrayAsStringTest(FixedStringTest):
type2test = bytearray type2test = bytearray
contains_bytes = True
class BytesAsStringTest(FixedStringTest): class BytesAsStringTest(FixedStringTest):
type2test = bytes type2test = bytes
contains_bytes = True
class SubclassTest(unittest.TestCase): class SubclassTest(unittest.TestCase):

View File

@ -10,6 +10,10 @@ What's New in Python 3.3 Alpha 1?
Core and Builtins Core and Builtins
----------------- -----------------
- Issue #12170: The count(), find(), rfind(), index() and rindex() methods
of bytes and bytearray objects now accept an integer between 0 and 255
as their first argument. Patch by Petri Lehtinen.
- Issue #12604: VTRACE macro expanded to no-op in _sre.c to avoid compiler - Issue #12604: VTRACE macro expanded to no-op in _sre.c to avoid compiler
warnings. Patch by Josh Triplett and Petri Lehtinen. warnings. Patch by Josh Triplett and Petri Lehtinen.

View File

@ -1071,24 +1071,41 @@ Py_LOCAL_INLINE(Py_ssize_t)
bytearray_find_internal(PyByteArrayObject *self, PyObject *args, int dir) bytearray_find_internal(PyByteArrayObject *self, PyObject *args, int dir)
{ {
PyObject *subobj; PyObject *subobj;
char byte;
Py_buffer subbuf; Py_buffer subbuf;
const char *sub;
Py_ssize_t sub_len;
Py_ssize_t start=0, end=PY_SSIZE_T_MAX; Py_ssize_t start=0, end=PY_SSIZE_T_MAX;
Py_ssize_t res; Py_ssize_t res;
if (!stringlib_parse_args_finds("find/rfind/index/rindex", if (!stringlib_parse_args_finds_byte("find/rfind/index/rindex",
args, &subobj, &start, &end)) args, &subobj, &byte, &start, &end))
return -2;
if (_getbuffer(subobj, &subbuf) < 0)
return -2; return -2;
if (subobj) {
if (_getbuffer(subobj, &subbuf) < 0)
return -2;
sub = subbuf.buf;
sub_len = subbuf.len;
}
else {
sub = &byte;
sub_len = 1;
}
if (dir > 0) if (dir > 0)
res = stringlib_find_slice( res = stringlib_find_slice(
PyByteArray_AS_STRING(self), PyByteArray_GET_SIZE(self), PyByteArray_AS_STRING(self), PyByteArray_GET_SIZE(self),
subbuf.buf, subbuf.len, start, end); sub, sub_len, start, end);
else else
res = stringlib_rfind_slice( res = stringlib_rfind_slice(
PyByteArray_AS_STRING(self), PyByteArray_GET_SIZE(self), PyByteArray_AS_STRING(self), PyByteArray_GET_SIZE(self),
subbuf.buf, subbuf.len, start, end); sub, sub_len, start, end);
PyBuffer_Release(&subbuf);
if (subobj)
PyBuffer_Release(&subbuf);
return res; return res;
} }
@ -1121,23 +1138,39 @@ static PyObject *
bytearray_count(PyByteArrayObject *self, PyObject *args) bytearray_count(PyByteArrayObject *self, PyObject *args)
{ {
PyObject *sub_obj; PyObject *sub_obj;
const char *str = PyByteArray_AS_STRING(self); const char *str = PyByteArray_AS_STRING(self), *sub;
Py_ssize_t sub_len;
char byte;
Py_ssize_t start = 0, end = PY_SSIZE_T_MAX; Py_ssize_t start = 0, end = PY_SSIZE_T_MAX;
Py_buffer vsub; Py_buffer vsub;
PyObject *count_obj; PyObject *count_obj;
if (!stringlib_parse_args_finds("count", args, &sub_obj, &start, &end)) if (!stringlib_parse_args_finds_byte("count", args, &sub_obj, &byte,
&start, &end))
return NULL; return NULL;
if (_getbuffer(sub_obj, &vsub) < 0) if (sub_obj) {
return NULL; if (_getbuffer(sub_obj, &vsub) < 0)
return NULL;
sub = vsub.buf;
sub_len = vsub.len;
}
else {
sub = &byte;
sub_len = 1;
}
ADJUST_INDICES(start, end, PyByteArray_GET_SIZE(self)); ADJUST_INDICES(start, end, PyByteArray_GET_SIZE(self));
count_obj = PyLong_FromSsize_t( count_obj = PyLong_FromSsize_t(
stringlib_count(str + start, end - start, vsub.buf, vsub.len, PY_SSIZE_T_MAX) stringlib_count(str + start, end - start, sub, sub_len, PY_SSIZE_T_MAX)
); );
PyBuffer_Release(&vsub);
if (sub_obj)
PyBuffer_Release(&vsub);
return count_obj; return count_obj;
} }

View File

@ -1230,31 +1230,42 @@ Py_LOCAL_INLINE(Py_ssize_t)
bytes_find_internal(PyBytesObject *self, PyObject *args, int dir) bytes_find_internal(PyBytesObject *self, PyObject *args, int dir)
{ {
PyObject *subobj; PyObject *subobj;
char byte;
Py_buffer subbuf;
const char *sub; const char *sub;
Py_ssize_t sub_len; Py_ssize_t sub_len;
Py_ssize_t start=0, end=PY_SSIZE_T_MAX; Py_ssize_t start=0, end=PY_SSIZE_T_MAX;
Py_ssize_t res;
if (!stringlib_parse_args_finds("find/rfind/index/rindex", if (!stringlib_parse_args_finds_byte("find/rfind/index/rindex",
args, &subobj, &start, &end)) args, &subobj, &byte, &start, &end))
return -2; return -2;
if (PyBytes_Check(subobj)) { if (subobj) {
sub = PyBytes_AS_STRING(subobj); if (_getbuffer(subobj, &subbuf) < 0)
sub_len = PyBytes_GET_SIZE(subobj); return -2;
sub = subbuf.buf;
sub_len = subbuf.len;
}
else {
sub = &byte;
sub_len = 1;
} }
else if (PyObject_AsCharBuffer(subobj, &sub, &sub_len))
/* XXX - the "expected a character buffer object" is pretty
confusing for a non-expert. remap to something else ? */
return -2;
if (dir > 0) if (dir > 0)
return stringlib_find_slice( res = stringlib_find_slice(
PyBytes_AS_STRING(self), PyBytes_GET_SIZE(self), PyBytes_AS_STRING(self), PyBytes_GET_SIZE(self),
sub, sub_len, start, end); sub, sub_len, start, end);
else else
return stringlib_rfind_slice( res = stringlib_rfind_slice(
PyBytes_AS_STRING(self), PyBytes_GET_SIZE(self), PyBytes_AS_STRING(self), PyBytes_GET_SIZE(self),
sub, sub_len, start, end); sub, sub_len, start, end);
if (subobj)
PyBuffer_Release(&subbuf);
return res;
} }
@ -1480,23 +1491,38 @@ bytes_count(PyBytesObject *self, PyObject *args)
PyObject *sub_obj; PyObject *sub_obj;
const char *str = PyBytes_AS_STRING(self), *sub; const char *str = PyBytes_AS_STRING(self), *sub;
Py_ssize_t sub_len; Py_ssize_t sub_len;
char byte;
Py_ssize_t start = 0, end = PY_SSIZE_T_MAX; Py_ssize_t start = 0, end = PY_SSIZE_T_MAX;
if (!stringlib_parse_args_finds("count", args, &sub_obj, &start, &end)) Py_buffer vsub;
PyObject *count_obj;
if (!stringlib_parse_args_finds_byte("count", args, &sub_obj, &byte,
&start, &end))
return NULL; return NULL;
if (PyBytes_Check(sub_obj)) { if (sub_obj) {
sub = PyBytes_AS_STRING(sub_obj); if (_getbuffer(sub_obj, &vsub) < 0)
sub_len = PyBytes_GET_SIZE(sub_obj); return NULL;
sub = vsub.buf;
sub_len = vsub.len;
}
else {
sub = &byte;
sub_len = 1;
} }
else if (PyObject_AsCharBuffer(sub_obj, &sub, &sub_len))
return NULL;
ADJUST_INDICES(start, end, PyBytes_GET_SIZE(self)); ADJUST_INDICES(start, end, PyBytes_GET_SIZE(self));
return PyLong_FromSsize_t( count_obj = PyLong_FromSsize_t(
stringlib_count(str + start, end - start, sub, sub_len, PY_SSIZE_T_MAX) stringlib_count(str + start, end - start, sub, sub_len, PY_SSIZE_T_MAX)
); );
if (sub_obj)
PyBuffer_Release(&vsub);
return count_obj;
} }

View File

@ -167,4 +167,47 @@ STRINGLIB(parse_args_finds_unicode)(const char * function_name, PyObject *args,
return 0; return 0;
} }
#else /* !STRINGLIB_IS_UNICODE */
/*
Wraps stringlib_parse_args_finds() and additionally checks whether the
first argument is an integer in range(0, 256).

If the first argument is an integer in that range, writes its value to
*byte and sets *subobj to NULL.  If it is an integer outside that range
(including one too large to fit in a Py_ssize_t), sets ValueError and
returns 0.  If it is not an integer at all, stores it in *subobj and
leaves *byte untouched so that the caller can treat it as a buffer.

The other parameters are similar to those of stringlib_parse_args_finds().
Returns 1 on success and 0 on failure with an exception set.
*/
Py_LOCAL_INLINE(int)
STRINGLIB(parse_args_finds_byte)(const char *function_name, PyObject *args,
                                 PyObject **subobj, char *byte,
                                 Py_ssize_t *start, Py_ssize_t *end)
{
    PyObject *tmp_subobj;
    Py_ssize_t ival;

    if (!STRINGLIB(parse_args_finds)(function_name, args, &tmp_subobj,
                                     start, end))
        return 0;

    /* Ask for OverflowError on overflow so that "integer too large" can be
       told apart from "not an integer" (which raises TypeError). */
    ival = PyNumber_AsSsize_t(tmp_subobj, PyExc_OverflowError);
    if (ival == -1 && PyErr_Occurred()) {
        if (!PyErr_ExceptionMatches(PyExc_OverflowError)) {
            /* Not an integer: hand the object back to the caller. */
            PyErr_Clear();
            *subobj = tmp_subobj;
            return 1;
        }
        /* An integer that does not fit in Py_ssize_t is necessarily out of
           byte range; ival is still -1, so the range check below reports
           it with the regular ValueError. */
        PyErr_Clear();
    }

    /* The first argument was an integer */
    if (ival < 0 || ival > 255) {
        PyErr_SetString(PyExc_ValueError, "byte must be in range(0, 256)");
        return 0;
    }

    *subobj = NULL;
    *byte = (char)ival;
    return 1;
}
#endif /* STRINGLIB_IS_UNICODE */ #endif /* STRINGLIB_IS_UNICODE */