# HG changeset patch # Parent 2096158376e5ed0a2e4837103422fc5e157f572b diff --git a/Lib/test/test_xmlrpc.py b/Lib/test/test_xmlrpc.py --- a/Lib/test/test_xmlrpc.py +++ b/Lib/test/test_xmlrpc.py @@ -24,6 +24,8 @@ 'ashortlong': 2, 'anotherlist': ['.zyx.41'], 'abase64': xmlrpclib.Binary(b"my dog has fleas"), + 'b64bytes': b"my dog has fleas", + 'b64bytearray': bytearray(b"my dog has fleas"), 'boolean': False, 'unicode': '\u4000\u6000\u8000', 'ukey\u4000': 'regular value', @@ -47,24 +49,33 @@ # since with use_datetime set to 1 the unmarshaller would create # datetime objects for the 'datetime[123]' keys as well dt = datetime.datetime(2005, 2, 10, 11, 41, 23) + self.assertEqual(dt, xmlrpclib.DateTime('20050210T11:41:23')) + s = xmlrpclib.dumps((dt,)) (newdt,), m = xmlrpclib.loads(s, use_datetime=1) self.assertEqual(newdt, dt) - self.assertEqual(m, None) + self.assertIs(type(newdt), datetime.datetime) + self.assertIsNone(m) (newdt,), m = xmlrpclib.loads(s, use_datetime=0) - self.assertEqual(newdt, xmlrpclib.DateTime('20050210T11:41:23')) + self.assertEqual(newdt, dt) + self.assertIs(type(newdt), xmlrpclib.DateTime) + self.assertIsNone(m) def test_datetime_before_1900(self): # same as before but with a date before 1900 dt = datetime.datetime(1, 2, 10, 11, 41, 23) + self.assertEqual(dt, xmlrpclib.DateTime('00010210T11:41:23')) s = xmlrpclib.dumps((dt,)) (newdt,), m = xmlrpclib.loads(s, use_datetime=1) self.assertEqual(newdt, dt) - self.assertEqual(m, None) + self.assertIs(type(newdt), datetime.datetime) + self.assertIsNone(m) (newdt,), m = xmlrpclib.loads(s, use_datetime=0) - self.assertEqual(newdt, xmlrpclib.DateTime('00010210T11:41:23')) + self.assertEqual(newdt, dt) + self.assertIs(type(newdt), xmlrpclib.DateTime) + self.assertIsNone(m) def test_bug_1164912 (self): d = xmlrpclib.DateTime() @@ -133,6 +144,22 @@ xmlrpclib.loads(strg)[0][0]) self.assertRaises(TypeError, xmlrpclib.dumps, (arg1,)) + def test_dump_bytes(self): + sample = b"my dog has fleas" + self.assertEqual(sample, xmlrpclib.Binary(sample)) + for type_ in bytes, bytearray, xmlrpclib.Binary: + value = type_(sample) + s = xmlrpclib.dumps((value,)) + (newvalue,), m = xmlrpclib.loads(s, use_bytes=1) + self.assertEqual(newvalue, sample) + self.assertIs(type(newvalue), bytes) + self.assertIsNone(m) + + (newvalue,), m = xmlrpclib.loads(s, use_bytes=0) + self.assertEqual(newvalue, sample) + self.assertIs(type(newvalue), xmlrpclib.Binary) + self.assertIsNone(m) + def test_get_host_info(self): # see bug #3613, this raised a TypeError transp = xmlrpc.client.Transport() @@ -140,9 +167,6 @@ ('host.tld', [('Authorization', 'Basic dXNlcg==')], {})) - def test_dump_bytes(self): - self.assertRaises(TypeError, xmlrpclib.dumps, (b"my dog has fleas",)) - def test_ssl_presence(self): try: import ssl diff --git a/Lib/xmlrpc/client.py b/Lib/xmlrpc/client.py --- a/Lib/xmlrpc/client.py +++ b/Lib/xmlrpc/client.py @@ -371,7 +371,7 @@ if data is None: data = b"" else: - if not isinstance(data, bytes): + if not isinstance(data, (bytes, bytearray)): raise TypeError("expected bytes, not %s" % data.__class__.__name__) data = bytes(data) # Make a copy of the bytes! @@ -544,6 +544,14 @@ write("\n") dispatch[str] = dump_unicode + def dump_bytes(self, value, write): + write("\n") + encoded = base64.encodebytes(value) + write(encoded.decode('ascii')) + write("\n") + dispatch[bytes] = dump_bytes + dispatch[bytearray] = dump_bytes + def dump_array(self, value, write): i = id(value) if i in self.memo: @@ -614,7 +622,7 @@ # and again, if you don't understand what's going on in here, # that's perfectly ok. - def __init__(self, use_datetime=False): + def __init__(self, use_datetime=False, use_bytes=False): self._type = None self._stack = [] self._marks = [] @@ -623,6 +631,7 @@ self._encoding = "utf-8" self.append = self._stack.append self._use_datetime = use_datetime + self._use_bytes = use_bytes def close(self): # return response tuple and target method @@ -734,6 +743,8 @@ def end_base64(self, data): value = Binary() value.decode(data.encode("ascii")) + if self._use_bytes: + value = value.data self.append(value) self._value = 0 dispatch["base64"] = end_base64 @@ -845,7 +856,7 @@ # # return A (parser, unmarshaller) tuple. -def getparser(use_datetime=False): +def getparser(use_datetime=False, use_bytes=False): """getparser() -> parser, unmarshaller Create an instance of the fastest available parser, and attach it @@ -856,10 +867,14 @@ mkdatetime = _datetime_type else: mkdatetime = _datetime - target = FastUnmarshaller(True, False, _binary, mkdatetime, Fault) + if use_bytes: + mkbytes = base64.decodebytes + else: + mkbytes = _binary + target = FastUnmarshaller(True, False, mkbytes, mkdatetime, Fault) parser = FastParser(target) else: - target = Unmarshaller(use_datetime=use_datetime) + target = Unmarshaller(use_datetime=use_datetime, use_bytes=use_bytes) if FastParser: parser = FastParser(target) else: @@ -956,7 +971,7 @@ # (None if not present). # @see Fault -def loads(data, use_datetime=False): +def loads(data, use_datetime=False, use_bytes=False): """data -> unmarshalled data, method name Convert an XML-RPC packet to unmarshalled data plus a method @@ -965,7 +980,7 @@ If the XML-RPC packet represents a fault condition, this function raises a Fault exception. """ - p, u = getparser(use_datetime=use_datetime) + p, u = getparser(use_datetime=use_datetime, use_bytes=use_bytes) p.feed(data) p.close() return u.close(), u.getmethodname() @@ -1077,8 +1092,9 @@ # that they can decode such a request encode_threshold = None #None = don't encode - def __init__(self, use_datetime=False): + def __init__(self, use_datetime=False, use_bytes=False): self._use_datetime = use_datetime + self._use_bytes = use_bytes self._connection = (None, None) self._extra_headers = [] @@ -1139,7 +1155,8 @@ def getparser(self): # get parser and unmarshaller - return getparser(use_datetime=self._use_datetime) + return getparser(use_datetime=self._use_datetime, + use_bytes=self._use_bytes) ## # Get authorization info from host parameter @@ -1346,7 +1363,7 @@ """ def __init__(self, uri, transport=None, encoding=None, verbose=False, - allow_none=False, use_datetime=False): + allow_none=False, use_datetime=False, use_bytes=False): # establish a "logical" server connection # get the url @@ -1360,9 +1377,10 @@ if transport is None: if type == "https": - transport = SafeTransport(use_datetime=use_datetime) + handler = SafeTransport else: - transport = Transport(use_datetime=use_datetime) + handler = Transport + transport = handler(use_datetime=use_datetime, use_bytes=use_bytes) self.__transport = transport self.__encoding = encoding or 'utf-8'