diff -r 99818330b4c0 Lib/test/test_xml_etree.py --- a/Lib/test/test_xml_etree.py Mon Sep 05 15:40:10 2016 -0700 +++ b/Lib/test/test_xml_etree.py Mon Sep 05 17:29:12 2016 -0700 @@ -5,6 +5,7 @@ # For this purpose, the module-level "ET" symbol is temporarily # monkey-patched when running the "test_xml_etree_c" test suite. +import codecs import copy import html import io @@ -17,6 +18,7 @@ import warnings import weakref from itertools import product +from unittest import mock from test import support from test.support import TESTFN, findfile, import_fresh_module, gc_collect, swap_attr @@ -104,7 +106,7 @@ class ModuleTest(unittest.TestCase): def serialize(elem, to_string=True, encoding='unicode', **options): - if encoding != 'unicode': + if not encoding or encoding.lower() != 'unicode': file = io.BytesIO() else: file = io.StringIO() @@ -157,8 +159,8 @@ class ElementTestCase: class ElementTreeTest(unittest.TestCase): - def serialize_check(self, elem, expected): - self.assertEqual(serialize(elem), expected) + def serialize_check(self, elem, expected, **options): + self.assertEqual(serialize(elem, **options), expected) def test_interface(self): # Test element tree interface. 
@@ -490,7 +492,6 @@ class ElementTreeTest(unittest.TestCase) self.assertEqual(b"".join(ET.tostringlist(element)), b'text') self.assertEqual(ET.tostring(element, "ascii"), - b"\n" b"text") _, ids = ET.XMLID("text") self.assertEqual(len(ids), 0) @@ -498,6 +499,59 @@ class ElementTreeTest(unittest.TestCase) self.assertEqual(len(ids), 1) self.assertEqual(ids["body"].tag, 'body') + def test_write_xml_declaration(self): + elem = ET.XML("") + + for encoding in ('unicode', 'UniCODE'): + for locale_encoding in ('utf8', 'latin9'): + for xml_declaration in (None, False, True): + with mock.patch('locale.getpreferredencoding', + return_value=locale_encoding): + if xml_declaration is None: + need_declaration = False + else: + need_declaration = xml_declaration + + if need_declaration: + expected = ("\n" + "" + % locale_encoding) + else: + expected = '' + + self.serialize_check(elem, + expected, + encoding=encoding, + xml_declaration=xml_declaration) + + for encoding in ("ASCII", "us-ascii", "UTF-8", "utf8", + "latin9", "GBK", + None): + for xml_declaration in (None, False, True): + enc_name = encoding if encoding else "us-ascii" + + if xml_declaration is None: + norm_enc = codecs.lookup(enc_name).name + need_declaration = (norm_enc not in {'utf-8', 'ascii'}) + else: + need_declaration = xml_declaration + + if need_declaration: + expected = ("\n" + "" % enc_name) + expected = expected.encode(enc_name) + else: + expected = b'' + + self.serialize_check(elem, + expected, + encoding=encoding, + xml_declaration=xml_declaration) + + # Invalid codec name + with self.assertRaises(LookupError): + serialize(elem, encoding="xxxxx") + def test_iterparse(self): # Test iterparse interface. 
@@ -1636,12 +1690,10 @@ class BugsTest(unittest.TestCase): e = ET.XML(b"" b't\xc3\xa3g') self.assertEqual(ET.tostring(e, 'ascii'), - b"\n" b'tãg') e = ET.XML(b"" b't\xe3g') self.assertEqual(ET.tostring(e, 'ascii'), - b"\n" b'tãg') def test_issue3151(self): @@ -2127,7 +2179,7 @@ class ElementIterTest(unittest.TestCase) sourcefile = serialize(doc, to_string=False) self.assertEqual(next(ET.iterparse(sourcefile))[0], 'end') - # With an explitit parser too (issue #9708) + # With an explicit parser too (issue #9708) sourcefile = serialize(doc, to_string=False) parser = ET.XMLParser(target=ET.TreeBuilder()) self.assertEqual(next(ET.iterparse(sourcefile, parser=parser))[0], diff -r 99818330b4c0 Lib/xml/etree/ElementTree.py --- a/Lib/xml/etree/ElementTree.py Mon Sep 05 15:40:10 2016 -0700 +++ b/Lib/xml/etree/ElementTree.py Mon Sep 05 17:29:12 2016 -0700 @@ -91,12 +91,13 @@ VERSION = "1.3.0" -import sys -import re -import warnings -import io +import codecs import collections import contextlib +import io +import re +import sys +import warnings from . 
import ElementPath @@ -755,18 +756,29 @@ class ElementTree: encoding = "utf-8" else: encoding = "us-ascii" + enc_lower = encoding.lower() with _get_writer(file_or_filename, enc_lower) as write: - if method == "xml" and (xml_declaration or - (xml_declaration is None and - enc_lower not in ("utf-8", "us-ascii", "unicode"))): - declared_encoding = encoding - if enc_lower == "unicode": - # Retrieve the default encoding for the xml declaration - import locale - declared_encoding = locale.getpreferredencoding() - write("<?xml version='1.0' encoding='%s'?>\n" % ( - declared_encoding,)) + if method == "xml": + # By default (xml_declaration=None), the XML declaration is + # skipped for UTF-8 and ASCII, since UTF-8 is the default XML + # encoding and ASCII is a subset of UTF-8 + if xml_declaration is None: + if enc_lower != "unicode": + norm_enc = codecs.lookup(enc_lower).name + xml_declaration = (norm_enc not in {"utf-8", "ascii"}) + else: + xml_declaration = False + + if xml_declaration: + if enc_lower == "unicode": + # Retrieve the default encoding for the xml declaration + import locale + encoding = locale.getpreferredencoding() + + write("<?xml version='1.0' encoding='%s'?>\n" + % (encoding,)) + if method == "text": _serialize_text(write, self._root) else: