diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index 61416ba7c69..8b169050879 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -1795,6 +1795,28 @@ class BasicElementTest(ElementTestCase, unittest.TestCase): self.assertRaises(TypeError, e.append, 'b') self.assertRaises(TypeError, e.extend, [ET.Element('bar'), 'foo']) self.assertRaises(TypeError, e.insert, 0, 'foo') + e[:] = [ET.Element('bar')] + with self.assertRaises(TypeError): + e[0] = 'foo' + with self.assertRaises(TypeError): + e[:] = [ET.Element('bar'), 'foo'] + + if hasattr(e, '__setstate__'): + state = { + 'tag': 'tag', + '_children': [None], # non-Element + 'attrib': 'attr', + 'tail': 'tail', + 'text': 'text', + } + self.assertRaises(TypeError, e.__setstate__, state) + + if hasattr(e, '__deepcopy__'): + class E(ET.Element): + def __deepcopy__(self, memo): + return None # non-Element + e[:] = [E('bar')] + self.assertRaises(TypeError, copy.deepcopy, e) def test_cyclic_gc(self): class Dummy: @@ -1981,26 +2003,6 @@ class BadElementTest(ElementTestCase, unittest.TestCase): elem = b.close() self.assertEqual(elem[0].tail, 'ABCDEFGHIJKL') - def test_element_iter(self): - # Issue #27863 - state = { - 'tag': 'tag', - '_children': [None], # non-Element - 'attrib': 'attr', - 'tail': 'tail', - 'text': 'text', - } - - e = ET.Element('tag') - try: - e.__setstate__(state) - except AttributeError: - e.__dict__ = state - - it = e.iter() - self.assertIs(next(it), e) - self.assertRaises(AttributeError, next, it) - def test_subscr(self): # Issue #27863 class X: diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index 371b37147e3..85586d0b1c4 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -217,11 +217,11 @@ class Element: return self._children[index] def __setitem__(self, index, element): - # if isinstance(index, slice): - # for elt in element: - # assert iselement(elt) - # else: - # assert iselement(element) + if isinstance(index, slice): + for elt in element: + self._assert_is_element(elt) + else: + self._assert_is_element(element) self._children[index] = element def __delitem__(self, index): diff --git a/Modules/_elementtree.c b/Modules/_elementtree.c index 919591467c0..f88315d7711 100644 --- a/Modules/_elementtree.c +++ b/Modules/_elementtree.c @@ -480,11 +480,24 @@ element_resize(ElementObject* self, Py_ssize_t extra) return -1; } +LOCAL(void) +raise_type_error(PyObject *element) +{ + PyErr_Format(PyExc_TypeError, + "expected an Element, not \"%.200s\"", + Py_TYPE(element)->tp_name); +} + LOCAL(int) element_add_subelement(ElementObject* self, PyObject* element) { /* add a child element to a parent */ + if (!Element_Check(element)) { + raise_type_error(element); + return -1; + } + if (element_resize(self, 1) < 0) return -1; @@ -803,7 +816,11 @@ _elementtree_Element___deepcopy___impl(ElementObject *self, PyObject *memo) for (i = 0; i < self->extra->length; i++) { PyObject* child = deepcopy(self->extra->children[i], memo); - if (!child) { + if (!child || !Element_Check(child)) { + if (child) { + raise_type_error(child); + Py_DECREF(child); + } element->extra->length = i; goto error; } @@ -1024,8 +1041,15 @@ element_setstate_from_attributes(ElementObject *self, /* Copy children */ for (i = 0; i < nchildren; i++) { - self->extra->children[i] = PyList_GET_ITEM(children, i); - Py_INCREF(self->extra->children[i]); + PyObject *child = PyList_GET_ITEM(children, i); + if (!Element_Check(child)) { + raise_type_error(child); + self->extra->length = i; + dealloc_extra(oldextra); + return NULL; + } + Py_INCREF(child); + self->extra->children[i] = child; } assert(!self->extra->length); @@ -1167,16 +1191,6 @@ _elementtree_Element_extend(ElementObject *self, PyObject *elements) for (i = 0; i < PySequence_Fast_GET_SIZE(seq); i++) { PyObject* element = PySequence_Fast_GET_ITEM(seq, i); Py_INCREF(element); - if (!Element_Check(element)) { - PyErr_Format( - PyExc_TypeError, - "expected an Element, not \"%.200s\"", - Py_TYPE(element)->tp_name); - Py_DECREF(seq); - Py_DECREF(element); - return NULL; - } - if (element_add_subelement(self, element) < 0) { Py_DECREF(seq); Py_DECREF(element); @@ -1219,8 +1233,7 @@ _elementtree_Element_find_impl(ElementObject *self, PyObject *path, for (i = 0; i < self->extra->length; i++) { PyObject* item = self->extra->children[i]; int rc; - if (!Element_Check(item)) - continue; + assert(Element_Check(item)); Py_INCREF(item); rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ); if (rc > 0) @@ -1266,8 +1279,7 @@ _elementtree_Element_findtext_impl(ElementObject *self, PyObject *path, for (i = 0; i < self->extra->length; i++) { PyObject *item = self->extra->children[i]; int rc; - if (!Element_Check(item)) - continue; + assert(Element_Check(item)); Py_INCREF(item); rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ); if (rc > 0) { @@ -1323,8 +1335,7 @@ _elementtree_Element_findall_impl(ElementObject *self, PyObject *path, for (i = 0; i < self->extra->length; i++) { PyObject* item = self->extra->children[i]; int rc; - if (!Element_Check(item)) - continue; + assert(Element_Check(item)); Py_INCREF(item); rc = PyObject_RichCompareBool(((ElementObject*)item)->tag, path, Py_EQ); if (rc != 0 && (rc < 0 || PyList_Append(out, item) < 0)) { @@ -1736,6 +1747,10 @@ element_setitem(PyObject* self_, Py_ssize_t index, PyObject* item) old = self->extra->children[index]; if (item) { + if (!Element_Check(item)) { + raise_type_error(item); + return -1; + } Py_INCREF(item); self->extra->children[index] = item; } else { @@ -1930,6 +1945,15 @@ element_ass_subscr(PyObject* self_, PyObject* item, PyObject* value) } } + for (i = 0; i < newlen; i++) { + PyObject *element = PySequence_Fast_GET_ITEM(seq, i); + if (!Element_Check(element)) { + raise_type_error(element); + Py_DECREF(seq); + return -1; + } + } + if (slicelen > 0) { /* to avoid recursive calls to this method (via decref), move old items to the recycle bin here, and get rid of them when @@ -2207,12 +2231,7 @@ elementiter_next(ElementIterObject *it) continue; } - if (!Element_Check(extra->children[child_index])) { - PyErr_Format(PyExc_AttributeError, - "'%.100s' object has no attribute 'iter'", - Py_TYPE(extra->children[child_index])->tp_name); - return NULL; - } + assert(Element_Check(extra->children[child_index])); elem = (ElementObject *)extra->children[child_index]; item->child_index++; Py_INCREF(elem);