Issue #6218: Make io.BytesIO and io.StringIO picklable.

This commit is contained in:
Alexandre Vassalotti 2009-07-22 03:24:36 +00:00
parent d2bb18b281
commit cf76e1ac92
5 changed files with 410 additions and 16 deletions

View File

@ -765,6 +765,11 @@ class BytesIO(BufferedIOBase):
self._buffer = buf
self._pos = 0
def __getstate__(self):
if self.closed:
raise ValueError("__getstate__ on closed file")
return self.__dict__.copy()
def getvalue(self):
"""Return the bytes value (contents) of the buffer
"""

View File

@ -9,6 +9,7 @@ from test import support
import io
import _pyio as pyio
import sys
import pickle
class MemorySeekTestMixin:
@ -346,6 +347,42 @@ class MemoryTestMixin:
memio = self.ioclass()
memio.foo = 1
def test_pickling(self):
buf = self.buftype("1234567890")
memio = self.ioclass(buf)
memio.foo = 42
memio.seek(2)
class PickleTestMemIO(self.ioclass):
def __init__(me, initvalue, foo):
self.ioclass.__init__(me, initvalue)
me.foo = foo
# __getnewargs__ is undefined on purpose. This checks that PEP 307
# is used to provide pickling support.
# Pickle expects the class to be on the module level. Here we use a
# little hack to allow the PickleTestMemIO class to derive from
# self.ioclass without having to define all combinations explictly on
# the module-level.
import __main__
PickleTestMemIO.__module__ = '__main__'
__main__.PickleTestMemIO = PickleTestMemIO
submemio = PickleTestMemIO(buf, 80)
submemio.seek(2)
# We only support pickle protocol 2 and onward since we use extended
# __reduce__ API of PEP 307 to provide pickling support.
for proto in range(2, pickle.HIGHEST_PROTOCOL):
for obj in (memio, submemio):
obj2 = pickle.loads(pickle.dumps(obj, protocol=proto))
self.assertEqual(obj.getvalue(), obj2.getvalue())
self.assertEqual(obj.__class__, obj2.__class__)
self.assertEqual(obj.foo, obj2.foo)
self.assertEqual(obj.tell(), obj2.tell())
obj.close()
self.assertRaises(ValueError, pickle.dumps, obj, proto)
del __main__.PickleTestMemIO
class PyBytesIOTest(MemoryTestMixin, MemorySeekTestMixin, unittest.TestCase):
@ -425,13 +462,26 @@ class PyBytesIOTest(MemoryTestMixin, MemorySeekTestMixin, unittest.TestCase):
self.assertEqual(memio.getvalue(), buf)
class PyStringIOTest(MemoryTestMixin, MemorySeekTestMixin, unittest.TestCase):
buftype = str
ioclass = pyio.StringIO
UnsupportedOperation = pyio.UnsupportedOperation
EOF = ""
class TextIOTestMixin:
# TextIO-specific behaviour.
def test_relative_seek(self):
memio = self.ioclass()
self.assertRaises(IOError, memio.seek, -1, 1)
self.assertRaises(IOError, memio.seek, 3, 1)
self.assertRaises(IOError, memio.seek, -3, 1)
self.assertRaises(IOError, memio.seek, -1, 2)
self.assertRaises(IOError, memio.seek, 1, 1)
self.assertRaises(IOError, memio.seek, 1, 2)
def test_textio_properties(self):
memio = self.ioclass()
# These are just dummy values but we nevertheless check them for fear
# of unexpected breakage.
self.assertTrue(memio.encoding is None)
self.assertEqual(memio.errors, "strict")
self.assertEqual(memio.line_buffering, False)
def test_newlines_property(self):
memio = self.ioclass(newline=None)
@ -513,7 +563,6 @@ class PyStringIOTest(MemoryTestMixin, MemorySeekTestMixin, unittest.TestCase):
def test_newline_cr(self):
# newline="\r"
memio = self.ioclass("a\nb\r\nc\rd", newline="\r")
memio.seek(0)
self.assertEqual(memio.read(), "a\rb\r\rc\rd")
memio.seek(0)
self.assertEqual(list(memio), ["a\r", "b\r", "\r", "c\r", "d"])
@ -521,7 +570,6 @@ class PyStringIOTest(MemoryTestMixin, MemorySeekTestMixin, unittest.TestCase):
def test_newline_crlf(self):
# newline="\r\n"
memio = self.ioclass("a\nb\r\nc\rd", newline="\r\n")
memio.seek(0)
self.assertEqual(memio.read(), "a\r\nb\r\r\nc\rd")
memio.seek(0)
self.assertEqual(list(memio), ["a\r\n", "b\r\r\n", "c\rd"])
@ -539,10 +587,59 @@ class PyStringIOTest(MemoryTestMixin, MemorySeekTestMixin, unittest.TestCase):
self.ioclass(newline=newline)
class PyStringIOTest(MemoryTestMixin, MemorySeekTestMixin,
TextIOTestMixin, unittest.TestCase):
buftype = str
ioclass = pyio.StringIO
UnsupportedOperation = pyio.UnsupportedOperation
EOF = ""
class PyStringIOPickleTest(TextIOTestMixin, unittest.TestCase):
"""Test if pickle restores properly the internal state of StringIO.
"""
buftype = str
UnsupportedOperation = pyio.UnsupportedOperation
EOF = ""
class ioclass(pyio.StringIO):
def __new__(cls, *args, **kwargs):
return pickle.loads(pickle.dumps(pyio.StringIO(*args, **kwargs)))
def __init__(self, *args, **kwargs):
pass
class CBytesIOTest(PyBytesIOTest):
ioclass = io.BytesIO
UnsupportedOperation = io.UnsupportedOperation
def test_getstate(self):
memio = self.ioclass()
state = memio.__getstate__()
self.assertEqual(len(state), 3)
bytearray(state[0]) # Check if state[0] supports the buffer interface.
self.assert_(isinstance(state[1], int))
self.assert_(isinstance(state[2], dict) or state[2] is None)
memio.close()
self.assertRaises(ValueError, memio.__getstate__)
def test_setstate(self):
# This checks whether __setstate__ does proper input validation.
memio = self.ioclass()
memio.__setstate__((b"no error", 0, None))
memio.__setstate__((bytearray(b"no error"), 0, None))
memio.__setstate__((b"no error", 0, {'spam': 3}))
self.assertRaises(ValueError, memio.__setstate__, (b"", -1, None))
self.assertRaises(TypeError, memio.__setstate__, ("unicode", 0, None))
self.assertRaises(TypeError, memio.__setstate__, (b"", 0.0, None))
self.assertRaises(TypeError, memio.__setstate__, (b"", 0, 0))
self.assertRaises(TypeError, memio.__setstate__, (b"len-test", 0))
self.assertRaises(TypeError, memio.__setstate__)
self.assertRaises(TypeError, memio.__setstate__, 0)
memio.close()
self.assertRaises(ValueError, memio.__setstate__, (b"closed", 0, None))
class CStringIOTest(PyStringIOTest):
ioclass = io.StringIO
UnsupportedOperation = io.UnsupportedOperation
@ -561,9 +658,48 @@ class CStringIOTest(PyStringIOTest):
self.assertEqual(memio.tell(), len(buf) * 2)
self.assertEqual(memio.getvalue(), buf + buf)
def test_getstate(self):
memio = self.ioclass()
state = memio.__getstate__()
self.assertEqual(len(state), 4)
self.assert_(isinstance(state[0], str))
self.assert_(isinstance(state[1], str))
self.assert_(isinstance(state[2], int))
self.assert_(isinstance(state[3], dict) or state[3] is None)
memio.close()
self.assertRaises(ValueError, memio.__getstate__)
def test_setstate(self):
# This checks whether __setstate__ does proper input validation.
memio = self.ioclass()
memio.__setstate__(("no error", "\n", 0, None))
memio.__setstate__(("no error", "", 0, {'spam': 3}))
self.assertRaises(ValueError, memio.__setstate__, ("", "f", 0, None))
self.assertRaises(ValueError, memio.__setstate__, ("", "", -1, None))
self.assertRaises(TypeError, memio.__setstate__, (b"", "", 0, None))
self.assertRaises(TypeError, memio.__setstate__, ("", b"", 0, None))
self.assertRaises(TypeError, memio.__setstate__, ("", "", 0.0, None))
self.assertRaises(TypeError, memio.__setstate__, ("", "", 0, 0))
self.assertRaises(TypeError, memio.__setstate__, ("len-test", 0))
self.assertRaises(TypeError, memio.__setstate__)
self.assertRaises(TypeError, memio.__setstate__, 0)
memio.close()
self.assertRaises(ValueError, memio.__setstate__, ("closed", "", 0, None))
class CStringIOPickleTest(PyStringIOPickleTest):
UnsupportedOperation = io.UnsupportedOperation
class ioclass(io.StringIO):
def __new__(cls, *args, **kwargs):
return pickle.loads(pickle.dumps(io.StringIO(*args, **kwargs)))
def __init__(self, *args, **kwargs):
pass
def test_main():
tests = [PyBytesIOTest, PyStringIOTest, CBytesIOTest, CStringIOTest]
tests = [PyBytesIOTest, PyStringIOTest, CBytesIOTest, CStringIOTest,
PyStringIOPickleTest, CStringIOPickleTest]
support.run_unittest(*tests)
if __name__ == '__main__':

