Issue 13641: let the base64 decoding helpers accept ASCII str input in
addition to bytes/bytearray.

What this patch does:
  * Lib/base64.py: adds ascii_types (bytes_types plus str) and uses it in
    the argument type checks touched below.  The base32 decoding path
    picks a str or bytes padding pattern to match its input, and uses a
    str-keyed reverse table (_b32rev_ascii) when given str.
  * Lib/test/test_base64.py: drops the "passing str raises TypeError"
    assertions for b64decode, standard_b64decode and b32decode, and adds
    str decoding cases for b64decode and b32decode.

Review notes (to confirm against the full files, which are not shown here):
  * The relaxed checks still raise "expected bytes, not ..." even though
    str is now accepted.
  * The map01 branch still calls _translate() with bytes keys
    (b'0', b'1'); confirm _translate copes with str s / str map01.
  * test_base64.py still asserts (context line) that
    urlsafe_b64decode("") raises TypeError; confirm that is intended.
  * No str test is added for the function patched in the last base64.py
    hunk (the casefold / s.upper() one); confirm the rest of its body
    handles str.

diff --git a/Lib/base64.py b/Lib/base64.py
--- a/Lib/base64.py
+++ b/Lib/base64.py
@@ -28,6 +28,7 @@
 
 
 bytes_types = (bytes, bytearray)  # Types acceptable as binary data
+ascii_types = bytes_types + (str,)
 
 
 def _translate(s, altchars):
@@ -79,10 +80,10 @@
     discarded prior to the padding check.  If validate is True,
     non-base64-alphabet characters in the input result in a binascii.Error.
     """
-    if not isinstance(s, bytes_types):
+    if not isinstance(s, ascii_types):
         raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     if altchars is not None:
-        if not isinstance(altchars, bytes_types):
+        if not isinstance(altchars, ascii_types):
             raise TypeError("expected bytes, not %s"
                             % altchars.__class__.__name__)
         assert len(altchars) == 2, repr(altchars)
@@ -145,8 +146,12 @@
     8: b'I', 17: b'R', 26: b'2',
     }
 
 _b32tab = [v[0] for k, v in sorted(_b32alphabet.items())]
 _b32rev = dict([(v[0], k) for k, v in _b32alphabet.items()])
+# Reverse table keyed by one-character str, used when decoding str input.
+# Derived from _b32alphabet so the two tables cannot drift apart.
+_b32rev_ascii = dict([(v.decode('ascii'), k)
+                      for k, v in _b32alphabet.items()])
 
 
 def b32encode(s):
@@ -211,7 +216,7 @@
     the input is incorrectly padded or if there are non-alphabet
     characters present in the input.
     """
-    if not isinstance(s, bytes_types):
+    if not isinstance(s, ascii_types):
         raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     quanta, leftover = divmod(len(s), 8)
     if leftover:
@@ -220,7 +225,7 @@
     # False, or the character to map the digit 1 (one) to.  It should be
     # either L (el) or I (eye).
     if map01 is not None:
-        if not isinstance(map01, bytes_types):
+        if not isinstance(map01, ascii_types):
             raise TypeError("expected bytes, not %s" % map01.__class__.__name__)
         assert len(map01) == 1, repr(map01)
         s = _translate(s, {b'0': b'O', b'1': map01})
@@ -230,7 +235,11 @@
     # characters because this will tell us how many null bytes to remove from
     # the end of the decoded string.
     padchars = 0
-    mo = re.search(b'(?P<pad>[=]*)$', s)
+    if isinstance(s, bytes_types):
+        pattern = b'(?P<pad>[=]*)$'
+    else:
+        pattern = '(?P<pad>[=]*)$'
+    mo = re.search(pattern, s)
     if mo:
         padchars = len(mo.group('pad'))
     if padchars > 0:
@@ -240,10 +249,13 @@
     acc = 0
     shift = 35
+    # Choose the reverse table once per call: iterating str yields
+    # one-character strings, iterating bytes yields ints.
+    rev = _b32rev_ascii if isinstance(s, str) else _b32rev
     for c in s:
