gh-85287: Change codecs to raise precise UnicodeEncodeError and UnicodeDecodeError (#113674)

Co-authored-by: Inada Naoki <songofacandy@gmail.com>
This commit is contained in:
John Sloboda 2024-03-17 00:58:42 -04:00 committed by GitHub
parent c514a975ab
commit 649857a157
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 306 additions and 81 deletions

View File

@ -11,7 +11,7 @@ ace_prefix = b"xn--"
sace_prefix = "xn--"
# This assumes query strings, so AllowUnassigned is true
def nameprep(label):
def nameprep(label): # type: (str) -> str
# Map
newlabel = []
for c in label:
@ -25,7 +25,7 @@ def nameprep(label):
label = unicodedata.normalize("NFKC", label)
# Prohibit
for c in label:
for i, c in enumerate(label):
if stringprep.in_table_c12(c) or \
stringprep.in_table_c22(c) or \
stringprep.in_table_c3(c) or \
@ -35,7 +35,7 @@ def nameprep(label):
stringprep.in_table_c7(c) or \
stringprep.in_table_c8(c) or \
stringprep.in_table_c9(c):
raise UnicodeError("Invalid character %r" % c)
raise UnicodeEncodeError("idna", label, i, i+1, f"Invalid character {c!r}")
# Check bidi
RandAL = [stringprep.in_table_d1(x) for x in label]
@ -46,29 +46,38 @@ def nameprep(label):
# This is table C.8, which was already checked
# 2) If a string contains any RandALCat character, the string
# MUST NOT contain any LCat character.
if any(stringprep.in_table_d2(x) for x in label):
raise UnicodeError("Violation of BIDI requirement 2")
for i, x in enumerate(label):
if stringprep.in_table_d2(x):
raise UnicodeEncodeError("idna", label, i, i+1,
"Violation of BIDI requirement 2")
# 3) If a string contains any RandALCat character, a
# RandALCat character MUST be the first character of the
# string, and a RandALCat character MUST be the last
# character of the string.
if not RandAL[0] or not RandAL[-1]:
raise UnicodeError("Violation of BIDI requirement 3")
if not RandAL[0]:
raise UnicodeEncodeError("idna", label, 0, 1,
"Violation of BIDI requirement 3")
if not RandAL[-1]:
raise UnicodeEncodeError("idna", label, len(label)-1, len(label),
"Violation of BIDI requirement 3")
return label
def ToASCII(label):
def ToASCII(label): # type: (str) -> bytes
try:
# Step 1: try ASCII
label = label.encode("ascii")
except UnicodeError:
label_ascii = label.encode("ascii")
except UnicodeEncodeError:
pass
else:
# Skip to step 3: UseSTD3ASCIIRules is false, so
# Skip to step 8.
if 0 < len(label) < 64:
return label
raise UnicodeError("label empty or too long")
if 0 < len(label_ascii) < 64:
return label_ascii
if len(label) == 0:
raise UnicodeEncodeError("idna", label, 0, 1, "label empty")
else:
raise UnicodeEncodeError("idna", label, 0, len(label), "label too long")
# Step 2: nameprep
label = nameprep(label)
@ -76,29 +85,34 @@ def ToASCII(label):
# Step 3: UseSTD3ASCIIRules is false
# Step 4: try ASCII
try:
label = label.encode("ascii")
except UnicodeError:
label_ascii = label.encode("ascii")
except UnicodeEncodeError:
pass
else:
# Skip to step 8.
if 0 < len(label) < 64:
return label
raise UnicodeError("label empty or too long")
return label_ascii
if len(label) == 0:
raise UnicodeEncodeError("idna", label, 0, 1, "label empty")
else:
raise UnicodeEncodeError("idna", label, 0, len(label), "label too long")
# Step 5: Check ACE prefix
if label[:4].lower() == sace_prefix:
raise UnicodeError("Label starts with ACE prefix")
if label.lower().startswith(sace_prefix):
raise UnicodeEncodeError(
"idna", label, 0, len(sace_prefix), "Label starts with ACE prefix")
# Step 6: Encode with PUNYCODE
label = label.encode("punycode")
label_ascii = label.encode("punycode")
# Step 7: Prepend ACE prefix
label = ace_prefix + label
label_ascii = ace_prefix + label_ascii
# Step 8: Check size
if 0 < len(label) < 64:
return label
raise UnicodeError("label empty or too long")
# do not check for empty as we prepend ace_prefix.
if len(label_ascii) < 64:
return label_ascii
raise UnicodeEncodeError("idna", label, 0, len(label), "label too long")
def ToUnicode(label):
if len(label) > 1024:
@ -110,7 +124,9 @@ def ToUnicode(label):
# per https://www.rfc-editor.org/rfc/rfc3454#section-3.1 while still
# preventing us from wasting time decoding a big thing that'll just
# hit the actual <= 63 length limit in Step 6.
raise UnicodeError("label way too long")
if isinstance(label, str):
label = label.encode("utf-8", errors="backslashreplace")
raise UnicodeDecodeError("idna", label, 0, len(label), "label way too long")
# Step 1: Check for ASCII
if isinstance(label, bytes):
pure_ascii = True
@ -118,25 +134,32 @@ def ToUnicode(label):
try:
label = label.encode("ascii")
pure_ascii = True
except UnicodeError:
except UnicodeEncodeError:
pure_ascii = False
if not pure_ascii:
assert isinstance(label, str)
# Step 2: Perform nameprep
label = nameprep(label)
# It doesn't say this, but apparently, it should be ASCII now
try:
label = label.encode("ascii")
except UnicodeError:
raise UnicodeError("Invalid character in IDN label")
except UnicodeEncodeError as exc:
raise UnicodeEncodeError("idna", label, exc.start, exc.end,
"Invalid character in IDN label")
# Step 3: Check for ACE prefix
if not label[:4].lower() == ace_prefix:
assert isinstance(label, bytes)
if not label.lower().startswith(ace_prefix):
return str(label, "ascii")
# Step 4: Remove ACE prefix
label1 = label[len(ace_prefix):]
# Step 5: Decode using PUNYCODE
result = label1.decode("punycode")
try:
result = label1.decode("punycode")
except UnicodeDecodeError as exc:
offset = len(ace_prefix)
raise UnicodeDecodeError("idna", label, offset+exc.start, offset+exc.end, exc.reason)
# Step 6: Apply ToASCII
label2 = ToASCII(result)
@ -144,7 +167,8 @@ def ToUnicode(label):
# Step 7: Compare the result of step 6 with the one of step 3
# label2 will already be in lower case.
if str(label, "ascii").lower() != str(label2, "ascii"):
raise UnicodeError("IDNA does not round-trip", label, label2)
raise UnicodeDecodeError("idna", label, 0, len(label),
f"IDNA does not round-trip, '{label!r}' != '{label2!r}'")
# Step 8: return the result of step 5
return result
@ -156,7 +180,7 @@ class Codec(codecs.Codec):
if errors != 'strict':
# IDNA is quite clear that implementations must be strict
raise UnicodeError("unsupported error handling "+errors)
raise UnicodeError(f"Unsupported error handling: {errors}")
if not input:
return b'', 0
@ -168,11 +192,16 @@ class Codec(codecs.Codec):
else:
# ASCII name: fast path
labels = result.split(b'.')
for label in labels[:-1]:
if not (0 < len(label) < 64):
raise UnicodeError("label empty or too long")
if len(labels[-1]) >= 64:
raise UnicodeError("label too long")
for i, label in enumerate(labels[:-1]):
if len(label) == 0:
offset = sum(len(l) for l in labels[:i]) + i
raise UnicodeEncodeError("idna", input, offset, offset+1,
"label empty")
for i, label in enumerate(labels):
if len(label) >= 64:
offset = sum(len(l) for l in labels[:i]) + i
raise UnicodeEncodeError("idna", input, offset, offset+len(label),
"label too long")
return result, len(input)
result = bytearray()
@ -182,17 +211,27 @@ class Codec(codecs.Codec):
del labels[-1]
else:
trailing_dot = b''
for label in labels:
for i, label in enumerate(labels):
if result:
# Join with U+002E
result.extend(b'.')
result.extend(ToASCII(label))
try:
result.extend(ToASCII(label))
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
offset = sum(len(l) for l in labels[:i]) + i
raise UnicodeEncodeError(
"idna",
input,
offset + exc.start,
offset + exc.end,
exc.reason,
)
return bytes(result+trailing_dot), len(input)
def decode(self, input, errors='strict'):
if errors != 'strict':
raise UnicodeError("Unsupported error handling "+errors)
raise UnicodeError(f"Unsupported error handling: {errors}")
if not input:
return "", 0
@ -218,8 +257,15 @@ class Codec(codecs.Codec):
trailing_dot = ''
result = []
for label in labels:
result.append(ToUnicode(label))
for i, label in enumerate(labels):
try:
u_label = ToUnicode(label)
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
offset = sum(len(x) for x in labels[:i]) + len(labels[:i])
raise UnicodeDecodeError(
"idna", input, offset+exc.start, offset+exc.end, exc.reason)
else:
result.append(u_label)
return ".".join(result)+trailing_dot, len(input)
@ -227,7 +273,7 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
def _buffer_encode(self, input, errors, final):
if errors != 'strict':
# IDNA is quite clear that implementations must be strict
raise UnicodeError("unsupported error handling "+errors)
raise UnicodeError(f"Unsupported error handling: {errors}")
if not input:
return (b'', 0)
@ -251,7 +297,16 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
# Join with U+002E
result.extend(b'.')
size += 1
result.extend(ToASCII(label))
try:
result.extend(ToASCII(label))
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
raise UnicodeEncodeError(
"idna",
input,
size + exc.start,
size + exc.end,
exc.reason,
)
size += len(label)
result += trailing_dot
@ -261,7 +316,7 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
def _buffer_decode(self, input, errors, final):
if errors != 'strict':
raise UnicodeError("Unsupported error handling "+errors)
raise UnicodeError("Unsupported error handling: {errors}")
if not input:
return ("", 0)
@ -271,7 +326,11 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
labels = dots.split(input)
else:
# Must be ASCII string
input = str(input, "ascii")
try:
input = str(input, "ascii")
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
raise UnicodeDecodeError("idna", input,
exc.start, exc.end, exc.reason)
labels = input.split(".")
trailing_dot = ''
@ -288,7 +347,18 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
result = []
size = 0
for label in labels:
result.append(ToUnicode(label))
try:
u_label = ToUnicode(label)
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
raise UnicodeDecodeError(
"idna",
input.encode("ascii", errors="backslashreplace"),
size + exc.start,
size + exc.end,
exc.reason,
)
else:
result.append(u_label)
if size:
size += 1
size += len(label)

View File

@ -1,4 +1,4 @@
""" Codec for the Punicode encoding, as specified in RFC 3492
""" Codec for the Punycode encoding, as specified in RFC 3492
Written by Martin v. Löwis.
"""
@ -131,10 +131,11 @@ def decode_generalized_number(extended, extpos, bias, errors):
j = 0
while 1:
try:
char = ord(extended[extpos])
char = extended[extpos]
except IndexError:
if errors == "strict":
raise UnicodeError("incomplete punicode string")
raise UnicodeDecodeError("punycode", extended, extpos, extpos+1,
"incomplete punycode string")
return extpos + 1, None
extpos += 1
if 0x41 <= char <= 0x5A: # A-Z
@ -142,8 +143,8 @@ def decode_generalized_number(extended, extpos, bias, errors):
elif 0x30 <= char <= 0x39:
digit = char - 22 # 0x30-26
elif errors == "strict":
raise UnicodeError("Invalid extended code point '%s'"
% extended[extpos-1])
raise UnicodeDecodeError("punycode", extended, extpos-1, extpos,
f"Invalid extended code point '{extended[extpos-1]}'")
else:
return extpos, None
t = T(j, bias)
@ -155,11 +156,14 @@ def decode_generalized_number(extended, extpos, bias, errors):
def insertion_sort(base, extended, errors):
"""3.2 Insertion unsort coding"""
"""3.2 Insertion sort coding"""
# This function raises UnicodeDecodeError with position in the extended.
# Caller should add the offset.
char = 0x80
pos = -1
bias = 72
extpos = 0
while extpos < len(extended):
newpos, delta = decode_generalized_number(extended, extpos,
bias, errors)
@ -171,7 +175,9 @@ def insertion_sort(base, extended, errors):
char += pos // (len(base) + 1)
if char > 0x10FFFF:
if errors == "strict":
raise UnicodeError("Invalid character U+%x" % char)
raise UnicodeDecodeError(
"punycode", extended, pos-1, pos,
f"Invalid character U+{char:x}")
char = ord('?')
pos = pos % (len(base) + 1)
base = base[:pos] + chr(char) + base[pos:]
@ -187,11 +193,21 @@ def punycode_decode(text, errors):
pos = text.rfind(b"-")
if pos == -1:
base = ""
extended = str(text, "ascii").upper()
extended = text.upper()
else:
base = str(text[:pos], "ascii", errors)
extended = str(text[pos+1:], "ascii").upper()
return insertion_sort(base, extended, errors)
try:
base = str(text[:pos], "ascii", errors)
except UnicodeDecodeError as exc:
raise UnicodeDecodeError("ascii", text, exc.start, exc.end,
exc.reason) from None
extended = text[pos+1:].upper()
try:
return insertion_sort(base, extended, errors)
except UnicodeDecodeError as exc:
offset = pos + 1
raise UnicodeDecodeError("punycode", text,
offset+exc.start, offset+exc.end,
exc.reason) from None
### Codec APIs
@ -203,7 +219,7 @@ class Codec(codecs.Codec):
def decode(self, input, errors='strict'):
if errors not in ('strict', 'replace', 'ignore'):
raise UnicodeError("Unsupported error handling "+errors)
raise UnicodeError(f"Unsupported error handling: {errors}")
res = punycode_decode(input, errors)
return res, len(input)
@ -214,7 +230,7 @@ class IncrementalEncoder(codecs.IncrementalEncoder):
class IncrementalDecoder(codecs.IncrementalDecoder):
def decode(self, input, final=False):
if self.errors not in ('strict', 'replace', 'ignore'):
raise UnicodeError("Unsupported error handling "+self.errors)
raise UnicodeError(f"Unsupported error handling: {self.errors}")
return punycode_decode(input, self.errors)
class StreamWriter(Codec,codecs.StreamWriter):

View File

@ -1,6 +1,6 @@
""" Python 'undefined' Codec
This codec will always raise a ValueError exception when being
This codec will always raise a UnicodeError exception when being
used. It is intended for use by the site.py file to switch off
automatic string to Unicode coercion.

View File

@ -64,7 +64,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
elif byteorder == 1:
self.decoder = codecs.utf_16_be_decode
elif consumed >= 2:
raise UnicodeError("UTF-16 stream does not start with BOM")
raise UnicodeDecodeError("utf-16", input, 0, 2, "Stream does not start with BOM")
return (output, consumed)
return self.decoder(input, self.errors, final)
@ -138,7 +138,7 @@ class StreamReader(codecs.StreamReader):
elif byteorder == 1:
self.decode = codecs.utf_16_be_decode
elif consumed>=2:
raise UnicodeError("UTF-16 stream does not start with BOM")
raise UnicodeDecodeError("utf-16", input, 0, 2, "Stream does not start with BOM")
return (object, consumed)
### encodings module API

View File

@ -59,7 +59,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
elif byteorder == 1:
self.decoder = codecs.utf_32_be_decode
elif consumed >= 4:
raise UnicodeError("UTF-32 stream does not start with BOM")
raise UnicodeDecodeError("utf-32", input, 0, 4, "Stream does not start with BOM")
return (output, consumed)
return self.decoder(input, self.errors, final)
@ -132,8 +132,8 @@ class StreamReader(codecs.StreamReader):
self.decode = codecs.utf_32_le_decode
elif byteorder == 1:
self.decode = codecs.utf_32_be_decode
elif consumed>=4:
raise UnicodeError("UTF-32 stream does not start with BOM")
elif consumed >= 4:
raise UnicodeDecodeError("utf-32", input, 0, 4, "Stream does not start with BOM")
return (object, consumed)
### encodings module API

View File

@ -482,11 +482,11 @@ class UTF32Test(ReadTest, unittest.TestCase):
def test_badbom(self):
s = io.BytesIO(4*b"\xff")
f = codecs.getreader(self.encoding)(s)
self.assertRaises(UnicodeError, f.read)
self.assertRaises(UnicodeDecodeError, f.read)
s = io.BytesIO(8*b"\xff")
f = codecs.getreader(self.encoding)(s)
self.assertRaises(UnicodeError, f.read)
self.assertRaises(UnicodeDecodeError, f.read)
def test_partial(self):
self.check_partial(
@ -666,11 +666,11 @@ class UTF16Test(ReadTest, unittest.TestCase):
def test_badbom(self):
s = io.BytesIO(b"\xff\xff")
f = codecs.getreader(self.encoding)(s)
self.assertRaises(UnicodeError, f.read)
self.assertRaises(UnicodeDecodeError, f.read)
s = io.BytesIO(b"\xff\xff\xff\xff")
f = codecs.getreader(self.encoding)(s)
self.assertRaises(UnicodeError, f.read)
self.assertRaises(UnicodeDecodeError, f.read)
def test_partial(self):
self.check_partial(
@ -1356,13 +1356,29 @@ class PunycodeTest(unittest.TestCase):
def test_decode_invalid(self):
testcases = [
(b"xn--w&", "strict", UnicodeError()),
(b"xn--w&", "strict", UnicodeDecodeError("punycode", b"", 5, 6, "")),
(b"&egbpdaj6bu4bxfgehfvwxn", "strict", UnicodeDecodeError("punycode", b"", 0, 1, "")),
(b"egbpdaj6bu&4bx&fgehfvwxn", "strict", UnicodeDecodeError("punycode", b"", 10, 11, "")),
(b"egbpdaj6bu4bxfgehfvwxn&", "strict", UnicodeDecodeError("punycode", b"", 22, 23, "")),
(b"\xFFProprostnemluvesky-uyb24dma41a", "strict", UnicodeDecodeError("ascii", b"", 0, 1, "")),
(b"Pro\xFFprostnemluvesky-uyb24dma41a", "strict", UnicodeDecodeError("ascii", b"", 3, 4, "")),
(b"Proprost&nemluvesky-uyb24&dma41a", "strict", UnicodeDecodeError("punycode", b"", 25, 26, "")),
(b"Proprostnemluvesky&-&uyb24dma41a", "strict", UnicodeDecodeError("punycode", b"", 20, 21, "")),
(b"Proprostnemluvesky-&uyb24dma41a", "strict", UnicodeDecodeError("punycode", b"", 19, 20, "")),
(b"Proprostnemluvesky-uyb24d&ma41a", "strict", UnicodeDecodeError("punycode", b"", 25, 26, "")),
(b"Proprostnemluvesky-uyb24dma41a&", "strict", UnicodeDecodeError("punycode", b"", 30, 31, "")),
(b"xn--w&", "ignore", "xn-"),
]
for puny, errors, expected in testcases:
with self.subTest(puny=puny, errors=errors):
if isinstance(expected, Exception):
self.assertRaises(UnicodeError, puny.decode, "punycode", errors)
with self.assertRaises(UnicodeDecodeError) as cm:
puny.decode("punycode", errors)
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, puny)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
else:
self.assertEqual(puny.decode("punycode", errors), expected)
@ -1532,7 +1548,7 @@ class NameprepTest(unittest.TestCase):
orig = str(orig, "utf-8", "surrogatepass")
if prepped is None:
# Input contains prohibited characters
self.assertRaises(UnicodeError, nameprep, orig)
self.assertRaises(UnicodeEncodeError, nameprep, orig)
else:
prepped = str(prepped, "utf-8", "surrogatepass")
try:
@ -1542,6 +1558,23 @@ class NameprepTest(unittest.TestCase):
class IDNACodecTest(unittest.TestCase):
invalid_decode_testcases = [
(b"\xFFpython.org", UnicodeDecodeError("idna", b"\xFFpython.org", 0, 1, "")),
(b"pyt\xFFhon.org", UnicodeDecodeError("idna", b"pyt\xFFhon.org", 3, 4, "")),
(b"python\xFF.org", UnicodeDecodeError("idna", b"python\xFF.org", 6, 7, "")),
(b"python.\xFForg", UnicodeDecodeError("idna", b"python.\xFForg", 7, 8, "")),
(b"python.o\xFFrg", UnicodeDecodeError("idna", b"python.o\xFFrg", 8, 9, "")),
(b"python.org\xFF", UnicodeDecodeError("idna", b"python.org\xFF", 10, 11, "")),
(b"xn--pythn-&mua.org", UnicodeDecodeError("idna", b"xn--pythn-&mua.org", 10, 11, "")),
(b"xn--pythn-m&ua.org", UnicodeDecodeError("idna", b"xn--pythn-m&ua.org", 11, 12, "")),
(b"xn--pythn-mua&.org", UnicodeDecodeError("idna", b"xn--pythn-mua&.org", 13, 14, "")),
]
invalid_encode_testcases = [
(f"foo.{'\xff'*60}", UnicodeEncodeError("idna", f"foo.{'\xff'*60}", 4, 64, "")),
("あさ.\u034f", UnicodeEncodeError("idna", "あさ.\u034f", 3, 4, "")),
]
def test_builtin_decode(self):
self.assertEqual(str(b"python.org", "idna"), "python.org")
self.assertEqual(str(b"python.org.", "idna"), "python.org.")
@ -1555,16 +1588,38 @@ class IDNACodecTest(unittest.TestCase):
self.assertEqual(str(b"bugs.XN--pythn-mua.org.", "idna"),
"bugs.pyth\xf6n.org.")
def test_builtin_decode_invalid(self):
for case, expected in self.invalid_decode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeDecodeError) as cm:
case.decode("idna")
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start, msg=f'reason: {exc.reason}')
self.assertEqual(exc.end, expected.end)
def test_builtin_encode(self):
self.assertEqual("python.org".encode("idna"), b"python.org")
self.assertEqual("python.org.".encode("idna"), b"python.org.")
self.assertEqual("pyth\xf6n.org".encode("idna"), b"xn--pythn-mua.org")
self.assertEqual("pyth\xf6n.org.".encode("idna"), b"xn--pythn-mua.org.")
def test_builtin_encode_invalid(self):
for case, expected in self.invalid_encode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeEncodeError) as cm:
case.encode("idna")
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
def test_builtin_decode_length_limit(self):
with self.assertRaisesRegex(UnicodeError, "way too long"):
with self.assertRaisesRegex(UnicodeDecodeError, "way too long"):
(b"xn--016c"+b"a"*1100).decode("idna")
with self.assertRaisesRegex(UnicodeError, "too long"):
with self.assertRaisesRegex(UnicodeDecodeError, "too long"):
(b"xn--016c"+b"a"*70).decode("idna")
def test_stream(self):
@ -1602,6 +1657,39 @@ class IDNACodecTest(unittest.TestCase):
self.assertEqual(decoder.decode(b"rg."), "org.")
self.assertEqual(decoder.decode(b"", True), "")
def test_incremental_decode_invalid(self):
iterdecode_testcases = [
(b"\xFFpython.org", UnicodeDecodeError("idna", b"\xFF", 0, 1, "")),
(b"pyt\xFFhon.org", UnicodeDecodeError("idna", b"pyt\xFF", 3, 4, "")),
(b"python\xFF.org", UnicodeDecodeError("idna", b"python\xFF", 6, 7, "")),
(b"python.\xFForg", UnicodeDecodeError("idna", b"\xFF", 0, 1, "")),
(b"python.o\xFFrg", UnicodeDecodeError("idna", b"o\xFF", 1, 2, "")),
(b"python.org\xFF", UnicodeDecodeError("idna", b"org\xFF", 3, 4, "")),
(b"xn--pythn-&mua.org", UnicodeDecodeError("idna", b"xn--pythn-&mua.", 10, 11, "")),
(b"xn--pythn-m&ua.org", UnicodeDecodeError("idna", b"xn--pythn-m&ua.", 11, 12, "")),
(b"xn--pythn-mua&.org", UnicodeDecodeError("idna", b"xn--pythn-mua&.", 13, 14, "")),
]
for case, expected in iterdecode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeDecodeError) as cm:
list(codecs.iterdecode((bytes([c]) for c in case), "idna"))
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
decoder = codecs.getincrementaldecoder("idna")()
for case, expected in self.invalid_decode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeDecodeError) as cm:
decoder.decode(case)
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
def test_incremental_encode(self):
self.assertEqual(
b"".join(codecs.iterencode("python.org", "idna")),
@ -1630,6 +1718,23 @@ class IDNACodecTest(unittest.TestCase):
self.assertEqual(encoder.encode("ample.org."), b"xn--xample-9ta.org.")
self.assertEqual(encoder.encode("", True), b"")
def test_incremental_encode_invalid(self):
iterencode_testcases = [
(f"foo.{'\xff'*60}", UnicodeEncodeError("idna", f"{'\xff'*60}", 0, 60, "")),
("あさ.\u034f", UnicodeEncodeError("idna", "\u034f", 0, 1, "")),
]
for case, expected in iterencode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeEncodeError) as cm:
list(codecs.iterencode(case, "idna"))
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
# codecs.getincrementalencoder.encode() does not throw an error
def test_errors(self):
"""Only supports "strict" error handler"""
"python.org".encode("idna", "strict")

