Fix quopri to operate consistently on bytes.

This commit is contained in:
Martin v. Löwis 2007-07-28 17:52:25 +00:00
parent f3f0c611dd
commit c582bfca26
2 changed files with 87 additions and 75 deletions

View File

@ -6,10 +6,10 @@
__all__ = ["encode", "decode", "encodestring", "decodestring"] __all__ = ["encode", "decode", "encodestring", "decodestring"]
ESCAPE = '=' ESCAPE = b'='
MAXLINESIZE = 76 MAXLINESIZE = 76
HEX = '0123456789ABCDEF' HEX = b'0123456789ABCDEF'
EMPTYSTRING = '' EMPTYSTRING = b''
try: try:
from binascii import a2b_qp, b2a_qp from binascii import a2b_qp, b2a_qp
@ -19,23 +19,25 @@ except ImportError:
def needsquoting(c, quotetabs, header):
    """Tell whether the byte 'c' has to be quoted-printable escaped.

    'c' must be a bytes object (callers pass a single byte).

    Spaces and tabs are quoted only when 'quotetabs' is set; the caller is
    responsible for the line-ending space/tab that RFC 1521 always requires
    to be encoded.  An underscore is quoted only when 'header' is set,
    because in headers '_' stands for a space.  Everything else is quoted
    when it is the escape byte itself or lies outside the range ' '..'~'.
    """
    assert isinstance(c, bytes)
    if c in b' \t':
        return quotetabs
    # In header mode '_' is the stand-in for space, so a literal '_' must
    # itself be escaped.
    if c == b'_':
        return header
    if c == ESCAPE:
        return True
    return not (b' ' <= c <= b'~')
def quote(c):
    """Return the escaped form of the single byte 'c'.

    'c' must be a bytes object of length one.  The result is the escape
    byte followed by the two digits looked up in HEX for the high and low
    nibble of the byte value.
    """
    assert isinstance(c, bytes) and len(c)==1
    high, low = divmod(ord(c), 16)
    # Indexing HEX (a bytes object) yields ints; bytes() packs them back.
    return ESCAPE + bytes((HEX[high], HEX[low]))
@ -56,12 +58,12 @@ def encode(input, output, quotetabs, header = 0):
output.write(odata) output.write(odata)
return return
def write(s, output=output, lineEnd='\n'): def write(s, output=output, lineEnd=b'\n'):
# RFC 1521 requires that the line ending in a space or tab must have # RFC 1521 requires that the line ending in a space or tab must have
# that trailing character encoded. # that trailing character encoded.
if s and s[-1:] in ' \t': if s and s[-1:] in b' \t':
output.write(s[:-1] + quote(s[-1]) + lineEnd) output.write(s[:-1] + quote(s[-1:]) + lineEnd)
elif s == '.': elif s == b'.':
output.write(quote(s) + lineEnd) output.write(quote(s) + lineEnd)
else: else:
output.write(s + lineEnd) output.write(s + lineEnd)
@ -73,16 +75,17 @@ def encode(input, output, quotetabs, header = 0):
break break
outline = [] outline = []
# Strip off any readline induced trailing newline # Strip off any readline induced trailing newline
stripped = '' stripped = b''
if line[-1:] == '\n': if line[-1:] == b'\n':
line = line[:-1] line = line[:-1]
stripped = '\n' stripped = b'\n'
# Calculate the un-length-limited encoded line # Calculate the un-length-limited encoded line
for c in line: for c in line:
c = bytes((c,))
if needsquoting(c, quotetabs, header): if needsquoting(c, quotetabs, header):
c = quote(c) c = quote(c)
if header and c == ' ': if header and c == b' ':
outline.append('_') outline.append(b'_')
else: else:
outline.append(c) outline.append(c)
# First, write out the previous line # First, write out the previous line
@ -94,7 +97,7 @@ def encode(input, output, quotetabs, header = 0):
while len(thisline) > MAXLINESIZE: while len(thisline) > MAXLINESIZE:
# Don't forget to include the soft line break `=' sign in the # Don't forget to include the soft line break `=' sign in the
# length calculation! # length calculation!
write(thisline[:MAXLINESIZE-1], lineEnd='=\n') write(thisline[:MAXLINESIZE-1], lineEnd=b'=\n')
thisline = thisline[MAXLINESIZE-1:] thisline = thisline[MAXLINESIZE-1:]
# Write out the current line # Write out the current line
prevline = thisline prevline = thisline
@ -105,9 +108,9 @@ def encode(input, output, quotetabs, header = 0):
def encodestring(s, quotetabs = 0, header = 0):
    """Return the quoted-printable encoding of the bytes 's'.

    When b2a_qp is available it does all the work; 'quotetabs' and
    'header' are passed through to it unchanged.  Otherwise 's' is fed
    through encode() using in-memory binary files.
    """
    # NOTE(review): b2a_qp is presumably set to None when the binascii
    # import fails -- that fallback is not visible from here; confirm.
    if b2a_qp is None:
        from io import BytesIO
        source = BytesIO(s)
        sink = BytesIO()
        encode(source, sink, quotetabs, header)
        return sink.getvalue()
    return b2a_qp(s, quotetabs = quotetabs, header = header)
