Make test_base64 pass.

Change binascii.Error to derive from ValueError
and raise binascii.Error wherever values are bad
(why on earth did the old code use TypeError?!?).
This commit is contained in:
Guido van Rossum 2007-05-22 21:56:47 +00:00
parent 0e225aa09b
commit 4581ae5fa2
5 changed files with 268 additions and 232 deletions

View File

@ -4,6 +4,7 @@
# Modified 04-Oct-1995 by Jack Jansen to use binascii module # Modified 04-Oct-1995 by Jack Jansen to use binascii module
# Modified 30-Dec-2003 by Barry Warsaw to add full RFC 3548 support # Modified 30-Dec-2003 by Barry Warsaw to add full RFC 3548 support
# Modified 22-May-2007 by Guido van Rossum to use bytes everywhere
import re import re
import struct import struct
@ -25,122 +26,131 @@ __all__ = [
'urlsafe_b64encode', 'urlsafe_b64decode', 'urlsafe_b64encode', 'urlsafe_b64decode',
] ]
_translation = [chr(_x) for _x in range(256)]
EMPTYSTRING = ''
def _translate(s, altchars): def _translate(s, altchars):
translation = _translation[:] assert isinstance(s, bytes), type(s)
translation = bytes(range(256))
for k, v in altchars.items(): for k, v in altchars.items():
translation[ord(k)] = v translation[ord(k)] = v[0]
return s.translate(''.join(translation)) return s.translate(translation)
# Base64 encoding/decoding uses binascii # Base64 encoding/decoding uses binascii
def b64encode(s, altchars=None): def b64encode(s, altchars=None):
"""Encode a string using Base64. """Encode a byte string using Base64.
s is the string to encode. Optional altchars must be a string of at least s is the byte string to encode. Optional altchars must be a byte
length 2 (additional characters are ignored) which specifies an string of length 2 which specifies an alternative alphabet for the
alternative alphabet for the '+' and '/' characters. This allows an '+' and '/' characters. This allows an application to
application to e.g. generate url or filesystem safe Base64 strings. e.g. generate url or filesystem safe Base64 strings.
The encoded string is returned. The encoded byte string is returned.
""" """
if not isinstance(s, bytes):
s = bytes(s)
# Strip off the trailing newline # Strip off the trailing newline
encoded = binascii.b2a_base64(s)[:-1] encoded = binascii.b2a_base64(s)[:-1]
if altchars is not None: if altchars is not None:
return _translate(encoded, {'+': altchars[0], '/': altchars[1]}) if not isinstance(altchars, bytes):
altchars = bytes(altchars)
assert len(altchars) == 2, repr(altchars)
return _translate(encoded, {'+': altchars[0:1], '/': altchars[1:2]})
return encoded return encoded
def b64decode(s, altchars=None): def b64decode(s, altchars=None):
"""Decode a Base64 encoded string. """Decode a Base64 encoded byte string.
s is the string to decode. Optional altchars must be a string of at least s is the byte string to decode. Optional altchars must be a
length 2 (additional characters are ignored) which specifies the string of length 2 which specifies the alternative alphabet used
alternative alphabet used instead of the '+' and '/' characters. instead of the '+' and '/' characters.
The decoded string is returned. A TypeError is raised if s were The decoded byte string is returned. binascii.Error is raised if
incorrectly padded or if there are non-alphabet characters present in the s were incorrectly padded or if there are non-alphabet characters
string. present in the string.
""" """
if not isinstance(s, bytes):
s = bytes(s)
if altchars is not None: if altchars is not None:
s = _translate(s, {altchars[0]: '+', altchars[1]: '/'}) if not isinstance(altchars, bytes):
try: altchars = bytes(altchars)
return binascii.a2b_base64(s) assert len(altchars) == 2, repr(altchars)
except binascii.Error as msg: s = _translate(s, {chr(altchars[0]): b'+', chr(altchars[1]): b'/'})
# Transform this exception for consistency return binascii.a2b_base64(s)
raise TypeError(msg)
def standard_b64encode(s): def standard_b64encode(s):
"""Encode a string using the standard Base64 alphabet. """Encode a byte string using the standard Base64 alphabet.
s is the string to encode. The encoded string is returned. s is the byte string to encode. The encoded byte string is returned.
""" """
return b64encode(s) return b64encode(s)
def standard_b64decode(s): def standard_b64decode(s):
"""Decode a string encoded with the standard Base64 alphabet. """Decode a byte string encoded with the standard Base64 alphabet.
s is the string to decode. The decoded string is returned. A TypeError s is the byte string to decode. The decoded byte string is
is raised if the string is incorrectly padded or if there are non-alphabet returned. binascii.Error is raised if the input is incorrectly
characters present in the string. padded or if there are non-alphabet characters present in the
input.
""" """
return b64decode(s) return b64decode(s)
def urlsafe_b64encode(s): def urlsafe_b64encode(s):
"""Encode a string using a url-safe Base64 alphabet. """Encode a byte string using a url-safe Base64 alphabet.
s is the string to encode. The encoded string is returned. The alphabet s is the byte string to encode. The encoded byte string is
uses '-' instead of '+' and '_' instead of '/'. returned. The alphabet uses '-' instead of '+' and '_' instead of
'/'.
""" """
return b64encode(s, '-_') return b64encode(s, b'-_')
def urlsafe_b64decode(s): def urlsafe_b64decode(s):
"""Decode a string encoded with the standard Base64 alphabet. """Decode a byte string encoded with the standard Base64 alphabet.
s is the string to decode. The decoded string is returned. A TypeError s is the byte string to decode. The decoded byte string is
is raised if the string is incorrectly padded or if there are non-alphabet returned. binascii.Error is raised if the input is incorrectly
characters present in the string. padded or if there are non-alphabet characters present in the
input.
The alphabet uses '-' instead of '+' and '_' instead of '/'. The alphabet uses '-' instead of '+' and '_' instead of '/'.
""" """
return b64decode(s, '-_') return b64decode(s, b'-_')
# Base32 encoding/decoding must be done in Python # Base32 encoding/decoding must be done in Python
_b32alphabet = { _b32alphabet = {
0: 'A', 9: 'J', 18: 'S', 27: '3', 0: b'A', 9: b'J', 18: b'S', 27: b'3',
1: 'B', 10: 'K', 19: 'T', 28: '4', 1: b'B', 10: b'K', 19: b'T', 28: b'4',
2: 'C', 11: 'L', 20: 'U', 29: '5', 2: b'C', 11: b'L', 20: b'U', 29: b'5',
3: 'D', 12: 'M', 21: 'V', 30: '6', 3: b'D', 12: b'M', 21: b'V', 30: b'6',
4: 'E', 13: 'N', 22: 'W', 31: '7', 4: b'E', 13: b'N', 22: b'W', 31: b'7',
5: 'F', 14: 'O', 23: 'X', 5: b'F', 14: b'O', 23: b'X',
6: 'G', 15: 'P', 24: 'Y', 6: b'G', 15: b'P', 24: b'Y',
7: 'H', 16: 'Q', 25: 'Z', 7: b'H', 16: b'Q', 25: b'Z',
8: 'I', 17: 'R', 26: '2', 8: b'I', 17: b'R', 26: b'2',
} }
_b32tab = [v for k, v in sorted(_b32alphabet.items())] _b32tab = [v[0] for k, v in sorted(_b32alphabet.items())]
_b32rev = dict([(v, int(k)) for k, v in _b32alphabet.items()]) _b32rev = dict([(v[0], k) for k, v in _b32alphabet.items()])
def b32encode(s): def b32encode(s):
"""Encode a string using Base32. """Encode a byte string using Base32.
s is the string to encode. The encoded string is returned. s is the byte string to encode. The encoded byte string is returned.
""" """
parts = [] if not isinstance(s, bytes):
s = bytes(s)
quanta, leftover = divmod(len(s), 5) quanta, leftover = divmod(len(s), 5)
# Pad the last quantum with zero bits if necessary # Pad the last quantum with zero bits if necessary
if leftover: if leftover:
s += ('\0' * (5 - leftover)) s = s + bytes(5 - leftover) # Don't use += !
quanta += 1 quanta += 1
encoded = bytes()
for i in range(quanta): for i in range(quanta):
# c1 and c2 are 16 bits wide, c3 is 8 bits wide. The intent of this # c1 and c2 are 16 bits wide, c3 is 8 bits wide. The intent of this
# code is to process the 40 bits in units of 5 bits. So we take the 1 # code is to process the 40 bits in units of 5 bits. So we take the 1
@ -150,57 +160,61 @@ def b32encode(s):
c1, c2, c3 = struct.unpack('!HHB', s[i*5:(i+1)*5]) c1, c2, c3 = struct.unpack('!HHB', s[i*5:(i+1)*5])
c2 += (c1 & 1) << 16 # 17 bits wide c2 += (c1 & 1) << 16 # 17 bits wide
c3 += (c2 & 3) << 8 # 10 bits wide c3 += (c2 & 3) << 8 # 10 bits wide
parts.extend([_b32tab[c1 >> 11], # bits 1 - 5 encoded += bytes([_b32tab[c1 >> 11], # bits 1 - 5
_b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10 _b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10
_b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15 _b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15
_b32tab[c2 >> 12], # bits 16 - 20 (1 - 5) _b32tab[c2 >> 12], # bits 16 - 20 (1 - 5)
_b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10) _b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10)
_b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15) _b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15)
_b32tab[c3 >> 5], # bits 31 - 35 (1 - 5) _b32tab[c3 >> 5], # bits 31 - 35 (1 - 5)
_b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5) _b32tab[c3 & 0x1f], # bits 36 - 40 (1 - 5)
]) ])
encoded = EMPTYSTRING.join(parts)
# Adjust for any leftover partial quanta # Adjust for any leftover partial quanta
if leftover == 1: if leftover == 1:
return encoded[:-6] + '======' return encoded[:-6] + b'======'
elif leftover == 2: elif leftover == 2:
return encoded[:-4] + '====' return encoded[:-4] + b'===='
elif leftover == 3: elif leftover == 3:
return encoded[:-3] + '===' return encoded[:-3] + b'==='
elif leftover == 4: elif leftover == 4:
return encoded[:-1] + '=' return encoded[:-1] + b'='
return encoded return encoded
def b32decode(s, casefold=False, map01=None): def b32decode(s, casefold=False, map01=None):
"""Decode a Base32 encoded string. """Decode a Base32 encoded byte string.
s is the string to decode. Optional casefold is a flag specifying whether s is the byte string to decode. Optional casefold is a flag
a lowercase alphabet is acceptable as input. For security purposes, the specifying whether a lowercase alphabet is acceptable as input.
default is False. For security purposes, the default is False.
RFC 3548 allows for optional mapping of the digit 0 (zero) to the letter O RFC 3548 allows for optional mapping of the digit 0 (zero) to the
(oh), and for optional mapping of the digit 1 (one) to either the letter I letter O (oh), and for optional mapping of the digit 1 (one) to
(eye) or letter L (el). The optional argument map01 when not None, either the letter I (eye) or letter L (el). The optional argument
specifies which letter the digit 1 should be mapped to (when map01 is not map01 when not None, specifies which letter the digit 1 should be
None, the digit 0 is always mapped to the letter O). For security mapped to (when map01 is not None, the digit 0 is always mapped to
purposes the default is None, so that 0 and 1 are not allowed in the the letter O). For security purposes the default is None, so that
input. 0 and 1 are not allowed in the input.
The decoded string is returned. A TypeError is raised if s were The decoded byte string is returned. binascii.Error is raised if
incorrectly padded or if there are non-alphabet characters present in the the input is incorrectly padded or if there are non-alphabet
string. characters present in the input.
""" """
if not isinstance(s, bytes):
s = bytes(s)
quanta, leftover = divmod(len(s), 8) quanta, leftover = divmod(len(s), 8)
if leftover: if leftover:
raise TypeError('Incorrect padding') raise binascii.Error('Incorrect padding')
# Handle section 2.4 zero and one mapping. The flag map01 will be either # Handle section 2.4 zero and one mapping. The flag map01 will be either
# False, or the character to map the digit 1 (one) to. It should be # False, or the character to map the digit 1 (one) to. It should be
# either L (el) or I (eye). # either L (el) or I (eye).
if map01: if map01:
s = _translate(s, {'0': 'O', '1': map01}) if not isinstance(map01, bytes):
map01 = bytes(map01)
assert len(map01) == 1, repr(map01)
s = _translate(s, {'0': b'O', '1': map01})
if casefold: if casefold:
s = s.upper() s = bytes(str(s, "ascii").upper(), "ascii")
# Strip off pad characters from the right. We need to count the pad # Strip off pad characters from the right. We need to count the pad
# characters because this will tell us how many null bytes to remove from # characters because this will tell us how many null bytes to remove from
# the end of the decoded string. # the end of the decoded string.
@ -225,9 +239,9 @@ def b32decode(s, casefold=False, map01=None):
acc = 0 acc = 0
shift = 35 shift = 35
# Process the last, partial quanta # Process the last, partial quanta
last = binascii.unhexlify('%010x' % acc) last = binascii.unhexlify(bytes('%010x' % acc))
if padchars == 0: if padchars == 0:
last = '' # No characters last = b'' # No characters
elif padchars == 1: elif padchars == 1:
last = last[:-1] last = last[:-1]
elif padchars == 3: elif padchars == 3:
@ -237,9 +251,9 @@ def b32decode(s, casefold=False, map01=None):
elif padchars == 6: elif padchars == 6:
last = last[:-4] last = last[:-4]
else: else:
raise TypeError('Incorrect padding') raise binascii.Error('Incorrect padding')
parts.append(last) parts.append(last)
return EMPTYSTRING.join(parts) return b''.join(parts)
@ -247,35 +261,37 @@ def b32decode(s, casefold=False, map01=None):
# lowercase. The RFC also recommends against accepting input case # lowercase. The RFC also recommends against accepting input case
# insensitively. # insensitively.
def b16encode(s): def b16encode(s):
"""Encode a string using Base16. """Encode a byte string using Base16.
s is the string to encode. The encoded string is returned. s is the byte string to encode. The encoded byte string is returned.
""" """
return binascii.hexlify(s).upper() return bytes(str(binascii.hexlify(s), "ascii").upper(), "ascii")
def b16decode(s, casefold=False): def b16decode(s, casefold=False):
"""Decode a Base16 encoded string. """Decode a Base16 encoded byte string.
s is the string to decode. Optional casefold is a flag specifying whether s is the byte string to decode. Optional casefold is a flag
a lowercase alphabet is acceptable as input. For security purposes, the specifying whether a lowercase alphabet is acceptable as input.
default is False. For security purposes, the default is False.
The decoded string is returned. A TypeError is raised if s were The decoded byte string is returned. binascii.Error is raised if
incorrectly padded or if there are non-alphabet characters present in the s were incorrectly padded or if there are non-alphabet characters
string. present in the string.
""" """
if not isinstance(s, bytes):
s = bytes(s)
if casefold: if casefold:
s = s.upper() s = bytes(str(s, "ascii").upper(), "ascii")
if re.search('[^0-9A-F]', s): if re.search('[^0-9A-F]', s):
raise TypeError('Non-base16 digit found') raise binascii.Error('Non-base16 digit found')
return binascii.unhexlify(s) return binascii.unhexlify(s)
# Legacy interface. This code could be cleaned up since I don't believe # Legacy interface. This code could be cleaned up since I don't believe
# binascii has any line length limitations. It just doesn't seem worth it # binascii has any line length limitations. It just doesn't seem worth it
# though. # though. The files should be opened in binary mode.
MAXLINESIZE = 76 # Excluding the CRLF MAXLINESIZE = 76 # Excluding the CRLF
MAXBINSIZE = (MAXLINESIZE//4)*3 MAXBINSIZE = (MAXLINESIZE//4)*3
@ -307,22 +323,26 @@ def decode(input, output):
def encodestring(s): def encodestring(s):
"""Encode a string into multiple lines of base-64 data.""" """Encode a string into multiple lines of base-64 data."""
if not isinstance(s, bytes):
s = bytes(s)
pieces = [] pieces = []
for i in range(0, len(s), MAXBINSIZE): for i in range(0, len(s), MAXBINSIZE):
chunk = s[i : i + MAXBINSIZE] chunk = s[i : i + MAXBINSIZE]
pieces.append(binascii.b2a_base64(chunk)) pieces.append(binascii.b2a_base64(chunk))
return "".join(pieces) return b"".join(pieces)
def decodestring(s): def decodestring(s):
"""Decode a string.""" """Decode a string."""
if not isinstance(s, bytes):
s = bytes(s)
return binascii.a2b_base64(s) return binascii.a2b_base64(s)
# Useable as a script... # Usable as a script...
def test(): def main():
"""Small test program""" """Small main program"""
import sys, getopt import sys, getopt
try: try:
opts, args = getopt.getopt(sys.argv[1:], 'deut') opts, args = getopt.getopt(sys.argv[1:], 'deut')
@ -339,19 +359,22 @@ def test():
if o == '-e': func = encode if o == '-e': func = encode
if o == '-d': func = decode if o == '-d': func = decode
if o == '-u': func = decode if o == '-u': func = decode
if o == '-t': test1(); return if o == '-t': test(); return
if args and args[0] != '-': if args and args[0] != '-':
func(open(args[0], 'rb'), sys.stdout) func(open(args[0], 'rb'), sys.stdout)
else: else:
func(sys.stdin, sys.stdout) func(sys.stdin, sys.stdout)
def test1(): def test():
s0 = "Aladdin:open sesame" s0 = b"Aladdin:open sesame"
print(repr(s0))
s1 = encodestring(s0) s1 = encodestring(s0)
print(repr(s1))
s2 = decodestring(s1) s2 = decodestring(s1)
print(s0, repr(s1), s2) print(repr(s2))
assert s0 == s2
if __name__ == '__main__': if __name__ == '__main__':
test() main()

View File

@ -1,37 +1,38 @@
import unittest import unittest
from test import test_support from test import test_support
import base64 import base64
import binascii
class LegacyBase64TestCase(unittest.TestCase): class LegacyBase64TestCase(unittest.TestCase):
def test_encodestring(self): def test_encodestring(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.encodestring("www.python.org"), "d3d3LnB5dGhvbi5vcmc=\n") eq(base64.encodestring(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=\n")
eq(base64.encodestring("a"), "YQ==\n") eq(base64.encodestring(b"a"), b"YQ==\n")
eq(base64.encodestring("ab"), "YWI=\n") eq(base64.encodestring(b"ab"), b"YWI=\n")
eq(base64.encodestring("abc"), "YWJj\n") eq(base64.encodestring(b"abc"), b"YWJj\n")
eq(base64.encodestring(""), "") eq(base64.encodestring(b""), b"")
eq(base64.encodestring("abcdefghijklmnopqrstuvwxyz" eq(base64.encodestring(b"abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"0123456789!@#0^&*();:<>,. []{}"), b"0123456789!@#0^&*();:<>,. []{}"),
"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE" b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT" b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT"
"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==\n") b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==\n")
def test_decodestring(self): def test_decodestring(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.decodestring("d3d3LnB5dGhvbi5vcmc=\n"), "www.python.org") eq(base64.decodestring(b"d3d3LnB5dGhvbi5vcmc=\n"), b"www.python.org")
eq(base64.decodestring("YQ==\n"), "a") eq(base64.decodestring(b"YQ==\n"), b"a")
eq(base64.decodestring("YWI=\n"), "ab") eq(base64.decodestring(b"YWI=\n"), b"ab")
eq(base64.decodestring("YWJj\n"), "abc") eq(base64.decodestring(b"YWJj\n"), b"abc")
eq(base64.decodestring("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE" eq(base64.decodestring(b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT" b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT"
"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==\n"), b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==\n"),
"abcdefghijklmnopqrstuvwxyz" b"abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"0123456789!@#0^&*();:<>,. []{}") b"0123456789!@#0^&*();:<>,. []{}")
eq(base64.decodestring(''), '') eq(base64.decodestring(b''), b'')
def test_encode(self): def test_encode(self):
eq = self.assertEqual eq = self.assertEqual
@ -59,127 +60,130 @@ class BaseXYTestCase(unittest.TestCase):
def test_b64encode(self): def test_b64encode(self):
eq = self.assertEqual eq = self.assertEqual
# Test default alphabet # Test default alphabet
eq(base64.b64encode("www.python.org"), "d3d3LnB5dGhvbi5vcmc=") eq(base64.b64encode(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=")
eq(base64.b64encode('\x00'), 'AA==') eq(base64.b64encode(b'\x00'), b'AA==')
eq(base64.b64encode("a"), "YQ==") eq(base64.b64encode(b"a"), b"YQ==")
eq(base64.b64encode("ab"), "YWI=") eq(base64.b64encode(b"ab"), b"YWI=")
eq(base64.b64encode("abc"), "YWJj") eq(base64.b64encode(b"abc"), b"YWJj")
eq(base64.b64encode(""), "") eq(base64.b64encode(b""), b"")
eq(base64.b64encode("abcdefghijklmnopqrstuvwxyz" eq(base64.b64encode(b"abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"0123456789!@#0^&*();:<>,. []{}"), b"0123456789!@#0^&*();:<>,. []{}"),
"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE" b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT" b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT"
"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==") b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==")
# Test with arbitrary alternative characters # Test with arbitrary alternative characters
eq(base64.b64encode('\xd3V\xbeo\xf7\x1d', altchars='*$'), '01a*b$cd') eq(base64.b64encode(b'\xd3V\xbeo\xf7\x1d', altchars='*$'), b'01a*b$cd')
# Test standard alphabet # Test standard alphabet
eq(base64.standard_b64encode("www.python.org"), "d3d3LnB5dGhvbi5vcmc=") eq(base64.standard_b64encode(b"www.python.org"), b"d3d3LnB5dGhvbi5vcmc=")
eq(base64.standard_b64encode("a"), "YQ==") eq(base64.standard_b64encode(b"a"), b"YQ==")
eq(base64.standard_b64encode("ab"), "YWI=") eq(base64.standard_b64encode(b"ab"), b"YWI=")
eq(base64.standard_b64encode("abc"), "YWJj") eq(base64.standard_b64encode(b"abc"), b"YWJj")
eq(base64.standard_b64encode(""), "") eq(base64.standard_b64encode(b""), b"")
eq(base64.standard_b64encode("abcdefghijklmnopqrstuvwxyz" eq(base64.standard_b64encode(b"abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"0123456789!@#0^&*();:<>,. []{}"), b"0123456789!@#0^&*();:<>,. []{}"),
"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE" b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT" b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT"
"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==") b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ==")
# Test with 'URL safe' alternative characters # Test with 'URL safe' alternative characters
eq(base64.urlsafe_b64encode('\xd3V\xbeo\xf7\x1d'), '01a-b_cd') eq(base64.urlsafe_b64encode(b'\xd3V\xbeo\xf7\x1d'), b'01a-b_cd')
def test_b64decode(self): def test_b64decode(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b64decode("d3d3LnB5dGhvbi5vcmc="), "www.python.org") eq(base64.b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org")
eq(base64.b64decode('AA=='), '\x00') eq(base64.b64decode(b'AA=='), b'\x00')
eq(base64.b64decode("YQ=="), "a") eq(base64.b64decode(b"YQ=="), b"a")
eq(base64.b64decode("YWI="), "ab") eq(base64.b64decode(b"YWI="), b"ab")
eq(base64.b64decode("YWJj"), "abc") eq(base64.b64decode(b"YWJj"), b"abc")
eq(base64.b64decode("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE" eq(base64.b64decode(b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT" b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0\nNT"
"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ=="), b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ=="),
"abcdefghijklmnopqrstuvwxyz" b"abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"0123456789!@#0^&*();:<>,. []{}") b"0123456789!@#0^&*();:<>,. []{}")
eq(base64.b64decode(''), '') eq(base64.b64decode(b''), b'')
# Test with arbitrary alternative characters # Test with arbitrary alternative characters
eq(base64.b64decode('01a*b$cd', altchars='*$'), '\xd3V\xbeo\xf7\x1d') eq(base64.b64decode(b'01a*b$cd', altchars='*$'), b'\xd3V\xbeo\xf7\x1d')
# Test standard alphabet # Test standard alphabet
eq(base64.standard_b64decode("d3d3LnB5dGhvbi5vcmc="), "www.python.org") eq(base64.standard_b64decode(b"d3d3LnB5dGhvbi5vcmc="), b"www.python.org")
eq(base64.standard_b64decode("YQ=="), "a") eq(base64.standard_b64decode(b"YQ=="), b"a")
eq(base64.standard_b64decode("YWI="), "ab") eq(base64.standard_b64decode(b"YWI="), b"ab")
eq(base64.standard_b64decode("YWJj"), "abc") eq(base64.standard_b64decode(b"YWJj"), b"abc")
eq(base64.standard_b64decode(""), "") eq(base64.standard_b64decode(b""), b"")
eq(base64.standard_b64decode("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE" eq(base64.standard_b64decode(b"YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXpBQkNE"
"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT" b"RUZHSElKS0xNTk9QUVJTVFVWV1hZWjAxMjM0NT"
"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ=="), b"Y3ODkhQCMwXiYqKCk7Ojw+LC4gW117fQ=="),
"abcdefghijklmnopqrstuvwxyz" b"abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"0123456789!@#0^&*();:<>,. []{}") b"0123456789!@#0^&*();:<>,. []{}")
# Test with 'URL safe' alternative characters # Test with 'URL safe' alternative characters
eq(base64.urlsafe_b64decode('01a-b_cd'), '\xd3V\xbeo\xf7\x1d') eq(base64.urlsafe_b64decode(b'01a-b_cd'), b'\xd3V\xbeo\xf7\x1d')
def test_b64decode_error(self): def test_b64decode_error(self):
self.assertRaises(TypeError, base64.b64decode, 'abc') self.assertRaises(binascii.Error, base64.b64decode, b'abc')
def test_b32encode(self): def test_b32encode(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b32encode(''), '') eq(base64.b32encode(b''), b'')
eq(base64.b32encode('\x00'), 'AA======') eq(base64.b32encode(b'\x00'), b'AA======')
eq(base64.b32encode('a'), 'ME======') eq(base64.b32encode(b'a'), b'ME======')
eq(base64.b32encode('ab'), 'MFRA====') eq(base64.b32encode(b'ab'), b'MFRA====')
eq(base64.b32encode('abc'), 'MFRGG===') eq(base64.b32encode(b'abc'), b'MFRGG===')
eq(base64.b32encode('abcd'), 'MFRGGZA=') eq(base64.b32encode(b'abcd'), b'MFRGGZA=')
eq(base64.b32encode('abcde'), 'MFRGGZDF') eq(base64.b32encode(b'abcde'), b'MFRGGZDF')
def test_b32decode(self): def test_b32decode(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b32decode(''), '') eq(base64.b32decode(b''), b'')
eq(base64.b32decode('AA======'), '\x00') eq(base64.b32decode(b'AA======'), b'\x00')
eq(base64.b32decode('ME======'), 'a') eq(base64.b32decode(b'ME======'), b'a')
eq(base64.b32decode('MFRA===='), 'ab') eq(base64.b32decode(b'MFRA===='), b'ab')
eq(base64.b32decode('MFRGG==='), 'abc') eq(base64.b32decode(b'MFRGG==='), b'abc')
eq(base64.b32decode('MFRGGZA='), 'abcd') eq(base64.b32decode(b'MFRGGZA='), b'abcd')
eq(base64.b32decode('MFRGGZDF'), 'abcde') eq(base64.b32decode(b'MFRGGZDF'), b'abcde')
def test_b32decode_casefold(self): def test_b32decode_casefold(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b32decode('', True), '') eq(base64.b32decode(b'', True), b'')
eq(base64.b32decode('ME======', True), 'a') eq(base64.b32decode(b'ME======', True), b'a')
eq(base64.b32decode('MFRA====', True), 'ab') eq(base64.b32decode(b'MFRA====', True), b'ab')
eq(base64.b32decode('MFRGG===', True), 'abc') eq(base64.b32decode(b'MFRGG===', True), b'abc')
eq(base64.b32decode('MFRGGZA=', True), 'abcd') eq(base64.b32decode(b'MFRGGZA=', True), b'abcd')
eq(base64.b32decode('MFRGGZDF', True), 'abcde') eq(base64.b32decode(b'MFRGGZDF', True), b'abcde')
# Lower cases # Lower cases
eq(base64.b32decode('me======', True), 'a') eq(base64.b32decode(b'me======', True), b'a')
eq(base64.b32decode('mfra====', True), 'ab') eq(base64.b32decode(b'mfra====', True), b'ab')
eq(base64.b32decode('mfrgg===', True), 'abc') eq(base64.b32decode(b'mfrgg===', True), b'abc')
eq(base64.b32decode('mfrggza=', True), 'abcd') eq(base64.b32decode(b'mfrggza=', True), b'abcd')
eq(base64.b32decode('mfrggzdf', True), 'abcde') eq(base64.b32decode(b'mfrggzdf', True), b'abcde')
# Expected exceptions # Expected exceptions
self.assertRaises(TypeError, base64.b32decode, 'me======') self.assertRaises(TypeError, base64.b32decode, b'me======')
# Mapping zero and one # Mapping zero and one
eq(base64.b32decode('MLO23456'), 'b\xdd\xad\xf3\xbe') eq(base64.b32decode(b'MLO23456'), b'b\xdd\xad\xf3\xbe')
eq(base64.b32decode('M1023456', map01='L'), 'b\xdd\xad\xf3\xbe') eq(base64.b32decode(b'M1023456', map01=b'L'), b'b\xdd\xad\xf3\xbe')
eq(base64.b32decode('M1023456', map01='I'), 'b\x1d\xad\xf3\xbe') eq(base64.b32decode(b'M1023456', map01=b'I'), b'b\x1d\xad\xf3\xbe')
def test_b32decode_error(self): def test_b32decode_error(self):
self.assertRaises(TypeError, base64.b32decode, 'abc') self.assertRaises(binascii.Error, base64.b32decode, b'abc')
self.assertRaises(TypeError, base64.b32decode, 'ABCDEF==') self.assertRaises(binascii.Error, base64.b32decode, b'ABCDEF==')
def test_b16encode(self): def test_b16encode(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b16encode('\x01\x02\xab\xcd\xef'), '0102ABCDEF') eq(base64.b16encode(b'\x01\x02\xab\xcd\xef'), b'0102ABCDEF')
eq(base64.b16encode('\x00'), '00') eq(base64.b16encode(b'\x00'), b'00')
def test_b16decode(self): def test_b16decode(self):
eq = self.assertEqual eq = self.assertEqual
eq(base64.b16decode('0102ABCDEF'), '\x01\x02\xab\xcd\xef') eq(base64.b16decode(b'0102ABCDEF'), b'\x01\x02\xab\xcd\xef')
eq(base64.b16decode('00'), '\x00') eq(base64.b16decode(b'00'), b'\x00')
# Lower case is not allowed without a flag # Lower case is not allowed without a flag
self.assertRaises(TypeError, base64.b16decode, '0102abcdef') self.assertRaises(binascii.Error, base64.b16decode, b'0102abcdef')
# Case fold # Case fold
eq(base64.b16decode('0102abcdef', True), '\x01\x02\xab\xcd\xef') eq(base64.b16decode(b'0102abcdef', True), b'\x01\x02\xab\xcd\xef')
def test_ErrorHeritage(self):
self.assert_(issubclass(binascii.Error, ValueError))

View File

@ -118,8 +118,8 @@ class BinASCIITest(unittest.TestCase):
t = binascii.b2a_hex(s) t = binascii.b2a_hex(s)
u = binascii.a2b_hex(t) u = binascii.a2b_hex(t)
self.assertEqual(s, u) self.assertEqual(s, u)
self.assertRaises(TypeError, binascii.a2b_hex, t[:-1]) self.assertRaises(binascii.Error, binascii.a2b_hex, t[:-1])
self.assertRaises(TypeError, binascii.a2b_hex, t[:-1] + b'q') self.assertRaises(binascii.Error, binascii.a2b_hex, t[:-1] + b'q')
# Verify the treatment of Unicode strings # Verify the treatment of Unicode strings
if test_support.have_unicode: if test_support.have_unicode:

View File

@ -1000,7 +1000,7 @@ binascii_unhexlify(PyObject *self, PyObject *args)
* raise an exception. * raise an exception.
*/ */
if (arglen % 2) { if (arglen % 2) {
PyErr_SetString(PyExc_TypeError, "Odd-length string"); PyErr_SetString(Error, "Odd-length string");
return NULL; return NULL;
} }
@ -1013,7 +1013,7 @@ binascii_unhexlify(PyObject *self, PyObject *args)
int top = to_int(Py_CHARMASK(argbuf[i])); int top = to_int(Py_CHARMASK(argbuf[i]));
int bot = to_int(Py_CHARMASK(argbuf[i+1])); int bot = to_int(Py_CHARMASK(argbuf[i+1]));
if (top == -1 || bot == -1) { if (top == -1 || bot == -1) {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(Error,
"Non-hexadecimal digit found"); "Non-hexadecimal digit found");
goto finally; goto finally;
} }
@ -1371,7 +1371,7 @@ initbinascii(void)
PyDict_SetItemString(d, "__doc__", x); PyDict_SetItemString(d, "__doc__", x);
Py_XDECREF(x); Py_XDECREF(x);
Error = PyErr_NewException("binascii.Error", NULL, NULL); Error = PyErr_NewException("binascii.Error", PyExc_ValueError, NULL);
PyDict_SetItemString(d, "Error", Error); PyDict_SetItemString(d, "Error", Error);
Incomplete = PyErr_NewException("binascii.Incomplete", NULL, NULL); Incomplete = PyErr_NewException("binascii.Incomplete", NULL, NULL);
PyDict_SetItemString(d, "Incomplete", Incomplete); PyDict_SetItemString(d, "Incomplete", Incomplete);

View File

@ -3519,11 +3519,20 @@ long_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
return PyLong_FromLong(0L); return PyLong_FromLong(0L);
if (base == -909) if (base == -909)
return PyNumber_Long(x); return PyNumber_Long(x);
else if (PyString_Check(x)) { else if (PyString_Check(x) || PyBytes_Check(x)) {
/* Since PyLong_FromString doesn't have a length parameter, /* Since PyLong_FromString doesn't have a length parameter,
* check here for possible NULs in the string. */ * check here for possible NULs in the string. */
char *string = PyString_AS_STRING(x); char *string;
if (strlen(string) != PyString_Size(x)) { int size;
if (PyBytes_Check(x)) {
string = PyBytes_AS_STRING(x);
size = PyBytes_GET_SIZE(x);
}
else {
string = PyString_AS_STRING(x);
size = PyString_GET_SIZE(x);
}
if (strlen(string) != size) {
/* create a repr() of the input string, /* create a repr() of the input string,
* just like PyLong_FromString does. */ * just like PyLong_FromString does. */
PyObject *srepr; PyObject *srepr;
@ -3536,7 +3545,7 @@ long_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
Py_DECREF(srepr); Py_DECREF(srepr);
return NULL; return NULL;
} }
return PyLong_FromString(PyString_AS_STRING(x), NULL, base); return PyLong_FromString(string, NULL, base);
} }
else if (PyUnicode_Check(x)) else if (PyUnicode_Check(x))
return PyLong_FromUnicode(PyUnicode_AS_UNICODE(x), return PyLong_FromUnicode(PyUnicode_AS_UNICODE(x),