diff --git a/Objects/unicodeobject.c b/Objects/unicodeobject.c index 35aa79f2bc8..9ab366d3be1 100644 --- a/Objects/unicodeobject.c +++ b/Objects/unicodeobject.c @@ -1637,6 +1637,51 @@ unicode_putchar(PyObject **p_unicode, Py_ssize_t *pos, return 0; } +/* Copy a ASCII or latin1 char* string into a Python Unicode string. + Return the length of the input string. + + WARNING: Don't copy the terminating null character and don't check the + maximum character (may write a latin1 character in an ASCII string). */ +static Py_ssize_t +unicode_write_cstr(PyObject *unicode, Py_ssize_t index, const char *str) +{ + enum PyUnicode_Kind kind = PyUnicode_KIND(unicode); + void *data = PyUnicode_DATA(unicode); + + switch (kind) { + case PyUnicode_1BYTE_KIND: { + Py_ssize_t len = strlen(str); + assert(index + len <= PyUnicode_GET_LENGTH(unicode)); + memcpy(data + index, str, len); + return len; + } + case PyUnicode_2BYTE_KIND: { + Py_UCS2 *start = (Py_UCS2 *)data + index; + Py_UCS2 *ucs2 = start; + assert(index <= PyUnicode_GET_LENGTH(unicode)); + + for (; *str; ++ucs2, ++str) + *ucs2 = (Py_UCS2)*str; + + assert((ucs2 - start) <= PyUnicode_GET_LENGTH(unicode)); + return ucs2 - start; + } + default: { + Py_UCS4 *start = (Py_UCS4 *)data + index; + Py_UCS4 *ucs4 = start; + assert(kind == PyUnicode_4BYTE_KIND); + assert(index <= PyUnicode_GET_LENGTH(unicode)); + + for (; *str; ++ucs4, ++str) + *ucs4 = (Py_UCS4)*str; + + assert((ucs4 - start) <= PyUnicode_GET_LENGTH(unicode)); + return ucs4 - start; + } + } +} + + static PyObject* get_latin1_char(unsigned char ch) { @@ -2590,19 +2635,23 @@ PyUnicode_FromFormatV(const char *format, va_list vargs) case 'u': case 'x': case 'p': + { + Py_ssize_t written; /* unused, since we already have the result */ if (*f == 'p') (void) va_arg(vargs, void *); else (void) va_arg(vargs, int); /* extract the result from numberresults and append. */ - for (; *numberresult; ++i, ++numberresult) - PyUnicode_WRITE(kind, data, i, *numberresult); + written = unicode_write_cstr(string, i, numberresult); /* skip over the separating '\0' */ + i += written; + numberresult += written; assert(*numberresult == '\0'); numberresult++; assert(numberresult <= numberresults + numbersize); break; + } case 's': { /* unused, since we already have the result */ @@ -2669,8 +2718,7 @@ PyUnicode_FromFormatV(const char *format, va_list vargs) PyUnicode_WRITE(kind, data, i++, '%'); break; default: - for (; *p; ++p, ++i) - PyUnicode_WRITE(kind, data, i, *p); + i += unicode_write_cstr(string, i, p); assert(i == PyUnicode_GET_LENGTH(string)); goto end; }