Issue #16076: make _elementtree.Element pickle-able in a way that is compatible with the Python version of the class.

Patch by Daniel Shahaf.
This commit is contained in:
Eli Bendersky 2013-01-10 06:06:01 -08:00
commit 065eeb1085
2 changed files with 237 additions and 22 deletions

View File

@ -16,14 +16,20 @@
import html import html
import io import io
import operator
import pickle import pickle
import sys import sys
import unittest import unittest
import weakref import weakref
from itertools import product
from test import support from test import support
from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect
# pyET is the pure-Python implementation.
#
# ET is pyET in test_xml_etree and is the C accelerated version in
# test_xml_etree_c.
pyET = None pyET = None
ET = None ET = None
@ -171,6 +177,38 @@ def check_element(element):
for elem in element: for elem in element:
check_element(elem) check_element(elem)
class ElementTestCase:
    """Mixin providing pickling and comparison helpers for Element tests."""

    @classmethod
    def setUpClass(cls):
        # A set, so pyET and ET collapse to one entry when they are the
        # same module object.
        cls.modules = {pyET, ET}

    def pickleRoundTrip(self, obj, name, dumper, loader):
        """Pickle *obj* and unpickle it again, swapping sys.modules[name].

        *dumper* is installed as sys.modules[name] while pickling and
        *loader* while unpickling; the previous entry is restored on exit.
        A pickle.PicklingError is re-raised as support.TestFailed.
        Returns the unpickled object.
        """
        original_module = sys.modules[name]
        try:
            sys.modules[name] = dumper
            payload = pickle.dumps(obj)
            sys.modules[name] = loader
            restored = pickle.loads(payload)
        except pickle.PicklingError as pe:
            # pyET must be second, because pyET may be (equal to) ET.
            labels = {ET: "cET", pyET: "pyET"}
            described = (obj,
                         labels.get(dumper, dumper),
                         labels.get(loader, loader))
            raise support.TestFailed(
                "Failed to round-trip %r from %r to %r" % described) from pe
        finally:
            sys.modules[name] = original_module
        return restored

    def assertEqualElements(self, alice, bob):
        """Assert that two element trees are structurally identical.

        Both arguments must be Element instances (from either module).
        Children are compared recursively, then tag, tail, text and attrib.
        """
        element_types = (ET.Element, pyET.Element)
        self.assertIsInstance(alice, element_types)
        self.assertIsInstance(bob, element_types)
        self.assertEqual(len(list(alice)), len(list(bob)))
        for left_child, right_child in zip(alice, bob):
            self.assertEqualElements(left_child, right_child)
        fields = operator.attrgetter('tag', 'tail', 'text', 'attrib')
        self.assertEqual(fields(alice), fields(bob))
