From 13e57219d3143e4bae976a90846d6902e0514006 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Thu, 27 Apr 2006 22:54:26 +0000 Subject: [PATCH] Implement bytes += bytes, bytes *= int, int in bytes, bytes in bytes. --- Lib/test/test_bytes.py | 51 +++++++++++++++++++++- Objects/bytesobject.c | 99 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 143 insertions(+), 7 deletions(-) diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index cf5cd5ad4d5..94524d4a6f1 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -296,8 +296,57 @@ class BytesTest(unittest.TestCase): self.assertRaises(TypeError, lambda: b * 3.14) self.assertRaises(TypeError, lambda: 3.14 * b) self.assertRaises(MemoryError, lambda: b * sys.maxint) + + def test_repeat_1char(self): self.assertEqual(bytes('x')*100, bytes('x'*100)) + def test_iconcat(self): + b = bytes("abc") + b1 = b + b += bytes("def") + self.assertEqual(b, bytes("abcdef")) + self.assertEqual(b, b1) + self.failUnless(b is b1) + + def test_irepeat(self): + b = bytes("abc") + b1 = b + b *= 3 + self.assertEqual(b, bytes("abcabcabc")) + self.assertEqual(b, b1) + self.failUnless(b is b1) + + def test_irepeat_1char(self): + b = bytes("x") + b1 = b + b *= 100 + self.assertEqual(b, bytes("x"*100)) + self.assertEqual(b, b1) + self.failUnless(b is b1) + + def test_contains(self): + b = bytes("abc") + self.failUnless(ord('a') in b) + self.failUnless(long(ord('a')) in b) + self.failIf(200 in b) + self.failIf(200L in b) + self.assertRaises(ValueError, lambda: 300 in b) + self.assertRaises(ValueError, lambda: -1 in b) + self.assertRaises(TypeError, lambda: None in b) + self.assertRaises(TypeError, lambda: float(ord('a')) in b) + self.assertRaises(TypeError, lambda: "a" in b) + self.failUnless(bytes("") in b) + self.failUnless(bytes("a") in b) + self.failUnless(bytes("b") in b) + self.failUnless(bytes("c") in b) + self.failUnless(bytes("ab") in b) + self.failUnless(bytes("bc") in b) + self.failUnless(bytes("abc") in b) + self.failIf(bytes("ac") in b) + self.failIf(bytes("d") in b) + self.failIf(bytes("dab") in b) + self.failIf(bytes("abd") in b) + # Optimizations: # __iter__? (optimization) # __reversed__? (optimization) @@ -311,7 +360,6 @@ class BytesTest(unittest.TestCase): # pop # NOT sort! # With int arg: - # __contains__ # index # count # append @@ -321,7 +369,6 @@ class BytesTest(unittest.TestCase): # startswith # endswidth # find, rfind - # __contains__ (bytes arg) # index, rindex (bytes arg) # join # replace diff --git a/Objects/bytesobject.c b/Objects/bytesobject.c index 36b44245008..c4f9eeca995 100644 --- a/Objects/bytesobject.c +++ b/Objects/bytesobject.c @@ -115,6 +115,31 @@ bytes_concat(PyBytesObject *self, PyObject *other) return (PyObject *)result; } +static PyObject * +bytes_iconcat(PyBytesObject *self, PyObject *other) +{ + Py_ssize_t mysize; + Py_ssize_t osize; + Py_ssize_t size; + + if (!PyBytes_Check(other)) { + PyErr_Format(PyExc_TypeError, + "can't concat bytes to %.100s", other->ob_type->tp_name); + return NULL; + } + + mysize = self->ob_size; + osize = ((PyBytesObject *)other)->ob_size; + size = mysize + osize; + if (size < 0) + return PyErr_NoMemory(); + if (PyBytes_Resize((PyObject *)self, size) < 0) + return NULL; + memcpy(self->ob_bytes + mysize, ((PyBytesObject *)other)->ob_bytes, osize); + Py_INCREF(self); + return (PyObject *)self; +} + static PyObject * bytes_repeat(PyBytesObject *self, Py_ssize_t count) { @@ -133,7 +158,7 @@ bytes_repeat(PyBytesObject *self, Py_ssize_t count) if (mysize == 1) memset(result->ob_bytes, self->ob_bytes[0], size); else { - int i; + Py_ssize_t i; for (i = 0; i < count; i++) memcpy(result->ob_bytes + i*mysize, self->ob_bytes, mysize); } @@ -141,6 +166,72 @@ bytes_repeat(PyBytesObject *self, Py_ssize_t count) return (PyObject *)result; } +static PyObject * +bytes_irepeat(PyBytesObject *self, Py_ssize_t count) +{ + Py_ssize_t mysize; + Py_ssize_t size; + + if (count < 0) + count = 0; + mysize = self->ob_size; + size = mysize * count; + if (count != 0 && size / count != mysize) + return PyErr_NoMemory(); + if (PyBytes_Resize((PyObject *)self, size) < 0) + return NULL; + + if (mysize == 1) + memset(self->ob_bytes, self->ob_bytes[0], size); + else { + Py_ssize_t i; + for (i = 1; i < count; i++) + memcpy(self->ob_bytes + i*mysize, self->ob_bytes, mysize); + } + + Py_INCREF(self); + return (PyObject *)self; +} + +static int +bytes_substring(PyBytesObject *self, PyBytesObject *other) +{ + Py_ssize_t i; + + if (other->ob_size == 1) { + return memchr(self->ob_bytes, other->ob_bytes[0], + self->ob_size) != NULL; + } + if (other->ob_size == 0) + return 1; /* Edge case */ + for (i = 0; i + other->ob_size <= self->ob_size; i++) { + /* XXX Yeah, yeah, lots of optimizations possible... */ + if (memcmp(self->ob_bytes + i, other->ob_bytes, other->ob_size) == 0) + return 1; + } + return 0; +} + +static int +bytes_contains(PyBytesObject *self, PyObject *value) +{ + Py_ssize_t ival; + + if (PyBytes_Check(value)) + return bytes_substring(self, (PyBytesObject *)value); + + ival = PyNumber_Index(value); + if (ival == -1 && PyErr_Occurred()) + return -1; + + if (ival < 0 || ival >= 256) { + PyErr_SetString(PyExc_ValueError, "byte must be in range(0, 256)"); + return -1; + } + + return memchr(self->ob_bytes, ival, self->ob_size) != NULL; +} + static PyObject * bytes_getitem(PyBytesObject *self, Py_ssize_t i) { @@ -590,11 +681,9 @@ static PySequenceMethods bytes_as_sequence = { (ssizessizeargfunc)bytes_getslice, /*sq_slice*/ (ssizeobjargproc)bytes_setitem, /*sq_ass_item*/ (ssizessizeobjargproc)bytes_setslice, /* sq_ass_slice */ -#if 0 (objobjproc)bytes_contains, /* sq_contains */ - (binaryfunc)bytes_inplace_concat, /* sq_inplace_concat */ - (ssizeargfunc)bytes_inplace_repeat, /* sq_inplace_repeat */ -#endif + (binaryfunc)bytes_iconcat, /* sq_inplace_concat */ + (ssizeargfunc)bytes_irepeat, /* sq_inplace_repeat */ }; static PyMappingMethods bytes_as_mapping = {