@ -124,44 +127,44 @@ def decode(input, output, header = 0):
output.write(odata) output.write(odata)
return return
new = '' new = b''
while 1: while 1:
line = input.readline() line = input.readline()
if not line: break if not line: break
i, n = 0, len(line) i, n = 0, len(line)
if n > 0 and line[n-1] == '\n': if n > 0 and line[n-1:n] == b'\n':
partial = 0; n = n-1 partial = 0; n = n-1
# Strip trailing whitespace # Strip trailing whitespace
while n > 0 and line[n-1] in " \t\r": while n > 0 and line[n-1:n] in b" \t\r":
n = n-1 n = n-1
else: else:
partial = 1 partial = 1
while i < n: while i < n:
c = line[i] c = line[i:i+1]
if c == '_' and header: if c == b'_' and header:
new = new + ' '; i = i+1 new = new + b' '; i = i+1
elif c != ESCAPE: elif c != ESCAPE:
new = new + c; i = i+1 new = new + c; i = i+1
elif i+1 == n and not partial: elif i+1 == n and not partial:
partial = 1; break partial = 1; break
elif i+1 < n and line[i+1] == ESCAPE: elif i+1 < n and line[i+1] == ESCAPE:
new = new + ESCAPE; i = i+2 new = new + ESCAPE; i = i+2
elif i+2 < n and ishex(line[i+1]) and ishex(line[i+2]): elif i+2 < n and ishex(line[i+1:i+2]) and ishex(line[i+2:i+3]):
new = new + chr(unhex(line[i+1:i+3])); i = i+3 new = new + bytes((unhex(line[i+1:i+3]),)); i = i+3
else: # Bad escape sequence -- leave it in else: # Bad escape sequence -- leave it in
new = new + c; i = i+1 new = new + c; i = i+1
if not partial: if not partial:
output.write(new + '\n') output.write(new + b'\n')
new = '' new = b''
if new: if new:
output.write(new) output.write(new)
def decodestring(s, header = 0):
    """Return the bytes obtained by quoted-printable decoding 's'.

    When a2b_qp is available it does all the work; 'header' is passed
    through to it unchanged.  Otherwise 's' is fed through decode() using
    in-memory binary files.
    """
    # NOTE(review): a2b_qp is presumably set to None when the binascii
    # import fails -- that fallback is not visible from here; confirm.
    if a2b_qp is None:
        from io import BytesIO
        source = BytesIO(s)
        sink = BytesIO()
        decode(source, sink, header = header)
        return sink.getvalue()
    return a2b_qp(s, header = header)
@ -169,21 +172,23 @@ def decodestring(s, header = 0):
# Other helper functions # Other helper functions
def ishex(c):
    """Return true if 'c' is a single ASCII hexadecimal digit.

    'c' must be a bytes object -- typically a one-byte slice such as
    ``line[i:i+1]``, not the integer obtained by indexing.  Only a bytes
    object of length exactly one can be a hex digit: an empty slice (as
    produced when slicing past the end of a line) and any multi-byte
    value yield False.
    """
    # Debug-time type guard only.  If it is stripped under -O, the bytes
    # operations below still raise TypeError for a non-bytes argument.
    assert isinstance(c, bytes)
    # Ordering comparisons on bytes are lexicographic, so a range test
    # such as b'0' <= c <= b'9' would wrongly accept b'1z'.  Require one
    # byte and test membership instead.
    return len(c) == 1 and c in b'0123456789abcdefABCDEF'