View File

@ -303,7 +303,7 @@ class Test_IncrementalDecoder(unittest.TestCase):
self.assertRaises(TypeError, decoder.setstate, 123)
self.assertRaises(TypeError, decoder.setstate, ("invalid", 0))
self.assertRaises(TypeError, decoder.setstate, (b"1234", "invalid"))
self.assertRaises(UnicodeError, decoder.setstate, (b"123456789", 0))
self.assertRaises(UnicodeDecodeError, decoder.setstate, (b"123456789", 0))
class Test_StreamReader(unittest.TestCase):
def test_bug1728403(self):

View File

@ -0,0 +1,2 @@
Changes Unicode codecs to return UnicodeEncodeError or UnicodeDecodeError,
rather than just UnicodeError.

View File

@ -825,8 +825,15 @@ encoder_encode_stateful(MultibyteStatefulEncoderContext *ctx,
if (inpos < datalen) {
if (datalen - inpos > MAXENCPENDING) {
/* normal codecs can't reach here */
PyErr_SetString(PyExc_UnicodeError,
"pending buffer overflow");
PyObject *excobj = PyObject_CallFunction(PyExc_UnicodeEncodeError,
"sOnns",
ctx->codec->encoding,
inbuf,
inpos, datalen,
"pending buffer overflow");
if (excobj == NULL) goto errorexit;
PyErr_SetObject(PyExc_UnicodeEncodeError, excobj);
Py_DECREF(excobj);
goto errorexit;
}
ctx->pending = PyUnicode_Substring(inbuf, inpos, datalen);
@ -857,7 +864,16 @@ decoder_append_pending(MultibyteStatefulDecoderContext *ctx,
npendings = (Py_ssize_t)(buf->inbuf_end - buf->inbuf);
if (npendings + ctx->pendingsize > MAXDECPENDING ||
npendings > PY_SSIZE_T_MAX - ctx->pendingsize) {
PyErr_SetString(PyExc_UnicodeError, "pending buffer overflow");
Py_ssize_t bufsize = (Py_ssize_t)(buf->inbuf_end - buf->inbuf_top);
PyObject *excobj = PyUnicodeDecodeError_Create(ctx->codec->encoding,
(const char *)buf->inbuf_top,
bufsize,
0,
bufsize,
"pending buffer overflow");
if (excobj == NULL) return -1;
PyErr_SetObject(PyExc_UnicodeDecodeError, excobj);
Py_DECREF(excobj);
return -1;
}
memcpy(ctx->pending + ctx->pendingsize, buf->inbuf, npendings);
@ -938,7 +954,17 @@ _multibytecodec_MultibyteIncrementalEncoder_getstate_impl(MultibyteIncrementalEn
return NULL;
}
if (pendingsize > MAXENCPENDING*4) {
PyErr_SetString(PyExc_UnicodeError, "pending buffer too large");
PyObject *excobj = PyObject_CallFunction(PyExc_UnicodeEncodeError,
"sOnns",
self->codec->encoding,
self->pending,
0, PyUnicode_GET_LENGTH(self->pending),
"pending buffer too large");
if (excobj == NULL) {
return NULL;
}
PyErr_SetObject(PyExc_UnicodeEncodeError, excobj);
Py_DECREF(excobj);
return NULL;
}
statebytes[0] = (unsigned char)pendingsize;
@ -1267,7 +1293,13 @@ _multibytecodec_MultibyteIncrementalDecoder_setstate_impl(MultibyteIncrementalDe
}
if (buffersize > MAXDECPENDING) {
PyErr_SetString(PyExc_UnicodeError, "pending buffer too large");
PyObject *excobj = PyUnicodeDecodeError_Create(self->codec->encoding,
PyBytes_AS_STRING(buffer), buffersize,
0, buffersize,
"pending buffer too large");
if (excobj == NULL) return NULL;
PyErr_SetObject(PyExc_UnicodeDecodeError, excobj);
Py_DECREF(excobj);
return NULL;
}