diff --git a/Lib/base64.py b/Lib/base64.py --- a/Lib/base64.py +++ b/Lib/base64.py @@ -28,6 +28,7 @@ bytes_types = (bytes, bytearray) # Types acceptable as binary data +ascii_types = bytes_types + (str,) def _translate(s, altchars): @@ -79,10 +80,10 @@ discarded prior to the padding check. If validate is True, non-base64-alphabet characters in the input result in a binascii.Error. """ - if not isinstance(s, bytes_types): + if not isinstance(s, ascii_types): raise TypeError("expected bytes, not %s" % s.__class__.__name__) if altchars is not None: - if not isinstance(altchars, bytes_types): + if not isinstance(altchars, ascii_types): raise TypeError("expected bytes, not %s" % altchars.__class__.__name__) assert len(altchars) == 2, repr(altchars) @@ -211,7 +212,7 @@ the input is incorrectly padded or if there are non-alphabet characters present in the input. """ - if not isinstance(s, bytes_types): + if not isinstance(s, ascii_types): raise TypeError("expected bytes, not %s" % s.__class__.__name__) quanta, leftover = divmod(len(s), 8) if leftover: @@ -220,7 +221,7 @@ # False, or the character to map the digit 1 (one) to. It should be # either L (el) or I (eye). if map01 is not None: - if not isinstance(map01, bytes_types): + if not isinstance(map01, ascii_types): raise TypeError("expected bytes, not %s" % map01.__class__.__name__) assert len(map01) == 1, repr(map01) s = _translate(s, {b'0': b'O', b'1': map01}) @@ -292,7 +293,7 @@ s were incorrectly padded or if there are non-alphabet characters present in the string. """ - if not isinstance(s, bytes_types): + if not isinstance(s, ascii_types): raise TypeError("expected bytes, not %s" % s.__class__.__name__) if casefold: s = s.upper() diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py --- a/Lib/test/test_base64.py +++ b/Lib/test/test_base64.py @@ -116,9 +116,6 @@ eq(base64.b64decode(b''), b'') # Test with arbitrary alternative characters eq(base64.b64decode(b'01a*b$cd', altchars=b'*$'), b'\xd3V\xbeo\xf7\x1d') - # Check if passing a str object raises an error - self.assertRaises(TypeError, base64.b64decode, "") - self.assertRaises(TypeError, base64.b64decode, b"", altchars="") # Test standard alphabet eq(base64.standard_b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org") eq(base64.standard_b64decode(b"YQ=="), b"a") @@ -131,9 +128,6 @@ b"abcdefghijklmnopqrstuvwxyz" b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"0123456789!@#0^&*();:<>,. []{}") - # Check if passing a str object raises an error - self.assertRaises(TypeError, base64.standard_b64decode, "") - self.assertRaises(TypeError, base64.standard_b64decode, b"", altchars="") # Test with 'URL safe' alternative characters eq(base64.urlsafe_b64decode(b'01a-b_cd'), b'\xd3V\xbeo\xf7\x1d') self.assertRaises(TypeError, base64.urlsafe_b64decode, "") @@ -157,6 +151,15 @@ with self.assertRaises(binascii.Error): base64.b64decode(bstr, validate=True) + def test_b64decode_ascii_chars(self): + # issue 13641: Decoding functions in the base64 + # module could accept unicode strings. + tests = (('MDA=', b'00'), + ('AA==', b'\x00'), + ('/w==', b'\xff')) + for str, result in tests: + self.assertEqual(base64.b64decode(str), result) + def test_b32encode(self): eq = self.assertEqual eq(base64.b32encode(b''), b'') @@ -199,7 +202,6 @@ eq(base64.b32decode(b'MLO23456'), b'b\xdd\xad\xf3\xbe') eq(base64.b32decode(b'M1023456', map01=b'L'), b'b\xdd\xad\xf3\xbe') eq(base64.b32decode(b'M1023456', map01=b'I'), b'b\x1d\xad\xf3\xbe') - self.assertRaises(TypeError, base64.b32decode, b"", map01="") def test_b32decode_error(self): self.assertRaises(binascii.Error, base64.b32decode, b'abc')