gh-126317: Simplify pickle code by using itertools.batched() (GH-126323)

This commit is contained in:
Lee Dong Wook 2024-11-02 23:07:32 +09:00 committed by GitHub
parent 10eeec2d4f
commit bd4be5e67d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 22 additions and 39 deletions

View File

@@ -26,7 +26,7 @@ Misc variables:
from types import FunctionType
from copyreg import dispatch_table
from copyreg import _extension_registry, _inverted_registry, _extension_cache
from itertools import islice
from itertools import batched
from functools import partial
import sys
from sys import maxsize
@@ -1033,31 +1033,26 @@ class _Pickler:
write(APPEND)
return
it = iter(items)
start = 0
while True:
tmp = list(islice(it, self._BATCHSIZE))
n = len(tmp)
if n > 1:
for batch in batched(items, self._BATCHSIZE):
batch_len = len(batch)
if batch_len != 1:
write(MARK)
for i, x in enumerate(tmp, start):
for i, x in enumerate(batch, start):
try:
save(x)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} item {i}')
raise
write(APPENDS)
elif n:
else:
try:
save(tmp[0])
save(batch[0])
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} item {start}')
raise
write(APPEND)
# else tmp is empty, and we're done
if n < self._BATCHSIZE:
return
start += n
start += batch_len
def save_dict(self, obj):
if self.bin:
@@ -1086,13 +1081,10 @@ class _Pickler:
write(SETITEM)
return
it = iter(items)
while True:
tmp = list(islice(it, self._BATCHSIZE))
n = len(tmp)
if n > 1:
for batch in batched(items, self._BATCHSIZE):
if len(batch) != 1:
write(MARK)
for k, v in tmp:
for k, v in batch:
save(k)
try:
save(v)
@@ -1100,8 +1092,8 @@ class _Pickler:
exc.add_note(f'when serializing {_T(obj)} item {k!r}')
raise
write(SETITEMS)
elif n:
k, v = tmp[0]
else:
k, v = batch[0]
save(k)
try:
save(v)
@@ -1109,9 +1101,6 @@ class _Pickler:
exc.add_note(f'when serializing {_T(obj)} item {k!r}')
raise
write(SETITEM)
# else tmp is empty, and we're done
if n < self._BATCHSIZE:
return
def save_set(self, obj):
save = self.save
@@ -1124,21 +1113,15 @@ class _Pickler:
write(EMPTY_SET)
self.memoize(obj)
it = iter(obj)
while True:
batch = list(islice(it, self._BATCHSIZE))
n = len(batch)
if n > 0:
write(MARK)
try:
for item in batch:
save(item)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} element')
raise
write(ADDITEMS)
if n < self._BATCHSIZE:
return
for batch in batched(obj, self._BATCHSIZE):
write(MARK)
try:
for item in batch:
save(item)
except BaseException as exc:
exc.add_note(f'when serializing {_T(obj)} element')
raise
write(ADDITEMS)
dispatch[set] = save_set
def save_frozenset(self, obj):