View File

@ -85,6 +85,8 @@ Library
- Issue #4005: Fixed a crash of pydoc when there was a zip file present in
sys.path.
- Issue #6218: io.StringIO and io.BytesIO instances are now picklable.
Extension Modules
-----------------

View File

@ -606,6 +606,120 @@ bytesio_close(bytesio *self)
Py_RETURN_NONE;
}
/* Pickling support.
Note that only pickle protocol 2 and onward are supported since we use
extended __reduce__ API of PEP 307 to make BytesIO instances picklable.
Providing support for protocol < 2 would require the __reduce_ex__ method
which is notably long-winded when defined properly.
For BytesIO, the implementation would similar to one coded for
object.__reduce_ex__, but slightly less general. To be more specific, we
could call bytesio_getstate directly and avoid checking for the presence of
a fallback __reduce__ method. However, we would still need a __newobj__
function to use the efficient instance representation of PEP 307.
*/
static PyObject *
bytesio_getstate(bytesio *self)
{
PyObject *initvalue = bytesio_getvalue(self);
PyObject *dict;
PyObject *state;
if (initvalue == NULL)
return NULL;
if (self->dict == NULL) {
Py_INCREF(Py_None);
dict = Py_None;
}
else {
dict = PyDict_Copy(self->dict);
if (dict == NULL)
return NULL;
}
state = Py_BuildValue("(OnN)", initvalue, self->pos, dict);
Py_DECREF(initvalue);
return state;
}
static PyObject *
bytesio_setstate(bytesio *self, PyObject *state)
{
PyObject *result;
PyObject *position_obj;
PyObject *dict;
Py_ssize_t pos;
assert(state != NULL);
/* We allow the state tuple to be longer than 3, because we may need
someday to extend the object's state without breaking
backward-compatibility. */
if (!PyTuple_Check(state) || Py_SIZE(state) < 3) {
PyErr_Format(PyExc_TypeError,
"%.200s.__setstate__ argument should be 3-tuple, got %.200s",
Py_TYPE(self)->tp_name, Py_TYPE(state)->tp_name);
return NULL;
}
/* Reset the object to its default state. This is only needed to handle
the case of repeated calls to __setstate__. */
self->string_size = 0;
self->pos = 0;
/* Set the value of the internal buffer. If state[0] does not support the
buffer protocol, bytesio_write will raise the appropriate TypeError. */
result = bytesio_write(self, PyTuple_GET_ITEM(state, 0));
if (result == NULL)
return NULL;
Py_DECREF(result);
/* Set carefully the position value. Alternatively, we could use the seek
method instead of modifying self->pos directly to better protect the
object internal state against errneous (or malicious) inputs. */
position_obj = PyTuple_GET_ITEM(state, 1);
if (!PyLong_Check(position_obj)) {
PyErr_Format(PyExc_TypeError,
"second item of state must be an integer, not %.200s",
Py_TYPE(position_obj)->tp_name);
return NULL;
}
pos = PyLong_AsSsize_t(position_obj);
if (pos == -1 && PyErr_Occurred())
return NULL;
if (pos < 0) {
PyErr_SetString(PyExc_ValueError,
"position value cannot be negative");
return NULL;
}
self->pos = pos;
/* Set the dictionary of the instance variables. */
dict = PyTuple_GET_ITEM(state, 2);
if (dict != Py_None) {
if (!PyDict_Check(dict)) {
PyErr_Format(PyExc_TypeError,
"third item of state should be a dict, got a %.200s",
Py_TYPE(dict)->tp_name);
return NULL;
}
if (self->dict) {
/* Alternatively, we could replace the internal dictionary
completely. However, it seems more practical to just update it. */
if (PyDict_Update(self->dict, dict) < 0)
return NULL;
}
else {
Py_INCREF(dict);
self->dict = dict;
}
}
Py_RETURN_NONE;
}
static void
bytesio_dealloc(bytesio *self)
{
@ -630,9 +744,9 @@ bytesio_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
if (self == NULL)
return NULL;
self->string_size = 0;
self->pos = 0;
self->buf_size = 0;
/* tp_alloc initializes all the fields to zero. So we don't have to
initialize them here. */
self->buf = (char *)PyMem_Malloc(0);
if (self->buf == NULL) {
Py_DECREF(self);
@ -705,6 +819,8 @@ static struct PyMethodDef bytesio_methods[] = {
{"getvalue", (PyCFunction)bytesio_getvalue, METH_VARARGS, getval_doc},
{"seek", (PyCFunction)bytesio_seek, METH_VARARGS, seek_doc},
{"truncate", (PyCFunction)bytesio_truncate, METH_VARARGS, truncate_doc},
{"__getstate__", (PyCFunction)bytesio_getstate, METH_NOARGS, NULL},
{"__setstate__", (PyCFunction)bytesio_setstate, METH_O, NULL},
{NULL, NULL} /* sentinel */
};

View File

@ -533,9 +533,9 @@ stringio_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
if (self == NULL)
return NULL;
self->string_size = 0;
self->pos = 0;
self->buf_size = 0;
/* tp_alloc initializes all the fields to zero. So we don't have to
initialize them here. */
self->buf = (Py_UNICODE *)PyMem_Malloc(0);
if (self->buf == NULL) {
Py_DECREF(self);
@ -597,6 +597,9 @@ stringio_init(stringio *self, PyObject *args, PyObject *kwds)
Py_CLEAR(self->writenl);
Py_CLEAR(self->decoder);
assert((newline != NULL && newline_obj != Py_None) ||
(newline == NULL && newline_obj == Py_None));
assert((newline != NULL && newline_obj != Py_None) ||
(newline == NULL && newline_obj == Py_None));
@ -672,6 +675,135 @@ stringio_writable(stringio *self, PyObject *args)
Py_RETURN_TRUE;
}
/* Pickling support.
The implementation of __getstate__ is similar to the one for BytesIO,
except that we also save the newline parameter. For __setstate__ and unlike
BytesIO, we call __init__ to restore the object's state. Doing so allows us
to avoid decoding the complex newline state while keeping the object
representation compact.
See comment in bytesio.c regarding why only pickle protocols and onward are
supported.
*/
static PyObject *
stringio_getstate(stringio *self)
{
PyObject *initvalue = stringio_getvalue(self);
PyObject *dict;
PyObject *state;
if (initvalue == NULL)
return NULL;
if (self->dict == NULL) {
Py_INCREF(Py_None);
dict = Py_None;
}
else {
dict = PyDict_Copy(self->dict);
if (dict == NULL)
return NULL;
}
state = Py_BuildValue("(OOnN)", initvalue,
self->readnl ? self->readnl : Py_None,
self->pos, dict);
Py_DECREF(initvalue);
return state;
}
static PyObject *
stringio_setstate(stringio *self, PyObject *state)
{
PyObject *initarg;
PyObject *position_obj;
PyObject *dict;
Py_ssize_t pos;
assert(state != NULL);
CHECK_CLOSED(self);
/* We allow the state tuple to be longer than 4, because we may need
someday to extend the object's state without breaking
backward-compatibility. */
if (!PyTuple_Check(state) || Py_SIZE(state) < 4) {
PyErr_Format(PyExc_TypeError,
"%.200s.__setstate__ argument should be 4-tuple, got %.200s",
Py_TYPE(self)->tp_name, Py_TYPE(state)->tp_name);
return NULL;
}
/* Initialize the object's state. */
initarg = PyTuple_GetSlice(state, 0, 2);
if (initarg == NULL)
return NULL;
if (stringio_init(self, initarg, NULL) < 0) {
Py_DECREF(initarg);
return NULL;
}
Py_DECREF(initarg);
/* Restore the buffer state. Even if __init__ did initialize the buffer,
we have to initialize it again since __init__ may translates the
newlines in the inital_value string. We clearly do not want that
because the string value in the state tuple has already been translated
once by __init__. So we do not take any chance and replace object's
buffer completely. */
{
Py_UNICODE *buf = PyUnicode_AS_UNICODE(PyTuple_GET_ITEM(state, 0));
Py_ssize_t bufsize = PyUnicode_GET_SIZE(PyTuple_GET_ITEM(state, 0));
if (resize_buffer(self, bufsize) < 0)
return NULL;
memcpy(self->buf, buf, bufsize * sizeof(Py_UNICODE));
self->string_size = bufsize;
}
/* Set carefully the position value. Alternatively, we could use the seek
method instead of modifying self->pos directly to better protect the
object internal state against errneous (or malicious) inputs. */
position_obj = PyTuple_GET_ITEM(state, 2);
if (!PyLong_Check(position_obj)) {
PyErr_Format(PyExc_TypeError,
"third item of state must be an integer, got %.200s",
Py_TYPE(position_obj)->tp_name);
return NULL;
}
pos = PyLong_AsSsize_t(position_obj);
if (pos == -1 && PyErr_Occurred())
return NULL;
if (pos < 0) {
PyErr_SetString(PyExc_ValueError,
"position value cannot be negative");
return NULL;
}
self->pos = pos;
/* Set the dictionary of the instance variables. */
dict = PyTuple_GET_ITEM(state, 3);
if (dict != Py_None) {
if (!PyDict_Check(dict)) {
PyErr_Format(PyExc_TypeError,
"fourth item of state should be a dict, got a %.200s",
Py_TYPE(dict)->tp_name);
return NULL;
}
if (self->dict) {
/* Alternatively, we could replace the internal dictionary
completely. However, it seems more practical to just update it. */
if (PyDict_Update(self->dict, dict) < 0)
return NULL;
}
else {
Py_INCREF(dict);
self->dict = dict;
}
}
Py_RETURN_NONE;
}
static PyObject *
stringio_closed(stringio *self, void *context)
{
@ -710,6 +842,9 @@ static struct PyMethodDef stringio_methods[] = {
{"seekable", (PyCFunction)stringio_seekable, METH_NOARGS},
{"readable", (PyCFunction)stringio_readable, METH_NOARGS},
{"writable", (PyCFunction)stringio_writable, METH_NOARGS},
{"__getstate__", (PyCFunction)stringio_getstate, METH_NOARGS},
{"__setstate__", (PyCFunction)stringio_setstate, METH_O},
{NULL, NULL} /* sentinel */
};