gh-85287: Change codecs to raise precise UnicodeEncodeError and UnicodeDecodeError (#113674)

Co-authored-by: Inada Naoki <songofacandy@gmail.com>
This commit is contained in:
John Sloboda 2024-03-17 00:58:42 -04:00 committed by GitHub
parent c514a975ab
commit 649857a157
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 306 additions and 81 deletions

View File

@ -11,7 +11,7 @@ ace_prefix = b"xn--"
sace_prefix = "xn--" sace_prefix = "xn--"
# This assumes query strings, so AllowUnassigned is true # This assumes query strings, so AllowUnassigned is true
def nameprep(label): def nameprep(label): # type: (str) -> str
# Map # Map
newlabel = [] newlabel = []
for c in label: for c in label:
@ -25,7 +25,7 @@ def nameprep(label):
label = unicodedata.normalize("NFKC", label) label = unicodedata.normalize("NFKC", label)
# Prohibit # Prohibit
for c in label: for i, c in enumerate(label):
if stringprep.in_table_c12(c) or \ if stringprep.in_table_c12(c) or \
stringprep.in_table_c22(c) or \ stringprep.in_table_c22(c) or \
stringprep.in_table_c3(c) or \ stringprep.in_table_c3(c) or \
@ -35,7 +35,7 @@ def nameprep(label):
stringprep.in_table_c7(c) or \ stringprep.in_table_c7(c) or \
stringprep.in_table_c8(c) or \ stringprep.in_table_c8(c) or \
stringprep.in_table_c9(c): stringprep.in_table_c9(c):
raise UnicodeError("Invalid character %r" % c) raise UnicodeEncodeError("idna", label, i, i+1, f"Invalid character {c!r}")
# Check bidi # Check bidi
RandAL = [stringprep.in_table_d1(x) for x in label] RandAL = [stringprep.in_table_d1(x) for x in label]
@ -46,29 +46,38 @@ def nameprep(label):
# This is table C.8, which was already checked # This is table C.8, which was already checked
# 2) If a string contains any RandALCat character, the string # 2) If a string contains any RandALCat character, the string
# MUST NOT contain any LCat character. # MUST NOT contain any LCat character.
if any(stringprep.in_table_d2(x) for x in label): for i, x in enumerate(label):
raise UnicodeError("Violation of BIDI requirement 2") if stringprep.in_table_d2(x):
raise UnicodeEncodeError("idna", label, i, i+1,
"Violation of BIDI requirement 2")
# 3) If a string contains any RandALCat character, a # 3) If a string contains any RandALCat character, a
# RandALCat character MUST be the first character of the # RandALCat character MUST be the first character of the
# string, and a RandALCat character MUST be the last # string, and a RandALCat character MUST be the last
# character of the string. # character of the string.
if not RandAL[0] or not RandAL[-1]: if not RandAL[0]:
raise UnicodeError("Violation of BIDI requirement 3") raise UnicodeEncodeError("idna", label, 0, 1,
"Violation of BIDI requirement 3")
if not RandAL[-1]:
raise UnicodeEncodeError("idna", label, len(label)-1, len(label),
"Violation of BIDI requirement 3")
return label return label
def ToASCII(label): def ToASCII(label): # type: (str) -> bytes
try: try:
# Step 1: try ASCII # Step 1: try ASCII
label = label.encode("ascii") label_ascii = label.encode("ascii")
except UnicodeError: except UnicodeEncodeError:
pass pass
else: else:
# Skip to step 3: UseSTD3ASCIIRules is false, so # Skip to step 3: UseSTD3ASCIIRules is false, so
# Skip to step 8. # Skip to step 8.
if 0 < len(label) < 64: if 0 < len(label_ascii) < 64:
return label return label_ascii
raise UnicodeError("label empty or too long") if len(label) == 0:
raise UnicodeEncodeError("idna", label, 0, 1, "label empty")
else:
raise UnicodeEncodeError("idna", label, 0, len(label), "label too long")
# Step 2: nameprep # Step 2: nameprep
label = nameprep(label) label = nameprep(label)
@ -76,29 +85,34 @@ def ToASCII(label):
# Step 3: UseSTD3ASCIIRules is false # Step 3: UseSTD3ASCIIRules is false
# Step 4: try ASCII # Step 4: try ASCII
try: try:
label = label.encode("ascii") label_ascii = label.encode("ascii")
except UnicodeError: except UnicodeEncodeError:
pass pass
else: else:
# Skip to step 8. # Skip to step 8.
if 0 < len(label) < 64: if 0 < len(label) < 64:
return label return label_ascii
raise UnicodeError("label empty or too long") if len(label) == 0:
raise UnicodeEncodeError("idna", label, 0, 1, "label empty")
else:
raise UnicodeEncodeError("idna", label, 0, len(label), "label too long")
# Step 5: Check ACE prefix # Step 5: Check ACE prefix
if label[:4].lower() == sace_prefix: if label.lower().startswith(sace_prefix):
raise UnicodeError("Label starts with ACE prefix") raise UnicodeEncodeError(
"idna", label, 0, len(sace_prefix), "Label starts with ACE prefix")
# Step 6: Encode with PUNYCODE # Step 6: Encode with PUNYCODE
label = label.encode("punycode") label_ascii = label.encode("punycode")
# Step 7: Prepend ACE prefix # Step 7: Prepend ACE prefix
label = ace_prefix + label label_ascii = ace_prefix + label_ascii
# Step 8: Check size # Step 8: Check size
if 0 < len(label) < 64: # do not check for empty as we prepend ace_prefix.
return label if len(label_ascii) < 64:
raise UnicodeError("label empty or too long") return label_ascii
raise UnicodeEncodeError("idna", label, 0, len(label), "label too long")
def ToUnicode(label): def ToUnicode(label):
if len(label) > 1024: if len(label) > 1024:
@ -110,7 +124,9 @@ def ToUnicode(label):
# per https://www.rfc-editor.org/rfc/rfc3454#section-3.1 while still # per https://www.rfc-editor.org/rfc/rfc3454#section-3.1 while still
# preventing us from wasting time decoding a big thing that'll just # preventing us from wasting time decoding a big thing that'll just
# hit the actual <= 63 length limit in Step 6. # hit the actual <= 63 length limit in Step 6.
raise UnicodeError("label way too long") if isinstance(label, str):
label = label.encode("utf-8", errors="backslashreplace")
raise UnicodeDecodeError("idna", label, 0, len(label), "label way too long")
# Step 1: Check for ASCII # Step 1: Check for ASCII
if isinstance(label, bytes): if isinstance(label, bytes):
pure_ascii = True pure_ascii = True
@ -118,25 +134,32 @@ def ToUnicode(label):
try: try:
label = label.encode("ascii") label = label.encode("ascii")
pure_ascii = True pure_ascii = True
except UnicodeError: except UnicodeEncodeError:
pure_ascii = False pure_ascii = False
if not pure_ascii: if not pure_ascii:
assert isinstance(label, str)
# Step 2: Perform nameprep # Step 2: Perform nameprep
label = nameprep(label) label = nameprep(label)
# It doesn't say this, but apparently, it should be ASCII now # It doesn't say this, but apparently, it should be ASCII now
try: try:
label = label.encode("ascii") label = label.encode("ascii")
except UnicodeError: except UnicodeEncodeError as exc:
raise UnicodeError("Invalid character in IDN label") raise UnicodeEncodeError("idna", label, exc.start, exc.end,
"Invalid character in IDN label")
# Step 3: Check for ACE prefix # Step 3: Check for ACE prefix
if not label[:4].lower() == ace_prefix: assert isinstance(label, bytes)
if not label.lower().startswith(ace_prefix):
return str(label, "ascii") return str(label, "ascii")
# Step 4: Remove ACE prefix # Step 4: Remove ACE prefix
label1 = label[len(ace_prefix):] label1 = label[len(ace_prefix):]
# Step 5: Decode using PUNYCODE # Step 5: Decode using PUNYCODE
result = label1.decode("punycode") try:
result = label1.decode("punycode")
except UnicodeDecodeError as exc:
offset = len(ace_prefix)
raise UnicodeDecodeError("idna", label, offset+exc.start, offset+exc.end, exc.reason)
# Step 6: Apply ToASCII # Step 6: Apply ToASCII
label2 = ToASCII(result) label2 = ToASCII(result)
@ -144,7 +167,8 @@ def ToUnicode(label):
# Step 7: Compare the result of step 6 with the one of step 3 # Step 7: Compare the result of step 6 with the one of step 3
# label2 will already be in lower case. # label2 will already be in lower case.
if str(label, "ascii").lower() != str(label2, "ascii"): if str(label, "ascii").lower() != str(label2, "ascii"):
raise UnicodeError("IDNA does not round-trip", label, label2) raise UnicodeDecodeError("idna", label, 0, len(label),
f"IDNA does not round-trip, '{label!r}' != '{label2!r}'")
# Step 8: return the result of step 5 # Step 8: return the result of step 5
return result return result
@ -156,7 +180,7 @@ class Codec(codecs.Codec):
if errors != 'strict': if errors != 'strict':
# IDNA is quite clear that implementations must be strict # IDNA is quite clear that implementations must be strict
raise UnicodeError("unsupported error handling "+errors) raise UnicodeError(f"Unsupported error handling: {errors}")
if not input: if not input:
return b'', 0 return b'', 0
@ -168,11 +192,16 @@ class Codec(codecs.Codec):
else: else:
# ASCII name: fast path # ASCII name: fast path
labels = result.split(b'.') labels = result.split(b'.')
for label in labels[:-1]: for i, label in enumerate(labels[:-1]):
if not (0 < len(label) < 64): if len(label) == 0:
raise UnicodeError("label empty or too long") offset = sum(len(l) for l in labels[:i]) + i
if len(labels[-1]) >= 64: raise UnicodeEncodeError("idna", input, offset, offset+1,
raise UnicodeError("label too long") "label empty")
for i, label in enumerate(labels):
if len(label) >= 64:
offset = sum(len(l) for l in labels[:i]) + i
raise UnicodeEncodeError("idna", input, offset, offset+len(label),
"label too long")
return result, len(input) return result, len(input)
result = bytearray() result = bytearray()
@ -182,17 +211,27 @@ class Codec(codecs.Codec):
del labels[-1] del labels[-1]
else: else:
trailing_dot = b'' trailing_dot = b''
for label in labels: for i, label in enumerate(labels):
if result: if result:
# Join with U+002E # Join with U+002E
result.extend(b'.') result.extend(b'.')
result.extend(ToASCII(label)) try:
result.extend(ToASCII(label))
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
offset = sum(len(l) for l in labels[:i]) + i
raise UnicodeEncodeError(
"idna",
input,
offset + exc.start,
offset + exc.end,
exc.reason,
)
return bytes(result+trailing_dot), len(input) return bytes(result+trailing_dot), len(input)
def decode(self, input, errors='strict'): def decode(self, input, errors='strict'):
if errors != 'strict': if errors != 'strict':
raise UnicodeError("Unsupported error handling "+errors) raise UnicodeError(f"Unsupported error handling: {errors}")
if not input: if not input:
return "", 0 return "", 0
@ -218,8 +257,15 @@ class Codec(codecs.Codec):
trailing_dot = '' trailing_dot = ''
result = [] result = []
for label in labels: for i, label in enumerate(labels):
result.append(ToUnicode(label)) try:
u_label = ToUnicode(label)
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
offset = sum(len(x) for x in labels[:i]) + len(labels[:i])
raise UnicodeDecodeError(
"idna", input, offset+exc.start, offset+exc.end, exc.reason)
else:
result.append(u_label)
return ".".join(result)+trailing_dot, len(input) return ".".join(result)+trailing_dot, len(input)
@ -227,7 +273,7 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
def _buffer_encode(self, input, errors, final): def _buffer_encode(self, input, errors, final):
if errors != 'strict': if errors != 'strict':
# IDNA is quite clear that implementations must be strict # IDNA is quite clear that implementations must be strict
raise UnicodeError("unsupported error handling "+errors) raise UnicodeError(f"Unsupported error handling: {errors}")
if not input: if not input:
return (b'', 0) return (b'', 0)
@ -251,7 +297,16 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
# Join with U+002E # Join with U+002E
result.extend(b'.') result.extend(b'.')
size += 1 size += 1
result.extend(ToASCII(label)) try:
result.extend(ToASCII(label))
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
raise UnicodeEncodeError(
"idna",
input,
size + exc.start,
size + exc.end,
exc.reason,
)
size += len(label) size += len(label)
result += trailing_dot result += trailing_dot
@ -261,7 +316,7 @@ class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
class IncrementalDecoder(codecs.BufferedIncrementalDecoder): class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
def _buffer_decode(self, input, errors, final): def _buffer_decode(self, input, errors, final):
if errors != 'strict': if errors != 'strict':
raise UnicodeError("Unsupported error handling "+errors) raise UnicodeError(f"Unsupported error handling: {errors}")
if not input: if not input:
return ("", 0) return ("", 0)
@ -271,7 +326,11 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
labels = dots.split(input) labels = dots.split(input)
else: else:
# Must be ASCII string # Must be ASCII string
input = str(input, "ascii") try:
input = str(input, "ascii")
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
raise UnicodeDecodeError("idna", input,
exc.start, exc.end, exc.reason)
labels = input.split(".") labels = input.split(".")
trailing_dot = '' trailing_dot = ''
@ -288,7 +347,18 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
result = [] result = []
size = 0 size = 0
for label in labels: for label in labels:
result.append(ToUnicode(label)) try:
u_label = ToUnicode(label)
except (UnicodeEncodeError, UnicodeDecodeError) as exc:
raise UnicodeDecodeError(
"idna",
input.encode("ascii", errors="backslashreplace"),
size + exc.start,
size + exc.end,
exc.reason,
)
else:
result.append(u_label)
if size: if size:
size += 1 size += 1
size += len(label) size += len(label)

View File

@ -1,4 +1,4 @@
""" Codec for the Punicode encoding, as specified in RFC 3492 """ Codec for the Punycode encoding, as specified in RFC 3492
Written by Martin v. Löwis. Written by Martin v. Löwis.
""" """
@ -131,10 +131,11 @@ def decode_generalized_number(extended, extpos, bias, errors):
j = 0 j = 0
while 1: while 1:
try: try:
char = ord(extended[extpos]) char = extended[extpos]
except IndexError: except IndexError:
if errors == "strict": if errors == "strict":
raise UnicodeError("incomplete punicode string") raise UnicodeDecodeError("punycode", extended, extpos, extpos+1,
"incomplete punycode string")
return extpos + 1, None return extpos + 1, None
extpos += 1 extpos += 1
if 0x41 <= char <= 0x5A: # A-Z if 0x41 <= char <= 0x5A: # A-Z
@ -142,8 +143,8 @@ def decode_generalized_number(extended, extpos, bias, errors):
elif 0x30 <= char <= 0x39: elif 0x30 <= char <= 0x39:
digit = char - 22 # 0x30-26 digit = char - 22 # 0x30-26
elif errors == "strict": elif errors == "strict":
raise UnicodeError("Invalid extended code point '%s'" raise UnicodeDecodeError("punycode", extended, extpos-1, extpos,
% extended[extpos-1]) f"Invalid extended code point '{extended[extpos-1]}'")
else: else:
return extpos, None return extpos, None
t = T(j, bias) t = T(j, bias)
@ -155,11 +156,14 @@ def decode_generalized_number(extended, extpos, bias, errors):
def insertion_sort(base, extended, errors): def insertion_sort(base, extended, errors):
"""3.2 Insertion unsort coding""" """3.2 Insertion sort coding"""
# This function raises UnicodeDecodeError with position in the extended.
# Caller should add the offset.
char = 0x80 char = 0x80
pos = -1 pos = -1
bias = 72 bias = 72
extpos = 0 extpos = 0
while extpos < len(extended): while extpos < len(extended):
newpos, delta = decode_generalized_number(extended, extpos, newpos, delta = decode_generalized_number(extended, extpos,
bias, errors) bias, errors)
@ -171,7 +175,9 @@ def insertion_sort(base, extended, errors):
char += pos // (len(base) + 1) char += pos // (len(base) + 1)
if char > 0x10FFFF: if char > 0x10FFFF:
if errors == "strict": if errors == "strict":
raise UnicodeError("Invalid character U+%x" % char) raise UnicodeDecodeError(
"punycode", extended, pos-1, pos,
f"Invalid character U+{char:x}")
char = ord('?') char = ord('?')
pos = pos % (len(base) + 1) pos = pos % (len(base) + 1)
base = base[:pos] + chr(char) + base[pos:] base = base[:pos] + chr(char) + base[pos:]
@ -187,11 +193,21 @@ def punycode_decode(text, errors):
pos = text.rfind(b"-") pos = text.rfind(b"-")
if pos == -1: if pos == -1:
base = "" base = ""
extended = str(text, "ascii").upper() extended = text.upper()
else: else:
base = str(text[:pos], "ascii", errors) try:
extended = str(text[pos+1:], "ascii").upper() base = str(text[:pos], "ascii", errors)
return insertion_sort(base, extended, errors) except UnicodeDecodeError as exc:
raise UnicodeDecodeError("ascii", text, exc.start, exc.end,
exc.reason) from None
extended = text[pos+1:].upper()
try:
return insertion_sort(base, extended, errors)
except UnicodeDecodeError as exc:
offset = pos + 1
raise UnicodeDecodeError("punycode", text,
offset+exc.start, offset+exc.end,
exc.reason) from None
### Codec APIs ### Codec APIs
@ -203,7 +219,7 @@ class Codec(codecs.Codec):
def decode(self, input, errors='strict'): def decode(self, input, errors='strict'):
if errors not in ('strict', 'replace', 'ignore'): if errors not in ('strict', 'replace', 'ignore'):
raise UnicodeError("Unsupported error handling "+errors) raise UnicodeError(f"Unsupported error handling: {errors}")
res = punycode_decode(input, errors) res = punycode_decode(input, errors)
return res, len(input) return res, len(input)
@ -214,7 +230,7 @@ class IncrementalEncoder(codecs.IncrementalEncoder):
class IncrementalDecoder(codecs.IncrementalDecoder): class IncrementalDecoder(codecs.IncrementalDecoder):
def decode(self, input, final=False): def decode(self, input, final=False):
if self.errors not in ('strict', 'replace', 'ignore'): if self.errors not in ('strict', 'replace', 'ignore'):
raise UnicodeError("Unsupported error handling "+self.errors) raise UnicodeError(f"Unsupported error handling: {self.errors}")
return punycode_decode(input, self.errors) return punycode_decode(input, self.errors)
class StreamWriter(Codec,codecs.StreamWriter): class StreamWriter(Codec,codecs.StreamWriter):

View File

@ -1,6 +1,6 @@
""" Python 'undefined' Codec """ Python 'undefined' Codec
This codec will always raise a ValueError exception when being This codec will always raise a UnicodeError exception when being
used. It is intended for use by the site.py file to switch off used. It is intended for use by the site.py file to switch off
automatic string to Unicode coercion. automatic string to Unicode coercion.

View File

@ -64,7 +64,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
elif byteorder == 1: elif byteorder == 1:
self.decoder = codecs.utf_16_be_decode self.decoder = codecs.utf_16_be_decode
elif consumed >= 2: elif consumed >= 2:
raise UnicodeError("UTF-16 stream does not start with BOM") raise UnicodeDecodeError("utf-16", input, 0, 2, "Stream does not start with BOM")
return (output, consumed) return (output, consumed)
return self.decoder(input, self.errors, final) return self.decoder(input, self.errors, final)
@ -138,7 +138,7 @@ class StreamReader(codecs.StreamReader):
elif byteorder == 1: elif byteorder == 1:
self.decode = codecs.utf_16_be_decode self.decode = codecs.utf_16_be_decode
elif consumed>=2: elif consumed>=2:
raise UnicodeError("UTF-16 stream does not start with BOM") raise UnicodeDecodeError("utf-16", input, 0, 2, "Stream does not start with BOM")
return (object, consumed) return (object, consumed)
### encodings module API ### encodings module API

View File

@ -59,7 +59,7 @@ class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
elif byteorder == 1: elif byteorder == 1:
self.decoder = codecs.utf_32_be_decode self.decoder = codecs.utf_32_be_decode
elif consumed >= 4: elif consumed >= 4:
raise UnicodeError("UTF-32 stream does not start with BOM") raise UnicodeDecodeError("utf-32", input, 0, 4, "Stream does not start with BOM")
return (output, consumed) return (output, consumed)
return self.decoder(input, self.errors, final) return self.decoder(input, self.errors, final)
@ -132,8 +132,8 @@ class StreamReader(codecs.StreamReader):
self.decode = codecs.utf_32_le_decode self.decode = codecs.utf_32_le_decode
elif byteorder == 1: elif byteorder == 1:
self.decode = codecs.utf_32_be_decode self.decode = codecs.utf_32_be_decode
elif consumed>=4: elif consumed >= 4:
raise UnicodeError("UTF-32 stream does not start with BOM") raise UnicodeDecodeError("utf-32", input, 0, 4, "Stream does not start with BOM")
return (object, consumed) return (object, consumed)
### encodings module API ### encodings module API

View File

@ -482,11 +482,11 @@ class UTF32Test(ReadTest, unittest.TestCase):
def test_badbom(self): def test_badbom(self):
s = io.BytesIO(4*b"\xff") s = io.BytesIO(4*b"\xff")
f = codecs.getreader(self.encoding)(s) f = codecs.getreader(self.encoding)(s)
self.assertRaises(UnicodeError, f.read) self.assertRaises(UnicodeDecodeError, f.read)
s = io.BytesIO(8*b"\xff") s = io.BytesIO(8*b"\xff")
f = codecs.getreader(self.encoding)(s) f = codecs.getreader(self.encoding)(s)
self.assertRaises(UnicodeError, f.read) self.assertRaises(UnicodeDecodeError, f.read)
def test_partial(self): def test_partial(self):
self.check_partial( self.check_partial(
@ -666,11 +666,11 @@ class UTF16Test(ReadTest, unittest.TestCase):
def test_badbom(self): def test_badbom(self):
s = io.BytesIO(b"\xff\xff") s = io.BytesIO(b"\xff\xff")
f = codecs.getreader(self.encoding)(s) f = codecs.getreader(self.encoding)(s)
self.assertRaises(UnicodeError, f.read) self.assertRaises(UnicodeDecodeError, f.read)
s = io.BytesIO(b"\xff\xff\xff\xff") s = io.BytesIO(b"\xff\xff\xff\xff")
f = codecs.getreader(self.encoding)(s) f = codecs.getreader(self.encoding)(s)
self.assertRaises(UnicodeError, f.read) self.assertRaises(UnicodeDecodeError, f.read)
def test_partial(self): def test_partial(self):
self.check_partial( self.check_partial(
@ -1356,13 +1356,29 @@ class PunycodeTest(unittest.TestCase):
def test_decode_invalid(self): def test_decode_invalid(self):
testcases = [ testcases = [
(b"xn--w&", "strict", UnicodeError()), (b"xn--w&", "strict", UnicodeDecodeError("punycode", b"", 5, 6, "")),
(b"&egbpdaj6bu4bxfgehfvwxn", "strict", UnicodeDecodeError("punycode", b"", 0, 1, "")),
(b"egbpdaj6bu&4bx&fgehfvwxn", "strict", UnicodeDecodeError("punycode", b"", 10, 11, "")),
(b"egbpdaj6bu4bxfgehfvwxn&", "strict", UnicodeDecodeError("punycode", b"", 22, 23, "")),
(b"\xFFProprostnemluvesky-uyb24dma41a", "strict", UnicodeDecodeError("ascii", b"", 0, 1, "")),
(b"Pro\xFFprostnemluvesky-uyb24dma41a", "strict", UnicodeDecodeError("ascii", b"", 3, 4, "")),
(b"Proprost&nemluvesky-uyb24&dma41a", "strict", UnicodeDecodeError("punycode", b"", 25, 26, "")),
(b"Proprostnemluvesky&-&uyb24dma41a", "strict", UnicodeDecodeError("punycode", b"", 20, 21, "")),
(b"Proprostnemluvesky-&uyb24dma41a", "strict", UnicodeDecodeError("punycode", b"", 19, 20, "")),
(b"Proprostnemluvesky-uyb24d&ma41a", "strict", UnicodeDecodeError("punycode", b"", 25, 26, "")),
(b"Proprostnemluvesky-uyb24dma41a&", "strict", UnicodeDecodeError("punycode", b"", 30, 31, "")),
(b"xn--w&", "ignore", "xn-"), (b"xn--w&", "ignore", "xn-"),
] ]
for puny, errors, expected in testcases: for puny, errors, expected in testcases:
with self.subTest(puny=puny, errors=errors): with self.subTest(puny=puny, errors=errors):
if isinstance(expected, Exception): if isinstance(expected, Exception):
self.assertRaises(UnicodeError, puny.decode, "punycode", errors) with self.assertRaises(UnicodeDecodeError) as cm:
puny.decode("punycode", errors)
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, puny)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
else: else:
self.assertEqual(puny.decode("punycode", errors), expected) self.assertEqual(puny.decode("punycode", errors), expected)
@ -1532,7 +1548,7 @@ class NameprepTest(unittest.TestCase):
orig = str(orig, "utf-8", "surrogatepass") orig = str(orig, "utf-8", "surrogatepass")
if prepped is None: if prepped is None:
# Input contains prohibited characters # Input contains prohibited characters
self.assertRaises(UnicodeError, nameprep, orig) self.assertRaises(UnicodeEncodeError, nameprep, orig)
else: else:
prepped = str(prepped, "utf-8", "surrogatepass") prepped = str(prepped, "utf-8", "surrogatepass")
try: try:
@ -1542,6 +1558,23 @@ class NameprepTest(unittest.TestCase):
class IDNACodecTest(unittest.TestCase): class IDNACodecTest(unittest.TestCase):
invalid_decode_testcases = [
(b"\xFFpython.org", UnicodeDecodeError("idna", b"\xFFpython.org", 0, 1, "")),
(b"pyt\xFFhon.org", UnicodeDecodeError("idna", b"pyt\xFFhon.org", 3, 4, "")),
(b"python\xFF.org", UnicodeDecodeError("idna", b"python\xFF.org", 6, 7, "")),
(b"python.\xFForg", UnicodeDecodeError("idna", b"python.\xFForg", 7, 8, "")),
(b"python.o\xFFrg", UnicodeDecodeError("idna", b"python.o\xFFrg", 8, 9, "")),
(b"python.org\xFF", UnicodeDecodeError("idna", b"python.org\xFF", 10, 11, "")),
(b"xn--pythn-&mua.org", UnicodeDecodeError("idna", b"xn--pythn-&mua.org", 10, 11, "")),
(b"xn--pythn-m&ua.org", UnicodeDecodeError("idna", b"xn--pythn-m&ua.org", 11, 12, "")),
(b"xn--pythn-mua&.org", UnicodeDecodeError("idna", b"xn--pythn-mua&.org", 13, 14, "")),
]
invalid_encode_testcases = [
(f"foo.{'\xff'*60}", UnicodeEncodeError("idna", f"foo.{'\xff'*60}", 4, 64, "")),
("あさ.\u034f", UnicodeEncodeError("idna", "あさ.\u034f", 3, 4, "")),
]
def test_builtin_decode(self): def test_builtin_decode(self):
self.assertEqual(str(b"python.org", "idna"), "python.org") self.assertEqual(str(b"python.org", "idna"), "python.org")
self.assertEqual(str(b"python.org.", "idna"), "python.org.") self.assertEqual(str(b"python.org.", "idna"), "python.org.")
@ -1555,16 +1588,38 @@ class IDNACodecTest(unittest.TestCase):
self.assertEqual(str(b"bugs.XN--pythn-mua.org.", "idna"), self.assertEqual(str(b"bugs.XN--pythn-mua.org.", "idna"),
"bugs.pyth\xf6n.org.") "bugs.pyth\xf6n.org.")
def test_builtin_decode_invalid(self):
for case, expected in self.invalid_decode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeDecodeError) as cm:
case.decode("idna")
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start, msg=f'reason: {exc.reason}')
self.assertEqual(exc.end, expected.end)
def test_builtin_encode(self): def test_builtin_encode(self):
self.assertEqual("python.org".encode("idna"), b"python.org") self.assertEqual("python.org".encode("idna"), b"python.org")
self.assertEqual("python.org.".encode("idna"), b"python.org.") self.assertEqual("python.org.".encode("idna"), b"python.org.")
self.assertEqual("pyth\xf6n.org".encode("idna"), b"xn--pythn-mua.org") self.assertEqual("pyth\xf6n.org".encode("idna"), b"xn--pythn-mua.org")
self.assertEqual("pyth\xf6n.org.".encode("idna"), b"xn--pythn-mua.org.") self.assertEqual("pyth\xf6n.org.".encode("idna"), b"xn--pythn-mua.org.")
def test_builtin_encode_invalid(self):
for case, expected in self.invalid_encode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeEncodeError) as cm:
case.encode("idna")
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
def test_builtin_decode_length_limit(self): def test_builtin_decode_length_limit(self):
with self.assertRaisesRegex(UnicodeError, "way too long"): with self.assertRaisesRegex(UnicodeDecodeError, "way too long"):
(b"xn--016c"+b"a"*1100).decode("idna") (b"xn--016c"+b"a"*1100).decode("idna")
with self.assertRaisesRegex(UnicodeError, "too long"): with self.assertRaisesRegex(UnicodeDecodeError, "too long"):
(b"xn--016c"+b"a"*70).decode("idna") (b"xn--016c"+b"a"*70).decode("idna")
def test_stream(self): def test_stream(self):
@ -1602,6 +1657,39 @@ class IDNACodecTest(unittest.TestCase):
self.assertEqual(decoder.decode(b"rg."), "org.") self.assertEqual(decoder.decode(b"rg."), "org.")
self.assertEqual(decoder.decode(b"", True), "") self.assertEqual(decoder.decode(b"", True), "")
def test_incremental_decode_invalid(self):
iterdecode_testcases = [
(b"\xFFpython.org", UnicodeDecodeError("idna", b"\xFF", 0, 1, "")),
(b"pyt\xFFhon.org", UnicodeDecodeError("idna", b"pyt\xFF", 3, 4, "")),
(b"python\xFF.org", UnicodeDecodeError("idna", b"python\xFF", 6, 7, "")),
(b"python.\xFForg", UnicodeDecodeError("idna", b"\xFF", 0, 1, "")),
(b"python.o\xFFrg", UnicodeDecodeError("idna", b"o\xFF", 1, 2, "")),
(b"python.org\xFF", UnicodeDecodeError("idna", b"org\xFF", 3, 4, "")),
(b"xn--pythn-&mua.org", UnicodeDecodeError("idna", b"xn--pythn-&mua.", 10, 11, "")),
(b"xn--pythn-m&ua.org", UnicodeDecodeError("idna", b"xn--pythn-m&ua.", 11, 12, "")),
(b"xn--pythn-mua&.org", UnicodeDecodeError("idna", b"xn--pythn-mua&.", 13, 14, "")),
]
for case, expected in iterdecode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeDecodeError) as cm:
list(codecs.iterdecode((bytes([c]) for c in case), "idna"))
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
decoder = codecs.getincrementaldecoder("idna")()
for case, expected in self.invalid_decode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeDecodeError) as cm:
decoder.decode(case)
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
def test_incremental_encode(self): def test_incremental_encode(self):
self.assertEqual( self.assertEqual(
b"".join(codecs.iterencode("python.org", "idna")), b"".join(codecs.iterencode("python.org", "idna")),
@ -1630,6 +1718,23 @@ class IDNACodecTest(unittest.TestCase):
self.assertEqual(encoder.encode("ample.org."), b"xn--xample-9ta.org.") self.assertEqual(encoder.encode("ample.org."), b"xn--xample-9ta.org.")
self.assertEqual(encoder.encode("", True), b"") self.assertEqual(encoder.encode("", True), b"")
def test_incremental_encode_invalid(self):
iterencode_testcases = [
(f"foo.{'\xff'*60}", UnicodeEncodeError("idna", f"{'\xff'*60}", 0, 60, "")),
("あさ.\u034f", UnicodeEncodeError("idna", "\u034f", 0, 1, "")),
]
for case, expected in iterencode_testcases:
with self.subTest(case=case, expected=expected):
with self.assertRaises(UnicodeEncodeError) as cm:
list(codecs.iterencode(case, "idna"))
exc = cm.exception
self.assertEqual(exc.encoding, expected.encoding)
self.assertEqual(exc.object, expected.object)
self.assertEqual(exc.start, expected.start)
self.assertEqual(exc.end, expected.end)
# codecs.getincrementalencoder.encode() does not throw an error
def test_errors(self): def test_errors(self):
"""Only supports "strict" error handler""" """Only supports "strict" error handler"""
"python.org".encode("idna", "strict") "python.org".encode("idna", "strict")

View File

@ -303,7 +303,7 @@ class Test_IncrementalDecoder(unittest.TestCase):
self.assertRaises(TypeError, decoder.setstate, 123) self.assertRaises(TypeError, decoder.setstate, 123)
self.assertRaises(TypeError, decoder.setstate, ("invalid", 0)) self.assertRaises(TypeError, decoder.setstate, ("invalid", 0))
self.assertRaises(TypeError, decoder.setstate, (b"1234", "invalid")) self.assertRaises(TypeError, decoder.setstate, (b"1234", "invalid"))
self.assertRaises(UnicodeError, decoder.setstate, (b"123456789", 0)) self.assertRaises(UnicodeDecodeError, decoder.setstate, (b"123456789", 0))
class Test_StreamReader(unittest.TestCase): class Test_StreamReader(unittest.TestCase):
def test_bug1728403(self): def test_bug1728403(self):

View File

@ -0,0 +1,2 @@
Changes Unicode codecs to raise UnicodeEncodeError or UnicodeDecodeError,
rather than just UnicodeError.

View File

@ -825,8 +825,15 @@ encoder_encode_stateful(MultibyteStatefulEncoderContext *ctx,
if (inpos < datalen) { if (inpos < datalen) {
if (datalen - inpos > MAXENCPENDING) { if (datalen - inpos > MAXENCPENDING) {
/* normal codecs can't reach here */ /* normal codecs can't reach here */
PyErr_SetString(PyExc_UnicodeError, PyObject *excobj = PyObject_CallFunction(PyExc_UnicodeEncodeError,
"pending buffer overflow"); "sOnns",
ctx->codec->encoding,
inbuf,
inpos, datalen,
"pending buffer overflow");
if (excobj == NULL) goto errorexit;
PyErr_SetObject(PyExc_UnicodeEncodeError, excobj);
Py_DECREF(excobj);
goto errorexit; goto errorexit;
} }
ctx->pending = PyUnicode_Substring(inbuf, inpos, datalen); ctx->pending = PyUnicode_Substring(inbuf, inpos, datalen);
@ -857,7 +864,16 @@ decoder_append_pending(MultibyteStatefulDecoderContext *ctx,
npendings = (Py_ssize_t)(buf->inbuf_end - buf->inbuf); npendings = (Py_ssize_t)(buf->inbuf_end - buf->inbuf);
if (npendings + ctx->pendingsize > MAXDECPENDING || if (npendings + ctx->pendingsize > MAXDECPENDING ||
npendings > PY_SSIZE_T_MAX - ctx->pendingsize) { npendings > PY_SSIZE_T_MAX - ctx->pendingsize) {
PyErr_SetString(PyExc_UnicodeError, "pending buffer overflow"); Py_ssize_t bufsize = (Py_ssize_t)(buf->inbuf_end - buf->inbuf_top);
PyObject *excobj = PyUnicodeDecodeError_Create(ctx->codec->encoding,
(const char *)buf->inbuf_top,
bufsize,
0,
bufsize,
"pending buffer overflow");
if (excobj == NULL) return -1;
PyErr_SetObject(PyExc_UnicodeDecodeError, excobj);
Py_DECREF(excobj);
return -1; return -1;
} }
memcpy(ctx->pending + ctx->pendingsize, buf->inbuf, npendings); memcpy(ctx->pending + ctx->pendingsize, buf->inbuf, npendings);
@ -938,7 +954,17 @@ _multibytecodec_MultibyteIncrementalEncoder_getstate_impl(MultibyteIncrementalEn
return NULL; return NULL;
} }
if (pendingsize > MAXENCPENDING*4) { if (pendingsize > MAXENCPENDING*4) {
PyErr_SetString(PyExc_UnicodeError, "pending buffer too large"); PyObject *excobj = PyObject_CallFunction(PyExc_UnicodeEncodeError,
"sOnns",
self->codec->encoding,
self->pending,
0, PyUnicode_GET_LENGTH(self->pending),
"pending buffer too large");
if (excobj == NULL) {
return NULL;
}
PyErr_SetObject(PyExc_UnicodeEncodeError, excobj);
Py_DECREF(excobj);
return NULL; return NULL;
} }
statebytes[0] = (unsigned char)pendingsize; statebytes[0] = (unsigned char)pendingsize;
@ -1267,7 +1293,13 @@ _multibytecodec_MultibyteIncrementalDecoder_setstate_impl(MultibyteIncrementalDe
} }
if (buffersize > MAXDECPENDING) { if (buffersize > MAXDECPENDING) {
PyErr_SetString(PyExc_UnicodeError, "pending buffer too large"); PyObject *excobj = PyUnicodeDecodeError_Create(self->codec->encoding,
PyBytes_AS_STRING(buffer), buffersize,
0, buffersize,
"pending buffer too large");
if (excobj == NULL) return NULL;
PyErr_SetObject(PyExc_UnicodeDecodeError, excobj);
Py_DECREF(excobj);
return NULL; return NULL;
} }