diff -r 51b5ee7cfa3b Lib/test/test_sax.py --- a/Lib/test/test_sax.py Sun Jul 15 06:19:44 2012 +0300 +++ b/Lib/test/test_sax.py Sun Jul 15 10:07:45 2012 +0300 @@ -13,7 +13,7 @@ from xml.sax.expatreader import create_parser from xml.sax.handler import feature_namespaces from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl -from io import StringIO +from io import BytesIO, StringIO from test.support import findfile, run_unittest import unittest @@ -158,31 +158,31 @@ # ===== XMLGenerator -start = '\n' +start = b'\n' class XmlgenTest(unittest.TestCase): def test_xmlgen_basic(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() gen.startElement("doc", {}) gen.endElement("doc") gen.endDocument() - self.assertEqual(result.getvalue(), start + "") + self.assertEqual(result.getvalue(), start + b"") def test_xmlgen_basic_empty(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result, short_empty_elements=True) gen.startDocument() gen.startElement("doc", {}) gen.endElement("doc") gen.endDocument() - self.assertEqual(result.getvalue(), start + "") + self.assertEqual(result.getvalue(), start + b"") def test_xmlgen_content(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -191,10 +191,10 @@ gen.endElement("doc") gen.endDocument() - self.assertEqual(result.getvalue(), start + "huhei") + self.assertEqual(result.getvalue(), start + b"huhei") def test_xmlgen_content_empty(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result, short_empty_elements=True) gen.startDocument() @@ -203,10 +203,10 @@ gen.endElement("doc") gen.endDocument() - self.assertEqual(result.getvalue(), start + "huhei") + self.assertEqual(result.getvalue(), start + b"huhei") def test_xmlgen_pi(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -215,10 +215,10 @@ gen.endElement("doc") gen.endDocument() - self.assertEqual(result.getvalue(), start + "") + self.assertEqual(result.getvalue(), start + b"") def test_xmlgen_content_escape(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -228,10 +228,10 @@ gen.endDocument() self.assertEqual(result.getvalue(), - start + "<huhei&") + start + b"<huhei&") def test_xmlgen_attr_escape(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -246,12 +246,12 @@ gen.endDocument() self.assertEqual(result.getvalue(), start + - ("" - "" - "")) + (b"" + b"" + b"")) def test_xmlgen_ignorable(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -260,10 +260,10 @@ gen.endElement("doc") gen.endDocument() - self.assertEqual(result.getvalue(), start + " ") + self.assertEqual(result.getvalue(), start + b" ") def test_xmlgen_ignorable_empty(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result, short_empty_elements=True) gen.startDocument() @@ -272,10 +272,10 @@ gen.endElement("doc") gen.endDocument() - self.assertEqual(result.getvalue(), start + " ") + self.assertEqual(result.getvalue(), start + b" ") def test_xmlgen_ns(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -290,10 +290,10 @@ self.assertEqual(result.getvalue(), start + \ ('' % - ns_uri)) + ns_uri).encode()) def test_xmlgen_ns_empty(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result, short_empty_elements=True) gen.startDocument() @@ -308,10 +308,10 @@ self.assertEqual(result.getvalue(), start + \ ('' % - ns_uri)) + ns_uri).encode()) def test_1463026_1(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -319,10 +319,10 @@ gen.endElementNS((None, 'a'), 'a') gen.endDocument() - self.assertEqual(result.getvalue(), start+'') + self.assertEqual(result.getvalue(), start + b'') def test_1463026_1_empty(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result, short_empty_elements=True) gen.startDocument() @@ -330,10 +330,10 @@ gen.endElementNS((None, 'a'), 'a') gen.endDocument() - self.assertEqual(result.getvalue(), start+'') + self.assertEqual(result.getvalue(), start + b'') def test_1463026_2(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -343,10 +343,10 @@ gen.endPrefixMapping(None) gen.endDocument() - self.assertEqual(result.getvalue(), start+'') + self.assertEqual(result.getvalue(), start + b'') def test_1463026_2_empty(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result, short_empty_elements=True) gen.startDocument() @@ -356,10 +356,10 @@ gen.endPrefixMapping(None) gen.endDocument() - self.assertEqual(result.getvalue(), start+'') + self.assertEqual(result.getvalue(), start + b'') def test_1463026_3(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -370,10 +370,10 @@ gen.endDocument() self.assertEqual(result.getvalue(), - start+'') + start + b'') def test_1463026_3_empty(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result, short_empty_elements=True) gen.startDocument() @@ -384,7 +384,7 @@ gen.endDocument() self.assertEqual(result.getvalue(), - start+'') + start + b'') def test_5027_1(self): # The xml prefix (as in xml:lang below) is reserved and bound by @@ -401,16 +401,16 @@ parser = make_parser() parser.setFeature(feature_namespaces, True) - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) parser.setContentHandler(gen) parser.parse(test_xml) self.assertEqual(result.getvalue(), start + ( - '' - 'Hello' - '')) + b'' + b'Hello' + b'')) def test_5027_2(self): # The xml prefix (as in xml:lang below) is reserved and bound by @@ -420,7 +420,7 @@ # # This test demonstrates the bug by direct manipulation of the # XMLGenerator. - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) gen.startDocument() @@ -436,14 +436,14 @@ self.assertEqual(result.getvalue(), start + ( - '' - 'Hello' - '')) + b'' + b'Hello' + b'')) class XMLFilterBaseTest(unittest.TestCase): def test_filter_basic(self): - result = StringIO() + result = BytesIO() gen = XMLGenerator(result) filter = XMLFilterBase() filter.setContentHandler(gen) @@ -455,7 +455,7 @@ filter.endElement("doc") filter.endDocument() - self.assertEqual(result.getvalue(), start + "content ") + self.assertEqual(result.getvalue(), start + b"content ") # =========================================================================== # @@ -463,7 +463,7 @@ # # =========================================================================== -with open(TEST_XMLFILE_OUT) as f: +with open(TEST_XMLFILE_OUT, 'rb') as f: xml_test_out = f.read() class ExpatReaderTest(XmlTestBase): @@ -472,11 +472,11 @@ def test_expat_file(self): parser = create_parser() - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser.setContentHandler(xmlgen) - with open(TEST_XMLFILE) as f: + with open(TEST_XMLFILE, 'rb') as f: parser.parse(f) self.assertEqual(result.getvalue(), xml_test_out) @@ -517,13 +517,13 @@ def resolveEntity(self, publicId, systemId): inpsrc = InputSource() - inpsrc.setByteStream(StringIO("")) + inpsrc.setByteStream(BytesIO(b"")) return inpsrc def test_expat_entityresolver(self): parser = create_parser() parser.setEntityResolver(self.TestEntityResolver()) - result = StringIO() + result = BytesIO() parser.setContentHandler(XMLGenerator(result)) parser.feed('") + b"") # ===== Attributes support @@ -602,7 +602,7 @@ def test_expat_inpsource_filename(self): parser = create_parser() - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser.setContentHandler(xmlgen) @@ -612,7 +612,7 @@ def test_expat_inpsource_sysid(self): parser = create_parser() - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser.setContentHandler(xmlgen) @@ -622,12 +622,12 @@ def test_expat_inpsource_stream(self): parser = create_parser() - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser.setContentHandler(xmlgen) inpsrc = InputSource() - with open(TEST_XMLFILE) as f: + with open(TEST_XMLFILE, 'rb') as f: inpsrc.setByteStream(f) parser.parse(inpsrc) @@ -636,7 +636,7 @@ # ===== IncrementalParser support def test_expat_incremental(self): - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser = create_parser() parser.setContentHandler(xmlgen) @@ -645,10 +645,10 @@ parser.feed("") parser.close() - self.assertEqual(result.getvalue(), start + "") + self.assertEqual(result.getvalue(), start + b"") def test_expat_incremental_reset(self): - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser = create_parser() parser.setContentHandler(xmlgen) @@ -656,7 +656,7 @@ parser.feed("") parser.feed("text") - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser.setContentHandler(xmlgen) parser.reset() @@ -666,12 +666,12 @@ parser.feed("") parser.close() - self.assertEqual(result.getvalue(), start + "text") + self.assertEqual(result.getvalue(), start + b"text") # ===== Locator support def test_expat_locator_noinfo(self): - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser = create_parser() parser.setContentHandler(xmlgen) @@ -685,7 +685,7 @@ self.assertEqual(parser.getLineNumber(), 1) def test_expat_locator_withinfo(self): - result = StringIO() + result = BytesIO() xmlgen = XMLGenerator(result) parser = create_parser() parser.setContentHandler(xmlgen) @@ -706,7 +706,7 @@ parser = create_parser() parser.setContentHandler(ContentHandler()) # do nothing source = InputSource() - source.setByteStream(StringIO("")) #ill-formed + source.setByteStream(BytesIO(b"")) #ill-formed name = "a file name" source.setSystemId(name) try: diff -r 51b5ee7cfa3b Lib/xml/sax/saxutils.py --- a/Lib/xml/sax/saxutils.py Sun Jul 15 06:19:44 2012 +0300 +++ b/Lib/xml/sax/saxutils.py Sun Jul 15 10:07:45 2012 +0300 @@ -4,18 +4,11 @@ """ import os, urllib.parse, urllib.request +import io +import contextlib from . import handler from . import xmlreader -# See whether the xmlcharrefreplace error handler is -# supported -try: - from codecs import xmlcharrefreplace_errors - _error_handling = "xmlcharrefreplace" - del xmlcharrefreplace_errors -except ImportError: - _error_handling = "strict" - def __dict_replace(s, d): """Replace substrings of a string using a dictionary.""" for key, value in d.items(): @@ -79,23 +72,55 @@ class XMLGenerator(handler.ContentHandler): def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False): - if out is None: - import sys - out = sys.stdout handler.ContentHandler.__init__(self) - self._out = out - self._ns_contexts = [{}] # contains uri -> prefix dicts - self._current_context = self._ns_contexts[-1] - self._undeclared_ns_maps = [] - self._encoding = encoding - self._short_empty_elements = short_empty_elements - self._pending_start_element = False - - def _write(self, text): - if isinstance(text, str): - self._out.write(text) - else: - self._out.write(text.encode(self._encoding, _error_handling)) + stack = contextlib.ExitStack() + self._close = stack.close + try: + if out is None: + import sys + out = sys.stdout + elif isinstance(out, io.TextIOBase): + # use a text writer as is + pass + else: + # wrap a binary writer with TextIOWrapper + if isinstance(out, io.BufferedIOBase): + pass + elif isinstance(out, io.RawIOBase): + out = io.BufferedWriter(out) + # Keep the original file open when the BufferedWriter is + # destroyed + stack.callback(out.detach) + else: + # This is to handle passed objects that aren't in the + # IOBase hierarchy, but just have a write method + writer = out + out = io.BufferedIOBase() + out.writable = lambda: True + out.write = writer.write + try: + # TextIOWrapper uses this methods to determine + # if BOM (for UTF-16, etc) should be added + out.seekable = writer.seekable + out.tell = writer.tell + except AttributeError: + pass + out = io.TextIOWrapper(out, encoding=encoding, + errors='xmlcharrefreplace', + newline='\n') + # Keep the original file open when the TextIOWrapper is + # destroyed + stack.callback(out.detach) + self._write = out.write + self._ns_contexts = [{}] # contains uri -> prefix dicts + self._current_context = self._ns_contexts[-1] + self._undeclared_ns_maps = [] + self._encoding = encoding + self._short_empty_elements = short_empty_elements + self._pending_start_element = False + except: + self._close + raise def _qname(self, name): """Builds a qualified name from a (ns_url, localname) pair""" @@ -119,12 +144,18 @@ self._write('>') self._pending_start_element = False + def __del__(self): + self._close() + # ContentHandler methods def startDocument(self): self._write('\n' % self._encoding) + def endDocument(self): + self._close() + def startPrefixMapping(self, prefix, uri): self._ns_contexts.append(self._current_context.copy()) self._current_context[uri] = prefix @@ -157,9 +188,9 @@ for prefix, uri in self._undeclared_ns_maps: if prefix: - self._out.write(' xmlns:%s="%s"' % (prefix, uri)) + self._write(' xmlns:%s="%s"' % (prefix, uri)) else: - self._out.write(' xmlns="%s"' % uri) + self._write(' xmlns="%s"' % uri) self._undeclared_ns_maps = [] for (name, value) in attrs.items():