-        val = _b32rev.get(c)
+        val = rev.get(c)
         if val is None:
             raise TypeError('Non-base32 digit found')
-        acc += _b32rev[c] << shift
+        acc += val << shift
         shift -= 5
         if shift < 0:
             parts.append(binascii.unhexlify(bytes('%010x' % acc, "ascii")))
@@ -292,7 +304,7 @@
     s were incorrectly padded or if there are non-alphabet characters
     present in the string.
     """
-    if not isinstance(s, bytes_types):
+    if not isinstance(s, ascii_types):
         raise TypeError("expected bytes, not %s" % s.__class__.__name__)
     if casefold:
         s = s.upper()
diff --git a/Lib/test/test_base64.py b/Lib/test/test_base64.py
--- a/Lib/test/test_base64.py
+++ b/Lib/test/test_base64.py
@@ -116,9 +116,6 @@
         eq(base64.b64decode(b''), b'')
         # Test with arbitrary alternative characters
         eq(base64.b64decode(b'01a*b$cd', altchars=b'*$'), b'\xd3V\xbeo\xf7\x1d')
-        # Check if passing a str object raises an error
-        self.assertRaises(TypeError, base64.b64decode, "")
-        self.assertRaises(TypeError, base64.b64decode, b"", altchars="")
         # Test standard alphabet
         eq(base64.standard_b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org")
         eq(base64.standard_b64decode(b"YQ=="), b"a")
@@ -131,9 +128,6 @@
            b"abcdefghijklmnopqrstuvwxyz"
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
            b"0123456789!@#0^&*();:<>,. []{}")
-        # Check if passing a str object raises an error
-        self.assertRaises(TypeError, base64.standard_b64decode, "")
-        self.assertRaises(TypeError, base64.standard_b64decode, b"", altchars="")
         # Test with 'URL safe' alternative characters
         eq(base64.urlsafe_b64decode(b'01a-b_cd'), b'\xd3V\xbeo\xf7\x1d')
         self.assertRaises(TypeError, base64.urlsafe_b64decode, "")
@@ -157,6 +151,15 @@
             with self.assertRaises(binascii.Error):
                 base64.b64decode(bstr, validate=True)
 
+    def test_b64decode_ascii_chars(self):
+        # issue 13641: Decoding functions in the base64
+        # module could accept unicode strings.
+        tests = (('MDA=', b'00'),
+                 ('AA==', b'\x00'),
+                 ('/w==', b'\xff'))
+        for data, result in tests:
+            self.assertEqual(base64.b64decode(data), result)
+
     def test_b32encode(self):
         eq = self.assertEqual
         eq(base64.b32encode(b''), b'')
@@ -177,7 +180,18 @@
         eq(base64.b32decode(b'MFRGG==='), b'abc')
         eq(base64.b32decode(b'MFRGGZA='), b'abcd')
         eq(base64.b32decode(b'MFRGGZDF'), b'abcde')
-        self.assertRaises(TypeError, base64.b32decode, "")
+
+    def test_b32decode_ascii_chars(self):
+        # issue 13641: Decoding functions in the base64
+        # module could accept unicode strings.
+        tests = (('AA======', b'\x00'),
+                 ('ME======', b'a'),
+                 ('MFRA====', b'ab'),
+                 ('MFRGG===', b'abc'),
+                 ('MFRGGZA=', b'abcd'),
+                 ('MFRGGZDF', b'abcde'))
+        for data, result in tests:
+            self.assertEqual(base64.b32decode(data), result)
 
     def test_b32decode_casefold(self):
         eq = self.assertEqual
@@ -199,7 +213,6 @@
         eq(base64.b32decode(b'MLO23456'), b'b\xdd\xad\xf3\xbe')
         eq(base64.b32decode(b'M1023456', map01=b'L'), b'b\xdd\xad\xf3\xbe')
         eq(base64.b32decode(b'M1023456', map01=b'I'), b'b\x1d\xad\xf3\xbe')
-        self.assertRaises(TypeError, base64.b32decode, b"", map01="")
 
     def test_b32decode_error(self):
         self.assertRaises(binascii.Error, base64.b32decode, b'abc')