# -------------------------------------------------------------------- # --------------------------------------------------------------------
# element tree tests # element tree tests
@ -1715,7 +1753,7 @@ def check_issue10777():
# -------------------------------------------------------------------- # --------------------------------------------------------------------
class BasicElementTest(unittest.TestCase): class BasicElementTest(ElementTestCase, unittest.TestCase):
def test_augmentation_type_errors(self): def test_augmentation_type_errors(self):
e = ET.Element('joe') e = ET.Element('joe')
self.assertRaises(TypeError, e.append, 'b') self.assertRaises(TypeError, e.append, 'b')
@ -1775,19 +1813,22 @@ class BasicElementTest(unittest.TestCase):
self.assertEqual(e1.get('w', default=7), 7) self.assertEqual(e1.get('w', default=7), 7)
def test_pickle(self): def test_pickle(self):
# For now this test only works for the Python version of ET, # issue #16076: the C implementation wasn't pickleable.
# so set sys.modules accordingly because pickle uses __import__ for dumper, loader in product(self.modules, repeat=2):
# to load the __module__ of the class. e = dumper.Element('foo', bar=42)
if pyET: e.text = "text goes here"
sys.modules['xml.etree.ElementTree'] = pyET e.tail = "opposite of head"
else: dumper.SubElement(e, 'child').append(dumper.Element('grandchild'))
raise unittest.SkipTest('only for the Python version') e.append(dumper.Element('child'))
e1 = ET.Element('foo', bar=42) e.findall('.//grandchild')[0].set('attr', 'other value')
s = pickle.dumps(e1)
e2 = pickle.loads(s)
self.assertEqual(e2.tag, 'foo')
self.assertEqual(e2.attrib['bar'], 42)
e2 = self.pickleRoundTrip(e, 'xml.etree.ElementTree',
dumper, loader)
self.assertEqual(e2.tag, 'foo')
self.assertEqual(e2.attrib['bar'], 42)
self.assertEqual(len(e2), 2)
self.assertEqualElements(e, e2)
class ElementTreeTest(unittest.TestCase): class ElementTreeTest(unittest.TestCase):
def test_istype(self): def test_istype(self):
@ -2433,7 +2474,7 @@ class KeywordArgsTest(unittest.TestCase):
class NoAcceleratorTest(unittest.TestCase): class NoAcceleratorTest(unittest.TestCase):
def setUp(self): def setUp(self):
if not pyET: if not pyET:
raise SkipTest('only for the Python version') raise unittest.SkipTest('only for the Python version')
# Test that the C accelerator was not imported for pyET # Test that the C accelerator was not imported for pyET
def test_correct_import_pyET(self): def test_correct_import_pyET(self):
@ -2486,10 +2527,10 @@ class CleanContext(object):
def test_main(module=None): def test_main(module=None):
# When invoked without a module, runs the Python ET tests by loading pyET. # When invoked without a module, runs the Python ET tests by loading pyET.
# Otherwise, uses the given module as the ET. # Otherwise, uses the given module as the ET.
global pyET
pyET = import_fresh_module('xml.etree.ElementTree',
blocked=['_elementtree'])
if module is None: if module is None:
global pyET
pyET = import_fresh_module('xml.etree.ElementTree',
blocked=['_elementtree'])
module = pyET module = pyET
global ET global ET
@ -2509,7 +2550,7 @@ def test_main(module=None):
# These tests will only run for the pure-Python version that doesn't import # These tests will only run for the pure-Python version that doesn't import
# _elementtree. We can't use skipUnless here, because pyET is filled in only # _elementtree. We can't use skipUnless here, because pyET is filled in only
# after the module is loaded. # after the module is loaded.
if pyET: if pyET is not ET:
test_classes.extend([ test_classes.extend([
NoAcceleratorTest, NoAcceleratorTest,
]) ])
@ -2518,7 +2559,7 @@ def test_main(module=None):
support.run_unittest(*test_classes) support.run_unittest(*test_classes)
# XXX the C module should give the same warnings as the Python module # XXX the C module should give the same warnings as the Python module
with CleanContext(quiet=(module is not pyET)): with CleanContext(quiet=(pyET is not ET)):
support.run_doctest(sys.modules[__name__], verbosity=True) support.run_doctest(sys.modules[__name__], verbosity=True)
finally: finally:
# don't interfere with subsequent tests # don't interfere with subsequent tests

View File

@ -814,6 +814,176 @@ element_sizeof(PyObject* _self, PyObject* args)
return PyLong_FromSsize_t(result); return PyLong_FromSsize_t(result);
} }
/* dict keys for getstate/setstate. */
#define PICKLED_TAG "tag"
#define PICKLED_CHILDREN "_children"
#define PICKLED_ATTRIB "attrib"
#define PICKLED_TAIL "tail"
#define PICKLED_TEXT "text"
/* __getstate__ returns a fabricated instance dict as in the pure-Python
 * Element implementation, for interoperability/interchangeability.  This
 * makes the pure-Python implementation details an API, but (a) there aren't
 * any unnecessary structures there; and (b) it buys compatibility with 3.2
 * pickles.  See issue #16076.
 *
 * Returns a new reference to the state dict, or NULL with an exception set.
 */
static PyObject *
element_getstate(ElementObject *self)
{
    Py_ssize_t i;   /* matches the type of PyList_GET_SIZE() */
    int noattrib;
    PyObject *instancedict, *children;

    /* Build a list of children. */
    children = PyList_New(self->extra ? self->extra->length : 0);
    if (!children)
        return NULL;
    /* The list is empty when self->extra is NULL, so the loop body never
     * dereferences a NULL 'extra'. */
    for (i = 0; i < PyList_GET_SIZE(children); i++) {
        PyObject *child = self->extra->children[i];
        Py_INCREF(child);   /* PyList_SET_ITEM steals this reference */
        PyList_SET_ITEM(children, i, child);
    }

    /* Construct the state object.  With no attributes an empty dict is
     * fabricated in place via the "{}" format unit. */
    noattrib = (self->extra == NULL || self->extra->attrib == Py_None);
    if (noattrib)
        instancedict = Py_BuildValue("{sOsOs{}sOsO}",
                                     PICKLED_TAG, self->tag,
                                     PICKLED_CHILDREN, children,
                                     PICKLED_ATTRIB,
                                     PICKLED_TEXT, self->text,
                                     PICKLED_TAIL, self->tail);
    else
        instancedict = Py_BuildValue("{sOsOsOsOsO}",
                                     PICKLED_TAG, self->tag,
                                     PICKLED_CHILDREN, children,
                                     PICKLED_ATTRIB, self->extra->attrib,
                                     PICKLED_TEXT, self->text,
                                     PICKLED_TAIL, self->tail);

    /* The "O" format unit makes Py_BuildValue take its own reference to
     * 'children', so our reference must be released on success (otherwise
     * the list leaks) as well as on failure.  Deallocating the list drops
     * the references it holds on the children; they must not be DECREF'd
     * by hand as well. */
    Py_DECREF(children);
    return instancedict;
}
/* Install the given state into 'self'.
 *
 * All object arguments are borrowed references; 'tag' is mandatory, the
 * others may be NULL.  A NULL 'text' or 'tail' is stored as None.
 * 'children', when given, must be a list.
 *
 * Returns None (new reference) on success, or NULL with an exception set.
 * Note that tag/text/tail have already been replaced if a later step fails.
 */
