diff --git a/Lib/test/test_sax.py b/Lib/test/test_sax.py index a16c821762c..825c16a8d03 100644 --- a/Lib/test/test_sax.py +++ b/Lib/test/test_sax.py @@ -14,6 +14,7 @@ from xml.sax.expatreader import create_parser from xml.sax.handler import feature_namespaces from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl from cStringIO import StringIO +import io import os.path import shutil import test.test_support as support @@ -170,9 +171,9 @@ class SaxutilsTest(unittest.TestCase): start = '\n' -class XmlgenTest(unittest.TestCase): +class XmlgenTest: def test_xmlgen_basic(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() gen.startElement("doc", {}) @@ -182,7 +183,7 @@ class XmlgenTest(unittest.TestCase): self.assertEqual(result.getvalue(), start + "") def test_xmlgen_content(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -194,7 +195,7 @@ class XmlgenTest(unittest.TestCase): self.assertEqual(result.getvalue(), start + "huhei") def test_xmlgen_pi(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -206,7 +207,7 @@ class XmlgenTest(unittest.TestCase): self.assertEqual(result.getvalue(), start + "") def test_xmlgen_content_escape(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -219,7 +220,7 @@ class XmlgenTest(unittest.TestCase): start + "<huhei&") def test_xmlgen_attr_escape(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -238,8 +239,41 @@ class XmlgenTest(unittest.TestCase): "" "")) + def test_xmlgen_encoding(self): + encodings = ('iso-8859-15', 'utf-8', + 'utf-16be', 'utf-16le', + 'utf-32be', 'utf-32le') + for encoding in encodings: + result = self.ioclass() + gen = XMLGenerator(result, encoding=encoding) + + gen.startDocument() + gen.startElement("doc", {"a": u'\u20ac'}) + gen.characters(u"\u20ac") + gen.endElement("doc") + gen.endDocument() + + self.assertEqual(result.getvalue(), ( + u'\n' + u'\u20ac' % encoding + ).encode(encoding, 'xmlcharrefreplace')) + + def test_xmlgen_unencodable(self): + result = self.ioclass() + gen = XMLGenerator(result, encoding='ascii') + + gen.startDocument() + gen.startElement("doc", {"a": u'\u20ac'}) + gen.characters(u"\u20ac") + gen.endElement("doc") + gen.endDocument() + + self.assertEqual(result.getvalue(), + '\n' + '') + def test_xmlgen_ignorable(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -251,7 +285,7 @@ class XmlgenTest(unittest.TestCase): self.assertEqual(result.getvalue(), start + " ") def test_xmlgen_ns(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -269,7 +303,7 @@ class XmlgenTest(unittest.TestCase): ns_uri)) def test_1463026_1(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -280,7 +314,7 @@ class XmlgenTest(unittest.TestCase): self.assertEqual(result.getvalue(), start+'') def test_1463026_2(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -293,7 +327,7 @@ class XmlgenTest(unittest.TestCase): self.assertEqual(result.getvalue(), start+'') def test_1463026_3(self): - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -321,7 +355,7 @@ class XmlgenTest(unittest.TestCase): parser = make_parser() parser.setFeature(feature_namespaces, True) - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) parser.setContentHandler(gen) parser.parse(test_xml) @@ -340,7 +374,7 @@ class XmlgenTest(unittest.TestCase): # # This test demonstrates the bug by direct manipulation of the # XMLGenerator. - result = StringIO() + result = self.ioclass() gen = XMLGenerator(result) gen.startDocument() @@ -360,6 +394,29 @@ class XmlgenTest(unittest.TestCase): 'Hello' '')) + def test_no_close_file(self): + result = self.ioclass() + def func(out): + gen = XMLGenerator(out) + gen.startDocument() + gen.startElement("doc", {}) + func(result) + self.assertFalse(result.closed) + +class StringXmlgenTest(XmlgenTest, unittest.TestCase): + ioclass = StringIO + +class BytesIOXmlgenTest(XmlgenTest, unittest.TestCase): + ioclass = io.BytesIO + +class WriterXmlgenTest(XmlgenTest, unittest.TestCase): + class ioclass(list): + write = list.append + closed = False + + def getvalue(self): + return b''.join(self) + class XMLFilterBaseTest(unittest.TestCase): def test_filter_basic(self): @@ -804,7 +861,9 @@ class XmlReaderTest(XmlTestBase): def test_main(): run_unittest(MakeParserTest, SaxutilsTest, - XmlgenTest, + StringXmlgenTest, + BytesIOXmlgenTest, + WriterXmlgenTest, ExpatReaderTest, ErrorReportingTest, XmlReaderTest) diff --git a/Lib/xml/sax/saxutils.py b/Lib/xml/sax/saxutils.py index 7989713f588..dad74f5389d 100644 --- a/Lib/xml/sax/saxutils.py +++ b/Lib/xml/sax/saxutils.py @@ -4,6 +4,7 @@ convenience of application and driver writers. """ import os, urlparse, urllib, types +import io import sys import handler import xmlreader @@ -13,15 +14,6 @@ try: except AttributeError: _StringTypes = [types.StringType] -# See whether the xmlcharrefreplace error handler is -# supported -try: - from codecs import xmlcharrefreplace_errors - _error_handling = "xmlcharrefreplace" - del xmlcharrefreplace_errors -except ImportError: - _error_handling = "strict" - def __dict_replace(s, d): """Replace substrings of a string using a dictionary.""" for key, value in d.items(): @@ -82,25 +74,46 @@ def quoteattr(data, entities={}): return data +def _gettextwriter(out, encoding): + if out is None: + import sys + out = sys.stdout + + if isinstance(out, io.RawIOBase): + buffer = io.BufferedIOBase(out) + # Keep the original file open when the TextIOWrapper is + # destroyed + buffer.close = lambda: None + else: + # This is to handle passed objects that aren't in the + # IOBase hierarchy, but just have a write method + buffer = io.BufferedIOBase() + buffer.writable = lambda: True + buffer.write = out.write + try: + # TextIOWrapper uses this methods to determine + # if BOM (for UTF-16, etc) should be added + buffer.seekable = out.seekable + buffer.tell = out.tell + except AttributeError: + pass + # wrap a binary writer with TextIOWrapper + return io.TextIOWrapper(buffer, encoding=encoding, + errors='xmlcharrefreplace', + newline='\n') + class XMLGenerator(handler.ContentHandler): def __init__(self, out=None, encoding="iso-8859-1"): - if out is None: - import sys - out = sys.stdout handler.ContentHandler.__init__(self) - self._out = out + out = _gettextwriter(out, encoding) + self._write = out.write + self._flush = out.flush self._ns_contexts = [{}] # contains uri -> prefix dicts self._current_context = self._ns_contexts[-1] self._undeclared_ns_maps = [] self._encoding = encoding - def _write(self, text): - if isinstance(text, str): - self._out.write(text) - else: - self._out.write(text.encode(self._encoding, _error_handling)) - def _qname(self, name): """Builds a qualified name from a (ns_url, localname) pair""" if name[0]: @@ -121,9 +134,12 @@ class XMLGenerator(handler.ContentHandler): # ContentHandler methods def startDocument(self): - self._write('\n' % + self._write(u'\n' % self._encoding) + def endDocument(self): + self._flush() + def startPrefixMapping(self, prefix, uri): self._ns_contexts.append(self._current_context.copy()) self._current_context[uri] = prefix @@ -134,39 +150,39 @@ class XMLGenerator(handler.ContentHandler): del self._ns_contexts[-1] def startElement(self, name, attrs): - self._write('<' + name) + self._write(u'<' + name) for (name, value) in attrs.items(): - self._write(' %s=%s' % (name, quoteattr(value))) - self._write('>') + self._write(u' %s=%s' % (name, quoteattr(value))) + self._write(u'>') def endElement(self, name): - self._write('' % name) + self._write(u'' % name) def startElementNS(self, name, qname, attrs): - self._write('<' + self._qname(name)) + self._write(u'<' + self._qname(name)) for prefix, uri in self._undeclared_ns_maps: if prefix: - self._out.write(' xmlns:%s="%s"' % (prefix, uri)) + self._write(u' xmlns:%s="%s"' % (prefix, uri)) else: - self._out.write(' xmlns="%s"' % uri) + self._write(u' xmlns="%s"' % uri) self._undeclared_ns_maps = [] for (name, value) in attrs.items(): - self._write(' %s=%s' % (self._qname(name), quoteattr(value))) - self._write('>') + self._write(u' %s=%s' % (self._qname(name), quoteattr(value))) + self._write(u'>') def endElementNS(self, name, qname): - self._write('' % self._qname(name)) + self._write(u'' % self._qname(name)) def characters(self, content): - self._write(escape(content)) + self._write(escape(unicode(content))) def ignorableWhitespace(self, content): - self._write(content) + self._write(unicode(content)) def processingInstruction(self, target, data): - self._write('' % (target, data)) + self._write(u'' % (target, data)) class XMLFilterBase(xmlreader.XMLReader): diff --git a/Misc/NEWS b/Misc/NEWS index eab49934c3e..39298b52783 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -202,6 +202,8 @@ Core and Builtins Library ------- +- Issue #1470548: XMLGenerator now works with UTF-16 and UTF-32 encodings. + - Issue #6975: os.path.realpath() now correctly resolves multiple nested symlinks on POSIX platforms.