Rewrite PyBytes_FromFormatV() using _PyBytesWriter API

* Add much more unit tests on PyBytes_FromFormatV()
* Remove the first loop to compute the length of the output string
* Use _PyBytesWriter to handle the bytes buffer, use overallocation
* Cleanup the code to make simpler and easier to review
This commit is contained in:
Victor Stinner 2015-10-14 00:21:35 +02:00
parent b6d84832bf
commit 03dab786b2
2 changed files with 243 additions and 181 deletions

View File

@ -783,25 +783,93 @@ class BytesTest(BaseBytesTest, unittest.TestCase):
# Test PyBytes_FromFormat()
def test_from_format(self):
test.support.import_module('ctypes')
from ctypes import pythonapi, py_object, c_int, c_char_p
_testcapi = test.support.import_module('_testcapi')
from ctypes import pythonapi, py_object
from ctypes import (
c_int, c_uint,
c_long, c_ulong,
c_size_t, c_ssize_t,
c_char_p)
PyBytes_FromFormat = pythonapi.PyBytes_FromFormat
PyBytes_FromFormat.restype = py_object
# basic tests
self.assertEqual(PyBytes_FromFormat(b'format'),
b'format')
self.assertEqual(PyBytes_FromFormat(b'Hello %s !', b'world'),
b'Hello world !')
# test formatters
self.assertEqual(PyBytes_FromFormat(b'c=%c', c_int(0)),
b'c=\0')
self.assertEqual(PyBytes_FromFormat(b'c=%c', c_int(ord('@'))),
b'c=@')
self.assertEqual(PyBytes_FromFormat(b'c=%c', c_int(255)),
b'c=\xff')
self.assertEqual(PyBytes_FromFormat(b'd=%d ld=%ld zd=%zd',
c_int(1), c_long(2),
c_size_t(3)),
b'd=1 ld=2 zd=3')
self.assertEqual(PyBytes_FromFormat(b'd=%d ld=%ld zd=%zd',
c_int(-1), c_long(-2),
c_size_t(-3)),
b'd=-1 ld=-2 zd=-3')
self.assertEqual(PyBytes_FromFormat(b'u=%u lu=%lu zu=%zu',
c_uint(123), c_ulong(456),
c_size_t(789)),
b'u=123 lu=456 zu=789')
self.assertEqual(PyBytes_FromFormat(b'i=%i', c_int(123)),
b'i=123')
self.assertEqual(PyBytes_FromFormat(b'i=%i', c_int(-123)),
b'i=-123')
self.assertEqual(PyBytes_FromFormat(b'x=%x', c_int(0xabc)),
b'x=abc')
self.assertEqual(PyBytes_FromFormat(b'ptr=%p',
c_char_p(0xabcdef)),
b'ptr=0xabcdef')
self.assertEqual(PyBytes_FromFormat(b's=%s', c_char_p(b'cstr')),
b's=cstr')
# test minimum and maximum integer values
size_max = c_size_t(-1).value
for formatstr, ctypes_type, value, py_formatter in (
(b'%d', c_int, _testcapi.INT_MIN, str),
(b'%d', c_int, _testcapi.INT_MAX, str),
(b'%ld', c_long, _testcapi.LONG_MIN, str),
(b'%ld', c_long, _testcapi.LONG_MAX, str),
(b'%lu', c_ulong, _testcapi.ULONG_MAX, str),
(b'%zd', c_ssize_t, _testcapi.PY_SSIZE_T_MIN, str),
(b'%zd', c_ssize_t, _testcapi.PY_SSIZE_T_MAX, str),
(b'%zu', c_size_t, size_max, str),
(b'%p', c_char_p, size_max, lambda value: '%#x' % value),
):
self.assertEqual(PyBytes_FromFormat(formatstr, ctypes_type(value)),
py_formatter(value).encode('ascii')),
# width and precision (width is currently ignored)
self.assertEqual(PyBytes_FromFormat(b'%5s', b'a'),
b'a')
self.assertEqual(PyBytes_FromFormat(b'%.3s', b'abcdef'),
b'abc')
# '%%' formatter
self.assertEqual(PyBytes_FromFormat(b'%%'),
b'%')
self.assertEqual(PyBytes_FromFormat(b'[%%]'),
b'[%]')
self.assertEqual(PyBytes_FromFormat(b'%%%c', c_int(ord('_'))),
b'%_')
self.assertEqual(PyBytes_FromFormat(b'%%s'),
b'%s')
# Invalid formats and partial formatting
self.assertEqual(PyBytes_FromFormat(b'%'), b'%')
self.assertEqual(PyBytes_FromFormat(b'%%'), b'%')
self.assertEqual(PyBytes_FromFormat(b'%%s'), b'%s')
self.assertEqual(PyBytes_FromFormat(b'[%%]'), b'[%]')
self.assertEqual(PyBytes_FromFormat(b'%%%c', c_int(ord('_'))), b'%_')
self.assertEqual(PyBytes_FromFormat(b'x=%i y=%', c_int(2), c_int(3)),
b'x=2 y=%')
self.assertEqual(PyBytes_FromFormat(b'c:%c', c_int(255)),
b'c:\xff')
self.assertEqual(PyBytes_FromFormat(b's:%s', c_char_p(b'cstr')),
b's:cstr')
# Issue #19969
# Issue #19969: %c must raise OverflowError for values
# not in the range [0; 255]
self.assertRaises(OverflowError,
PyBytes_FromFormat, b'%c', c_int(-1))
self.assertRaises(OverflowError,

View File

@ -174,190 +174,184 @@ PyBytes_FromString(const char *str)
PyObject *
PyBytes_FromFormatV(const char *format, va_list vargs)
{
va_list count;
Py_ssize_t n = 0;
const char* f;
char *s;
PyObject* string;
const char *f;
const char *p;
Py_ssize_t prec;
int longflag;
int size_tflag;
/* Longest 64-bit formatted numbers:
- "18446744073709551615\0" (21 bytes)
- "-9223372036854775808\0" (21 bytes)
Decimal takes the most space (it isn't enough for octal.)
Py_VA_COPY(count, vargs);
/* step 1: figure out how large a buffer we need */
for (f = format; *f; f++) {
if (*f == '%') {
const char* p = f;
while (*++f && *f != '%' && !Py_ISALPHA(*f))
;
Longest 64-bit pointer representation:
"0xffffffffffffffff\0" (19 bytes). */
char buffer[21];
_PyBytesWriter writer;
/* skip the 'l' or 'z' in {%ld, %zd, %lu, %zu} since
* they don't affect the amount of space we reserve.
*/
if ((*f == 'l' || *f == 'z') &&
(f[1] == 'd' || f[1] == 'u'))
++f;
_PyBytesWriter_Init(&writer);
switch (*f) {
case 'c':
{
int c = va_arg(count, int);
if (c < 0 || c > 255) {
PyErr_SetString(PyExc_OverflowError,
"PyBytes_FromFormatV(): %c format "
"expects an integer in range [0; 255]");
return NULL;
}
n++;
break;
}
case '%':
n++;
break;
case 'd': case 'u': case 'i': case 'x':
(void) va_arg(count, int);
/* 20 bytes is enough to hold a 64-bit
integer. Decimal takes the most space.
This isn't enough for octal. */
n += 20;
break;
case 's':
s = va_arg(count, char*);
n += strlen(s);
break;
case 'p':
(void) va_arg(count, int);
/* maximum 64-bit pointer representation:
* 0xffffffffffffffff
* so 19 characters is enough.
* XXX I count 18 -- what's the extra for?
*/
n += 19;
break;
default:
/* if we stumble upon an unknown
formatting code, copy the rest of
the format string to the output
string. (we cannot just skip the
code, since there's no way to know
what's in the argument list) */
n += strlen(p);
goto expand;
}
} else
n++;
}
expand:
/* step 2: fill the buffer */
/* Since we've analyzed how much space we need for the worst case,
use sprintf directly instead of the slower PyOS_snprintf. */
string = PyBytes_FromStringAndSize(NULL, n);
if (!string)
s = _PyBytesWriter_Alloc(&writer, strlen(format));
if (s == NULL)
return NULL;
writer.overallocate = 1;
s = PyBytes_AsString(string);
#define WRITE_BYTES(str) \
do { \
s = _PyBytesWriter_WriteBytes(&writer, s, (str), strlen(str)); \
if (s == NULL) \
goto error; \
} while (0)
for (f = format; *f; f++) {
if (*f == '%') {
const char* p = f++;
if (*f != '%') {
*s++ = *f;
continue;
}
p = f++;
/* ignore the width (ex: 10 in "%10s") */
while (Py_ISDIGIT(*f))
f++;
/* parse the precision (ex: 10 in "%.10s") */
prec = 0;
if (*f == '.') {
f++;
for (; Py_ISDIGIT(*f); f++) {
prec = (prec * 10) + (*f - '0');
}
}
while (*f && *f != '%' && !Py_ISALPHA(*f))
f++;
/* handle the long flag ('l'), but only for %ld and %lu.
others can be added when necessary. */
longflag = 0;
if (*f == 'l' && (f[1] == 'd' || f[1] == 'u')) {
longflag = 1;
++f;
}
/* handle the size_t flag ('z'). */
size_tflag = 0;
if (*f == 'z' && (f[1] == 'd' || f[1] == 'u')) {
size_tflag = 1;
++f;
}
/* substract bytes preallocated for the format string
(ex: 2 for "%s") */
writer.min_size -= (f - p + 1);
switch (*f) {
case 'c':
{
int c = va_arg(vargs, int);
if (c < 0 || c > 255) {
PyErr_SetString(PyExc_OverflowError,
"PyBytes_FromFormatV(): %c format "
"expects an integer in range [0; 255]");
goto error;
}
writer.min_size++;
*s++ = (unsigned char)c;
break;
}
case 'd':
if (longflag)
sprintf(buffer, "%ld", va_arg(vargs, long));
else if (size_tflag)
sprintf(buffer, "%" PY_FORMAT_SIZE_T "d",
va_arg(vargs, Py_ssize_t));
else
sprintf(buffer, "%d", va_arg(vargs, int));
assert(strlen(buffer) < sizeof(buffer));
WRITE_BYTES(buffer);
break;
case 'u':
if (longflag)
sprintf(buffer, "%lu",
va_arg(vargs, unsigned long));
else if (size_tflag)
sprintf(buffer, "%" PY_FORMAT_SIZE_T "u",
va_arg(vargs, size_t));
else
sprintf(buffer, "%u",
va_arg(vargs, unsigned int));
assert(strlen(buffer) < sizeof(buffer));
WRITE_BYTES(buffer);
break;
case 'i':
sprintf(buffer, "%i", va_arg(vargs, int));
assert(strlen(buffer) < sizeof(buffer));
WRITE_BYTES(buffer);
break;
case 'x':
sprintf(buffer, "%x", va_arg(vargs, int));
assert(strlen(buffer) < sizeof(buffer));
WRITE_BYTES(buffer);
break;
case 's':
{
Py_ssize_t i;
int longflag = 0;
int size_tflag = 0;
/* parse the width.precision part (we're only
interested in the precision value, if any) */
n = 0;
while (Py_ISDIGIT(*f))
n = (n*10) + *f++ - '0';
if (*f == '.') {
f++;
n = 0;
while (Py_ISDIGIT(*f))
n = (n*10) + *f++ - '0';
p = va_arg(vargs, char*);
i = strlen(p);
if (prec > 0 && i > prec)
i = prec;
s = _PyBytesWriter_WriteBytes(&writer, s, p, i);
if (s == NULL)
goto error;
break;
}
case 'p':
sprintf(buffer, "%p", va_arg(vargs, void*));
assert(strlen(buffer) < sizeof(buffer));
/* %p is ill-defined: ensure leading 0x. */
if (buffer[1] == 'X')
buffer[1] = 'x';
else if (buffer[1] != 'x') {
memmove(buffer+2, buffer, strlen(buffer)+1);
buffer[0] = '0';
buffer[1] = 'x';
}
while (*f && *f != '%' && !Py_ISALPHA(*f))
f++;
/* handle the long flag, but only for %ld and %lu.
others can be added when necessary. */
if (*f == 'l' && (f[1] == 'd' || f[1] == 'u')) {
longflag = 1;
++f;
}
/* handle the size_t flag. */
if (*f == 'z' && (f[1] == 'd' || f[1] == 'u')) {
size_tflag = 1;
++f;
WRITE_BYTES(buffer);
break;
case '%':
writer.min_size++;
*s++ = '%';
break;
default:
if (*f == 0) {
/* fix min_size if we reached the end of the format string */
writer.min_size++;
}
switch (*f) {
case 'c':
{
int c = va_arg(vargs, int);
/* c has been checked for overflow in the first step */
*s++ = (unsigned char)c;
break;
}
case 'd':
if (longflag)
sprintf(s, "%ld", va_arg(vargs, long));
else if (size_tflag)
sprintf(s, "%" PY_FORMAT_SIZE_T "d",
va_arg(vargs, Py_ssize_t));
else
sprintf(s, "%d", va_arg(vargs, int));
s += strlen(s);
break;
case 'u':
if (longflag)
sprintf(s, "%lu",
va_arg(vargs, unsigned long));
else if (size_tflag)
sprintf(s, "%" PY_FORMAT_SIZE_T "u",
va_arg(vargs, size_t));
else
sprintf(s, "%u",
va_arg(vargs, unsigned int));
s += strlen(s);
break;
case 'i':
sprintf(s, "%i", va_arg(vargs, int));
s += strlen(s);
break;
case 'x':
sprintf(s, "%x", va_arg(vargs, int));
s += strlen(s);
break;
case 's':
p = va_arg(vargs, char*);
i = strlen(p);
if (n > 0 && i > n)
i = n;
Py_MEMCPY(s, p, i);
s += i;
break;
case 'p':
sprintf(s, "%p", va_arg(vargs, void*));
/* %p is ill-defined: ensure leading 0x. */
if (s[1] == 'X')
s[1] = 'x';
else if (s[1] != 'x') {
memmove(s+2, s, strlen(s)+1);
s[0] = '0';
s[1] = 'x';
}
s += strlen(s);
break;
case '%':
*s++ = '%';
break;
default:
strcpy(s, p);
s += strlen(s);
goto end;
}
} else
*s++ = *f;
/* invalid format string: copy unformatted string and exit */
WRITE_BYTES(p);
return _PyBytesWriter_Finish(&writer, s);
}
}
end:
_PyBytes_Resize(&string, s - PyBytes_AS_STRING(string));
return string;
#undef WRITE_BYTES
return _PyBytesWriter_Finish(&writer, s);
error:
_PyBytesWriter_Dealloc(&writer);
return NULL;
}
PyObject *