static PyObject *
element_setstate_from_attributes(ElementObject *self,
                                 PyObject *tag,
                                 PyObject *attrib,
                                 PyObject *text,
                                 PyObject *tail,
                                 PyObject *children)
{
    Py_ssize_t i, nchildren;
    PyObject *old;

    if (!tag) {
        PyErr_SetString(PyExc_TypeError, "tag may not be NULL");
        return NULL;
    }
    /* Substitute None without taking a reference here: exactly one
     * reference is taken below when the value is stored in 'self'. */
    if (!text)
        text = Py_None;
    if (!tail)
        tail = Py_None;

    /* Take the new reference before releasing the old one, so the slot
     * never holds a dangling pointer while a DECREF runs arbitrary code. */
    Py_INCREF(tag);
    old = self->tag;
    self->tag = tag;
    Py_XDECREF(old);

    Py_INCREF(text);
    old = self->text;
    self->text = text;
    Py_XDECREF(old);

    Py_INCREF(tail);
    old = self->tail;
    self->tail = tail;
    Py_XDECREF(old);

    /* Handle ATTRIB and CHILDREN. */
    if (!children && !attrib)
        Py_RETURN_NONE;

    /* Compute 'nchildren'. */
    if (children) {
        if (!PyList_Check(children)) {
            PyErr_SetString(PyExc_TypeError, "'_children' is not a list");
            return NULL;
        }
        nchildren = PyList_Size(children);
    }
    else {
        nchildren = 0;
    }

    /* Allocate 'extra'. */
    if (element_resize(self, nchildren)) {
        return NULL;
    }
    assert(self->extra && self->extra->allocated >= nchildren);

    /* Copy children.
     * NOTE(review): any children already stored in 'extra' are overwritten
     * without being released; this presumably relies on __setstate__ being
     * called on a freshly created, childless element -- confirm. */
    for (i = 0; i < nchildren; i++) {
        self->extra->children[i] = PyList_GET_ITEM(children, i);
        Py_INCREF(self->extra->children[i]);
    }

    self->extra->length = nchildren;
    /* NOTE(review): this may under-report the capacity element_resize()
     * actually reserved (the assert above only guarantees >=) -- confirm
     * whether this assignment is needed. */
    self->extra->allocated = nchildren;

    /* Stash attrib. */
    if (attrib) {
        Py_INCREF(attrib);
        old = self->extra->attrib;
        self->extra->attrib = attrib;
        Py_XDECREF(old);
    }

    Py_RETURN_NONE;
}
/* __setstate__ for Element instance from the Python implementation.
 * 'state' should be the instance dict.
 *
 * The dict is unpacked by treating it as a keyword-argument mapping, so a
 * key that is not one of the PICKLED_* names makes parsing fail with the
 * exception raised by PyArg_ParseTupleAndKeywords.
 *
 * Returns None (new reference) on success, or NULL with an exception set.
 */
