Issue #18408: Fix _PyMem_DebugRealloc()

Don't mark old extra memory dead before calling realloc(). realloc() can fail
and realloc() must not touch the original buffer on failure.

So mark old extra memory dead only on success if the new buffer did not move
(has the same address).
This commit is contained in:
Victor Stinner 2013-07-09 00:44:43 +02:00
parent 9e6b4d715c
commit c4266360fc
1 changed files with 8 additions and 6 deletions

View File

@ -1780,7 +1780,7 @@ static void *
_PyMem_DebugRealloc(void *ctx, void *p, size_t nbytes) _PyMem_DebugRealloc(void *ctx, void *p, size_t nbytes)
{ {
debug_alloc_api_t *api = (debug_alloc_api_t *)ctx; debug_alloc_api_t *api = (debug_alloc_api_t *)ctx;
uchar *q = (uchar *)p; uchar *q = (uchar *)p, *oldq;
uchar *tail; uchar *tail;
size_t total; /* nbytes + 4*SST */ size_t total; /* nbytes + 4*SST */
size_t original_nbytes; size_t original_nbytes;
@ -1797,24 +1797,26 @@ _PyMem_DebugRealloc(void *ctx, void *p, size_t nbytes)
/* overflow: can't represent total as a size_t */ /* overflow: can't represent total as a size_t */
return NULL; return NULL;
if (nbytes < original_nbytes) {
/* shrinking: mark old extra memory dead */
memset(q + nbytes, DEADBYTE, original_nbytes - nbytes + 2*SST);
}
/* Resize and add decorations. We may get a new pointer here, in which /* Resize and add decorations. We may get a new pointer here, in which
* case we didn't get the chance to mark the old memory with DEADBYTE, * case we didn't get the chance to mark the old memory with DEADBYTE,
* but we live with that. * but we live with that.
*/ */
oldq = q;
q = (uchar *)api->alloc.realloc(api->alloc.ctx, q - 2*SST, total); q = (uchar *)api->alloc.realloc(api->alloc.ctx, q - 2*SST, total);
if (q == NULL) if (q == NULL)
return NULL; return NULL;
if (q == oldq && nbytes < original_nbytes) {
/* shrinking: mark old extra memory dead */
memset(q + nbytes, DEADBYTE, original_nbytes - nbytes);
}
write_size_t(q, nbytes); write_size_t(q, nbytes);
assert(q[SST] == (uchar)api->api_id); assert(q[SST] == (uchar)api->api_id);
for (i = 1; i < SST; ++i) for (i = 1; i < SST; ++i)
assert(q[SST + i] == FORBIDDENBYTE); assert(q[SST + i] == FORBIDDENBYTE);
q += 2*SST; q += 2*SST;
tail = q + nbytes; tail = q + nbytes;
memset(tail, FORBIDDENBYTE, SST); memset(tail, FORBIDDENBYTE, SST);
write_size_t(tail + SST, serialno); write_size_t(tail + SST, serialno);