def unhex(s):
    """Get the integer value of a hexadecimal number.

    's' is a bytes object of ASCII hexadecimal digits (upper or lower
    case), most significant digit first.  An empty 's' yields 0.

    Raises AssertionError if 's' contains a byte that is not a hex digit.
    """
    bits = 0
    # Iterating a bytes object yields integer ordinals, so the digits can
    # be classified numerically without converting back to bytes.
    for byte in s:
        if ord('0') <= byte <= ord('9'):
            digit = byte - ord('0')
        elif ord('a') <= byte <= ord('f'):
            digit = byte - ord('a') + 10
        elif ord('A') <= byte <= ord('F'):
            digit = byte - ord('A') + 10
        else:
            # Raised explicitly instead of via ``assert False``: an assert
            # is removed under ``python -O``, which would let the loop go
            # on with an unbound or stale digit offset.  The exception
            # type and message are kept as they were.
            raise AssertionError("non-hex digit "+repr(bytes((byte,))))
        bits = bits*16 + digit
    return bits
@ -214,18 +219,18 @@ def main():
sts = 0 sts = 0
for file in args: for file in args:
if file == '-': if file == '-':
fp = sys.stdin fp = sys.stdin.buffer
else: else:
try: try:
fp = open(file) fp = open(file, "rb")
except IOError as msg: except IOError as msg:
sys.stderr.write("%s: can't open (%s)\n" % (file, msg)) sys.stderr.write("%s: can't open (%s)\n" % (file, msg))
sts = 1 sts = 1
continue continue
if deco: if deco:
decode(fp, sys.stdout) decode(fp, sys.stdout.buffer)
else: else:
encode(fp, sys.stdout, tabs) encode(fp, sys.stdout.buffer, tabs)
if fp is not sys.stdin: if fp is not sys.stdin:
fp.close() fp.close()
if sts: if sts:

View File