static PyObject *
element_setstate_from_Python(ElementObject *self, PyObject *state)
{
    static char *kwlist[] = {PICKLED_TAG, PICKLED_ATTRIB, PICKLED_TEXT,
                             PICKLED_TAIL, PICKLED_CHILDREN, 0};
    PyObject *args;
    PyObject *tag, *attrib, *text, *tail, *children;
    int error;

    /* Every key is optional at this level; keys that are absent from
     * 'state' leave the corresponding pointer NULL. */
    tag = attrib = text = tail = children = NULL;

    /* An empty positional tuple: all values come from 'state'. */
    args = PyTuple_New(0);
    if (!args)
        return NULL;    /* do not pass NULL on to the parser or Py_DECREF */

    error = ! PyArg_ParseTupleAndKeywords(args, state, "|$OOOOO", kwlist, &tag,
                                          &attrib, &text, &tail, &children);
    Py_DECREF(args);

    if (error)
        return NULL;
    else
        return element_setstate_from_attributes(self, tag, attrib, text,
                                                tail, children);
}
/* __setstate__ entry point.  Only an exact dict (the fabricated instance
 * dict produced by __getstate__ or by the pure-Python Element) is accepted;
 * anything else raises TypeError.
 */
static PyObject *
element_setstate(ElementObject *self, PyObject *state)
{
    if (PyDict_CheckExact(state))
        return element_setstate_from_Python(self, state);

    PyErr_Format(PyExc_TypeError,
                 "Don't know how to unpickle \"%.200R\" as an Element",
                 state);
    return NULL;
}
LOCAL(int) LOCAL(int)
checkpath(PyObject* tag) checkpath(PyObject* tag)
{ {
@ -1587,6 +1757,8 @@ static PyMethodDef element_methods[] = {
{"__copy__", (PyCFunction) element_copy, METH_VARARGS}, {"__copy__", (PyCFunction) element_copy, METH_VARARGS},
{"__deepcopy__", (PyCFunction) element_deepcopy, METH_VARARGS}, {"__deepcopy__", (PyCFunction) element_deepcopy, METH_VARARGS},
{"__sizeof__", element_sizeof, METH_NOARGS}, {"__sizeof__", element_sizeof, METH_NOARGS},
{"__getstate__", (PyCFunction)element_getstate, METH_NOARGS},
{"__setstate__", (PyCFunction)element_setstate, METH_O},
{NULL, NULL} {NULL, NULL}
}; };
@ -1691,7 +1863,7 @@ static PyMappingMethods element_as_mapping = {
static PyTypeObject Element_Type = { static PyTypeObject Element_Type = {
PyVarObject_HEAD_INIT(NULL, 0) PyVarObject_HEAD_INIT(NULL, 0)
"Element", sizeof(ElementObject), 0, "xml.etree.ElementTree.Element", sizeof(ElementObject), 0,
/* methods */ /* methods */
(destructor)element_dealloc, /* tp_dealloc */ (destructor)element_dealloc, /* tp_dealloc */
0, /* tp_print */ 0, /* tp_print */
@ -1913,6 +2085,8 @@ elementiter_next(ElementIterObject *it)
static PyTypeObject ElementIter_Type = { static PyTypeObject ElementIter_Type = {
PyVarObject_HEAD_INIT(NULL, 0) PyVarObject_HEAD_INIT(NULL, 0)
/* Using the module's name since the pure-Python implementation does not
have such a type. */
"_elementtree._element_iterator", /* tp_name */ "_elementtree._element_iterator", /* tp_name */
sizeof(ElementIterObject), /* tp_basicsize */ sizeof(ElementIterObject), /* tp_basicsize */
0, /* tp_itemsize */ 0, /* tp_itemsize */
@ -2458,7 +2632,7 @@ static PyMethodDef treebuilder_methods[] = {
static PyTypeObject TreeBuilder_Type = { static PyTypeObject TreeBuilder_Type = {
PyVarObject_HEAD_INIT(NULL, 0) PyVarObject_HEAD_INIT(NULL, 0)
"TreeBuilder", sizeof(TreeBuilderObject), 0, "xml.etree.ElementTree.TreeBuilder", sizeof(TreeBuilderObject), 0,
/* methods */ /* methods */
(destructor)treebuilder_dealloc, /* tp_dealloc */ (destructor)treebuilder_dealloc, /* tp_dealloc */
0, /* tp_print */ 0, /* tp_print */
@ -3420,7 +3594,7 @@ xmlparser_getattro(XMLParserObject* self, PyObject* nameobj)
static PyTypeObject XMLParser_Type = { static PyTypeObject XMLParser_Type = {
PyVarObject_HEAD_INIT(NULL, 0) PyVarObject_HEAD_INIT(NULL, 0)
"XMLParser", sizeof(XMLParserObject), 0, "xml.etree.ElementTree.XMLParser", sizeof(XMLParserObject), 0,
/* methods */ /* methods */
(destructor)xmlparser_dealloc, /* tp_dealloc */ (destructor)xmlparser_dealloc, /* tp_dealloc */
0, /* tp_print */ 0, /* tp_print */