diff -r e0f997a7aaa5 Lib/test/test_xml_etree.py --- a/Lib/test/test_xml_etree.py Sun May 20 18:34:11 2012 +0200 +++ b/Lib/test/test_xml_etree.py Mon May 21 01:26:35 2012 +0300 @@ -923,6 +923,10 @@ b'abc' >>> serialize(elem, encoding="iso-8859-1") b"\nabc" + >>> serialize(elem, encoding="utf-16").decode('utf-16') + "\nabc" + >>> serialize(elem, encoding="utf-32").decode('utf-32') + "\nabc" >>> elem.text = "<&\"\'>" >>> serialize(elem) @@ -933,6 +937,10 @@ b'<&"\'>' >>> serialize(elem, encoding="iso-8859-1") b'\n<&"\'>' + >>> serialize(elem, encoding="utf-16").decode('utf-16') + '\n<&"\'>' + >>> serialize(elem, encoding="utf-32").decode('utf-32') + '\n<&"\'>' >>> elem.attrib["key"] = "<&\"\'>" >>> elem.text = None @@ -944,6 +952,10 @@ b'' >>> serialize(elem, encoding="iso-8859-1") b'\n' + >>> serialize(elem, encoding="utf-16").decode('utf-16') + '\n' + >>> serialize(elem, encoding="utf-32").decode('utf-32') + '\n' >>> elem.text = '\xe5\xf6\xf6<>' >>> elem.attrib.clear() @@ -955,6 +967,10 @@ b'åöö<>' >>> serialize(elem, encoding="iso-8859-1") b"\n\xe5\xf6\xf6<>" + >>> serialize(elem, encoding="utf-16").decode('utf-16') + "\nåöö<>" + >>> serialize(elem, encoding="utf-32").decode('utf-32') + "\nåöö<>" >>> elem.attrib["key"] = '\xe5\xf6\xf6<>' >>> elem.text = None @@ -966,6 +982,10 @@ b'' >>> serialize(elem, encoding="iso-8859-1") b'\n' + >>> serialize(elem, encoding="utf-16").decode('utf-16') + '\n' + >>> serialize(elem, encoding="utf-32").decode('utf-32') + '\n' """ def methods(): @@ -2081,6 +2101,27 @@ self.assertEqual(tree.getroot().tag, 'site') +class UserIOTest(unittest.TestCase): + class dummy: + pass + + def test_read_from_user_reader(self): + raw = io.BytesIO(b'''''') + reader = self.dummy() + reader.read = raw.read + tree = ET.ElementTree() + tree.parse(reader) + self.assertEqual(tree.getroot().tag, 'site') + + def test_write_to_user_writer(self): + raw = io.BytesIO() + writer = self.dummy() + writer.write = raw.write + tree = ET.parse(io.BytesIO(b'''''')) + tree.write(writer) + self.assertEqual(raw.getvalue(), b'''''') + + class ParseErrorTest(unittest.TestCase): def test_subclass(self): self.assertIsInstance(ET.ParseError(), SyntaxError) @@ -2155,6 +2196,7 @@ ElementSlicingTest, BasicElementTest, StringIOTest, + UserIOTest, ParseErrorTest, ElementTreeTest, TreeBuilderTest] diff -r e0f997a7aaa5 Lib/xml/etree/ElementTree.py --- a/Lib/xml/etree/ElementTree.py Sun May 20 18:34:11 2012 +0200 +++ b/Lib/xml/etree/ElementTree.py Mon May 21 01:26:35 2012 +0300 @@ -100,6 +100,7 @@ import sys import re import warnings +import io class _SimpleElementPath: # emulate pre-1.2 find/findtext/findall behaviour @@ -835,20 +836,29 @@ encoding = encoding.lower() if hasattr(file_or_filename, "write"): file = file_or_filename + if encoding != "unicode": + if not isinstance(file, io.BufferedIOBase): + file = io.BufferedIOBase() + file.writable = lambda: True + file.write = file_or_filename.write + try: + # Required to write BOM + file.seekable = file_or_filename.seekable + file.tell = file_or_filename.tell + except AttributeError: + pass + file = io.TextIOWrapper(file, encoding=encoding, + errors="xmlcharrefreplace", + newline="\n") + close_file = False else: if encoding != "unicode": - file = open(file_or_filename, "wb") + file = open(file_or_filename, "w", encoding=encoding, + errors="xmlcharrefreplace") else: file = open(file_or_filename, "w") - if encoding != "unicode": - def write(text): - try: - return file.write(text.encode(encoding, - "xmlcharrefreplace")) - except (TypeError, AttributeError): - _raise_serialization_error(text) - else: - write = file.write + close_file = True + write = file.write if method == "xml" and (xml_declaration or (xml_declaration is None and encoding not in ("utf-8", "us-ascii", "unicode"))): @@ -864,8 +874,10 @@ qnames, namespaces = _namespaces(self._root, default_namespace) serialize = _serialize[method] serialize(write, self._root, qnames, namespaces) - if file_or_filename is not file: + if close_file: file.close() + elif file_or_filename is not file: + file.detach() def write_c14n(self, file): # lxml.etree compatibility. use output method instead @@ -1159,10 +1171,9 @@ # @defreturn string def tostring(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() + file = io.BufferedIOBase() + file.writable = lambda: True file.write = data.append ElementTree(element).write(file, encoding, method=method) if encoding in (str, "unicode"): @@ -1186,10 +1197,9 @@ # @since 1.3 def tostringlist(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() + file = io.BufferedIOBase() + file.writable = lambda: True file.write = data.append ElementTree(element).write(file, encoding, method=method) # FIXME: merge small fragments into larger parts