Let decoding error handlers replace or mutate the input via exc.object.

unicode_decode_call_errorhandler() now receives the input start and end
pointers by reference.  After the handler returns, both are re-read from
the exception's "object" attribute, so each decoder continues with the
bytes the handler left there instead of the stale original buffer.

Review fixes folded into this version of the patch:

  * test_mutatingdecodehandler: the second loop decoded with
    "test.replacing" again, so the "test.mutating" handler was registered
    but never exercised.  It now decodes with "test.mutating" and checks
    the result.
  * unicode_decode_call_errorhandler: after raising TypeError for a
    non-bytes object the code fell through and applied the PyBytes_*
    macros to it; it now jumps to onError.
  * unicode_decode_call_errorhandler: the new reference returned by
    PyUnicodeDecodeError_GetObject() was never released; it is now
    dropped on both the error and the success path.

NOTE(review): the "index" blob ids below are those of the previous
version of the patch.  Indentation of context lines was re-derived and
may need whitespace-tolerant matching when applying.  The
"goto utf7Error;" hunk is whitespace-only; the old line is assumed to be
tab-indented -- confirm against the tree.

diff --git a/Lib/test/test_codeccallbacks.py b/Lib/test/test_codeccallbacks.py
index 77834c4e486..f76ec656e5a 100644
--- a/Lib/test/test_codeccallbacks.py
+++ b/Lib/test/test_codeccallbacks.py
@@ -806,6 +806,42 @@ class CodecCallbackTest(unittest.TestCase):
             text = 'abcghi'*n
             text.translate(charmap)
 
+    def test_mutatingdecodehandler(self):
+        # Decoding error handlers may swap or mutate exc.object; the decoder
+        # has to carry on with whatever object the handler leaves behind.
+        baddata = [
+            ("ascii", b"\xff"),
+            ("utf-7", b"++"),
+            ("utf-8", b"\xff"),
+            ("utf-16", b"\xff"),
+            ("unicode-escape", b"\\u123g"),
+            ("raw-unicode-escape", b"\\u123g"),
+            ("unicode-internal", b"\xff"),
+        ]
+
+        def replacing(exc):
+            if isinstance(exc, UnicodeDecodeError):
+                exc.object = 42
+                return ("\u4242", 0)
+            else:
+                raise TypeError("don't know how to handle %r" % exc)
+        codecs.register_error("test.replacing", replacing)
+        # A non-bytes object must make the decoder raise TypeError
+        for (encoding, data) in baddata:
+            self.assertRaises(TypeError, data.decode, encoding, "test.replacing")
+
+        def mutating(exc):
+            if isinstance(exc, UnicodeDecodeError):
+                exc.object[:] = b""
+                return ("\u4242", 0)
+            else:
+                raise TypeError("don't know how to handle %r" % exc)
+        codecs.register_error("test.mutating", mutating)
+        # If the decoder doesn't pick up the modified input the following
+        # will lead to an endless loop
+        for (encoding, data) in baddata:
+            self.assertEqual(data.decode(encoding, "test.mutating"), "\u4242")
+
 def test_main():
     test.test_support.run_unittest(CodecCallbackTest)
 
diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c
index 6944eabd42e..d1b5747f579 100644
--- a/Objects/unicodeobject.c
+++ b/Objects/unicodeobject.c
@@ -1269,7 +1269,7 @@ int PyUnicode_SetDefaultEncoding(const char *encoding)
 static
 int unicode_decode_call_errorhandler(const char *errors, PyObject **errorHandler,
                  const char *encoding, const char *reason,
-                 const char *input, Py_ssize_t insize, Py_ssize_t *startinpos, Py_ssize_t *endinpos, PyObject **exceptionObject, const char **inptr,
+                 const char **input, const char **inend, Py_ssize_t *startinpos, Py_ssize_t *endinpos, PyObject **exceptionObject, const char **inptr,
                  PyObject **output, Py_ssize_t *outpos, Py_UNICODE **outptr)
 {
     static char *argparse = "O!n;decoding error handler must return (unicode, int) tuple";
@@ -1277,9 +1277,11 @@ int unicode_decode_call_errorhandler(const char *errors, PyObject **errorHandler
     PyObject *restuple = NULL;
     PyObject *repunicode = NULL;
     Py_ssize_t outsize = PyUnicode_GET_SIZE(*output);
+    Py_ssize_t insize;
     Py_ssize_t requiredsize;
     Py_ssize_t newpos;
     Py_UNICODE *repptr;
+    PyObject *inputobj = NULL;
     Py_ssize_t repsize;
     int res = -1;
 
@@ -1291,7 +1293,7 @@ int unicode_decode_call_errorhandler(const char *errors, PyObject **errorHandler
 
     if (*exceptionObject == NULL) {
         *exceptionObject = PyUnicodeDecodeError_Create(
-            encoding, input, insize, *startinpos, *endinpos, reason);
+            encoding, *input, *inend-*input, *startinpos, *endinpos, reason);
         if (*exceptionObject == NULL)
             goto onError;
     }
@@ -1313,6 +1315,24 @@ int unicode_decode_call_errorhandler(const char *errors, PyObject **errorHandler
     }
     if (!PyArg_ParseTuple(restuple, argparse, &PyUnicode_Type, &repunicode, &newpos))
         goto onError;
+
+    /* Copy back the bytes variables, which might have been modified by the
+       callback */
+    inputobj = PyUnicodeDecodeError_GetObject(*exceptionObject);
+    if (!inputobj)
+        goto onError;
+    if (!PyBytes_Check(inputobj)) {
+        PyErr_Format(PyExc_TypeError, "exception attribute object must be bytes");
+        Py_DECREF(inputobj);
+        goto onError;
+    }
+    *input = PyBytes_AS_STRING(inputobj);
+    insize = PyBytes_GET_SIZE(inputobj);
+    *inend = *input + insize;
+    /* GetObject handed us a new reference; the exception keeps its own,
+       so dropping ours does not free the buffer we now point into */
+    Py_DECREF(inputobj);
+
     if (newpos<0)
         newpos = insize+newpos;
     if (newpos<0 || newpos>insize) {
@@ -1335,10 +1355,11 @@ int unicode_decode_call_errorhandler(const char *errors, PyObject **errorHandler
         *outptr = PyUnicode_AS_UNICODE(*output) + *outpos;
     }
     *endinpos = newpos;
-    *inptr = input + newpos;
+    *inptr = *input + newpos;
     Py_UNICODE_COPY(*outptr, repptr, repsize);
     *outptr += repsize;
     *outpos += repsize;
+
     /* we made it! */
     res = 0;
 
@@ -1503,7 +1524,7 @@ PyObject *PyUnicode_DecodeUTF7(const char *s,
         else if (SPECIAL(ch,0,0)) {
             errmsg = "unexpected special character";
             s++;
-	    goto utf7Error;
+            goto utf7Error;
         }
         else {
             *p++ = ch;
@@ -1516,7 +1537,7 @@ PyObject *PyUnicode_DecodeUTF7(const char *s,
         if (unicode_decode_call_errorhandler(
                  errors, &errorHandler,
                  "utf7", errmsg,
-                 starts, size, &startinpos, &endinpos, &exc, &s,
+                 &starts, &e, &startinpos, &endinpos, &exc, &s,
                  (PyObject **)&unicode, &outpos, &p))
             goto onError;
     }
@@ -1527,7 +1548,7 @@ PyObject *PyUnicode_DecodeUTF7(const char *s,
         if (unicode_decode_call_errorhandler(
                  errors, &errorHandler,
                  "utf7", "unterminated shift sequence",
-                 starts, size, &startinpos, &endinpos, &exc, &s,
+                 &starts, &e, &startinpos, &endinpos, &exc, &s,
                  (PyObject **)&unicode, &outpos, &p))
             goto onError;
         if (s < e)
@@ -1848,7 +1869,7 @@ PyObject *PyUnicode_DecodeUTF8Stateful(const char *s,
         if (unicode_decode_call_errorhandler(
                  errors, &errorHandler,
                  "utf8", errmsg,
-                 starts, size, &startinpos, &endinpos, &exc, &s,
+                 &starts, &e, &startinpos, &endinpos, &exc, &s,
                  (PyObject **)&unicode, &outpos, &p))
             goto onError;
     }
@@ -2132,7 +2153,7 @@ PyUnicode_DecodeUTF16Stateful(const char *s,
         if (unicode_decode_call_errorhandler(
                  errors, &errorHandler,
                  "utf16", errmsg,
-                 starts, size, &startinpos, &endinpos, &exc, (const char **)&q,
+                 &starts, (const char **)&e, &startinpos, &endinpos, &exc, (const char **)&q,
                  (PyObject **)&unicode, &outpos, &p))
             goto onError;
     }
@@ -2342,7 +2363,7 @@ PyObject *PyUnicode_DecodeUnicodeEscape(const char *s,
                 if (unicode_decode_call_errorhandler(
                     errors, &errorHandler,
                     "unicodeescape", "end of string in escape sequence",
-                    starts, size, &startinpos, &endinpos, &exc, &s,
+                    &starts, &end, &startinpos, &endinpos, &exc, &s,
                     (PyObject **)&v, &outpos, &p))
                     goto onError;
                 goto nextByte;
@@ -2354,7 +2375,7 @@ PyObject *PyUnicode_DecodeUnicodeEscape(const char *s,
                     if (unicode_decode_call_errorhandler(
                         errors, &errorHandler,
                         "unicodeescape", message,
-                        starts, size, &startinpos, &endinpos, &exc, &s,
+                        &starts, &end, &startinpos, &endinpos, &exc, &s,
                         (PyObject **)&v, &outpos, &p))
                         goto onError;
                     goto nextByte;
@@ -2393,7 +2414,7 @@ PyObject *PyUnicode_DecodeUnicodeEscape(const char *s,
                 if (unicode_decode_call_errorhandler(
                     errors, &errorHandler,
                     "unicodeescape", "illegal Unicode character",
-                    starts, size, &startinpos, &endinpos, &exc, &s,
+                    &starts, &end, &startinpos, &endinpos, &exc, &s,
                     (PyObject **)&v, &outpos, &p))
                     goto onError;
             }
@@ -2435,7 +2456,7 @@ PyObject *PyUnicode_DecodeUnicodeEscape(const char *s,
             if (unicode_decode_call_errorhandler(
                 errors, &errorHandler,
                 "unicodeescape", message,
-                starts, size, &startinpos, &endinpos, &exc, &s,
+                &starts, &end, &startinpos, &endinpos, &exc, &s,
                 (PyObject **)&v, &outpos, &p))
                 goto onError;
             break;
@@ -2449,7 +2470,7 @@ PyObject *PyUnicode_DecodeUnicodeEscape(const char *s,
                 if (unicode_decode_call_errorhandler(
                     errors, &errorHandler,
                     "unicodeescape", message,
-                    starts, size, &startinpos, &endinpos, &exc, &s,
+                    &starts, &end, &startinpos, &endinpos, &exc, &s,
                     (PyObject **)&v, &outpos, &p))
                     goto onError;
             }
@@ -2728,7 +2749,7 @@ PyObject *PyUnicode_DecodeRawUnicodeEscape(const char *s,
                 if (unicode_decode_call_errorhandler(
                     errors, &errorHandler,
                     "rawunicodeescape", "truncated \\uXXXX",
-                    starts, size, &startinpos, &endinpos, &exc, &s,
+                    &starts, &end, &startinpos, &endinpos, &exc, &s,
                     (PyObject **)&v, &outpos, &p))
                     goto onError;
                 goto nextByte;
@@ -2746,7 +2767,7 @@ PyObject *PyUnicode_DecodeRawUnicodeEscape(const char *s,
             if (unicode_decode_call_errorhandler(
                 errors, &errorHandler,
                 "rawunicodeescape", "\\Uxxxxxxxx out of range",
-                starts, size, &startinpos, &endinpos, &exc, &s,
+                &starts, &end, &startinpos, &endinpos, &exc, &s,
                 (PyObject **)&v, &outpos, &p))
                 goto onError;
         }
@@ -2897,7 +2918,7 @@ PyObject *_PyUnicode_DecodeUnicodeInternal(const char *s,
             if (unicode_decode_call_errorhandler(
                     errors, &errorHandler,
                     "unicode_internal", reason,
-                    starts, size, &startinpos, &endinpos, &exc, &s,
+                    &starts, &end, &startinpos, &endinpos, &exc, &s,
                     (PyObject **)&v, &outpos, &p)) {
                 goto onError;
             }
@@ -3277,7 +3298,7 @@ PyObject *PyUnicode_DecodeASCII(const char *s,
             if (unicode_decode_call_errorhandler(
                  errors, &errorHandler,
                  "ascii", "ordinal not in range(128)",
-                 starts, size, &startinpos, &endinpos, &exc, &s,
+                 &starts, &e, &startinpos, &endinpos, &exc, &s,
                  (PyObject **)&v, &outpos, &p))
                 goto onError;
         }
@@ -3578,7 +3599,7 @@ PyObject *PyUnicode_DecodeCharmap(const char *s,
                 if (unicode_decode_call_errorhandler(
                      errors, &errorHandler,
                      "charmap", "character maps to <undefined>",
-                     starts, size, &startinpos, &endinpos, &exc, &s,
+                     &starts, &e, &startinpos, &endinpos, &exc, &s,
                      (PyObject **)&v, &outpos, &p)) {
                     goto onError;
                 }
@@ -3628,7 +3649,7 @@ PyObject *PyUnicode_DecodeCharmap(const char *s,
                 if (unicode_decode_call_errorhandler(
                      errors, &errorHandler,
                      "charmap", "character maps to <undefined>",
-                     starts, size, &startinpos, &endinpos, &exc, &s,
+                     &starts, &e, &startinpos, &endinpos, &exc, &s,
                      (PyObject **)&v, &outpos, &p)) {
                     Py_DECREF(x);
                     goto onError;