diff -r d03dbc324b60 Lib/test/test_xml_etree.py --- a/Lib/test/test_xml_etree.py Sat Jul 07 22:15:22 2012 +1000 +++ b/Lib/test/test_xml_etree.py Sat Jul 07 17:23:00 2012 +0300 @@ -888,65 +888,6 @@ """ ET.XML("" % encoding) -def encoding(): - r""" - Test encoding issues. - - >>> elem = ET.Element("tag") - >>> elem.text = "abc" - >>> serialize(elem) - 'abc' - >>> serialize(elem, encoding="utf-8") - b'abc' - >>> serialize(elem, encoding="us-ascii") - b'abc' - >>> serialize(elem, encoding="iso-8859-1") - b"\nabc" - - >>> elem.text = "<&\"\'>" - >>> serialize(elem) - '<&"\'>' - >>> serialize(elem, encoding="utf-8") - b'<&"\'>' - >>> serialize(elem, encoding="us-ascii") # cdata characters - b'<&"\'>' - >>> serialize(elem, encoding="iso-8859-1") - b'\n<&"\'>' - - >>> elem.attrib["key"] = "<&\"\'>" - >>> elem.text = None - >>> serialize(elem) - '' - >>> serialize(elem, encoding="utf-8") - b'' - >>> serialize(elem, encoding="us-ascii") - b'' - >>> serialize(elem, encoding="iso-8859-1") - b'\n' - - >>> elem.text = '\xe5\xf6\xf6<>' - >>> elem.attrib.clear() - >>> serialize(elem) - '\xe5\xf6\xf6<>' - >>> serialize(elem, encoding="utf-8") - b'\xc3\xa5\xc3\xb6\xc3\xb6<>' - >>> serialize(elem, encoding="us-ascii") - b'åöö<>' - >>> serialize(elem, encoding="iso-8859-1") - b"\n\xe5\xf6\xf6<>" - - >>> elem.attrib["key"] = '\xe5\xf6\xf6<>' - >>> elem.text = None - >>> serialize(elem) - '' - >>> serialize(elem, encoding="utf-8") - b'' - >>> serialize(elem, encoding="us-ascii") - b'' - >>> serialize(elem, encoding="iso-8859-1") - b'\n' - """ - def methods(): r""" Test serialization methods. @@ -2166,16 +2107,129 @@ self.assertEqual(self._subelem_tags(e), ['a1']) -class StringIOTest(unittest.TestCase): +class IOTest(unittest.TestCase): + def test_encoding(self): + # Test encoding issues. + elem = ET.Element("tag") + elem.text = "abc" + self.assertEqual(serialize(elem), 'abc') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'abc') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'abc') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "abc" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.text = "<&\"\'>" + self.assertEqual(serialize(elem), '<&"\'>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'<&"\'>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'<&"\'>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "<&\"'>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.attrib["key"] = "<&\"\'>" + self.assertEqual(serialize(elem), '') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.text = '\xe5\xf6\xf6<>' + self.assertEqual(serialize(elem), '\xe5\xf6\xf6<>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'\xc3\xa5\xc3\xb6\xc3\xb6<>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'åöö<>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "åöö<>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.attrib["key"] = '\xe5\xf6\xf6<>' + self.assertEqual(serialize(elem), '') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "" % enc).encode(enc)) + def test_read_from_stringio(self): tree = ET.ElementTree() + stream = io.StringIO('''''') + tree.parse(stream) + self.assertEqual(tree.getroot().tag, 'site') + + def test_write_to_stringio(self): stream = io.StringIO() - stream.write('''''') - stream.seek(0) - tree.parse(stream) + tree = ET.ElementTree(ET.XML('''''')) + tree.write(stream, encoding='unicode') + self.assertEqual(stream.getvalue(), '''''') + def test_read_from_bytesio(self): + tree = ET.ElementTree() + raw = io.BytesIO(b'''''') + tree.parse(raw) self.assertEqual(tree.getroot().tag, 'site') + def test_write_to_bytesio(self): + raw = io.BytesIO() + tree = ET.ElementTree(ET.XML('''''')) + tree.write(raw) + self.assertEqual(raw.getvalue(), b'''''') + + class dummy: + pass + + def test_read_from_user_text_reader(self): + stream = io.StringIO('''''') + reader = self.dummy() + reader.read = stream.read + tree = ET.ElementTree() + tree.parse(reader) + self.assertEqual(tree.getroot().tag, 'site') + + def test_write_to_user_text_writer(self): + stream = io.StringIO() + writer = self.dummy() + writer.write = stream.write + tree = ET.ElementTree(ET.XML('''''')) + tree.write(writer, encoding='unicode') + self.assertEqual(stream.getvalue(), '''''') + + def test_read_from_user_binary_reader(self): + raw = io.BytesIO(b'''''') + reader = self.dummy() + reader.read = raw.read + tree = ET.ElementTree() + tree.parse(reader) + self.assertEqual(tree.getroot().tag, 'site') + tree = ET.ElementTree() + + def test_write_to_user_binary_writer(self): + raw = io.BytesIO() + writer = self.dummy() + writer.write = raw.write + tree = ET.ElementTree(ET.XML('''''')) + tree.write(writer) + self.assertEqual(raw.getvalue(), b'''''') + class ParseErrorTest(unittest.TestCase): def test_subclass(self): @@ -2299,7 +2353,7 @@ test_classes = [ ElementSlicingTest, BasicElementTest, - StringIOTest, + IOTest, ParseErrorTest, XincludeTest, ElementTreeTest, diff -r d03dbc324b60 Lib/xml/etree/ElementTree.py --- a/Lib/xml/etree/ElementTree.py Sat Jul 07 22:15:22 2012 +1000 +++ b/Lib/xml/etree/ElementTree.py Sat Jul 07 17:23:00 2012 +0300 @@ -100,6 +100,7 @@ import sys import re import warnings +import io from . import ElementPath @@ -814,20 +815,32 @@ encoding = encoding.lower() if hasattr(file_or_filename, "write"): file = file_or_filename + if encoding != "unicode": + if not isinstance(file, io.BufferedIOBase): + if isinstance(file, io.RawIOBase): + file = io.BufferedWriter(file) + else: + file = io.BufferedIOBase() + file.writable = lambda: True + file.write = file_or_filename.write + try: + # Required to write BOM + file.seekable = file_or_filename.seekable + file.tell = file_or_filename.tell + except AttributeError: + pass + file = io.TextIOWrapper(file, encoding=encoding, + errors="xmlcharrefreplace", + newline="\n") + close_file = False else: if encoding != "unicode": - file = open(file_or_filename, "wb") + file = open(file_or_filename, "w", encoding=encoding, + errors="xmlcharrefreplace") else: file = open(file_or_filename, "w") - if encoding != "unicode": - def write(text): - try: - return file.write(text.encode(encoding, - "xmlcharrefreplace")) - except (TypeError, AttributeError): - _raise_serialization_error(text) - else: - write = file.write + close_file = True + write = file.write if method == "xml" and (xml_declaration or (xml_declaration is None and encoding not in ("utf-8", "us-ascii", "unicode"))): @@ -843,8 +856,11 @@ qnames, namespaces = _namespaces(self._root, default_namespace) serialize = _serialize[method] serialize(write, self._root, qnames, namespaces) - if file_or_filename is not file: + if close_file: file.close() + elif file_or_filename is not file: + file.flush() + file.detach() def write_c14n(self, file): # lxml.etree compatibility. use output method instead @@ -1134,10 +1150,9 @@ # @defreturn string def tostring(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() + file = io.BufferedIOBase() + file.writable = lambda: True file.write = data.append ElementTree(element).write(file, encoding, method=method) if encoding in (str, "unicode"): @@ -1161,10 +1176,9 @@ # @since 1.3 def tostringlist(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() + file = io.BufferedIOBase() + file.writable = lambda: True file.write = data.append ElementTree(element).write(file, encoding, method=method) # FIXME: merge small fragments into larger parts