diff --git a/Doc/library/xml.etree.elementtree.rst b/Doc/library/xml.etree.elementtree.rst index 6fe81c92963..dc13c49002a 100644 --- a/Doc/library/xml.etree.elementtree.rst +++ b/Doc/library/xml.etree.elementtree.rst @@ -281,14 +281,15 @@ Element Objects .. method:: append(subelement) - Adds the element *subelement* to the end of this elements internal list - of subelements. + Adds the element *subelement* to the end of this element's internal list + of subelements. Raises :exc:`TypeError` if *subelement* is not an + :class:`Element`. .. method:: extend(subelements) Appends *subelements* from a sequence object with zero or more elements. - Raises :exc:`AssertionError` if a subelement is not a valid object. + Raises :exc:`TypeError` if a subelement is not an :class:`Element`. .. versionadded:: 3.2 @@ -325,9 +326,10 @@ Element Objects Use method :meth:`Element.iter` instead. - .. method:: insert(index, element) + .. method:: insert(index, subelement) - Inserts a subelement at the given position in this element. + Inserts *subelement* at the given position in this element. Raises + :exc:`TypeError` if *subelement* is not an :class:`Element`. .. method:: iter(tag=None) diff --git a/Lib/test/test_xml_etree.py b/Lib/test/test_xml_etree.py index 50e5196d6fd..8a1ea0f688e 100644 --- a/Lib/test/test_xml_etree.py +++ b/Lib/test/test_xml_etree.py @@ -1839,8 +1839,15 @@ def check_issue10777(): # -------------------------------------------------------------------- -class ElementTreeTest(unittest.TestCase): +class BasicElementTest(unittest.TestCase): + def test_augmentation_type_errors(self): + e = ET.Element('joe') + self.assertRaises(TypeError, e.append, 'b') + self.assertRaises(TypeError, e.extend, [ET.Element('bar'), 'foo']) + self.assertRaises(TypeError, e.insert, 0, 'foo') + +class ElementTreeTest(unittest.TestCase): def test_istype(self): self.assertIsInstance(ET.ParseError, type) self.assertIsInstance(ET.QName, type) @@ -1879,7 +1886,6 @@ class ElementTreeTest(unittest.TestCase): class TreeBuilderTest(unittest.TestCase): - sample1 = ('' @@ -1931,7 +1937,6 @@ class TreeBuilderTest(unittest.TestCase): class NoAcceleratorTest(unittest.TestCase): - # Test that the C accelerator was not imported for pyET def test_correct_import_pyET(self): self.assertEqual(pyET.Element.__module__, 'xml.etree.ElementTree') @@ -2096,6 +2101,7 @@ def test_main(module=pyET): test_classes = [ ElementSlicingTest, + BasicElementTest, StringIOTest, ParseErrorTest, ElementTreeTest, diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index 10ee896c289..5f974f65b08 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -298,7 +298,7 @@ class Element: # @param element The element to add. def append(self, element): - # assert iselement(element) + self._assert_is_element(element) self._children.append(element) ## @@ -308,8 +308,8 @@ class Element: # @since 1.3 def extend(self, elements): - # for element in elements: - # assert iselement(element) + for element in elements: + self._assert_is_element(element) self._children.extend(elements) ## @@ -318,9 +318,13 @@ class Element: # @param index Where to insert the new subelement. def insert(self, index, element): - # assert iselement(element) + self._assert_is_element(element) self._children.insert(index, element) + def _assert_is_element(self, e): + if not isinstance(e, Element): + raise TypeError('expected an Element, not %s' % type(e).__name__) + ## # Removes a matching subelement. Unlike the find methods, # this method compares elements based on identity, not on tag diff --git a/Modules/_elementtree.c b/Modules/_elementtree.c index a50a3e7a67b..e8309df2997 100644 --- a/Modules/_elementtree.c +++ b/Modules/_elementtree.c @@ -803,6 +803,15 @@ element_extend(ElementObject* self, PyObject* args) seqlen = PySequence_Size(seq); for (i = 0; i < seqlen; i++) { PyObject* element = PySequence_Fast_GET_ITEM(seq, i); + if (!PyObject_IsInstance(element, (PyObject *)&Element_Type)) { + Py_DECREF(seq); + PyErr_Format( + PyExc_TypeError, + "expected an Element, not \"%.200s\"", + Py_TYPE(element)->tp_name); + return NULL; + } + if (element_add_subelement(self, element) < 0) { Py_DECREF(seq); return NULL;