diff -r 51b5ee7cfa3b Lib/test/test_sax.py
--- a/Lib/test/test_sax.py Sun Jul 15 06:19:44 2012 +0300
+++ b/Lib/test/test_sax.py Sun Jul 15 10:07:45 2012 +0300
@@ -13,7 +13,7 @@
from xml.sax.expatreader import create_parser
from xml.sax.handler import feature_namespaces
from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
-from io import StringIO
+from io import BytesIO, StringIO
from test.support import findfile, run_unittest
import unittest
@@ -158,31 +158,31 @@
# ===== XMLGenerator
-start = '\n'
+start = b'\n'
class XmlgenTest(unittest.TestCase):
def test_xmlgen_basic(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
gen.startElement("doc", {})
gen.endElement("doc")
gen.endDocument()
- self.assertEqual(result.getvalue(), start + "")
+ self.assertEqual(result.getvalue(), start + b"")
def test_xmlgen_basic_empty(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument()
gen.startElement("doc", {})
gen.endElement("doc")
gen.endDocument()
- self.assertEqual(result.getvalue(), start + "")
+ self.assertEqual(result.getvalue(), start + b"")
def test_xmlgen_content(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -191,10 +191,10 @@
gen.endElement("doc")
gen.endDocument()
- self.assertEqual(result.getvalue(), start + "huhei")
+ self.assertEqual(result.getvalue(), start + b"huhei")
def test_xmlgen_content_empty(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument()
@@ -203,10 +203,10 @@
gen.endElement("doc")
gen.endDocument()
- self.assertEqual(result.getvalue(), start + "huhei")
+ self.assertEqual(result.getvalue(), start + b"huhei")
def test_xmlgen_pi(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -215,10 +215,10 @@
gen.endElement("doc")
gen.endDocument()
- self.assertEqual(result.getvalue(), start + "")
+ self.assertEqual(result.getvalue(), start + b"")
def test_xmlgen_content_escape(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -228,10 +228,10 @@
gen.endDocument()
self.assertEqual(result.getvalue(),
- start + "<huhei&")
+ start + b"<huhei&")
def test_xmlgen_attr_escape(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -246,12 +246,12 @@
gen.endDocument()
self.assertEqual(result.getvalue(), start +
- (""
- ""
- ""))
+ (b""
+ b""
+ b""))
def test_xmlgen_ignorable(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -260,10 +260,10 @@
gen.endElement("doc")
gen.endDocument()
- self.assertEqual(result.getvalue(), start + " ")
+ self.assertEqual(result.getvalue(), start + b" ")
def test_xmlgen_ignorable_empty(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument()
@@ -272,10 +272,10 @@
gen.endElement("doc")
gen.endDocument()
- self.assertEqual(result.getvalue(), start + " ")
+ self.assertEqual(result.getvalue(), start + b" ")
def test_xmlgen_ns(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -290,10 +290,10 @@
self.assertEqual(result.getvalue(), start + \
('' %
- ns_uri))
+ ns_uri).encode())
def test_xmlgen_ns_empty(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument()
@@ -308,10 +308,10 @@
self.assertEqual(result.getvalue(), start + \
('' %
- ns_uri))
+ ns_uri).encode())
def test_1463026_1(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -319,10 +319,10 @@
gen.endElementNS((None, 'a'), 'a')
gen.endDocument()
- self.assertEqual(result.getvalue(), start+'')
+ self.assertEqual(result.getvalue(), start + b'')
def test_1463026_1_empty(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument()
@@ -330,10 +330,10 @@
gen.endElementNS((None, 'a'), 'a')
gen.endDocument()
- self.assertEqual(result.getvalue(), start+'')
+ self.assertEqual(result.getvalue(), start + b'')
def test_1463026_2(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -343,10 +343,10 @@
gen.endPrefixMapping(None)
gen.endDocument()
- self.assertEqual(result.getvalue(), start+'')
+ self.assertEqual(result.getvalue(), start + b'')
def test_1463026_2_empty(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument()
@@ -356,10 +356,10 @@
gen.endPrefixMapping(None)
gen.endDocument()
- self.assertEqual(result.getvalue(), start+'')
+ self.assertEqual(result.getvalue(), start + b'')
def test_1463026_3(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -370,10 +370,10 @@
gen.endDocument()
self.assertEqual(result.getvalue(),
- start+'')
+ start + b'')
def test_1463026_3_empty(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument()
@@ -384,7 +384,7 @@
gen.endDocument()
self.assertEqual(result.getvalue(),
- start+'')
+ start + b'')
def test_5027_1(self):
# The xml prefix (as in xml:lang below) is reserved and bound by
@@ -401,16 +401,16 @@
parser = make_parser()
parser.setFeature(feature_namespaces, True)
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
parser.setContentHandler(gen)
parser.parse(test_xml)
self.assertEqual(result.getvalue(),
start + (
- ''
- 'Hello'
- ''))
+ b''
+ b'Hello'
+ b''))
def test_5027_2(self):
# The xml prefix (as in xml:lang below) is reserved and bound by
@@ -420,7 +420,7 @@
#
# This test demonstrates the bug by direct manipulation of the
# XMLGenerator.
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
gen.startDocument()
@@ -436,14 +436,14 @@
self.assertEqual(result.getvalue(),
start + (
- ''
- 'Hello'
- ''))
+ b''
+ b'Hello'
+ b''))
class XMLFilterBaseTest(unittest.TestCase):
def test_filter_basic(self):
- result = StringIO()
+ result = BytesIO()
gen = XMLGenerator(result)
filter = XMLFilterBase()
filter.setContentHandler(gen)
@@ -455,7 +455,7 @@
filter.endElement("doc")
filter.endDocument()
- self.assertEqual(result.getvalue(), start + "content ")
+ self.assertEqual(result.getvalue(), start + b"content ")
# ===========================================================================
#
@@ -463,7 +463,7 @@
#
# ===========================================================================
-with open(TEST_XMLFILE_OUT) as f:
+with open(TEST_XMLFILE_OUT, 'rb') as f:
xml_test_out = f.read()
class ExpatReaderTest(XmlTestBase):
@@ -472,11 +472,11 @@
def test_expat_file(self):
parser = create_parser()
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen)
- with open(TEST_XMLFILE) as f:
+ with open(TEST_XMLFILE, 'rb') as f:
parser.parse(f)
self.assertEqual(result.getvalue(), xml_test_out)
@@ -517,13 +517,13 @@
def resolveEntity(self, publicId, systemId):
inpsrc = InputSource()
- inpsrc.setByteStream(StringIO(""))
+ inpsrc.setByteStream(BytesIO(b""))
return inpsrc
def test_expat_entityresolver(self):
parser = create_parser()
parser.setEntityResolver(self.TestEntityResolver())
- result = StringIO()
+ result = BytesIO()
parser.setContentHandler(XMLGenerator(result))
parser.feed('")
+ b"")
# ===== Attributes support
@@ -602,7 +602,7 @@
def test_expat_inpsource_filename(self):
parser = create_parser()
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen)
@@ -612,7 +612,7 @@
def test_expat_inpsource_sysid(self):
parser = create_parser()
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen)
@@ -622,12 +622,12 @@
def test_expat_inpsource_stream(self):
parser = create_parser()
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen)
inpsrc = InputSource()
- with open(TEST_XMLFILE) as f:
+ with open(TEST_XMLFILE, 'rb') as f:
inpsrc.setByteStream(f)
parser.parse(inpsrc)
@@ -636,7 +636,7 @@
# ===== IncrementalParser support
def test_expat_incremental(self):
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser = create_parser()
parser.setContentHandler(xmlgen)
@@ -645,10 +645,10 @@
parser.feed("")
parser.close()
- self.assertEqual(result.getvalue(), start + "")
+ self.assertEqual(result.getvalue(), start + b"")
def test_expat_incremental_reset(self):
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser = create_parser()
parser.setContentHandler(xmlgen)
@@ -656,7 +656,7 @@
parser.feed("")
parser.feed("text")
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen)
parser.reset()
@@ -666,12 +666,12 @@
parser.feed("")
parser.close()
- self.assertEqual(result.getvalue(), start + "text")
+ self.assertEqual(result.getvalue(), start + b"text")
# ===== Locator support
def test_expat_locator_noinfo(self):
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser = create_parser()
parser.setContentHandler(xmlgen)
@@ -685,7 +685,7 @@
self.assertEqual(parser.getLineNumber(), 1)
def test_expat_locator_withinfo(self):
- result = StringIO()
+ result = BytesIO()
xmlgen = XMLGenerator(result)
parser = create_parser()
parser.setContentHandler(xmlgen)
@@ -706,7 +706,7 @@
parser = create_parser()
parser.setContentHandler(ContentHandler()) # do nothing
source = InputSource()
- source.setByteStream(StringIO("")) #ill-formed
+ source.setByteStream(BytesIO(b"")) #ill-formed
name = "a file name"
source.setSystemId(name)
try:
diff -r 51b5ee7cfa3b Lib/xml/sax/saxutils.py
--- a/Lib/xml/sax/saxutils.py Sun Jul 15 06:19:44 2012 +0300
+++ b/Lib/xml/sax/saxutils.py Sun Jul 15 10:07:45 2012 +0300
@@ -4,18 +4,11 @@
"""
import os, urllib.parse, urllib.request
+import io
+import contextlib
from . import handler
from . import xmlreader
-# See whether the xmlcharrefreplace error handler is
-# supported
-try:
- from codecs import xmlcharrefreplace_errors
- _error_handling = "xmlcharrefreplace"
- del xmlcharrefreplace_errors
-except ImportError:
- _error_handling = "strict"
-
def __dict_replace(s, d):
"""Replace substrings of a string using a dictionary."""
for key, value in d.items():
@@ -79,23 +72,55 @@
class XMLGenerator(handler.ContentHandler):
def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
- if out is None:
- import sys
- out = sys.stdout
handler.ContentHandler.__init__(self)
- self._out = out
- self._ns_contexts = [{}] # contains uri -> prefix dicts
- self._current_context = self._ns_contexts[-1]
- self._undeclared_ns_maps = []
- self._encoding = encoding
- self._short_empty_elements = short_empty_elements
- self._pending_start_element = False
-
- def _write(self, text):
- if isinstance(text, str):
- self._out.write(text)
- else:
- self._out.write(text.encode(self._encoding, _error_handling))
+ stack = contextlib.ExitStack()
+ self._close = stack.close
+ try:
+ if out is None:
+ import sys
+ out = sys.stdout
+ elif isinstance(out, io.TextIOBase):
+ # use a text writer as is
+ pass
+ else:
+ # wrap a binary writer with TextIOWrapper
+ if isinstance(out, io.BufferedIOBase):
+ pass
+ elif isinstance(out, io.RawIOBase):
+ out = io.BufferedWriter(out)
+ # Keep the original file open when the BufferedWriter is
+ # destroyed
+ stack.callback(out.detach)
+ else:
+ # This is to handle passed objects that aren't in the
+ # IOBase hierarchy, but just have a write method
+ writer = out
+ out = io.BufferedIOBase()
+ out.writable = lambda: True
+ out.write = writer.write
+ try:
+ # TextIOWrapper uses this methods to determine
+ # if BOM (for UTF-16, etc) should be added
+ out.seekable = writer.seekable
+ out.tell = writer.tell
+ except AttributeError:
+ pass
+ out = io.TextIOWrapper(out, encoding=encoding,
+ errors='xmlcharrefreplace',
+ newline='\n')
+ # Keep the original file open when the TextIOWrapper is
+ # destroyed
+ stack.callback(out.detach)
+ self._write = out.write
+ self._ns_contexts = [{}] # contains uri -> prefix dicts
+ self._current_context = self._ns_contexts[-1]
+ self._undeclared_ns_maps = []
+ self._encoding = encoding
+ self._short_empty_elements = short_empty_elements
+ self._pending_start_element = False
+ except:
+ self._close
+ raise
def _qname(self, name):
"""Builds a qualified name from a (ns_url, localname) pair"""
@@ -119,12 +144,18 @@
self._write('>')
self._pending_start_element = False
+ def __del__(self):
+ self._close()
+
# ContentHandler methods
def startDocument(self):
self._write('\n' %
self._encoding)
+ def endDocument(self):
+ self._close()
+
def startPrefixMapping(self, prefix, uri):
self._ns_contexts.append(self._current_context.copy())
self._current_context[uri] = prefix
@@ -157,9 +188,9 @@
for prefix, uri in self._undeclared_ns_maps:
if prefix:
- self._out.write(' xmlns:%s="%s"' % (prefix, uri))
+ self._write(' xmlns:%s="%s"' % (prefix, uri))
else:
- self._out.write(' xmlns="%s"' % uri)
+ self._write(' xmlns="%s"' % uri)
self._undeclared_ns_maps = []
for (name, value) in attrs.items():