Issue #13641: Decoding functions in the base64 module now accept ASCII-only unicode strings.

Patch by Catalin Iacob.
This commit is contained in:
Antoine Pitrou 2012-02-20 19:30:23 +01:00
parent 0588eac218
commit ea6b4d5f70
4 changed files with 130 additions and 73 deletions

View File

@ -18,9 +18,14 @@ POST request. The encoding algorithm is not the same as the
There are two interfaces provided by this module. The modern interface There are two interfaces provided by this module. The modern interface
supports encoding and decoding ASCII byte string objects using all three supports encoding and decoding ASCII byte string objects using all three
alphabets. The legacy interface provides for encoding and decoding to and from alphabets. Additionally, the decoding functions of the modern interface also
file-like objects as well as byte strings, but only using the Base64 standard accept Unicode strings containing only ASCII characters. The legacy interface
alphabet. provides for encoding and decoding to and from file-like objects as well as
byte strings, but only using the Base64 standard alphabet.
.. versionchanged:: 3.3
ASCII-only Unicode strings are now accepted by the decoding functions of
the modern interface.
The modern interface provides: The modern interface provides:

View File

@ -29,6 +29,16 @@ __all__ = [
bytes_types = (bytes, bytearray) # Types acceptable as binary data bytes_types = (bytes, bytearray) # Types acceptable as binary data
def _bytes_from_decode_data(s):
if isinstance(s, str):
try:
return s.encode('ascii')
except UnicodeEncodeError:
raise ValueError('string argument should contain only ASCII characters')
elif isinstance(s, bytes_types):
return s
else:
raise TypeError("argument should be bytes or ASCII string, not %s" % s.__class__.__name__)
def _translate(s, altchars): def _translate(s, altchars):
if not isinstance(s, bytes_types): if not isinstance(s, bytes_types):
@ -79,12 +89,9 @@ def b64decode(s, altchars=None, validate=False):
discarded prior to the padding check. If validate is True, discarded prior to the padding check. If validate is True,
non-base64-alphabet characters in the input result in a binascii.Error. non-base64-alphabet characters in the input result in a binascii.Error.
""" """
if not isinstance(s, bytes_types): s = _bytes_from_decode_data(s)
raise TypeError("expected bytes, not %s" % s.__class__.__name__)
if altchars is not None: if altchars is not None:
if not isinstance(altchars, bytes_types): altchars = _bytes_from_decode_data(altchars)
raise TypeError("expected bytes, not %s"
% altchars.__class__.__name__)
assert len(altchars) == 2, repr(altchars) assert len(altchars) == 2, repr(altchars)
s = _translate(s, {chr(altchars[0]): b'+', chr(altchars[1]): b'/'}) s = _translate(s, {chr(altchars[0]): b'+', chr(altchars[1]): b'/'})
if validate and not re.match(b'^[A-Za-z0-9+/]*={0,2}$', s): if validate and not re.match(b'^[A-Za-z0-9+/]*={0,2}$', s):
@ -211,8 +218,7 @@ def b32decode(s, casefold=False, map01=None):
the input is incorrectly padded or if there are non-alphabet the input is incorrectly padded or if there are non-alphabet
characters present in the input. characters present in the input.
""" """
if not isinstance(s, bytes_types): s = _bytes_from_decode_data(s)
raise TypeError("expected bytes, not %s" % s.__class__.__name__)
quanta, leftover = divmod(len(s), 8) quanta, leftover = divmod(len(s), 8)
if leftover: if leftover:
raise binascii.Error('Incorrect padding') raise binascii.Error('Incorrect padding')
@ -220,8 +226,7 @@ def b32decode(s, casefold=False, map01=None):
# False, or the character to map the digit 1 (one) to. It should be # False, or the character to map the digit 1 (one) to. It should be
# either L (el) or I (eye). # either L (el) or I (eye).
if map01 is not None: if map01 is not None:
if not isinstance(map01, bytes_types): map01 = _bytes_from_decode_data(map01)
raise TypeError("expected bytes, not %s" % map01.__class__.__name__)
assert len(map01) == 1, repr(map01) assert len(map01) == 1, repr(map01)
s = _translate(s, {b'0': b'O', b'1': map01}) s = _translate(s, {b'0': b'O', b'1': map01})
if casefold: if casefold:
@ -292,8 +297,7 @@ def b16decode(s, casefold=False):
s were incorrectly padded or if there are non-alphabet characters s were incorrectly padded or if there are non-alphabet characters
present in the string. present in the string.
""" """
if not isinstance(s, bytes_types): s = _bytes_from_decode_data(s)
raise TypeError("expected bytes, not %s" % s.__class__.__name__)
if casefold: if casefold:
s = s.upper() s = s.upper()
if re.search(b'[^0-9A-F]', s): if re.search(b'[^0-9A-F]', s):

View File

@ -102,44 +102,53 @@ class BaseXYTestCase(unittest.TestCase):
def test_b64decode(self): def test_b64decode(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org")
eq(base64.b64decode(b'AA=='), b'\x00') tests = {b"d3d3LnB5dGhvbi5vcmc=": b"www.python.org",
eq(base64.b64decode(b"YQ=="), b"a") b'AA==': b'\x00',
eq(base64.b64decode(b"YWI="), b"ab") b"YQ==": b"a",
eq(base64.b64decode(b"YWJj"), b"abc") b"YWI=": b"ab",
eq(base64.b64decode(b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE" b"YWJj": b"abc",
b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT" b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ=="), b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT"
b"abcdefghijklmnopqrstuvwxyz" b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==":
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
b"0123456789!@#0^&*();:<>,. []{}") b"abcdefghijklmnopqrstuvwxyz"
eq(base64.b64decode(b''), b'') b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
b"0123456789!@#0^&*();:<>,. []{}",
b'': b'',
}
for data, res in tests.items():
eq(base64.b64decode(data), res)
eq(base64.b64decode(data.decode('ascii')), res)
# Test with arbitrary alternative characters # Test with arbitrary alternative characters
eq(base64.b64decode(b'01a*b$cd', altchars=b'*$'), b'\xd3V\xbeo\xf7\x1d') tests_altchars = {(b'01a*b$cd', b'*$'): b'\xd3V\xbeo\xf7\x1d',
# Check if passing a str object raises an error }
self.assertRaises(TypeError, base64.b64decode, "") for (data, altchars), res in tests_altchars.items():
self.assertRaises(TypeError, base64.b64decode, b"", altchars="") data_str = data.decode('ascii')
altchars_str = altchars.decode('ascii')
eq(base64.b64decode(data, altchars=altchars), res)
eq(base64.b64decode(data_str, altchars=altchars), res)
eq(base64.b64decode(data, altchars=altchars_str), res)
eq(base64.b64decode(data_str, altchars=altchars_str), res)
# Test standard alphabet # Test standard alphabet
eq(base64.standard_b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org") for data, res in tests.items():
eq(base64.standard_b64decode(b"YQ=="), b"a") eq(base64.standard_b64decode(data), res)
eq(base64.standard_b64decode(b"YWI="), b"ab") eq(base64.standard_b64decode(data.decode('ascii')), res)
eq(base64.standard_b64decode(b"YWJj"), b"abc")
eq(base64.standard_b64decode(b""), b"")
eq(base64.standard_b64decode(b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT"
b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ=="),
b"abcdefghijklmnopqrstuvwxyz"
b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
b"0123456789!@#0^&*();:<>,. []{}")
# Check if passing a str object raises an error
self.assertRaises(TypeError, base64.standard_b64decode, "")
self.assertRaises(TypeError, base64.standard_b64decode, b"", altchars="")
# Test with 'URL safe' alternative characters # Test with 'URL safe' alternative characters
eq(base64.urlsafe_b64decode(b'01a-b_cd'), b'\xd3V\xbeo\xf7\x1d') tests_urlsafe = {b'01a-b_cd': b'\xd3V\xbeo\xf7\x1d',
self.assertRaises(TypeError, base64.urlsafe_b64decode, "") b'': b'',
}
for data, res in tests_urlsafe.items():
eq(base64.urlsafe_b64decode(data), res)
eq(base64.urlsafe_b64decode(data.decode('ascii')), res)
def test_b64decode_padding_error(self): def test_b64decode_padding_error(self):
self.assertRaises(binascii.Error, base64.b64decode, b'abc') self.assertRaises(binascii.Error, base64.b64decode, b'abc')
self.assertRaises(binascii.Error, base64.b64decode, 'abc')
def test_b64decode_invalid_chars(self): def test_b64decode_invalid_chars(self):
# issue 1466065: Test some invalid characters. # issue 1466065: Test some invalid characters.
@ -154,8 +163,10 @@ class BaseXYTestCase(unittest.TestCase):
(b'YWJj\nYWI=', b'abcab')) (b'YWJj\nYWI=', b'abcab'))
for bstr, res in tests: for bstr, res in tests:
self.assertEqual(base64.b64decode(bstr), res) self.assertEqual(base64.b64decode(bstr), res)
self.assertEqual(base64.b64decode(bstr.decode('ascii')), res)
with self.assertRaises(binascii.Error): with self.assertRaises(binascii.Error):
base64.b64decode(bstr, validate=True) base64.b64decode(bstr, validate=True)
base64.b64decode(bstr.decode('ascii'), validate=True)
def test_b32encode(self): def test_b32encode(self):
eq = self.assertEqual eq = self.assertEqual
@ -170,40 +181,62 @@ class BaseXYTestCase(unittest.TestCase):
def test_b32decode(self): def test_b32decode(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b32decode(b''), b'') tests = {b'': b'',
eq(base64.b32decode(b'AA======'), b'\x00') b'AA======': b'\x00',
eq(base64.b32decode(b'ME======'), b'a') b'ME======': b'a',
eq(base64.b32decode(b'MFRA===='), b'ab') b'MFRA====': b'ab',
eq(base64.b32decode(b'MFRGG==='), b'abc') b'MFRGG===': b'abc',
eq(base64.b32decode(b'MFRGGZA='), b'abcd') b'MFRGGZA=': b'abcd',
eq(base64.b32decode(b'MFRGGZDF'), b'abcde') b'MFRGGZDF': b'abcde',
self.assertRaises(TypeError, base64.b32decode, "") }
for data, res in tests.items():
eq(base64.b32decode(data), res)
eq(base64.b32decode(data.decode('ascii')), res)
def test_b32decode_casefold(self): def test_b32decode_casefold(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b32decode(b'', True), b'') tests = {b'': b'',
eq(base64.b32decode(b'ME======', True), b'a') b'ME======': b'a',
eq(base64.b32decode(b'MFRA====', True), b'ab') b'MFRA====': b'ab',
eq(base64.b32decode(b'MFRGG===', True), b'abc') b'MFRGG===': b'abc',
eq(base64.b32decode(b'MFRGGZA=', True), b'abcd') b'MFRGGZA=': b'abcd',
eq(base64.b32decode(b'MFRGGZDF', True), b'abcde') b'MFRGGZDF': b'abcde',
# Lower cases # Lower cases
eq(base64.b32decode(b'me======', True), b'a') b'me======': b'a',
eq(base64.b32decode(b'mfra====', True), b'ab') b'mfra====': b'ab',
eq(base64.b32decode(b'mfrgg===', True), b'abc') b'mfrgg===': b'abc',
eq(base64.b32decode(b'mfrggza=', True), b'abcd') b'mfrggza=': b'abcd',
eq(base64.b32decode(b'mfrggzdf', True), b'abcde') b'mfrggzdf': b'abcde',
# Expected exceptions }
for data, res in tests.items():
eq(base64.b32decode(data, True), res)
eq(base64.b32decode(data.decode('ascii'), True), res)
self.assertRaises(TypeError, base64.b32decode, b'me======') self.assertRaises(TypeError, base64.b32decode, b'me======')
self.assertRaises(TypeError, base64.b32decode, 'me======')
# Mapping zero and one # Mapping zero and one
eq(base64.b32decode(b'MLO23456'), b'b\xdd\xad\xf3\xbe') eq(base64.b32decode(b'MLO23456'), b'b\xdd\xad\xf3\xbe')
eq(base64.b32decode(b'M1023456', map01=b'L'), b'b\xdd\xad\xf3\xbe') eq(base64.b32decode('MLO23456'), b'b\xdd\xad\xf3\xbe')
eq(base64.b32decode(b'M1023456', map01=b'I'), b'b\x1d\xad\xf3\xbe')
self.assertRaises(TypeError, base64.b32decode, b"", map01="") map_tests = {(b'M1023456', b'L'): b'b\xdd\xad\xf3\xbe',
(b'M1023456', b'I'): b'b\x1d\xad\xf3\xbe',
}
for (data, map01), res in map_tests.items():
data_str = data.decode('ascii')
map01_str = map01.decode('ascii')
eq(base64.b32decode(data, map01=map01), res)
eq(base64.b32decode(data_str, map01=map01), res)
eq(base64.b32decode(data, map01=map01_str), res)
eq(base64.b32decode(data_str, map01=map01_str), res)
def test_b32decode_error(self): def test_b32decode_error(self):
self.assertRaises(binascii.Error, base64.b32decode, b'abc') for data in [b'abc', b'ABCDEF==']:
self.assertRaises(binascii.Error, base64.b32decode, b'ABCDEF==') with self.assertRaises(binascii.Error):
base64.b32decode(data)
base64.b32decode(data.decode('ascii'))
def test_b16encode(self): def test_b16encode(self):
eq = self.assertEqual eq = self.assertEqual
@ -214,12 +247,24 @@ class BaseXYTestCase(unittest.TestCase):
def test_b16decode(self): def test_b16decode(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b16decode(b'0102ABCDEF'), b'\x01\x02\xab\xcd\xef') eq(base64.b16decode(b'0102ABCDEF'), b'\x01\x02\xab\xcd\xef')
eq(base64.b16decode('0102ABCDEF'), b'\x01\x02\xab\xcd\xef')
eq(base64.b16decode(b'00'), b'\x00') eq(base64.b16decode(b'00'), b'\x00')
eq(base64.b16decode('00'), b'\x00')
# Lower case is not allowed without a flag # Lower case is not allowed without a flag
self.assertRaises(binascii.Error, base64.b16decode, b'0102abcdef') self.assertRaises(binascii.Error, base64.b16decode, b'0102abcdef')
self.assertRaises(binascii.Error, base64.b16decode, '0102abcdef')
# Case fold # Case fold
eq(base64.b16decode(b'0102abcdef', True), b'\x01\x02\xab\xcd\xef') eq(base64.b16decode(b'0102abcdef', True), b'\x01\x02\xab\xcd\xef')
self.assertRaises(TypeError, base64.b16decode, "") eq(base64.b16decode('0102abcdef', True), b'\x01\x02\xab\xcd\xef')
def test_decode_nonascii_str(self):
decode_funcs = (base64.b64decode,
base64.standard_b64decode,
base64.urlsafe_b64decode,
base64.b32decode,
base64.b16decode)
for f in decode_funcs:
self.assertRaises(ValueError, f, 'with non-ascii \xcb')
def test_ErrorHeritage(self): def test_ErrorHeritage(self):
self.assertTrue(issubclass(binascii.Error, ValueError)) self.assertTrue(issubclass(binascii.Error, ValueError))

View File

@ -469,6 +469,9 @@ Core and Builtins
Library Library
------- -------
- Issue #13641: Decoding functions in the base64 module now accept ASCII-only
unicode strings. Patch by Catalin Iacob.
- Issue #14043: Speed up importlib's _FileFinder by at least 8x, and add a - Issue #14043: Speed up importlib's _FileFinder by at least 8x, and add a
new importlib.invalidate_caches() function. new importlib.invalidate_caches() function.