diff -r e0f997a7aaa5 Lib/test/test_xml_etree.py
--- a/Lib/test/test_xml_etree.py Sun May 20 18:34:11 2012 +0200
+++ b/Lib/test/test_xml_etree.py Mon May 21 01:26:35 2012 +0300
@@ -923,6 +923,10 @@
b'abc'
>>> serialize(elem, encoding="iso-8859-1")
b"\nabc"
+ >>> serialize(elem, encoding="utf-16").decode('utf-16')
+ "\nabc"
+ >>> serialize(elem, encoding="utf-32").decode('utf-32')
+ "\nabc"
>>> elem.text = "<&\"\'>"
>>> serialize(elem)
@@ -933,6 +937,10 @@
b'<&"\'>'
>>> serialize(elem, encoding="iso-8859-1")
b'\n<&"\'>'
+ >>> serialize(elem, encoding="utf-16").decode('utf-16')
+ '\n<&"\'>'
+ >>> serialize(elem, encoding="utf-32").decode('utf-32')
+ '\n<&"\'>'
>>> elem.attrib["key"] = "<&\"\'>"
>>> elem.text = None
@@ -944,6 +952,10 @@
b''
>>> serialize(elem, encoding="iso-8859-1")
b'\n'
+ >>> serialize(elem, encoding="utf-16").decode('utf-16')
+ '\n'
+ >>> serialize(elem, encoding="utf-32").decode('utf-32')
+ '\n'
>>> elem.text = '\xe5\xf6\xf6<>'
>>> elem.attrib.clear()
@@ -955,6 +967,10 @@
b'åöö<>'
>>> serialize(elem, encoding="iso-8859-1")
b"\n\xe5\xf6\xf6<>"
+ >>> serialize(elem, encoding="utf-16").decode('utf-16')
+ "\nåöö<>"
+ >>> serialize(elem, encoding="utf-32").decode('utf-32')
+ "\nåöö<>"
>>> elem.attrib["key"] = '\xe5\xf6\xf6<>'
>>> elem.text = None
@@ -966,6 +982,10 @@
b''
>>> serialize(elem, encoding="iso-8859-1")
b'\n'
+ >>> serialize(elem, encoding="utf-16").decode('utf-16')
+ '\n'
+ >>> serialize(elem, encoding="utf-32").decode('utf-32')
+ '\n'
"""
def methods():
@@ -2081,6 +2101,27 @@
self.assertEqual(tree.getroot().tag, 'site')
+class UserIOTest(unittest.TestCase):
+ class dummy:
+ pass
+
+ def test_read_from_user_reader(self):
+ raw = io.BytesIO(b'''''')
+ reader = self.dummy()
+ reader.read = raw.read
+ tree = ET.ElementTree()
+ tree.parse(reader)
+ self.assertEqual(tree.getroot().tag, 'site')
+
+ def test_write_to_user_writer(self):
+ raw = io.BytesIO()
+ writer = self.dummy()
+ writer.write = raw.write
+ tree = ET.parse(io.BytesIO(b''''''))
+ tree.write(writer)
+ self.assertEqual(raw.getvalue(), b'''''')
+
+
class ParseErrorTest(unittest.TestCase):
def test_subclass(self):
self.assertIsInstance(ET.ParseError(), SyntaxError)
@@ -2155,6 +2196,7 @@
ElementSlicingTest,
BasicElementTest,
StringIOTest,
+ UserIOTest,
ParseErrorTest,
ElementTreeTest,
TreeBuilderTest]
diff -r e0f997a7aaa5 Lib/xml/etree/ElementTree.py
--- a/Lib/xml/etree/ElementTree.py Sun May 20 18:34:11 2012 +0200
+++ b/Lib/xml/etree/ElementTree.py Mon May 21 01:26:35 2012 +0300
@@ -100,6 +100,7 @@
import sys
import re
import warnings
+import io
class _SimpleElementPath:
# emulate pre-1.2 find/findtext/findall behaviour
@@ -835,20 +836,29 @@
encoding = encoding.lower()
if hasattr(file_or_filename, "write"):
file = file_or_filename
+ if encoding != "unicode":
+ if not isinstance(file, io.BufferedIOBase):
+ file = io.BufferedIOBase()
+ file.writable = lambda: True
+ file.write = file_or_filename.write
+ try:
+ # Required to write BOM
+ file.seekable = file_or_filename.seekable
+ file.tell = file_or_filename.tell
+ except AttributeError:
+ pass
+ file = io.TextIOWrapper(file, encoding=encoding,
+ errors="xmlcharrefreplace",
+ newline="\n")
+ close_file = False
else:
if encoding != "unicode":
- file = open(file_or_filename, "wb")
+ file = open(file_or_filename, "w", encoding=encoding,
+ errors="xmlcharrefreplace")
else:
file = open(file_or_filename, "w")
- if encoding != "unicode":
- def write(text):
- try:
- return file.write(text.encode(encoding,
- "xmlcharrefreplace"))
- except (TypeError, AttributeError):
- _raise_serialization_error(text)
- else:
- write = file.write
+ close_file = True
+ write = file.write
if method == "xml" and (xml_declaration or
(xml_declaration is None and
encoding not in ("utf-8", "us-ascii", "unicode"))):
@@ -864,8 +874,10 @@
qnames, namespaces = _namespaces(self._root, default_namespace)
serialize = _serialize[method]
serialize(write, self._root, qnames, namespaces)
- if file_or_filename is not file:
+ if close_file:
file.close()
+ elif file_or_filename is not file:
+ file.detach()
def write_c14n(self, file):
# lxml.etree compatibility. use output method instead
@@ -1159,10 +1171,9 @@
# @defreturn string
def tostring(element, encoding=None, method=None):
- class dummy:
- pass
data = []
- file = dummy()
+ file = io.BufferedIOBase()
+ file.writable = lambda: True
file.write = data.append
ElementTree(element).write(file, encoding, method=method)
if encoding in (str, "unicode"):
@@ -1186,10 +1197,9 @@
# @since 1.3
def tostringlist(element, encoding=None, method=None):
- class dummy:
- pass
data = []
- file = dummy()
+ file = io.BufferedIOBase()
+ file.writable = lambda: True
file.write = data.append
ElementTree(element).write(file, encoding, method=method)
# FIXME: merge small fragments into larger parts