diff --git a/Lib/xml/etree/ElementTree.py b/Lib/xml/etree/ElementTree.py index c26a764..b404e20 100644 --- a/Lib/xml/etree/ElementTree.py +++ b/Lib/xml/etree/ElementTree.py @@ -96,6 +96,7 @@ VERSION = "1.3.0" # structure, and convert it from and to XML. ## +import io import sys import re import warnings @@ -811,43 +812,64 @@ class ElementTree: elif method not in _serialize: # FIXME: raise an ImportError for c14n if ElementC14N is missing? raise ValueError("unknown method %r" % method) - if hasattr(file_or_filename, "write"): - file = file_or_filename + + if encoding: + _encoding = encoding else: + import locale + _encoding = locale.getpreferredencoding() + + close_required = False + bytes_write = False + if isinstance(file_or_filename, io.BufferedIOBase): + file = file_or_filename + bytes_write = True + elif isinstance(file_or_filename, io.TextIOBase): + bytes_write = False + try: + fd = file_or_filename.fileno() + except io.UnsupportedOperation: + fd = -1 + if fd >= 0: + if file_or_filename.encoding: + _encoding = file_or_filename.encoding + file = open(fd, "w", + encoding=_encoding, errors="xmlcharrefreplace", + closefd=False) + else: + file = file_or_filename + elif hasattr(file_or_filename, "write"): + file = file_or_filename if encoding: - file = open(file_or_filename, "wb") + bytes_write = True else: - file = open(file_or_filename, "w") - if encoding: + bytes_write = False + else: + file = open(file_or_filename, "w", + encoding=_encoding, errors="xmlcharrefreplace") + bytes_write = False + close_required = True + if bytes_write: def write(text): try: - return file.write(text.encode(encoding, + return file.write(text.encode(_encoding, "xmlcharrefreplace")) except (TypeError, AttributeError): _raise_serialization_error(text) else: write = file.write - if not encoding: - if method == "c14n": - encoding = "utf-8" - else: - encoding = None - elif xml_declaration or (xml_declaration is None and - encoding not in ("utf-8", "us-ascii")): + + if xml_declaration or _encoding.upper() not in ("UTF-8", "US-ASCII"): if method == "xml": - encoding_ = encoding - if not encoding: - # Retrieve the default encoding for the xml declaration - import locale - encoding_ = locale.getpreferredencoding() - write("\n" % encoding_) + write("\n" % _encoding) + if method == "text": _serialize_text(write, self._root) else: qnames, namespaces = _namespaces(self._root, default_namespace) serialize = _serialize[method] serialize(write, self._root, qnames, namespaces) - if file_or_filename is not file: + if close_required: file.close() def write_c14n(self, file):