diff -r 677a9326b4d4 Lib/test/test_xml_etree.py --- a/Lib/test/test_xml_etree.py Mon Jul 09 18:16:11 2012 -0700 +++ b/Lib/test/test_xml_etree.py Fri Jul 13 23:23:04 2012 +0300 @@ -21,7 +21,7 @@ import weakref from test import support -from test.support import findfile, import_fresh_module, gc_collect +from test.support import TESTFN, findfile, unlink, import_fresh_module, gc_collect pyET = None ET = None @@ -888,65 +888,6 @@ """ ET.XML("" % encoding) -def encoding(): - r""" - Test encoding issues. - - >>> elem = ET.Element("tag") - >>> elem.text = "abc" - >>> serialize(elem) - 'abc' - >>> serialize(elem, encoding="utf-8") - b'abc' - >>> serialize(elem, encoding="us-ascii") - b'abc' - >>> serialize(elem, encoding="iso-8859-1") - b"\nabc" - - >>> elem.text = "<&\"\'>" - >>> serialize(elem) - '<&"\'>' - >>> serialize(elem, encoding="utf-8") - b'<&"\'>' - >>> serialize(elem, encoding="us-ascii") # cdata characters - b'<&"\'>' - >>> serialize(elem, encoding="iso-8859-1") - b'\n<&"\'>' - - >>> elem.attrib["key"] = "<&\"\'>" - >>> elem.text = None - >>> serialize(elem) - '' - >>> serialize(elem, encoding="utf-8") - b'' - >>> serialize(elem, encoding="us-ascii") - b'' - >>> serialize(elem, encoding="iso-8859-1") - b'\n' - - >>> elem.text = '\xe5\xf6\xf6<>' - >>> elem.attrib.clear() - >>> serialize(elem) - '\xe5\xf6\xf6<>' - >>> serialize(elem, encoding="utf-8") - b'\xc3\xa5\xc3\xb6\xc3\xb6<>' - >>> serialize(elem, encoding="us-ascii") - b'åöö<>' - >>> serialize(elem, encoding="iso-8859-1") - b"\n\xe5\xf6\xf6<>" - - >>> elem.attrib["key"] = '\xe5\xf6\xf6<>' - >>> elem.text = None - >>> serialize(elem) - '' - >>> serialize(elem, encoding="utf-8") - b'' - >>> serialize(elem, encoding="us-ascii") - b'' - >>> serialize(elem, encoding="iso-8859-1") - b'\n' - """ - def methods(): r""" Test serialization methods. @@ -2166,16 +2107,185 @@ self.assertEqual(self._subelem_tags(e), ['a1']) -class StringIOTest(unittest.TestCase): +class IOTest(unittest.TestCase): + def tearDown(self): + unlink(TESTFN) + + def test_encoding(self): + # Test encoding issues. + elem = ET.Element("tag") + elem.text = "abc" + self.assertEqual(serialize(elem), 'abc') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'abc') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'abc') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "abc" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.text = "<&\"\'>" + self.assertEqual(serialize(elem), '<&"\'>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'<&"\'>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'<&"\'>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "<&\"'>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.attrib["key"] = "<&\"\'>" + self.assertEqual(serialize(elem), '') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.text = '\xe5\xf6\xf6<>' + self.assertEqual(serialize(elem), '\xe5\xf6\xf6<>') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'\xc3\xa5\xc3\xb6\xc3\xb6<>') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'åöö<>') + for enc in ("iso-8859-1", "utf-16", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "åöö<>" % enc).encode(enc)) + + elem = ET.Element("tag") + elem.attrib["key"] = '\xe5\xf6\xf6<>' + self.assertEqual(serialize(elem), '') + self.assertEqual(serialize(elem, encoding="utf-8"), + b'') + self.assertEqual(serialize(elem, encoding="us-ascii"), + b'') + for enc in ("iso-8859-1", "utf-16", "utf-16le", "utf-16be", "utf-32"): + self.assertEqual(serialize(elem, encoding=enc), + ("\n" + "" % enc).encode(enc)) + + def test_write_to_filename(self): + tree = ET.ElementTree(ET.XML('''''')) + tree.write(TESTFN) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''''') + + def test_write_to_text_file(self): + tree = ET.ElementTree(ET.XML('''''')) + with open(TESTFN, 'w', encoding='utf-8') as f: + tree.write(f, encoding='unicode') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''''') + + def test_write_to_binary_file(self): + tree = ET.ElementTree(ET.XML('''''')) + with open(TESTFN, 'wb') as f: + tree.write(f) + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), b'''''') + + def test_write_to_binary_file_with_bom(self): + tree = ET.ElementTree(ET.XML('''''')) + # test BOM writing to buffered file + with open(TESTFN, 'wb') as f: + tree.write(f, encoding='utf-16') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), + '''\n''' + ''''''.encode("utf-16")) + # test BOM writing to non-buffered file + with open(TESTFN, 'wb', buffering=0) as f: + tree.write(f, encoding='utf-16') + self.assertFalse(f.closed) + with open(TESTFN, 'rb') as f: + self.assertEqual(f.read(), + '''\n''' + ''''''.encode("utf-16")) + def test_read_from_stringio(self): tree = ET.ElementTree() + stream = io.StringIO('''''') + tree.parse(stream) + self.assertEqual(tree.getroot().tag, 'site') + + def test_write_to_stringio(self): + tree = ET.ElementTree(ET.XML('''''')) stream = io.StringIO() - stream.write('''''') - stream.seek(0) - tree.parse(stream) + tree.write(stream, encoding='unicode') + self.assertEqual(stream.getvalue(), '''''') + def test_read_from_bytesio(self): + tree = ET.ElementTree() + raw = io.BytesIO(b'''''') + tree.parse(raw) self.assertEqual(tree.getroot().tag, 'site') + def test_write_to_bytesio(self): + tree = ET.ElementTree(ET.XML('''''')) + raw = io.BytesIO() + tree.write(raw) + self.assertEqual(raw.getvalue(), b'''''') + + class dummy: + pass + + def test_read_from_user_text_reader(self): + stream = io.StringIO('''''') + reader = self.dummy() + reader.read = stream.read + tree = ET.ElementTree() + tree.parse(reader) + self.assertEqual(tree.getroot().tag, 'site') + + def test_write_to_user_text_writer(self): + tree = ET.ElementTree(ET.XML('''''')) + stream = io.StringIO() + writer = self.dummy() + writer.write = stream.write + tree.write(writer, encoding='unicode') + self.assertEqual(stream.getvalue(), '''''') + + def test_read_from_user_binary_reader(self): + raw = io.BytesIO(b'''''') + reader = self.dummy() + reader.read = raw.read + tree = ET.ElementTree() + tree.parse(reader) + self.assertEqual(tree.getroot().tag, 'site') + tree = ET.ElementTree() + + def test_write_to_user_binary_writer(self): + tree = ET.ElementTree(ET.XML('''''')) + raw = io.BytesIO() + writer = self.dummy() + writer.write = raw.write + tree.write(writer) + self.assertEqual(raw.getvalue(), b'''''') + + def test_write_to_user_binary_writer_with_bom(self): + tree = ET.ElementTree(ET.XML('''''')) + raw = io.BytesIO() + writer = self.dummy() + writer.write = raw.write + writer.seekable = lambda: True + writer.tell = raw.tell + tree.write(writer, encoding="utf-16") + self.assertEqual(raw.getvalue(), + '''\n''' + ''''''.encode("utf-16")) + class ParseErrorTest(unittest.TestCase): def test_subclass(self): @@ -2299,7 +2409,7 @@ test_classes = [ ElementSlicingTest, BasicElementTest, - StringIOTest, + IOTest, ParseErrorTest, XincludeTest, ElementTreeTest, diff -r 677a9326b4d4 Lib/xml/etree/ElementTree.py --- a/Lib/xml/etree/ElementTree.py Mon Jul 09 18:16:11 2012 -0700 +++ b/Lib/xml/etree/ElementTree.py Fri Jul 13 23:23:04 2012 +0300 @@ -100,6 +100,8 @@ import sys import re import warnings +import io +import contextlib from . import ElementPath @@ -812,39 +814,22 @@ encoding = "unicode" else: encoding = encoding.lower() - if hasattr(file_or_filename, "write"): - file = file_or_filename - else: - if encoding != "unicode": - file = open(file_or_filename, "wb") + with _get_writer(file_or_filename, encoding) as write: + if method == "xml" and (xml_declaration or + (xml_declaration is None and + encoding not in ("utf-8", "us-ascii", "unicode"))): + declared_encoding = encoding + if encoding == "unicode": + # Retrieve the default encoding for the xml declaration + import locale + declared_encoding = locale.getpreferredencoding() + write("\n" % declared_encoding) + if method == "text": + _serialize_text(write, self._root) else: - file = open(file_or_filename, "w") - if encoding != "unicode": - def write(text): - try: - return file.write(text.encode(encoding, - "xmlcharrefreplace")) - except (TypeError, AttributeError): - _raise_serialization_error(text) - else: - write = file.write - if method == "xml" and (xml_declaration or - (xml_declaration is None and - encoding not in ("utf-8", "us-ascii", "unicode"))): - declared_encoding = encoding - if encoding == "unicode": - # Retrieve the default encoding for the xml declaration - import locale - declared_encoding = locale.getpreferredencoding() - write("\n" % declared_encoding) - if method == "text": - _serialize_text(write, self._root) - else: - qnames, namespaces = _namespaces(self._root, default_namespace) - serialize = _serialize[method] - serialize(write, self._root, qnames, namespaces) - if file_or_filename is not file: - file.close() + qnames, namespaces = _namespaces(self._root, default_namespace) + serialize = _serialize[method] + serialize(write, self._root, qnames, namespaces) def write_c14n(self, file): # lxml.etree compatibility. use output method instead @@ -853,6 +838,54 @@ # -------------------------------------------------------------------- # serialization support +@contextlib.contextmanager +def _get_writer(file_or_filename, encoding): + # returns text write method and release all resourses after using + try: + write = file_or_filename.write + except AttributeError: + # file_or_filename is a file name + if encoding == "unicode": + file = open(file_or_filename, "w") + else: + file = open(file_or_filename, "w", encoding=encoding, + errors="xmlcharrefreplace") + with file: + yield file.write + else: + # file_or_filename is a file-like object + # encoding determines if it is a text or binary writer + if encoding == "unicode": + # use a text writer as is + yield write + else: + # wrap a binary writer with TextIOWrapper + with contextlib.ExitStack() as stack: + if isinstance(file_or_filename, io.BufferedIOBase): + file = file_or_filename + elif isinstance(file_or_filename, io.RawIOBase): + file = io.BufferedWriter(file_or_filename) + # keep the original file open when the BufferedWriter is destroyed + stack.callback(file.detach) + else: + file = io.BufferedIOBase() + file.writable = lambda: True + file.write = write + try: + # TextIOWrapper uses this methods to determine + # if BOM (for UTF-16, etc) should be added + file.seekable = file_or_filename.seekable + file.tell = file_or_filename.tell + except AttributeError: + pass + file = io.TextIOWrapper(file, + encoding=encoding, + errors="xmlcharrefreplace", + newline="\n") + # keep the original file open when the TextIOWrapper is destroyed + stack.callback(file.detach) + yield file.write + def _namespaces(elem, default_namespace=None): # identify namespaces used in this tree @@ -1134,10 +1167,9 @@ # @defreturn string def tostring(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() + file = io.BufferedIOBase() + file.writable = lambda: True file.write = data.append ElementTree(element).write(file, encoding, method=method) if encoding in (str, "unicode"): @@ -1161,10 +1193,9 @@ # @since 1.3 def tostringlist(element, encoding=None, method=None): - class dummy: - pass data = [] - file = dummy() + file = io.BufferedIOBase() + file.writable = lambda: True file.write = data.append ElementTree(element).write(file, encoding, method=method) # FIXME: merge small fragments into larger parts