diff -r 677a9326b4d4 Lib/test/test_xml_etree.py
--- a/Lib/test/test_xml_etree.py Mon Jul 09 18:16:11 2012 -0700
+++ b/Lib/test/test_xml_etree.py Fri Jul 13 23:23:04 2012 +0300
@@ -21,7 +21,7 @@
import weakref
from test import support
-from test.support import findfile, import_fresh_module, gc_collect
+from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect
pyET = None
ET = None
@@ -888,65 +888,6 @@
"""
ET.XML("" % encoding)
-def encoding():
- r"""
- Test encoding issues.
-
- >>> elem = ET.Element("tag")
- >>> elem.text = "abc"
- >>> serialize(elem)
- 'abc'
- >>> serialize(elem, encoding="utf-8")
- b'abc'
- >>> serialize(elem, encoding="us-ascii")
- b'abc'
- >>> serialize(elem, encoding="iso-8859-1")
- b"\nabc"
-
- >>> elem.text = "<&\"\'>"
- >>> serialize(elem)
- '<&"\'>'
- >>> serialize(elem, encoding="utf-8")
- b'<&"\'>'
- >>> serialize(elem, encoding="us-ascii") # cdata characters
- b'<&"\'>'
- >>> serialize(elem, encoding="iso-8859-1")
- b'\n<&"\'>'
-
- >>> elem.attrib["key"] = "<&\"\'>"
- >>> elem.text = None
- >>> serialize(elem)
- ''
- >>> serialize(elem, encoding="utf-8")
- b''
- >>> serialize(elem, encoding="us-ascii")
- b''
- >>> serialize(elem, encoding="iso-8859-1")
- b'\n'
-
- >>> elem.text = '\xe5\xf6\xf6<>'
- >>> elem.attrib.clear()
- >>> serialize(elem)
- '\xe5\xf6\xf6<>'
- >>> serialize(elem, encoding="utf-8")
- b'\xc3\xa5\xc3\xb6\xc3\xb6<>'
- >>> serialize(elem, encoding="us-ascii")
- b'åöö<>'
- >>> serialize(elem, encoding="iso-8859-1")
- b"\n\xe5\xf6\xf6<>"
-
- >>> elem.attrib["key"] = '\xe5\xf6\xf6<>'
- >>> elem.text = None
- >>> serialize(elem)
- ''
- >>> serialize(elem, encoding="utf-8")
- b''
- >>> serialize(elem, encoding="us-ascii")
- b''
- >>> serialize(elem, encoding="iso-8859-1")
- b'\n'
- """
-
def methods():
r"""
Test serialization methods.
@@ -2166,16 +2107,185 @@
self.assertEqual(self._subelem_tags(e), ['a1'])
-class StringIOTest(unittest.TestCase):
+class IOTest(unittest.TestCase):
+    # Exercises serialize(), ElementTree.parse() and ElementTree.write()
+    # against file names, real files, io.StringIO/io.BytesIO and minimal
+    # duck-typed reader/writer objects.
+
+    def tearDown(self):
+        # Several tests create TESTFN; remove it after every test.
+        unlink(TESTFN)
+
+    def test_encoding(self):
+        # Test encoding issues.
+        # For each element: text output, utf-8 and us-ascii output (no XML
+        # declaration), then encodings that do get a declaration.
+        elem = ET.Element("tag")
+        elem.text = "abc"
+        self.assertEqual(serialize(elem), '<tag>abc</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                         b'<tag>abc</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                         b'<tag>abc</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                             ("<?xml version='1.0' encoding='%s'?>\n"
+                              "<tag>abc</tag>" % enc).encode(enc))
+
+        # cdata: only &, < and > are escaped
+        elem = ET.Element("tag")
+        elem.text = "<&\"\'>"
+        self.assertEqual(serialize(elem), '<tag>&lt;&amp;"\'&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                         b'<tag>&lt;&amp;"\'&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                         b'<tag>&lt;&amp;"\'&gt;</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                             ("<?xml version='1.0' encoding='%s'?>\n"
+                              "<tag>&lt;&amp;\"'&gt;</tag>" % enc).encode(enc))
+
+        # attribute values: the double quote is escaped as well
+        elem = ET.Element("tag")
+        elem.attrib["key"] = "<&\"\'>"
+        self.assertEqual(serialize(elem),
+                         '<tag key="&lt;&amp;&quot;\'&gt;" />')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                         b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                         b'<tag key="&lt;&amp;&quot;\'&gt;" />')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                             ("<?xml version='1.0' encoding='%s'?>\n"
+                              "<tag key=\"&lt;&amp;&quot;'&gt;\" />"
+                              % enc).encode(enc))
+
+        # non-ASCII cdata: us-ascii falls back to character references
+        elem = ET.Element("tag")
+        elem.text = '\xe5\xf6\xf6<>'
+        self.assertEqual(serialize(elem), '<tag>\xe5\xf6\xf6&lt;&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                         b'<tag>\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;</tag>')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                         b'<tag>&#229;&#246;&#246;&lt;&gt;</tag>')
+        for enc in ("iso-8859-1", "utf-16", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                             ("<?xml version='1.0' encoding='%s'?>\n"
+                              "<tag>åöö&lt;&gt;</tag>" % enc).encode(enc))
+
+        # non-ASCII attribute value
+        elem = ET.Element("tag")
+        elem.attrib["key"] = '\xe5\xf6\xf6<>'
+        self.assertEqual(serialize(elem),
+                         '<tag key="\xe5\xf6\xf6&lt;&gt;" />')
+        self.assertEqual(serialize(elem, encoding="utf-8"),
+                         b'<tag key="\xc3\xa5\xc3\xb6\xc3\xb6&lt;&gt;" />')
+        self.assertEqual(serialize(elem, encoding="us-ascii"),
+                         b'<tag key="&#229;&#246;&#246;&lt;&gt;" />')
+        for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"):
+            self.assertEqual(serialize(elem, encoding=enc),
+                             ("<?xml version='1.0' encoding='%s'?>\n"
+                              "<tag key=\"åöö&lt;&gt;\" />" % enc).encode(enc))
+
+    def test_write_to_filename(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        tree.write(TESTFN)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_text_file(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        with open(TESTFN, 'w', encoding='utf-8') as f:
+            tree.write(f, encoding='unicode')
+            # write() must not close a file object it was handed
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_binary_file(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        with open(TESTFN, 'wb') as f:
+            tree.write(f)
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(), b'''<site />''')
+
+    def test_write_to_binary_file_with_bom(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        # test BOM writing to buffered file
+        with open(TESTFN, 'wb') as f:
+            tree.write(f, encoding='utf-16')
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(),
+                             '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                             '''<site />'''.encode("utf-16"))
+        # test BOM writing to non-buffered file
+        with open(TESTFN, 'wb', buffering=0) as f:
+            tree.write(f, encoding='utf-16')
+            self.assertFalse(f.closed)
+        with open(TESTFN, 'rb') as f:
+            self.assertEqual(f.read(),
+                             '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                             '''<site />'''.encode("utf-16"))
+
     def test_read_from_stringio(self):
         tree = ET.ElementTree()
+        stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
+        tree.parse(stream)
+        self.assertEqual(tree.getroot().tag, 'site')
+
+    def test_write_to_stringio(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
         stream = io.StringIO()
-        stream.write('''<?xml version="1.0"?><site></site>''')
-        stream.seek(0)
-        tree.parse(stream)
+        tree.write(stream, encoding='unicode')
+        self.assertEqual(stream.getvalue(), '''<site />''')
 
+    def test_read_from_bytesio(self):
+        tree = ET.ElementTree()
+        raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+        tree.parse(raw)
         self.assertEqual(tree.getroot().tag, 'site')
 
+    def test_write_to_bytesio(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        tree.write(raw)
+        self.assertEqual(raw.getvalue(), b'''<site />''')
+
+    # Bare object used to build readers/writers that expose only the
+    # attributes a test assigns (read, write, seekable, tell).
+    class dummy:
+        pass
+
+    def test_read_from_user_text_reader(self):
+        stream = io.StringIO('''<?xml version="1.0"?><site></site>''')
+        reader = self.dummy()
+        reader.read = stream.read
+        tree = ET.ElementTree()
+        tree.parse(reader)
+        self.assertEqual(tree.getroot().tag, 'site')
+
+    def test_write_to_user_text_writer(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        stream = io.StringIO()
+        writer = self.dummy()
+        writer.write = stream.write
+        tree.write(writer, encoding='unicode')
+        self.assertEqual(stream.getvalue(), '''<site />''')
+
+    def test_read_from_user_binary_reader(self):
+        raw = io.BytesIO(b'''<?xml version="1.0"?><site></site>''')
+        reader = self.dummy()
+        reader.read = raw.read
+        tree = ET.ElementTree()
+        tree.parse(reader)
+        self.assertEqual(tree.getroot().tag, 'site')
+
+    def test_write_to_user_binary_writer(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        writer = self.dummy()
+        writer.write = raw.write
+        tree.write(writer)
+        self.assertEqual(raw.getvalue(), b'''<site />''')
+
+    def test_write_to_user_binary_writer_with_bom(self):
+        tree = ET.ElementTree(ET.XML('''<site />'''))
+        raw = io.BytesIO()
+        writer = self.dummy()
+        writer.write = raw.write
+        # seekable()/tell() are exposed so the writer can be probed for its
+        # position when deciding whether to emit a BOM
+        writer.seekable = lambda: True
+        writer.tell = raw.tell
+        tree.write(writer, encoding="utf-16")
+        self.assertEqual(raw.getvalue(),
+                         '''<?xml version='1.0' encoding='utf-16'?>\n'''
+                         '''<site />'''.encode("utf-16"))
+
class ParseErrorTest(unittest.TestCase):
def test_subclass(self):
@@ -2299,7 +2409,7 @@
test_classes = [
ElementSlicingTest,
BasicElementTest,
- StringIOTest,
+ IOTest,
ParseErrorTest,
XincludeTest,
ElementTreeTest,
diff -r 677a9326b4d4 Lib/xml/etree/ElementTree.py
--- a/Lib/xml/etree/ElementTree.py Mon Jul 09 18:16:11 2012 -0700
+++ b/Lib/xml/etree/ElementTree.py Fri Jul 13 23:23:04 2012 +0300
@@ -100,6 +100,8 @@
import sys
import re
import warnings
+import io
+import contextlib
from . import ElementPath
@@ -812,39 +814,22 @@
encoding = "unicode"
else:
encoding = encoding.lower()
- if hasattr(file_or_filename, "write"):
- file = file_or_filename
- else:
- if encoding != "unicode":
- file = open(file_or_filename, "wb")
+ with _get_writer(file_or_filename, encoding) as write:
+ if method == "xml" and (xml_declaration or
+ (xml_declaration is None and
+ encoding not in ("utf-8", "us-ascii", "unicode"))):
+ declared_encoding = encoding
+ if encoding == "unicode":
+ # Retrieve the default encoding for the xml declaration
+ import locale
+ declared_encoding = locale.getpreferredencoding()
+ write("\n" % declared_encoding)
+ if method == "text":
+ _serialize_text(write, self._root)
else:
- file = open(file_or_filename, "w")
- if encoding != "unicode":
- def write(text):
- try:
- return file.write(text.encode(encoding,
- "xmlcharrefreplace"))
- except (TypeError, AttributeError):
- _raise_serialization_error(text)
- else:
- write = file.write
- if method == "xml" and (xml_declaration or
- (xml_declaration is None and
- encoding not in ("utf-8", "us-ascii", "unicode"))):
- declared_encoding = encoding
- if encoding == "unicode":
- # Retrieve the default encoding for the xml declaration
- import locale
- declared_encoding = locale.getpreferredencoding()
- write("\n" % declared_encoding)
- if method == "text":
- _serialize_text(write, self._root)
- else:
- qnames, namespaces = _namespaces(self._root, default_namespace)
- serialize = _serialize[method]
- serialize(write, self._root, qnames, namespaces)
- if file_or_filename is not file:
- file.close()
+ qnames, namespaces = _namespaces(self._root, default_namespace)
+ serialize = _serialize[method]
+ serialize(write, self._root, qnames, namespaces)
def write_c14n(self, file):
# lxml.etree compatibility. use output method instead
@@ -853,6 +838,54 @@
# --------------------------------------------------------------------
# serialization support
+@contextlib.contextmanager
+def _get_writer(file_or_filename, encoding):
+    """Yield a text ``write`` callable for *file_or_filename*.
+
+    *file_or_filename* is either an object with a ``write`` attribute or a
+    file name.  *encoding* is the target encoding as normalized by the
+    caller: "unicode" means the destination takes ``str`` and is used as
+    is; any other value means the destination takes bytes, and text is
+    encoded with the "xmlcharrefreplace" error handler.
+
+    A file opened here from a file name is closed on exit.  A
+    caller-supplied object is left open: temporary wrappers around it are
+    detached, not closed.
+    """
+    # returns text write method and releases all resources after using
+    try:
+        write = file_or_filename.write
+    except AttributeError:
+        # file_or_filename is a file name
+        if encoding == "unicode":
+            file = open(file_or_filename, "w")
+        else:
+            # newline="\n" disables newline translation, so the encoded
+            # bytes match the TextIOWrapper path below on every platform
+            file = open(file_or_filename, "w", encoding=encoding,
+                        errors="xmlcharrefreplace", newline="\n")
+        with file:
+            yield file.write
+    else:
+        # file_or_filename is a file-like object
+        # encoding determines if it is a text or binary writer
+        if encoding == "unicode":
+            # use a text writer as is
+            yield write
+        else:
+            # wrap a binary writer with TextIOWrapper
+            with contextlib.ExitStack() as stack:
+                if isinstance(file_or_filename, io.BufferedIOBase):
+                    file = file_or_filename
+                elif isinstance(file_or_filename, io.RawIOBase):
+                    file = io.BufferedWriter(file_or_filename)
+                    # keep the original file open when the BufferedWriter
+                    # is destroyed
+                    stack.callback(file.detach)
+                else:
+                    # arbitrary object with a write() method: present it
+                    # through a BufferedIOBase shell TextIOWrapper accepts
+                    file = io.BufferedIOBase()
+                    file.writable = lambda: True
+                    file.write = write
+                    try:
+                        # TextIOWrapper uses these methods to determine
+                        # if BOM (for UTF-16, etc) should be added
+                        file.seekable = file_or_filename.seekable
+                        file.tell = file_or_filename.tell
+                    except AttributeError:
+                        pass
+                file = io.TextIOWrapper(file,
+                                        encoding=encoding,
+                                        errors="xmlcharrefreplace",
+                                        newline="\n")
+                # keep the original file open when the TextIOWrapper is
+                # destroyed
+                stack.callback(file.detach)
+                yield file.write
+
def _namespaces(elem, default_namespace=None):
# identify namespaces used in this tree
@@ -1134,10 +1167,9 @@
# @defreturn string
def tostring(element, encoding=None, method=None):
- class dummy:
- pass
data = []
- file = dummy()
+ file = io.BufferedIOBase()
+ file.writable = lambda: True
file.write = data.append
ElementTree(element).write(file, encoding, method=method)
if encoding in (str, "unicode"):
@@ -1161,10 +1193,9 @@
# @since 1.3
def tostringlist(element, encoding=None, method=None):
- class dummy:
- pass
data = []
- file = dummy()
+ file = io.BufferedIOBase()
+ file.writable = lambda: True
file.write = data.append
ElementTree(element).write(file, encoding, method=method)
# FIXME: merge small fragments into larger parts