@ -6,7 +6,7 @@ import quopri
ENCSAMPLE = """\ ENCSAMPLE = b"""\
Here's a bunch of special=20 Here's a bunch of special=20
=A1=A2=A3=A4=A5=A6=A7=A8=A9 =A1=A2=A3=A4=A5=A6=A7=A8=A9
@ -25,8 +25,8 @@ characters... have fun!
""" """
# First line ends with a space # First line ends with a space
DECSAMPLE = "Here's a bunch of special \n" + \ DECSAMPLE = b"Here's a bunch of special \n" + \
"""\ b"""\
\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9 \xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9
\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3 \xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3
@ -67,48 +67,48 @@ class QuopriTestCase(unittest.TestCase):
# used in the "quotetabs=0" tests. # used in the "quotetabs=0" tests.
STRINGS = ( STRINGS = (
# Some normal strings # Some normal strings
('hello', 'hello'), (b'hello', b'hello'),
('''hello (b'''hello
there there
world''', '''hello world''', b'''hello
there there
world'''), world'''),
('''hello (b'''hello
there there
world world
''', '''hello ''', b'''hello
there there
world world
'''), '''),
('\201\202\203', '=81=82=83'), (b'\201\202\203', b'=81=82=83'),
# Add some trailing MUST QUOTE strings # Add some trailing MUST QUOTE strings
('hello ', 'hello=20'), (b'hello ', b'hello=20'),
('hello\t', 'hello=09'), (b'hello\t', b'hello=09'),
# Some long lines. First, a single line of 108 characters # Some long lines. First, a single line of 108 characters
('xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\xd8\xd9\xda\xdb\xdc\xdd\xde\xdfxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', (b'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\xd8\xd9\xda\xdb\xdc\xdd\xde\xdfxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx',
'''xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx=D8=D9=DA=DB=DC=DD=DE=DFx= b'''xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx=D8=D9=DA=DB=DC=DD=DE=DFx=
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'''), xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'''),
# A line of exactly 76 characters, no soft line break should be needed # A line of exactly 76 characters, no soft line break should be needed
('yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy', (b'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy',
'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy'), b'yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy'),
# A line of 77 characters, forcing a soft line break at position 75, # A line of 77 characters, forcing a soft line break at position 75,
# and a second line of exactly 2 characters (because the soft line # and a second line of exactly 2 characters (because the soft line
# break `=' sign counts against the line length limit). # break `=' sign counts against the line length limit).
('zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz', (b'zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz',
'''zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz= b'''zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz=
zz'''), zz'''),
# A line of 151 characters, forcing a soft line break at position 75, # A line of 151 characters, forcing a soft line break at position 75,
# with a second line of exactly 76 characters and no trailing = # with a second line of exactly 76 characters and no trailing =
('zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz', (b'zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz',
'''zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz= b'''zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz=
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz'''), zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz'''),
# A string containing a hard line break, but which the first line is # A string containing a hard line break, but which the first line is
# 151 characters and the second line is exactly 76 characters. This # 151 characters and the second line is exactly 76 characters. This
# should leave us with three lines, the first which has a soft line # should leave us with three lines, the first which has a soft line
# break, and which the second and third do not. # break, and which the second and third do not.
('''yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy (b'''yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz''', zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz''',
'''yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy= b'''yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy=
yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz'''), zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz'''),
# Now some really complex stuff ;) # Now some really complex stuff ;)
@ -117,14 +117,14 @@ zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz''')
# These are used in the "quotetabs=1" tests. # These are used in the "quotetabs=1" tests.
ESTRINGS = ( ESTRINGS = (
('hello world', 'hello=20world'), (b'hello world', b'hello=20world'),
('hello\tworld', 'hello=09world'), (b'hello\tworld', b'hello=09world'),
) )
# These are used in the "header=1" tests. # These are used in the "header=1" tests.
HSTRINGS = ( HSTRINGS = (
('hello world', 'hello_world'), (b'hello world', b'hello_world'),
('hello_world', 'hello=5Fworld'), (b'hello_world', b'hello=5Fworld'),
) )
@withpythonimplementation @withpythonimplementation
@ -161,18 +161,18 @@ zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz''')
@withpythonimplementation
def test_embedded_ws(self):
    # With quotetabs enabled, embedded spaces and tabs must come out
    # escaped, and decoding the escaped form must give the input back.
    for plain, encoded in self.ESTRINGS:
        self.assertEqual(quopri.encodestring(plain, quotetabs=True), encoded)
        self.assertEqual(quopri.decodestring(encoded), plain)
@withpythonimplementation
def test_encode_header(self):
    # Header-mode encoding of each HSTRINGS sample must match its
    # expected encoded form.
    for plain, encoded in self.HSTRINGS:
        self.assertEqual(quopri.encodestring(plain, header=True), encoded)
@withpythonimplementation
def test_decode_header(self):
    # Header-mode decoding of each HSTRINGS encoded form must give the
    # plain sample back.
    for plain, encoded in self.HSTRINGS:
        self.assertEqual(quopri.decodestring(encoded, header=True), plain)
def test_scriptencode(self): def test_scriptencode(self):
(p, e) = self.STRINGS[-1] (p, e) = self.STRINGS[-1]
@ -182,13 +182,20 @@ zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz''')
# On Windows, Python will output the result to stdout using # On Windows, Python will output the result to stdout using
# CRLF, as the mode of stdout is text mode. To compare this # CRLF, as the mode of stdout is text mode. To compare this
# with the expected result, we need to do a line-by-line comparison. # with the expected result, we need to do a line-by-line comparison.
self.assertEqual(cout.splitlines(), e.splitlines()) cout = cout.decode('latin-1').splitlines()
e = e.decode('latin-1').splitlines()
assert len(cout)==len(e)
for i in range(len(cout)):
self.assertEqual(cout[i], e[i])
self.assertEqual(cout, e)
def test_scriptdecode(self):
    # Pipe the encoded form of the last STRINGS sample through
    # "python -mquopri -d" and check the output against the plain form.
    plain, encoded = self.STRINGS[-1]
    child = subprocess.Popen([sys.executable, "-mquopri", "-d"],
                             stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    stdout_data, stderr_data = child.communicate(encoded)
    # Both sides are decoded as latin-1 and compared line by line rather
    # than as raw bytes.
    actual_lines = stdout_data.decode('latin-1').splitlines()
    expected_lines = plain.decode('latin-1').splitlines()
    self.assertEqual(actual_lines, expected_lines)
def test_main(): def test_main():