diff -r 99818330b4c0 Lib/test/test_xml_etree.py
--- a/Lib/test/test_xml_etree.py Mon Sep 05 15:40:10 2016 -0700
+++ b/Lib/test/test_xml_etree.py Mon Sep 05 17:29:12 2016 -0700
@@ -5,6 +5,7 @@
# For this purpose, the module-level "ET" symbol is temporarily
# monkey-patched when running the "test_xml_etree_c" test suite.
+import codecs
import copy
import html
import io
@@ -17,6 +18,7 @@ import warnings
import weakref
from itertools import product
+from unittest import mock
from test import support
from test.support import TESTFN, findfile, import_fresh_module, gc_collect, swap_attr
@@ -104,7 +106,7 @@ class ModuleTest(unittest.TestCase):
def serialize(elem, to_string=True, encoding='unicode', **options):
- if encoding != 'unicode':
+ if not encoding or encoding.lower() != 'unicode':
file = io.BytesIO()
else:
file = io.StringIO()
@@ -157,8 +159,8 @@ class ElementTestCase:
class ElementTreeTest(unittest.TestCase):
- def serialize_check(self, elem, expected):
- self.assertEqual(serialize(elem), expected)
+ def serialize_check(self, elem, expected, **options):
+ self.assertEqual(serialize(elem, **options), expected)
def test_interface(self):
# Test element tree interface.
@@ -490,7 +492,6 @@ class ElementTreeTest(unittest.TestCase)
self.assertEqual(b"".join(ET.tostringlist(element)),
b'<html><body>text</body></html>')
self.assertEqual(ET.tostring(element, "ascii"),
- b"<?xml version='1.0' encoding='ascii'?>\n"
b"<html><body>text</body></html>")
_, ids = ET.XMLID("<html><body>text</body></html>")
self.assertEqual(len(ids), 0)
@@ -498,6 +499,59 @@ class ElementTreeTest(unittest.TestCase)
self.assertEqual(len(ids), 1)
self.assertEqual(ids["body"].tag, 'body')
+ def test_write_xml_declaration(self):
+ elem = ET.XML("<tag/>")
+
+ for encoding in ('unicode', 'UniCODE'):
+ for locale_encoding in ('utf8', 'latin9'):
+ for xml_declaration in (None, False, True):
+ with mock.patch('locale.getpreferredencoding',
+ return_value=locale_encoding):
+ if xml_declaration is None:
+ need_declaration = False
+ else:
+ need_declaration = xml_declaration
+
+ if need_declaration:
+ expected = ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag />"
+ % locale_encoding)
+ else:
+ expected = '<tag />'
+
+ self.serialize_check(elem,
+ expected,
+ encoding=encoding,
+ xml_declaration=xml_declaration)
+
+ for encoding in ("ASCII", "us-ascii", "UTF-8", "utf8",
+ "latin9", "GBK",
+ None):
+ for xml_declaration in (None, False, True):
+ enc_name = encoding if encoding else "us-ascii"
+
+ if xml_declaration is None:
+ norm_enc = codecs.lookup(enc_name).name
+ need_declaration = (norm_enc not in {'utf-8', 'ascii'})
+ else:
+ need_declaration = xml_declaration
+
+ if need_declaration:
+ expected = ("<?xml version='1.0' encoding='%s'?>\n"
+ "<tag />" % enc_name)
+ expected = expected.encode(enc_name)
+ else:
+ expected = b'<tag />'
+
+ self.serialize_check(elem,
+ expected,
+ encoding=encoding,
+ xml_declaration=xml_declaration)
+
+ # Invalid codec name
+ with self.assertRaises(LookupError):
+ serialize(elem, encoding="xxxxx")
+
def test_iterparse(self):
# Test iterparse interface.
@@ -1636,12 +1690,10 @@ class BugsTest(unittest.TestCase):
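e = ET.XML(b"<?xml version='1.0' encoding='utf-8'?>"
b'<body>t\xc3\xa3g</body>')
self.assertEqual(ET.tostring(e, 'ascii'),
- b"<?xml version='1.0' encoding='ascii'?>\n"
b'<body>t&#227;g</body>')
e = ET.XML(b"<?xml version='1.0' encoding='iso-8859-1'?>"
b'<body>t\xe3g</body>')
self.assertEqual(ET.tostring(e, 'ascii'),
- b"<?xml version='1.0' encoding='ascii'?>\n"
b'<body>t&#227;g</body>')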
def test_issue3151(self):
@@ -2127,7 +2179,7 @@ class ElementIterTest(unittest.TestCase)
sourcefile = serialize(doc, to_string=False)
self.assertEqual(next(ET.iterparse(sourcefile))[0], 'end')
- # With an explitit parser too (issue #9708)
+ # With an explicit parser too (issue #9708)
sourcefile = serialize(doc, to_string=False)
parser = ET.XMLParser(target=ET.TreeBuilder())
self.assertEqual(next(ET.iterparse(sourcefile, parser=parser))[0],
diff -r 99818330b4c0 Lib/xml/etree/ElementTree.py
--- a/Lib/xml/etree/ElementTree.py Mon Sep 05 15:40:10 2016 -0700
+++ b/Lib/xml/etree/ElementTree.py Mon Sep 05 17:29:12 2016 -0700
@@ -91,12 +91,13 @@
VERSION = "1.3.0"
-import sys
-import re
-import warnings
-import io
+import codecs
import collections
import contextlib
+import io
+import re
+import sys
+import warnings
from . import ElementPath
@@ -755,18 +756,29 @@ class ElementTree:
encoding = "utf-8"
else:
encoding = "us-ascii"
+
enc_lower = encoding.lower()
with _get_writer(file_or_filename, enc_lower) as write:
- if method == "xml" and (xml_declaration or
- (xml_declaration is None and
- enc_lower not in ("utf-8", "us-ascii", "unicode"))):
- declared_encoding = encoding
- if enc_lower == "unicode":
- # Retrieve the default encoding for the xml declaration
- import locale
- declared_encoding = locale.getpreferredencoding()
- write("<?xml version='1.0' encoding='%s'?>\n" % (
- declared_encoding,))
+ if method == "xml":
+ # By default (xml_declaration=None), the XML declaration is
+ # skipped for UTF-8 and ASCII, since UTF-8 is the default XML
+ # encoding and ASCII is a subset of UTF-8
+ if xml_declaration is None:
+ if enc_lower != "unicode":
+ norm_enc = codecs.lookup(enc_lower).name
+ xml_declaration = (norm_enc not in {"utf-8", "ascii"})
+ else:
+ xml_declaration = False
+
+ if xml_declaration:
+ if enc_lower == "unicode":
+ # Retrieve the default encoding for the xml declaration
+ import locale
+ encoding = locale.getpreferredencoding()
+
+ write("<?xml version='1.0' encoding='%s'?>\n"
+ % (encoding,))
+
if method == "text":
_serialize_text(write, self._root)
else: