Index: Lib/xmlrpc/client.py =================================================================== --- Lib/xmlrpc/client.py (revision 74444) +++ Lib/xmlrpc/client.py (working copy) @@ -475,21 +475,28 @@ # that's perfectly ok. def __init__(self, encoding=None, allow_none=0): - self.memo = {} - self.data = None + self.__memo = set() self.encoding = encoding self.allow_none = allow_none - dispatch = {} + __dispatch = {} + def _add_memo(self, value): + i = id(value) + if i in self.__memo: + raise TypeError("cannot marshal recursive data") + self.__memo.add(i) + + def _del_memo(self, value): + self.__memo.discard(id(value)) + def dumps(self, values): out = [] write = out.append - dump = self.__dump if isinstance(values, Fault): # fault instance write("\n") - dump({'faultCode': values.faultCode, + self._dump({'faultCode': values.faultCode, 'faultString': values.faultString}, write) write("\n") @@ -503,15 +510,15 @@ write("\n") for v in values: write("\n") - dump(v, write) + self._dump(v, write) write("\n") write("\n") result = "".join(out) return result - def __dump(self, value, write): + def _dump(self, value, write): try: - f = self.dispatch[type(value)] + f = self.__dispatch[type(value)] except KeyError: # check if this object can be marshalled as a structure try: @@ -522,18 +529,18 @@ # because we don't know how to marshal these types # (e.g. a string sub-class) for type_ in type(value).__mro__: - if type_ in self.dispatch.keys(): + if type_ in self.__dispatch.keys(): raise TypeError("cannot marshal %s objects" % type(value)) # XXX(twouters): using "_arbitrary_instance" as key as a quick-fix # for the p3yk merge, this should probably be fixed more neatly. - f = self.dispatch["_arbitrary_instance"] + f = self.__dispatch["_arbitrary_instance"] f(self, value, write) def dump_nil (self, value, write): if not self.allow_none: raise TypeError("cannot marshal None unless allow_none is enabled") write("") - dispatch[type(None)] = dump_nil + __dispatch[type(None)] = dump_nil def dump_int(self, value, write): # in case ints are > 32 bits @@ -542,13 +549,13 @@ write("") write(str(value)) write("\n") - #dispatch[int] = dump_int + #__dispatch[int] = dump_int def dump_bool(self, value, write): write("") write(value and "1" or "0") write("\n") - dispatch[bool] = dump_bool + __dispatch[bool] = dump_bool def dump_long(self, value, write): if value > MAXINT or value < MININT: @@ -556,64 +563,56 @@ write("") write(str(int(value))) write("\n") - dispatch[int] = dump_long + __dispatch[int] = dump_long def dump_double(self, value, write): write("") write(repr(value)) write("\n") - dispatch[float] = dump_double + __dispatch[float] = dump_double def dump_string(self, value, write, escape=escape): write("") write(escape(value)) write("\n") - dispatch[bytes] = dump_string + __dispatch[bytes] = dump_string def dump_unicode(self, value, write, escape=escape): write("") write(escape(value)) write("\n") - dispatch[str] = dump_unicode + __dispatch[str] = dump_unicode def dump_array(self, value, write): - i = id(value) - if i in self.memo: - raise TypeError("cannot marshal recursive sequences") - self.memo[i] = None - dump = self.__dump + self._add_memo(value) write("\n") for v in value: - dump(v, write) + self._dump(v, write) write("\n") - del self.memo[i] - dispatch[tuple] = dump_array - dispatch[list] = dump_array + self._del_memo(value) + __dispatch[tuple] = dump_array + __dispatch[list] = dump_array def dump_struct(self, value, write, escape=escape): - i = id(value) - if i in self.memo: - raise TypeError("cannot marshal recursive dictionaries") - self.memo[i] = None - dump = self.__dump + self._add_memo(value) write("\n") for k, v in value.items(): write("\n") if not isinstance(k, str): raise TypeError("dictionary key must be string") write("%s\n" % escape(k)) - dump(v, write) + self._dump(v, write) write("\n") write("\n") - del self.memo[i] - dispatch[dict] = dump_struct + self._del_memo(value) + __dispatch[dict] = dump_struct if datetime: def dump_datetime(self, value, write): write("") write(_strftime(value)) write("\n") - dispatch[datetime.datetime] = dump_datetime + __dispatch[datetime.datetime] = dump_datetime def dump_instance(self, value, write): # check for special wrappers @@ -624,11 +623,11 @@ else: # store instance attributes as a struct (really?) self.dump_struct(value.__dict__, write) - dispatch[DateTime] = dump_instance - dispatch[Binary] = dump_instance + __dispatch[DateTime] = dump_instance + __dispatch[Binary] = dump_instance # XXX(twouters): using "_arbitrary_instance" as key as a quick-fix # for the p3yk merge, this should probably be fixed more neatly. - dispatch["_arbitrary_instance"] = dump_instance + __dispatch["_arbitrary_instance"] = dump_instance ## # XML-RPC unmarshaller. @@ -654,7 +653,6 @@ self._data = [] self._methodname = None self._encoding = "utf-8" - self.append = self._stack.append self._use_datetime = use_datetime if use_datetime and not datetime: raise ValueError("the datetime module is not available") @@ -714,36 +712,36 @@ dispatch = {} def end_nil (self, data): - self.append(None) + self._stack.append(None) self._value = 0 dispatch["nil"] = end_nil def end_boolean(self, data): if data == "0": - self.append(False) + self._stack.append(False) elif data == "1": - self.append(True) + self._stack.append(True) else: raise TypeError("bad boolean value") self._value = 0 dispatch["boolean"] = end_boolean def end_int(self, data): - self.append(int(data)) + self._stack.append(int(data)) self._value = 0 dispatch["i4"] = end_int dispatch["i8"] = end_int dispatch["int"] = end_int def end_double(self, data): - self.append(float(data)) + self._stack.append(float(data)) self._value = 0 dispatch["double"] = end_double def end_string(self, data): if self._encoding: data = data.decode(self._encoding) - self.append(data) + self._stack.append(data) self._value = 0 dispatch["string"] = end_string dispatch["name"] = end_string # struct keys are always strings @@ -769,7 +767,7 @@ def end_base64(self, data): value = Binary() value.decode(data.encode("ascii")) - self.append(value) + self._stack.append(value) self._value = 0 dispatch["base64"] = end_base64 @@ -778,7 +776,7 @@ value.decode(data) if self._use_datetime: value = _datetime_type(data) - self.append(value) + self._stack.append(value) dispatch["dateTime.iso8601"] = end_dateTime def end_value(self, data): @@ -880,7 +878,7 @@ # # return A (parser, unmarshaller) tuple. -def getparser(use_datetime=0): +def getparser(use_datetime=0, parser=None, unmarshaller=None): """getparser() -> parser, unmarshaller Create an instance of the fastest available parser, and attach it @@ -888,22 +886,42 @@ """ if use_datetime and not datetime: raise ValueError("the datetime module is not available") - if FastParser and FastUnmarshaller: + + if unmarshaller is not None: + target = unmarshaller(use_datetime=use_datetime) + elif FastUnmarshaller is not None: if use_datetime: mkdatetime = _datetime_type else: mkdatetime = _datetime target = FastUnmarshaller(True, False, _binary, mkdatetime, Fault) + else: + target = Unmarshaller(use_datetime=use_datetime) + + if parser is not None: + parser_obj = parser(target) + elif FastParser is not None: parser = FastParser(target) else: - target = Unmarshaller(use_datetime=use_datetime) - if FastParser: - parser = FastParser(target) - else: - parser = ExpatParser(target) + parser = ExpatParser(target) + return parser, target ## +# Create marshaller object +# This function picks the fastest available marshaller. +# +# return A marshaller object + +def getmarshaller(encoding=None, allow_none=0, marshaller=None): + if marshaller is not None: + return marshaller(encoding=encoding, allow_none=allow_none) + elif FastMarshaller is not None: + return FastMarshaller(encoding) + else: + return Marshaller(encoding, allow_none) + +## # Convert a Python tuple or a Fault instance to an XML-RPC packet. # # @def dumps(params, **options) @@ -914,10 +932,11 @@ # If used with a tuple, the tuple must be a singleton (that is, # it must contain exactly one element). # @keyparam encoding The packet encoding. +# @keyparam marshaller Custom marshaller class # @return A string containing marshalled data. def dumps(params, methodname=None, methodresponse=None, encoding=None, - allow_none=0): + allow_none=0, marshaller=None): """data [,options] -> marshalled data Convert an argument tuple or a Fault instance to an XML-RPC @@ -948,11 +967,7 @@ if not encoding: encoding = "utf-8" - if FastMarshaller: - m = FastMarshaller(encoding) - else: - m = Marshaller(encoding, allow_none) - + m = getmarshaller(encoding=encoding, allow_none=allow_none, marshaller=marshaller) data = m.dumps(params) if encoding != "utf-8": @@ -989,11 +1004,14 @@ # represents a fault condition, this function raises a Fault exception. # # @param data An XML-RPC packet, given as an 8-bit string. +# @keyparam use_datetime use datetime module +# @keyparam parser custom parser class +# @keyparam unmarshaller custom unmarshaller class # @return A tuple containing the unpacked data, and the method name # (None if not present). # @see Fault -def loads(data, use_datetime=0): +def loads(data, use_datetime=0, parser=None, unmarshaller=None): """data -> unmarshalled data, method name Convert an XML-RPC packet to unmarshalled data plus a method @@ -1002,7 +1020,7 @@ If the XML-RPC packet represents a fault condition, this function raises a Fault exception. """ - p, u = getparser(use_datetime=use_datetime) + p, u = getparser(use_datetime=use_datetime, parser=parser, unmarshaller=unmarshaller) p.feed(data) p.close() return u.close(), u.getmethodname() @@ -1114,8 +1132,7 @@ # that they can decode such a request encode_threshold = None #None = don't encode - def __init__(self, use_datetime=0): - self._use_datetime = use_datetime + def __init__(self): self._connection = (None, None) self._extra_headers = [] @@ -1148,7 +1165,7 @@ resp = http_conn.getresponse() if resp.status == 200: self.verbose = verbose - return self.parse_response(resp) + return self.response_iterator(resp) except Fault: raise @@ -1168,16 +1185,26 @@ dict(resp.getheaders()) ) + def response_iterator(self, resp): - ## - # Create parser. - # - # @return A 2-tuple containing a parser and a unmarshaller. + if resp.getheader("Content-Encoding", "") == "gzip": + stream = GzipDecodedResponse(resp) + else: + stream = resp - def getparser(self): - # get parser and unmarshaller - return getparser(use_datetime=self._use_datetime) + while 1: + response = stream.read(1024) + if not response: + break + if self.verbose: + print("body:", repr(response)) + yield response + if stream is not resp: + stream.close() + + raise StopIteration() + ## # Get authorization info from host parameter # Host may be a string, or a (host, x509-dict) tuple; if a string, @@ -1289,35 +1316,7 @@ connection.putheader("Content-Length", str(len(request_body))) connection.endheaders(request_body) - ## - # Parse response. - # - # @param file Stream. - # @return Response tuple and target method. - def parse_response(self, response): - # read response data from httpresponse, and parse it - if response.getheader("Content-Encoding", "") == "gzip": - stream = GzipDecodedResponse(response) - else: - stream = response - - p, u = self.getparser() - - while 1: - data = stream.read(1024) - if not data: - break - if self.verbose: - print("body:", repr(data)) - p.feed(data) - - if stream is not response: - stream.close() - p.close() - - return u.close() - ## # Standard transport class for XML-RPC over HTTPS. @@ -1380,7 +1379,8 @@ """ def __init__(self, uri, transport=None, encoding=None, verbose=0, - allow_none=0, use_datetime=0): + allow_none=0, use_datetime=0, marshaller=None, parser=None, + unmarshaller=None): # establish a "logical" server connection # get the url @@ -1394,14 +1394,18 @@ if transport is None: if type == "https": - transport = SafeTransport(use_datetime=use_datetime) + transport = SafeTransport() else: - transport = Transport(use_datetime=use_datetime) + transport = Transport() self.__transport = transport + self.__use_datetime = use_datetime self.__encoding = encoding or 'utf-8' self.__verbose = verbose self.__allow_none = allow_none + self.__marshaller = marshaller + self.__parser = parser + self.__unmarshaller = unmarshaller def __close(self): self.__transport.close() @@ -1410,15 +1414,24 @@ # call a method on the remote server request = dumps(params, methodname, encoding=self.__encoding, - allow_none=self.__allow_none).encode(self.__encoding) + allow_none=self.__allow_none, + marshaller=self.__marshaller).encode(self.__encoding) - response = self.__transport.request( + p, u = getparser(use_datetime=self.__use_datetime, + parser=self.__parser, unmarshaller=self.__unmarshaller) + + response_iterator = self.__transport.request( self.__host, self.__handler, request, verbose=self.__verbose ) + for part in response_iterator: + p.feed(part) + p.close() + response = u.close() + if len(response) == 1: response = response[0] Index: Lib/xmlrpc/server.py =================================================================== --- Lib/xmlrpc/server.py (revision 74444) +++ Lib/xmlrpc/server.py (working copy) @@ -159,11 +159,15 @@ reason to instantiate this class directly. """ - def __init__(self, allow_none=False, encoding=None): + def __init__(self, allow_none=False, encoding=None, + marshaller=None, parser=None, unmarshaller=None): self.funcs = {} self.instance = None self.allow_none = allow_none self.encoding = encoding or 'utf-8' + self.marshaller = marshaller + self.parser = parser + self.unmarshaller = unmarshaller def register_instance(self, instance, allow_dotted_names=False): """Registers an instance to respond to XML-RPC requests. @@ -244,7 +248,7 @@ """ try: - params, method = loads(data) + params, method = loads(data, parser=self.parser, unmarshaller=self.unmarshaller) # generate response if dispatch_method is not None: @@ -254,17 +258,19 @@ # wrap response in a singleton tuple response = (response,) response = dumps(response, methodresponse=1, - allow_none=self.allow_none, encoding=self.encoding) + allow_none=self.allow_none, encoding=self.encoding, + marshaller=self.marshaller) except Fault as fault: response = dumps(fault, allow_none=self.allow_none, - encoding=self.encoding) + encoding=self.encoding, + marshaller=self.marshaller) except: # report exception back to server exc_type, exc_value, exc_tb = sys.exc_info() response = dumps( Fault(1, "%s:%s" % (exc_type, exc_value)), encoding=self.encoding, allow_none=self.allow_none, - ) + marshaller=self.marshaller) return response.encode(self.encoding) @@ -570,10 +576,12 @@ _send_traceback_header = False def __init__(self, addr, requestHandler=SimpleXMLRPCRequestHandler, - logRequests=True, allow_none=False, encoding=None, bind_and_activate=True): + logRequests=True, allow_none=False, encoding=None, bind_and_activate=True, + marshaller=None, parser=None, unmarshaller=None): self.logRequests = logRequests - SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding) + SimpleXMLRPCDispatcher.__init__(self, allow_none, encoding, marshaller=marshaller, + parser=parser, unmarshaller=unmarshaller) socketserver.TCPServer.__init__(self, addr, requestHandler, bind_and_activate) # [Bug #1222790] If possible, set close-on-exec flag; if a Index: Lib/test/test_xmlrpc.py =================================================================== --- Lib/test/test_xmlrpc.py (revision 74412) +++ Lib/test/test_xmlrpc.py (working copy) @@ -573,9 +573,9 @@ class Transport(xmlrpclib.Transport): #custom transport, stores the response length for our perusal fake_gzip = False - def parse_response(self, response): + def response_iterator(self, response): self.response_length=int(response.getheader("content-length", 0)) - return xmlrpclib.Transport.parse_response(self, response) + return xmlrpclib.Transport.response_iterator(self, response) def send_content(self, connection, body): if self.fake_gzip: