Index: Lib/xmlrpc/client.py
===================================================================
--- Lib/xmlrpc/client.py (revision 74444)
+++ Lib/xmlrpc/client.py (working copy)
@@ -475,21 +475,28 @@
# that's perfectly ok.
def __init__(self, encoding=None, allow_none=0):
- self.memo = {}
- self.data = None
+ self.__memo = set()
self.encoding = encoding
self.allow_none = allow_none
- dispatch = {}
+ __dispatch = {}
+ def _add_memo(self, value):
+ i = id(value)
+ if i in self.__memo:
+ raise TypeError("cannot marshal recursive data")
+ self.__memo.add(i)
+
+ def _del_memo(self, value):
+ self.__memo.discard(id(value))
+
def dumps(self, values):
out = []
write = out.append
- dump = self.__dump
if isinstance(values, Fault):
# fault instance
write("\n")
- dump({'faultCode': values.faultCode,
+ self._dump({'faultCode': values.faultCode,
'faultString': values.faultString},
write)
write("\n")
@@ -503,15 +510,15 @@
write("\n")
for v in values:
write("\n")
- dump(v, write)
+ self._dump(v, write)
write("\n")
write("\n")
result = "".join(out)
return result
- def __dump(self, value, write):
+ def _dump(self, value, write):
try:
- f = self.dispatch[type(value)]
+ f = self.__dispatch[type(value)]
except KeyError:
# check if this object can be marshalled as a structure
try:
@@ -522,18 +529,18 @@
# because we don't know how to marshal these types
# (e.g. a string sub-class)
for type_ in type(value).__mro__:
- if type_ in self.dispatch.keys():
+ if type_ in self.__dispatch.keys():
raise TypeError("cannot marshal %s objects" % type(value))
# XXX(twouters): using "_arbitrary_instance" as key as a quick-fix
# for the p3yk merge, this should probably be fixed more neatly.
- f = self.dispatch["_arbitrary_instance"]
+ f = self.__dispatch["_arbitrary_instance"]
f(self, value, write)
def dump_nil (self, value, write):
if not self.allow_none:
raise TypeError("cannot marshal None unless allow_none is enabled")
write("")
- dispatch[type(None)] = dump_nil
+ __dispatch[type(None)] = dump_nil
def dump_int(self, value, write):
# in case ints are > 32 bits
@@ -542,13 +549,13 @@
write("")
write(str(value))
write("\n")
- #dispatch[int] = dump_int
+ #__dispatch[int] = dump_int
def dump_bool(self, value, write):
write("")
write(value and "1" or "0")
write("\n")
- dispatch[bool] = dump_bool
+ __dispatch[bool] = dump_bool
def dump_long(self, value, write):
if value > MAXINT or value < MININT:
@@ -556,64 +563,56 @@
write("")
write(str(int(value)))
write("\n")
- dispatch[int] = dump_long
+ __dispatch[int] = dump_long
def dump_double(self, value, write):
write("")
write(repr(value))
write("\n")
- dispatch[float] = dump_double
+ __dispatch[float] = dump_double
def dump_string(self, value, write, escape=escape):
write("")
write(escape(value))
write("\n")
- dispatch[bytes] = dump_string
+ __dispatch[bytes] = dump_string
def dump_unicode(self, value, write, escape=escape):
write("")
write(escape(value))
write("\n")
- dispatch[str] = dump_unicode
+ __dispatch[str] = dump_unicode
def dump_array(self, value, write):
- i = id(value)
- if i in self.memo:
- raise TypeError("cannot marshal recursive sequences")
- self.memo[i] = None
- dump = self.__dump
+ self._add_memo(value)
write("\n")
for v in value:
- dump(v, write)
+ self._dump(v, write)
write("\n")
- del self.memo[i]
- dispatch[tuple] = dump_array
- dispatch[list] = dump_array
+ self._del_memo(value)
+ __dispatch[tuple] = dump_array
+ __dispatch[list] = dump_array
def dump_struct(self, value, write, escape=escape):
- i = id(value)
- if i in self.memo:
- raise TypeError("cannot marshal recursive dictionaries")
- self.memo[i] = None
- dump = self.__dump
+ self._add_memo(value)
write("\n")
for k, v in value.items():
write("\n")
if not isinstance(k, str):
raise TypeError("dictionary key must be string")
write("%s\n" % escape(k))
- dump(v, write)
+ self._dump(v, write)
write("\n")
write("\n")
- del self.memo[i]
- dispatch[dict] = dump_struct
+ self._del_memo(value)
+ __dispatch[dict] = dump_struct
if datetime:
def dump_datetime(self, value, write):
write("")
write(_strftime(value))
write("\n")
- dispatch[datetime.datetime] = dump_datetime
+ __dispatch[datetime.datetime] = dump_datetime
def dump_instance(self, value, write):
# check for special wrappers
@@ -624,11 +623,11 @@
else:
# store instance attributes as a struct (really?)
self.dump_struct(value.__dict__, write)
- dispatch[DateTime] = dump_instance
- dispatch[Binary] = dump_instance
+ __dispatch[DateTime] = dump_instance
+ __dispatch[Binary] = dump_instance
# XXX(twouters): using "_arbitrary_instance" as key as a quick-fix
# for the p3yk merge, this should probably be fixed more neatly.
- dispatch["_arbitrary_instance"] = dump_instance
+ __dispatch["_arbitrary_instance"] = dump_instance
##
# XML-RPC unmarshaller.
@@ -654,7 +653,6 @@
self._data = []
self._methodname = None
self._encoding = "utf-8"
- self.append = self._stack.append
self._use_datetime = use_datetime
if use_datetime and not datetime:
raise ValueError("the datetime module is not available")
@@ -714,36 +712,36 @@
dispatch = {}
def end_nil (self, data):
- self.append(None)
+ self._stack.append(None)
self._value = 0
dispatch["nil"] = end_nil
def end_boolean(self, data):
if data == "0":
- self.append(False)
+ self._stack.append(False)
elif data == "1":
- self.append(True)
+ self._stack.append(True)
else:
raise TypeError("bad boolean value")
self._value = 0
dispatch["boolean"] = end_boolean
def end_int(self, data):
- self.append(int(data))
+ self._stack.append(int(data))
self._value = 0
dispatch["i4"] = end_int
dispatch["i8"] = end_int
dispatch["int"] = end_int
def end_double(self, data):
- self.append(float(data))
+ self._stack.append(float(data))
self._value = 0
dispatch["double"] = end_double
def end_string(self, data):
if self._encoding:
data = data.decode(self._encoding)
- self.append(data)
+ self._stack.append(data)
self._value = 0
dispatch["string"] = end_string
dispatch["name"] = end_string # struct keys are always strings
@@ -769,7 +767,7 @@
def end_base64(self, data):
value = Binary()
value.decode(data.encode("ascii"))
- self.append(value)
+ self._stack.append(value)
self._value = 0
dispatch["base64"] = end_base64
@@ -778,7 +776,7 @@
value.decode(data)
if self._use_datetime:
value = _datetime_type(data)
- self.append(value)
+ self._stack.append(value)
dispatch["dateTime.iso8601"] = end_dateTime
def end_value(self, data):
@@ -880,7 +878,7 @@
#
# return A (parser, unmarshaller) tuple.
-def getparser(use_datetime=0):
+def getparser(use_datetime=0, parser=None, unmarshaller=None):
"""getparser() -> parser, unmarshaller
Create an instance of the fastest available parser, and attach it
@@ -888,22 +886,42 @@
"""
if use_datetime and not datetime:
raise ValueError("the datetime module is not available")
- if FastParser and FastUnmarshaller:
+
+ if unmarshaller is not None:
+ target = unmarshaller(use_datetime=use_datetime)
+ elif FastUnmarshaller is not None:
if use_datetime:
mkdatetime = _datetime_type
else:
mkdatetime = _datetime
target = FastUnmarshaller(True, False, _binary, mkdatetime, Fault)
+ else:
+ target = Unmarshaller(use_datetime=use_datetime)
+
+ if parser is not None:
+ parser_obj = parser(target)
+ elif FastParser is not None:
parser = FastParser(target)
else:
- target = Unmarshaller(use_datetime=use_datetime)
- if FastParser:
- parser = FastParser(target)
- else:
- parser = ExpatParser(target)
+ parser = ExpatParser(target)
+
return parser, target
##
+# Create marshaller object
+# This function picks the fastest available marshaller.
+#
+# return A marshaller object
+
+def getmarshaller(encoding=None, allow_none=0, marshaller=None):
+ if marshaller is not None:
+ return marshaller(encoding=encoding, allow_none=allow_none)
+ elif FastMarshaller is not None:
+ return FastMarshaller(encoding)
+ else:
+ return Marshaller(encoding, allow_none)
+
+##
# Convert a Python tuple or a Fault instance to an XML-RPC packet.
#
# @def dumps(params, **options)
@@ -914,10 +932,11 @@
# If used with a tuple, the tuple must be a singleton (that is,
# it must contain exactly one element).
# @keyparam encoding The packet encoding.
+# @keyparam marshaller Custom marshaller class
# @return A string containing marshalled data.
def dumps(params, methodname=None, methodresponse=None, encoding=None,
- allow_none=0):
+ allow_none=0, marshaller=None):
"""data [,options] -> marshalled data
Convert an argument tuple or a Fault instance to an XML-RPC
@@ -948,11 +967,7 @@
if not encoding:
encoding = "utf-8"
- if FastMarshaller:
- m = FastMarshaller(encoding)
- else:
- m = Marshaller(encoding, allow_none)
-
+ m = getmarshaller(encoding=encoding, allow_none=allow_none, marshaller=marshaller)
data = m.dumps(params)
if encoding != "utf-8":
@@ -989,11 +1004,14 @@
# represents a fault condition, this function raises a Fault exception.
#
# @param data An XML-RPC packet, given as an 8-bit string.
+# @keyparam use_datetime use datetime module
+# @keyparam parser custom parser class
+# @keyparam unmarshaller custom unmarshaller class
# @return A tuple containing the unpacked data, and the method name
# (None if not present).
# @see Fault
-def loads(data, use_datetime=0):
+def loads(data, use_datetime=0, parser=None, unmarshaller=None):
"""data -> unmarshalled data, method name
Convert an XML-RPC packet to unmarshalled data plus a method
@@ -1002,7 +1020,7 @@
If the XML-RPC packet represents a fault condition, this function
raises a Fault exception.
"""
- p, u = getparser(use_datetime=use_datetime)
+ p, u = getparser(use_datetime=use_datetime, parser=parser, unmarshaller=unmarshaller)
p.feed(data)
p.close()
return u.close(), u.getmethodname()
@@ -1114,8 +1132,7 @@
# that they can decode such a request
encode_threshold = None #None = don't encode
- def __init__(self, use_datetime=0):
- self._use_datetime = use_datetime
+ def __init__(self):
self._connection = (None, None)
self._extra_headers = []
@@ -1148,7 +1165,7 @@
resp = http_conn.getresponse()
if resp.status == 200:
self.verbose = verbose
- return self.parse_response(resp)
+ return self.response_iterator(resp)
except Fault:
raise
@@ -1168,16 +1185,26 @@
dict(resp.getheaders())
)
+ def response_iterator(self, resp):
- ##
- # Create parser.
- #
- # @return A 2-tuple containing a parser and a unmarshaller.
+ if resp.getheader("Content-Encoding", "") == "gzip":
+ stream = GzipDecodedResponse(resp)
+ else:
+ stream = resp
- def getparser(self):
- # get parser and unmarshaller
- return getparser(use_datetime=self._use_datetime)
+ while 1:
+ response = stream.read(1024)
+ if not response:
+ break
+ if self.verbose:
+ print("body:", repr(response))
+ yield response
+ if stream is not resp:
+ stream.close()
+
+ raise StopIteration()
+
##
# Get authorization info from host parameter
# Host may be a string, or a (host, x509-dict) tuple; if a string,
@@ -1289,35 +1316,7 @@
connection.putheader("Content-Length", str(len(request_body)))
connection.endheaders(request_body)
- ##
- # Parse response.
- #
- # @param file Stream.
- # @return Response tuple and target method.
- def parse_response(self, response):
- # read response data from httpresponse, and parse it
- if response.getheader("Content-Encoding", "") == "gzip":
- stream = GzipDecodedResponse(response)
- else:
- stream = response
-
- p, u = self.getparser()
-
- while 1:
- data = stream.read(1024)
- if not data:
- break
- if self.verbose:
- print("body:", repr(data))
- p.feed(data)
-
- if stream is not response:
- stream.close()
- p.close()
-
- return u.close()
-
##
# Standard transport class for XML-RPC over HTTPS.
@@ -1380,7 +1379,8 @@
"""
def __init__(self, uri, transport=None, encoding=None, verbose=0,
- allow_none=0, use_datetime=0):
+ allow_none=0, use_datetime=0, marshaller=None, parser=None,
+ unmarshaller=None):
# establish a "logical" server connection
# get the url
@@ -1394,14 +1394,18 @@
if transport is None:
if type == "https":
- transport = SafeTransport(use_datetime=use_datetime)
+ transport = SafeTransport()
else:
- transport = Transport(use_datetime=use_datetime)
+ transport = Transport()
self.__transport = transport
+ self.__use_datetime = use_datetime
self.__encoding = encoding or 'utf-8'
self.__verbose = verbose
self.__allow_none = allow_none
+ self.__marshaller = marshaller
+ self.__parser = parser
+ self.__unmarshaller = unmarshaller
def __close(self):
self.__transport.close()
@@ -1410,15 +1414,24 @@
# call a method on the remote server
request = dumps(params, methodname, encoding=self.__encoding,
- allow_none=self.__allow_none).encode(self.__encoding)
+ allow_none=self.__allow_none,
+ marshaller=self.__marshaller).encode(self.__encoding)
- response = self.__transport.request(
+ p, u = getparser(use_datetime=self.__use_datetime,
+ parser=self.__parser, unmarshaller=self.__unmarshaller)
+
+ response_iterator = self.__transport.request(
self.__host,
self.__handler,
request,
verbose=self.__verbose
)
+ for part in response_iterator:
+ p.feed(part)
+ p.close()
+ response = u.close()
+
if len(response) == 1:
response = response[0]
Index: Lib/xmlrpc/server.py
===================================================================
--- Lib/xmlrpc/server.py (revision 74444)
+++ Lib/xmlrpc/server.py (working copy)
@@ -159,11 +159,15 @@
reason to instantiate this class directly.
"""
- def __init__(self, allow_none=False, encoding=None):
+ def __init__(self, allow_none=False, encoding=None,
+ marshaller=None, parser=None, unmarshaller=None):
self.funcs = {}
self.instance = None
self.allow_none = allow_none
self.encoding = encoding or 'utf-8'
+ self.marshaller = marshaller
+ self.parser = parser
+ self.unmarshaller = unmarshaller
def register_instance(self, instance, allow_dotted_names=False):
"""Registers an instance to respond to XML-RPC requests.
@@ -244,7 +248,7 @@
"""
try:
- params, method = loads(data)
+ params, method = loads(data, parser=self.parser, unmarshaller=self.unmarshaller)
# generate response
if dispatch_method is not None:
@@ -254,17 +258,19 @@
# wrap response in a singleton tuple
response = (response,)
response = dumps(response, methodresponse=1,
- allow_none=self.allow_none, encoding=self.encoding)
+ allow_none=self.allow_none, encoding=self.encoding,
+ marshaller=self.marshaller)
except Fault as fault:
response = dumps(fault, allow_none=self.allow_none,
- encoding=self.encoding)
+ encoding=self.encoding,
+ marshaller=self.marshaller)
except:
# report exception back to server
exc_type, exc_value, exc_tb = sys.exc_info()
response = dumps(
Fault(1, "%s:%s" % (exc_type, exc_value)),
encoding=self.encoding, allow_none=self.allow_none,
- )
+ marshaller=self.marshaller)
return response.encode(self.encoding)
@@ -570,10 +576,12 @@
_send_traceback_header = False
def __init__(self, addr, requestHandler=SimpleXMLRPCRequestHandler,
- logRequests=True, allow_none=False, encoding=None, bind_and_activate=True):
+ logRequests=True, allow_none=False, encoding=None, bind_and_activate=True,
+ marshaller=None, parser=None, unmarshaller=None):
self.logRequests = logRequests
- SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding)
+ SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding, marshaller=marshaller,
+ parser=parser, unmarshaller=unmarshaller)
socketserver.TCPServer.__init__(self, addr, requestHandler, bind_and_activate)
# [Bug #1222790] If possible, set close-on-exec flag; if a
Index: Lib/test/test_xmlrpc.py
===================================================================
--- Lib/test/test_xmlrpc.py (revision 74412)
+++ Lib/test/test_xmlrpc.py (working copy)
@@ -573,9 +573,9 @@
class Transport(xmlrpclib.Transport):
#custom transport, stores the response length for our perusal
fake_gzip = False
- def parse_response(self, response):
+ def response_iterator(self, response):
self.response_length=int(response.getheader("content-length", 0))
- return xmlrpclib.Transport.parse_response(self, response)
+ return xmlrpclib.Transport.response_iterator(self, response)
def send_content(self, connection, body):
if self.fake_gzip: