Issue #17810: Implement PEP 3154, pickle protocol 4.

Most of the work is by Alexandre.
This commit is contained in:
Antoine Pitrou 2013-11-23 18:59:12 +01:00
parent 95401c5f6b
commit c9dc4a2a8a
12 changed files with 3132 additions and 1006 deletions

View File

@ -459,12 +459,29 @@ implementation of this behaviour::
Classes can alter the default behaviour by providing one or several special
methods:
.. method:: object.__getnewargs_ex__()
In protocols 4 and newer, classes that implements the
:meth:`__getnewargs_ex__` method can dictate the values passed to the
:meth:`__new__` method upon unpickling. The method must return a pair
``(args, kwargs)`` where *args* is a tuple of positional arguments
and *kwargs* a dictionary of named arguments for constructing the
object. Those will be passed to the :meth:`__new__` method upon
unpickling.
You should implement this method if the :meth:`__new__` method of your
class requires keyword-only arguments. Otherwise, it is recommended for
compatibility to implement :meth:`__getnewargs__`.
.. method:: object.__getnewargs__()
In protocol 2 and newer, classes that implements the :meth:`__getnewargs__`
method can dictate the values passed to the :meth:`__new__` method upon
unpickling. This is often needed for classes whose :meth:`__new__` method
requires arguments.
This method serve a similar purpose as :meth:`__getnewargs_ex__` but
for protocols 2 and newer. It must return a tuple of arguments `args`
which will be passed to the :meth:`__new__` method upon unpickling.
In protocols 4 and newer, :meth:`__getnewargs__` will not be called if
:meth:`__getnewargs_ex__` is defined.
.. method:: object.__getstate__()
@ -496,10 +513,10 @@ the methods :meth:`__getstate__` and :meth:`__setstate__`.
At unpickling time, some methods like :meth:`__getattr__`,
:meth:`__getattribute__`, or :meth:`__setattr__` may be called upon the
instance. In case those methods rely on some internal invariant being true,
the type should implement :meth:`__getnewargs__` to establish such an
invariant; otherwise, neither :meth:`__new__` nor :meth:`__init__` will be
called.
instance. In case those methods rely on some internal invariant being
true, the type should implement :meth:`__getnewargs__` or
:meth:`__getnewargs_ex__` to establish such an invariant; otherwise,
neither :meth:`__new__` nor :meth:`__init__` will be called.
.. index:: pair: copy; protocol
@ -511,7 +528,7 @@ objects. [#]_
Although powerful, implementing :meth:`__reduce__` directly in your classes is
error prone. For this reason, class designers should use the high-level
interface (i.e., :meth:`__getnewargs__`, :meth:`__getstate__` and
interface (i.e., :meth:`__getnewargs_ex__`, :meth:`__getstate__` and
:meth:`__setstate__`) whenever possible. We will show, however, cases where
using :meth:`__reduce__` is the only option or leads to more efficient pickling
or both.

View File

@ -109,6 +109,7 @@ New expected features for Python implementations:
Significantly Improved Library Modules:
* Single-dispatch generic functions in :mod:`functoools` (:pep:`443`)
* New :mod:`pickle` protocol 4 (:pep:`3154`)
* SHA-3 (Keccak) support for :mod:`hashlib`.
* TLSv1.1 and TLSv1.2 support for :mod:`ssl`.
* :mod:`multiprocessing` now has option to avoid using :func:`os.fork`
@ -285,6 +286,20 @@ described in the PEP. Existing importers should be updated to implement
the new methods.
Pickle protocol 4
=================
The new :mod:`pickle` protocol addresses a number of issues that were present
in previous protocols, such as the serialization of nested classes, very
large strings and containers, or classes whose :meth:`__new__` method takes
keyword-only arguments. It also brings a couple efficiency improvements.
.. seealso::
:pep:`3154` - Pickle protocol 4
PEP written by Antoine Pitrou and implemented by Alexandre Vassalotti.
Other Language Changes
======================

View File

@ -87,6 +87,12 @@ def _reduce_ex(self, proto):
def __newobj__(cls, *args):
return cls.__new__(cls, *args)
def __newobj_ex__(cls, args, kwargs):
"""Used by pickle protocol 4, instead of __newobj__ to allow classes with
keyword-only arguments to be pickled correctly.
"""
return cls.__new__(cls, *args, **kwargs)
def _slotnames(cls):
"""Return a list of slot names for a given class.

View File

@ -23,7 +23,7 @@ Misc variables:
"""
from types import FunctionType, BuiltinFunctionType
from types import FunctionType, BuiltinFunctionType, ModuleType
from copyreg import dispatch_table
from copyreg import _extension_registry, _inverted_registry, _extension_cache
from itertools import islice
@ -42,17 +42,18 @@ __all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler",
bytes_types = (bytes, bytearray)
# These are purely informational; no code uses these.
format_version = "3.0" # File format version we write
format_version = "4.0" # File format version we write
compatible_formats = ["1.0", # Original protocol 0
"1.1", # Protocol 0 with INST added
"1.2", # Original protocol 1
"1.3", # Protocol 1 with BINFLOAT added
"2.0", # Protocol 2
"3.0", # Protocol 3
"4.0", # Protocol 4
] # Old format versions we can read
# This is the highest protocol number we know how to read.
HIGHEST_PROTOCOL = 3
HIGHEST_PROTOCOL = 4
# The protocol we write by default. May be less than HIGHEST_PROTOCOL.
# We intentionally write a protocol that Python 2.x cannot read;
@ -164,7 +165,196 @@ _tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3]
BINBYTES = b'B' # push bytes; counted binary string argument
SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes
__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$",x)])
# Protocol 4
SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes
BINUNICODE8 = b'\x8d' # push very long string
BINBYTES8 = b'\x8e' # push very long bytes string
EMPTY_SET = b'\x8f' # push empty set on the stack
ADDITEMS = b'\x90' # modify set by adding topmost stack items
FROZENSET = b'\x91' # build frozenset from topmost stack items
NEWOBJ_EX = b'\x92' # like NEWOBJ but work with keyword only arguments
STACK_GLOBAL = b'\x93' # same as GLOBAL but using names on the stacks
MEMOIZE = b'\x94' # store top of the stack in memo
FRAME = b'\x95' # indicate the beginning of a new frame
__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)])
class _Framer:
_FRAME_SIZE_TARGET = 64 * 1024
def __init__(self, file_write):
self.file_write = file_write
self.current_frame = None
def _commit_frame(self):
f = self.current_frame
with f.getbuffer() as data:
n = len(data)
write = self.file_write
write(FRAME)
write(pack("<Q", n))
write(data)
f.seek(0)
f.truncate()
def start_framing(self):
self.current_frame = io.BytesIO()
def end_framing(self):
if self.current_frame is not None:
self._commit_frame()
self.current_frame = None
def write(self, data):
f = self.current_frame
if f is None:
return self.file_write(data)
else:
n = len(data)
if f.tell() >= self._FRAME_SIZE_TARGET:
self._commit_frame()
return f.write(data)
class _Unframer:
def __init__(self, file_read, file_readline, file_tell=None):
self.file_read = file_read
self.file_readline = file_readline
self.file_tell = file_tell
self.framing_enabled = False
self.current_frame = None
self.frame_start = None
def read(self, n):
if n == 0:
return b''
_file_read = self.file_read
if not self.framing_enabled:
return _file_read(n)
f = self.current_frame
if f is not None:
data = f.read(n)
if data:
if len(data) < n:
raise UnpicklingError(
"pickle exhausted before end of frame")
return data
frame_opcode = _file_read(1)
if frame_opcode != FRAME:
raise UnpicklingError(
"expected a FRAME opcode, got {} instead".format(frame_opcode))
frame_size, = unpack("<Q", _file_read(8))
if frame_size > sys.maxsize:
raise ValueError("frame size > sys.maxsize: %d" % frame_size)
if self.file_tell is not None:
self.frame_start = self.file_tell()
f = self.current_frame = io.BytesIO(_file_read(frame_size))
self.readline = f.readline
data = f.read(n)
assert len(data) == n, (len(data), n)
return data
def readline(self):
if not self.framing_enabled:
return self.file_readline()
else:
return self.current_frame.readline()
def tell(self):
if self.file_tell is None:
return None
elif self.current_frame is None:
return self.file_tell()
else:
return self.frame_start + self.current_frame.tell()
# Tools used for pickling.
def _getattribute(obj, name, allow_qualname=False):
dotted_path = name.split(".")
if not allow_qualname and len(dotted_path) > 1:
raise AttributeError("Can't get qualified attribute {!r} on {!r}; " +
"use protocols >= 4 to enable support"
.format(name, obj))
for subpath in dotted_path:
if subpath == '<locals>':
raise AttributeError("Can't get local attribute {!r} on {!r}"
.format(name, obj))
try:
obj = getattr(obj, subpath)
except AttributeError:
raise AttributeError("Can't get attribute {!r} on {!r}"
.format(name, obj))
return obj
def whichmodule(obj, name, allow_qualname=False):
"""Find the module an object belong to."""
module_name = getattr(obj, '__module__', None)
if module_name is not None:
return module_name
for module_name, module in sys.modules.items():
if module_name == '__main__' or module is None:
continue
try:
if _getattribute(module, name, allow_qualname) is obj:
return module_name
except AttributeError:
pass
return '__main__'
def encode_long(x):
r"""Encode a long to a two's complement little-endian binary string.
Note that 0 is a special case, returning an empty string, to save a
byte in the LONG1 pickling context.
>>> encode_long(0)
b''
>>> encode_long(255)
b'\xff\x00'
>>> encode_long(32767)
b'\xff\x7f'
>>> encode_long(-256)
b'\x00\xff'
>>> encode_long(-32768)
b'\x00\x80'
>>> encode_long(-128)
b'\x80'
>>> encode_long(127)
b'\x7f'
>>>
"""
if x == 0:
return b''
nbytes = (x.bit_length() >> 3) + 1
result = x.to_bytes(nbytes, byteorder='little', signed=True)
if x < 0 and nbytes > 1:
if result[-1] == 0xff and (result[-2] & 0x80) != 0:
result = result[:-1]
return result
def decode_long(data):
r"""Decode a long from a two's complement little-endian binary string.
>>> decode_long(b'')
0
>>> decode_long(b"\xff\x00")
255
>>> decode_long(b"\xff\x7f")
32767
>>> decode_long(b"\x00\xff")
-256
>>> decode_long(b"\x00\x80")
-32768
>>> decode_long(b"\x80")
-128
>>> decode_long(b"\x7f")
127
"""
return int.from_bytes(data, byteorder='little', signed=True)
# Pickling machinery
@ -174,9 +364,9 @@ class _Pickler:
"""This takes a binary file for writing a pickle data stream.
The optional protocol argument tells the pickler to use the
given protocol; supported protocols are 0, 1, 2, 3. The default
protocol is 3; a backward-incompatible protocol designed for
Python 3.0.
given protocol; supported protocols are 0, 1, 2, 3 and 4. The
default protocol is 3; a backward-incompatible protocol designed for
Python 3.
Specifying a negative protocol version selects the highest
protocol version supported. The higher the protocol used, the
@ -189,8 +379,8 @@ class _Pickler:
meets this interface.
If fix_imports is True and protocol is less than 3, pickle will try to
map the new Python 3.x names to the old module names used in Python
2.x, so that the pickle data stream is readable with Python 2.x.
map the new Python 3 names to the old module names used in Python 2,
so that the pickle data stream is readable with Python 2.
"""
if protocol is None:
protocol = DEFAULT_PROTOCOL
@ -199,7 +389,7 @@ class _Pickler:
elif not 0 <= protocol <= HIGHEST_PROTOCOL:
raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL)
try:
self.write = file.write
self._file_write = file.write
except AttributeError:
raise TypeError("file must have a 'write' attribute")
self.memo = {}
@ -223,13 +413,22 @@ class _Pickler:
"""Write a pickled representation of obj to the open file."""
# Check whether Pickler was initialized correctly. This is
# only needed to mimic the behavior of _pickle.Pickler.dump().
if not hasattr(self, "write"):
if not hasattr(self, "_file_write"):
raise PicklingError("Pickler.__init__() was not called by "
"%s.__init__()" % (self.__class__.__name__,))
if self.proto >= 2:
self.write(PROTO + pack("<B", self.proto))
self._file_write(PROTO + pack("<B", self.proto))
if self.proto >= 4:
framer = _Framer(self._file_write)
framer.start_framing()
self.write = framer.write
else:
framer = None
self.write = self._file_write
self.save(obj)
self.write(STOP)
if framer is not None:
framer.end_framing()
def memoize(self, obj):
"""Store an object in the memo."""
@ -249,19 +448,21 @@ class _Pickler:
if self.fast:
return
assert id(obj) not in self.memo
memo_len = len(self.memo)
self.write(self.put(memo_len))
self.memo[id(obj)] = memo_len, obj
idx = len(self.memo)
self.write(self.put(idx))
self.memo[id(obj)] = idx, obj
# Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i.
def put(self, i):
if self.bin:
if i < 256:
return BINPUT + pack("<B", i)
def put(self, idx):
if self.proto >= 4:
return MEMOIZE
elif self.bin:
if idx < 256:
return BINPUT + pack("<B", idx)
else:
return LONG_BINPUT + pack("<I", i)
return PUT + repr(i).encode("ascii") + b'\n'
return LONG_BINPUT + pack("<I", idx)
else:
return PUT + repr(idx).encode("ascii") + b'\n'
# Return a GET (BINGET, LONG_BINGET) opcode string, with argument i.
def get(self, i):
@ -349,24 +550,33 @@ class _Pickler:
else:
self.write(PERSID + str(pid).encode("ascii") + b'\n')
def save_reduce(self, func, args, state=None,
listitems=None, dictitems=None, obj=None):
def save_reduce(self, func, args, state=None, listitems=None,
dictitems=None, obj=None):
# This API is called by some subclasses
# Assert that args is a tuple
if not isinstance(args, tuple):
raise PicklingError("args from save_reduce() should be a tuple")
# Assert that func is callable
raise PicklingError("args from save_reduce() must be a tuple")
if not callable(func):
raise PicklingError("func from save_reduce() should be callable")
raise PicklingError("func from save_reduce() must be callable")
save = self.save
write = self.write
# Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
# A __reduce__ implementation can direct protocol 2 to
func_name = getattr(func, "__name__", "")
if self.proto >= 4 and func_name == "__newobj_ex__":
cls, args, kwargs = args
if not hasattr(cls, "__new__"):
raise PicklingError("args[0] from {} args has no __new__"
.format(func_name))
if obj is not None and cls is not obj.__class__:
raise PicklingError("args[0] from {} args has the wrong class"
.format(func_name))
save(cls)
save(args)
save(kwargs)
write(NEWOBJ_EX)
elif self.proto >= 2 and func_name == "__newobj__":
# A __reduce__ implementation can direct protocol 2 or newer to
# use the more efficient NEWOBJ opcode, while still
# allowing protocol 0 and 1 to work normally. For this to
# work, the function returned by __reduce__ should be
@ -409,6 +619,12 @@ class _Pickler:
write(REDUCE)
if obj is not None:
# If the object is already in the memo, this means it is
# recursive. In this case, throw away everything we put on the
# stack, and fetch the object back from the memo.
if id(obj) in self.memo:
write(POP + self.get(self.memo[id(obj)][0]))
else:
self.memoize(obj)
# More new special cases (that work with older protocols as
@ -493,8 +709,10 @@ class _Pickler:
(str(obj, 'latin1'), 'latin1'), obj=obj)
return
n = len(obj)
if n < 256:
if n <= 0xff:
self.write(SHORT_BINBYTES + pack("<B", n) + obj)
elif n > 0xffffffff and self.proto >= 4:
self.write(BINBYTES8 + pack("<Q", n) + obj)
else:
self.write(BINBYTES + pack("<I", n) + obj)
self.memoize(obj)
@ -504,11 +722,17 @@ class _Pickler:
if self.bin:
encoded = obj.encode('utf-8', 'surrogatepass')
n = len(encoded)
if n <= 0xff and self.proto >= 4:
self.write(SHORT_BINUNICODE + pack("<B", n) + encoded)
elif n > 0xffffffff and self.proto >= 4:
self.write(BINUNICODE8 + pack("<Q", n) + encoded)
else:
self.write(BINUNICODE + pack("<I", n) + encoded)
else:
obj = obj.replace("\\", "\\u005c")
obj = obj.replace("\n", "\\u000a")
self.write(UNICODE + obj.encode('raw-unicode-escape') + b'\n')
self.write(UNICODE + obj.encode('raw-unicode-escape') +
b'\n')
self.memoize(obj)
dispatch[str] = save_str
@ -647,33 +871,79 @@ class _Pickler:
if n < self._BATCHSIZE:
return
def save_set(self, obj):
save = self.save
write = self.write
if self.proto < 4:
self.save_reduce(set, (list(obj),), obj=obj)
return
write(EMPTY_SET)
self.memoize(obj)
it = iter(obj)
while True:
batch = list(islice(it, self._BATCHSIZE))
n = len(batch)
if n > 0:
write(MARK)
for item in batch:
save(item)
write(ADDITEMS)
if n < self._BATCHSIZE:
return
dispatch[set] = save_set
def save_frozenset(self, obj):
save = self.save
write = self.write
if self.proto < 4:
self.save_reduce(frozenset, (list(obj),), obj=obj)
return
write(MARK)
for item in obj:
save(item)
if id(obj) in self.memo:
# If the object is already in the memo, this means it is
# recursive. In this case, throw away everything we put on the
# stack, and fetch the object back from the memo.
write(POP_MARK + self.get(self.memo[id(obj)][0]))
return
write(FROZENSET)
self.memoize(obj)
dispatch[frozenset] = save_frozenset
def save_global(self, obj, name=None):
write = self.write
memo = self.memo
if name is None and self.proto >= 4:
name = getattr(obj, '__qualname__', None)
if name is None:
name = obj.__name__
module = getattr(obj, "__module__", None)
if module is None:
module = whichmodule(obj, name)
module_name = whichmodule(obj, name, allow_qualname=self.proto >= 4)
try:
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
__import__(module_name, level=0)
module = sys.modules[module_name]
obj2 = _getattribute(module, name, allow_qualname=self.proto >= 4)
except (ImportError, KeyError, AttributeError):
raise PicklingError(
"Can't pickle %r: it's not found as %s.%s" %
(obj, module, name))
(obj, module_name, name))
else:
if klass is not obj:
if obj2 is not obj:
raise PicklingError(
"Can't pickle %r: it's not the same object as %s.%s" %
(obj, module, name))
(obj, module_name, name))
if self.proto >= 2:
code = _extension_registry.get((module, name))
code = _extension_registry.get((module_name, name))
if code:
assert code > 0
if code <= 0xff:
@ -684,17 +954,23 @@ class _Pickler:
write(EXT4 + pack("<i", code))
return
# Non-ASCII identifiers are supported only with protocols >= 3.
if self.proto >= 3:
write(GLOBAL + bytes(module, "utf-8") + b'\n' +
if self.proto >= 4:
self.save(module_name)
self.save(name)
write(STACK_GLOBAL)
elif self.proto >= 3:
write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
else:
if self.fix_imports:
if (module, name) in _compat_pickle.REVERSE_NAME_MAPPING:
module, name = _compat_pickle.REVERSE_NAME_MAPPING[(module, name)]
if module in _compat_pickle.REVERSE_IMPORT_MAPPING:
module = _compat_pickle.REVERSE_IMPORT_MAPPING[module]
r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
if (module_name, name) in r_name_mapping:
module_name, name = r_name_mapping[(module_name, name)]
if module_name in r_import_mapping:
module_name = r_import_mapping[module_name]
try:
write(GLOBAL + bytes(module, "ascii") + b'\n' +
write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
@ -703,39 +979,15 @@ class _Pickler:
self.memoize(obj)
dispatch[FunctionType] = save_global
dispatch[BuiltinFunctionType] = save_global
dispatch[type] = save_global
# A cache for whichmodule(), mapping a function object to the name of
# the module in which the function was found.
classmap = {} # called classmap for backwards compatibility
def whichmodule(func, funcname):
"""Figure out the module in which a function occurs.
Search sys.modules for the module.
Cache in classmap.
Return a module name.
If the function cannot be found, return "__main__".
"""
# Python functions should always get an __module__ from their globals.
mod = getattr(func, "__module__", None)
if mod is not None:
return mod
if func in classmap:
return classmap[func]
for name, module in list(sys.modules.items()):
if module is None:
continue # skip dummy package entries
if name != '__main__' and getattr(module, funcname, None) is func:
break
def save_method(self, obj):
if obj.__self__ is None or type(obj.__self__) is ModuleType:
self.save_global(obj)
else:
name = '__main__'
classmap[func] = name
return name
self.save_reduce(getattr, (obj.__self__, obj.__name__), obj=obj)
dispatch[FunctionType] = save_global
dispatch[BuiltinFunctionType] = save_method
dispatch[type] = save_global
# Unpickling machinery
@ -764,8 +1016,8 @@ class _Unpickler:
instances pickled by Python 2.x; these default to 'ASCII' and
'strict', respectively.
"""
self.readline = file.readline
self.read = file.read
self._file_readline = file.readline
self._file_read = file.read
self.memo = {}
self.encoding = encoding
self.errors = errors
@ -779,12 +1031,16 @@ class _Unpickler:
"""
# Check whether Unpickler was initialized correctly. This is
# only needed to mimic the behavior of _pickle.Unpickler.dump().
if not hasattr(self, "read"):
if not hasattr(self, "_file_read"):
raise UnpicklingError("Unpickler.__init__() was not called by "
"%s.__init__()" % (self.__class__.__name__,))
self._unframer = _Unframer(self._file_read, self._file_readline)
self.read = self._unframer.read
self.readline = self._unframer.readline
self.mark = object() # any new unique object
self.stack = []
self.append = self.stack.append
self.proto = 0
read = self.read
dispatch = self.dispatch
try:
@ -822,6 +1078,8 @@ class _Unpickler:
if not 0 <= proto <= HIGHEST_PROTOCOL:
raise ValueError("unsupported pickle protocol: %d" % proto)
self.proto = proto
if proto >= 4:
self._unframer.framing_enabled = True
dispatch[PROTO[0]] = load_proto
def load_persid(self):
@ -940,6 +1198,14 @@ class _Unpickler:
self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
dispatch[BINUNICODE[0]] = load_binunicode
def load_binunicode8(self):
len, = unpack('<Q', self.read(8))
if len > maxsize:
raise UnpicklingError("BINUNICODE8 exceeds system's maximum size "
"of %d bytes" % maxsize)
self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
dispatch[BINUNICODE8[0]] = load_binunicode8
def load_short_binstring(self):
len = self.read(1)[0]
data = self.read(len)
@ -952,6 +1218,11 @@ class _Unpickler:
self.append(self.read(len))
dispatch[SHORT_BINBYTES[0]] = load_short_binbytes
def load_short_binunicode(self):
len = self.read(1)[0]
self.append(str(self.read(len), 'utf-8', 'surrogatepass'))
dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode
def load_tuple(self):
k = self.marker()
self.stack[k:] = [tuple(self.stack[k+1:])]
@ -981,6 +1252,15 @@ class _Unpickler:
self.append({})
dispatch[EMPTY_DICT[0]] = load_empty_dictionary
def load_empty_set(self):
self.append(set())
dispatch[EMPTY_SET[0]] = load_empty_set
def load_frozenset(self):
k = self.marker()
self.stack[k:] = [frozenset(self.stack[k+1:])]
dispatch[FROZENSET[0]] = load_frozenset
def load_list(self):
k = self.marker()
self.stack[k:] = [self.stack[k+1:]]
@ -1029,11 +1309,19 @@ class _Unpickler:
def load_newobj(self):
args = self.stack.pop()
cls = self.stack[-1]
cls = self.stack.pop()
obj = cls.__new__(cls, *args)
self.stack[-1] = obj
self.append(obj)
dispatch[NEWOBJ[0]] = load_newobj
def load_newobj_ex(self):
kwargs = self.stack.pop()
args = self.stack.pop()
cls = self.stack.pop()
obj = cls.__new__(cls, *args, **kwargs)
self.append(obj)
dispatch[NEWOBJ_EX[0]] = load_newobj_ex
def load_global(self):
module = self.readline()[:-1].decode("utf-8")
name = self.readline()[:-1].decode("utf-8")
@ -1041,6 +1329,14 @@ class _Unpickler:
self.append(klass)
dispatch[GLOBAL[0]] = load_global
def load_stack_global(self):
name = self.stack.pop()
module = self.stack.pop()
if type(name) is not str or type(module) is not str:
raise UnpicklingError("STACK_GLOBAL requires str")
self.append(self.find_class(module, name))
dispatch[STACK_GLOBAL[0]] = load_stack_global
def load_ext1(self):
code = self.read(1)[0]
self.get_extension(code)
@ -1080,9 +1376,8 @@ class _Unpickler:
if module in _compat_pickle.IMPORT_MAPPING:
module = _compat_pickle.IMPORT_MAPPING[module]
__import__(module, level=0)
mod = sys.modules[module]
klass = getattr(mod, name)
return klass
return _getattribute(sys.modules[module], name,
allow_qualname=self.proto >= 4)
def load_reduce(self):
stack = self.stack
@ -1146,6 +1441,11 @@ class _Unpickler:
self.memo[i] = self.stack[-1]
dispatch[LONG_BINPUT[0]] = load_long_binput
def load_memoize(self):
memo = self.memo
memo[len(memo)] = self.stack[-1]
dispatch[MEMOIZE[0]] = load_memoize
def load_append(self):
stack = self.stack
value = stack.pop()
@ -1185,6 +1485,20 @@ class _Unpickler:
del stack[mark:]
dispatch[SETITEMS[0]] = load_setitems
def load_additems(self):
stack = self.stack
mark = self.marker()
set_obj = stack[mark - 1]
items = stack[mark + 1:]
if isinstance(set_obj, set):
set_obj.update(items)
else:
add = set_obj.add
for item in items:
add(item)
del stack[mark:]
dispatch[ADDITEMS[0]] = load_additems
def load_build(self):
stack = self.stack
state = stack.pop()
@ -1218,86 +1532,46 @@ class _Unpickler:
raise _Stop(value)
dispatch[STOP[0]] = load_stop
# Encode/decode ints.
def encode_long(x):
r"""Encode a long to a two's complement little-endian binary string.
Note that 0 is a special case, returning an empty string, to save a
byte in the LONG1 pickling context.
>>> encode_long(0)
b''
>>> encode_long(255)
b'\xff\x00'
>>> encode_long(32767)
b'\xff\x7f'
>>> encode_long(-256)
b'\x00\xff'
>>> encode_long(-32768)
b'\x00\x80'
>>> encode_long(-128)
b'\x80'
>>> encode_long(127)
b'\x7f'
>>>
"""
if x == 0:
return b''
nbytes = (x.bit_length() >> 3) + 1
result = x.to_bytes(nbytes, byteorder='little', signed=True)
if x < 0 and nbytes > 1:
if result[-1] == 0xff and (result[-2] & 0x80) != 0:
result = result[:-1]
return result
def decode_long(data):
r"""Decode an int from a two's complement little-endian binary string.
>>> decode_long(b'')
0
>>> decode_long(b"\xff\x00")
255
>>> decode_long(b"\xff\x7f")
32767
>>> decode_long(b"\x00\xff")
-256
>>> decode_long(b"\x00\x80")
-32768
>>> decode_long(b"\x80")
-128
>>> decode_long(b"\x7f")
127
"""
return int.from_bytes(data, byteorder='little', signed=True)
# Shorthands
def dump(obj, file, protocol=None, *, fix_imports=True):
Pickler(file, protocol, fix_imports=fix_imports).dump(obj)
def _dump(obj, file, protocol=None, *, fix_imports=True):
_Pickler(file, protocol, fix_imports=fix_imports).dump(obj)
def dumps(obj, protocol=None, *, fix_imports=True):
def _dumps(obj, protocol=None, *, fix_imports=True):
f = io.BytesIO()
Pickler(f, protocol, fix_imports=fix_imports).dump(obj)
_Pickler(f, protocol, fix_imports=fix_imports).dump(obj)
res = f.getvalue()
assert isinstance(res, bytes_types)
return res
def load(file, *, fix_imports=True, encoding="ASCII", errors="strict"):
return Unpickler(file, fix_imports=fix_imports,
def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict"):
return _Unpickler(file, fix_imports=fix_imports,
encoding=encoding, errors=errors).load()
def loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"):
def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict"):
if isinstance(s, str):
raise TypeError("Can't load pickle from unicode string")
file = io.BytesIO(s)
return Unpickler(file, fix_imports=fix_imports,
return _Unpickler(file, fix_imports=fix_imports,
encoding=encoding, errors=errors).load()
# Use the faster _pickle if possible
try:
from _pickle import *
from _pickle import (
PickleError,
PicklingError,
UnpicklingError,
Pickler,
Unpickler,
dump,
dumps,
load,
loads
)
except ImportError:
Pickler, Unpickler = _Pickler, _Unpickler
dump, dumps, load, loads = _dump, _dumps, _load, _loads
# Doctest
def _test():

View File

@ -11,6 +11,7 @@ dis(pickle, out=None, memo=None, indentlevel=4)
'''
import codecs
import io
import pickle
import re
import sys
@ -168,6 +169,7 @@ UP_TO_NEWLINE = -1
TAKEN_FROM_ARGUMENT1 = -2 # num bytes is 1-byte unsigned int
TAKEN_FROM_ARGUMENT4 = -3 # num bytes is 4-byte signed little-endian int
TAKEN_FROM_ARGUMENT4U = -4 # num bytes is 4-byte unsigned little-endian int
TAKEN_FROM_ARGUMENT8U = -5 # num bytes is 8-byte unsigned little-endian int
class ArgumentDescriptor(object):
__slots__ = (
@ -175,7 +177,7 @@ class ArgumentDescriptor(object):
'name',
# length of argument, in bytes; an int; UP_TO_NEWLINE and
# TAKEN_FROM_ARGUMENT{1,4} are negative values for variable-length
# TAKEN_FROM_ARGUMENT{1,4,8} are negative values for variable-length
# cases
'n',
@ -196,7 +198,8 @@ class ArgumentDescriptor(object):
n in (UP_TO_NEWLINE,
TAKEN_FROM_ARGUMENT1,
TAKEN_FROM_ARGUMENT4,
TAKEN_FROM_ARGUMENT4U))
TAKEN_FROM_ARGUMENT4U,
TAKEN_FROM_ARGUMENT8U))
self.n = n
self.reader = reader
@ -288,6 +291,27 @@ uint4 = ArgumentDescriptor(
doc="Four-byte unsigned integer, little-endian.")
def read_uint8(f):
r"""
>>> import io
>>> read_uint8(io.BytesIO(b'\xff\x00\x00\x00\x00\x00\x00\x00'))
255
>>> read_uint8(io.BytesIO(b'\xff' * 8)) == 2**64-1
True
"""
data = f.read(8)
if len(data) == 8:
return _unpack("<Q", data)[0]
raise ValueError("not enough data in stream to read uint8")
uint8 = ArgumentDescriptor(
name='uint8',
n=8,
reader=read_uint8,
doc="Eight-byte unsigned integer, little-endian.")
def read_stringnl(f, decode=True, stripquotes=True):
r"""
>>> import io
@ -381,6 +405,36 @@ stringnl_noescape_pair = ArgumentDescriptor(
a single blank separating the two strings.
""")
def read_string1(f):
r"""
>>> import io
>>> read_string1(io.BytesIO(b"\x00"))
''
>>> read_string1(io.BytesIO(b"\x03abcdef"))
'abc'
"""
n = read_uint1(f)
assert n >= 0
data = f.read(n)
if len(data) == n:
return data.decode("latin-1")
raise ValueError("expected %d bytes in a string1, but only %d remain" %
(n, len(data)))
string1 = ArgumentDescriptor(
name="string1",
n=TAKEN_FROM_ARGUMENT1,
reader=read_string1,
doc="""A counted string.
The first argument is a 1-byte unsigned int giving the number
of bytes in the string, and the second argument is that many
bytes.
""")
def read_string4(f):
r"""
>>> import io
@ -415,28 +469,28 @@ string4 = ArgumentDescriptor(
""")
def read_string1(f):
def read_bytes1(f):
r"""
>>> import io
>>> read_string1(io.BytesIO(b"\x00"))
''
>>> read_string1(io.BytesIO(b"\x03abcdef"))
'abc'
>>> read_bytes1(io.BytesIO(b"\x00"))
b''
>>> read_bytes1(io.BytesIO(b"\x03abcdef"))
b'abc'
"""
n = read_uint1(f)
assert n >= 0
data = f.read(n)
if len(data) == n:
return data.decode("latin-1")
raise ValueError("expected %d bytes in a string1, but only %d remain" %
return data
raise ValueError("expected %d bytes in a bytes1, but only %d remain" %
(n, len(data)))
string1 = ArgumentDescriptor(
name="string1",
bytes1 = ArgumentDescriptor(
name="bytes1",
n=TAKEN_FROM_ARGUMENT1,
reader=read_string1,
doc="""A counted string.
reader=read_bytes1,
doc="""A counted bytes string.
The first argument is a 1-byte unsigned int giving the number
of bytes in the string, and the second argument is that many
@ -486,6 +540,7 @@ def read_bytes4(f):
"""
n = read_uint4(f)
assert n >= 0
if n > sys.maxsize:
raise ValueError("bytes4 byte count > sys.maxsize: %d" % n)
data = f.read(n)
@ -505,6 +560,39 @@ bytes4 = ArgumentDescriptor(
""")
def read_bytes8(f):
r"""
>>> import io
>>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x00\x00abc"))
b''
>>> read_bytes8(io.BytesIO(b"\x03\x00\x00\x00\x00\x00\x00\x00abcdef"))
b'abc'
>>> read_bytes8(io.BytesIO(b"\x00\x00\x00\x00\x00\x00\x03\x00abcdef"))
Traceback (most recent call last):
...
ValueError: expected 844424930131968 bytes in a bytes8, but only 6 remain
"""
n = read_uint8(f)
assert n >= 0
if n > sys.maxsize:
raise ValueError("bytes8 byte count > sys.maxsize: %d" % n)
data = f.read(n)
if len(data) == n:
return data
raise ValueError("expected %d bytes in a bytes8, but only %d remain" %
(n, len(data)))
bytes8 = ArgumentDescriptor(
name="bytes8",
n=TAKEN_FROM_ARGUMENT8U,
reader=read_bytes8,
doc="""A counted bytes string.
The first argument is a 8-byte little-endian unsigned int giving
the number of bytes, and the second argument is that many bytes.
""")
def read_unicodestringnl(f):
r"""
>>> import io
@ -530,6 +618,46 @@ unicodestringnl = ArgumentDescriptor(
escape sequences.
""")
def read_unicodestring1(f):
r"""
>>> import io
>>> s = 'abcd\uabcd'
>>> enc = s.encode('utf-8')
>>> enc
b'abcd\xea\xaf\x8d'
>>> n = bytes([len(enc)]) # little-endian 1-byte length
>>> t = read_unicodestring1(io.BytesIO(n + enc + b'junk'))
>>> s == t
True
>>> read_unicodestring1(io.BytesIO(n + enc[:-1]))
Traceback (most recent call last):
...
ValueError: expected 7 bytes in a unicodestring1, but only 6 remain
"""
n = read_uint1(f)
assert n >= 0
data = f.read(n)
if len(data) == n:
return str(data, 'utf-8', 'surrogatepass')
raise ValueError("expected %d bytes in a unicodestring1, but only %d "
"remain" % (n, len(data)))
unicodestring1 = ArgumentDescriptor(
name="unicodestring1",
n=TAKEN_FROM_ARGUMENT1,
reader=read_unicodestring1,
doc="""A counted Unicode string.
The first argument is a 1-byte little-endian signed int
giving the number of bytes in the string, and the second
argument-- the UTF-8 encoding of the Unicode string --
contains that many bytes.
""")
def read_unicodestring4(f):
r"""
>>> import io
@ -549,6 +677,7 @@ def read_unicodestring4(f):
"""
n = read_uint4(f)
assert n >= 0
if n > sys.maxsize:
raise ValueError("unicodestring4 byte count > sys.maxsize: %d" % n)
data = f.read(n)
@ -570,6 +699,47 @@ unicodestring4 = ArgumentDescriptor(
""")
def read_unicodestring8(f):
r"""
>>> import io
>>> s = 'abcd\uabcd'
>>> enc = s.encode('utf-8')
>>> enc
b'abcd\xea\xaf\x8d'
>>> n = bytes([len(enc)]) + bytes(7) # little-endian 8-byte length
>>> t = read_unicodestring8(io.BytesIO(n + enc + b'junk'))
>>> s == t
True
>>> read_unicodestring8(io.BytesIO(n + enc[:-1]))
Traceback (most recent call last):
...
ValueError: expected 7 bytes in a unicodestring8, but only 6 remain
"""
n = read_uint8(f)
assert n >= 0
if n > sys.maxsize:
raise ValueError("unicodestring8 byte count > sys.maxsize: %d" % n)
data = f.read(n)
if len(data) == n:
return str(data, 'utf-8', 'surrogatepass')
raise ValueError("expected %d bytes in a unicodestring8, but only %d "
"remain" % (n, len(data)))
unicodestring8 = ArgumentDescriptor(
name="unicodestring8",
n=TAKEN_FROM_ARGUMENT8U,
reader=read_unicodestring8,
doc="""A counted Unicode string.
The first argument is a 8-byte little-endian signed int
giving the number of bytes in the string, and the second
argument-- the UTF-8 encoding of the Unicode string --
contains that many bytes.
""")
def read_decimalnl_short(f):
r"""
>>> import io
@ -859,6 +1029,16 @@ pydict = StackObject(
obtype=dict,
doc="A Python dict object.")
pyset = StackObject(
name="set",
obtype=set,
doc="A Python set object.")
pyfrozenset = StackObject(
name="frozenset",
obtype=set,
doc="A Python frozenset object.")
anyobject = StackObject(
name='any',
obtype=object,
@ -1142,6 +1322,19 @@ opcodes = [
literally as the string content.
"""),
I(name='BINBYTES8',
code='\x8e',
arg=bytes8,
stack_before=[],
stack_after=[pybytes],
proto=4,
doc="""Push a Python bytes object.
There are two arguments: the first is a 8-byte unsigned int giving
the number of bytes in the string, and the second is that many bytes,
which are taken literally as the string content.
"""),
# Ways to spell None.
I(name='NONE',
@ -1190,6 +1383,19 @@ opcodes = [
until the next newline character.
"""),
I(name='SHORT_BINUNICODE',
code='\x8c',
arg=unicodestring1,
stack_before=[],
stack_after=[pyunicode],
proto=4,
doc="""Push a Python Unicode string object.
There are two arguments: the first is a 1-byte little-endian signed int
giving the number of bytes in the string. The second is that many
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
I(name='BINUNICODE',
code='X',
arg=unicodestring4,
@ -1203,6 +1409,19 @@ opcodes = [
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
I(name='BINUNICODE8',
code='\x8d',
arg=unicodestring8,
stack_before=[],
stack_after=[pyunicode],
proto=4,
doc="""Push a Python Unicode string object.
There are two arguments: the first is a 8-byte little-endian signed int
giving the number of bytes in the string. The second is that many
bytes, and is the UTF-8 encoding of the Unicode string.
"""),
# Ways to spell floats.
I(name='FLOAT',
@ -1428,6 +1647,54 @@ opcodes = [
1, 2, ..., n, and in that order.
"""),
# Ways to build sets
I(name='EMPTY_SET',
code='\x8f',
arg=None,
stack_before=[],
stack_after=[pyset],
proto=4,
doc="Push an empty set."),
I(name='ADDITEMS',
code='\x90',
arg=None,
stack_before=[pyset, markobject, stackslice],
stack_after=[pyset],
proto=4,
doc="""Add an arbitrary number of items to an existing set.
The slice of the stack following the topmost markobject is taken as
a sequence of items, added to the set immediately under the topmost
markobject. Everything at and after the topmost markobject is popped,
leaving the mutated set at the top of the stack.
Stack before: ... pyset markobject item_1 ... item_n
Stack after: ... pyset
where pyset has been modified via pyset.add(item_i) = item_i for i in
1, 2, ..., n, and in that order.
"""),
# Way to build frozensets
I(name='FROZENSET',
code='\x91',
arg=None,
stack_before=[markobject, stackslice],
stack_after=[pyfrozenset],
proto=4,
doc="""Build a frozenset out of the topmost slice, after markobject.
All the stack entries following the topmost markobject are placed into
a single Python frozenset, which single frozenset object replaces all
of the stack from the topmost markobject onward. For example,
Stack before: ... markobject 1 2 3
Stack after: ... frozenset({1, 2, 3})
"""),
# Stack manipulation.
I(name='POP',
@ -1549,6 +1816,18 @@ opcodes = [
unsigned little-endian integer following.
"""),
I(name='MEMOIZE',
code='\x94',
arg=None,
stack_before=[anyobject],
stack_after=[anyobject],
proto=4,
doc="""Store the stack top into the memo. The stack is not popped.
The index of the memo location to write is the number of
elements currently present in the memo.
"""),
# Access the extension registry (predefined objects). Akin to the GET
# family.
@ -1614,6 +1893,15 @@ opcodes = [
stack, so unpickling subclasses can override this form of lookup.
"""),
I(name='STACK_GLOBAL',
code='\x93',
arg=None,
stack_before=[pyunicode, pyunicode],
stack_after=[anyobject],
proto=0,
doc="""Push a global object (module.attr) on the stack.
"""),
# Ways to build objects of classes pickle doesn't know about directly
# (user-defined classes). I despair of documenting this accurately
# and comprehensibly -- you really have to read the pickle code to
@ -1770,6 +2058,21 @@ opcodes = [
onto the stack.
"""),
I(name='NEWOBJ_EX',
code='\x92',
arg=None,
stack_before=[anyobject, anyobject, anyobject],
stack_after=[anyobject],
proto=4,
doc="""Build an object instance.
The stack before should be thought of as containing a class
object followed by an argument tuple and by a keyword argument dict
(the dict being the stack top). Call these cls and args. They are
popped off the stack, and the value returned by
cls.__new__(cls, *args, *kwargs) is pushed back onto the stack.
"""),
# Machine control.
I(name='PROTO',
@ -1797,6 +2100,20 @@ opcodes = [
empty then.
"""),
# Framing support.
I(name='FRAME',
code='\x95',
arg=uint8,
stack_before=[],
stack_after=[],
proto=4,
doc="""Indicate the beginning of a new frame.
The unpickler may use this opcode to safely prefetch data from its
underlying stream.
"""),
# Ways to deal with persistent IDs.
I(name='PERSID',
@ -1903,6 +2220,38 @@ del assure_pickle_consistency
##############################################################################
# A pickle opcode generator.
def _genops(data, yield_end_pos=False):
if isinstance(data, bytes_types):
data = io.BytesIO(data)
if hasattr(data, "tell"):
getpos = data.tell
else:
getpos = lambda: None
while True:
pos = getpos()
code = data.read(1)
opcode = code2op.get(code.decode("latin-1"))
if opcode is None:
if code == b"":
raise ValueError("pickle exhausted before seeing STOP")
else:
raise ValueError("at position %s, opcode %r unknown" % (
"<unknown>" if pos is None else pos,
code))
if opcode.arg is None:
arg = None
else:
arg = opcode.arg.reader(data)
if yield_end_pos:
yield opcode, arg, pos, getpos()
else:
yield opcode, arg, pos
if code == b'.':
assert opcode.name == 'STOP'
break
def genops(pickle):
"""Generate all the opcodes in a pickle.
@ -1926,62 +2275,47 @@ def genops(pickle):
used. Else (the pickle doesn't have a tell(), and it's not obvious how
to query its current position) pos is None.
"""
if isinstance(pickle, bytes_types):
import io
pickle = io.BytesIO(pickle)
if hasattr(pickle, "tell"):
getpos = pickle.tell
else:
getpos = lambda: None
while True:
pos = getpos()
code = pickle.read(1)
opcode = code2op.get(code.decode("latin-1"))
if opcode is None:
if code == b"":
raise ValueError("pickle exhausted before seeing STOP")
else:
raise ValueError("at position %s, opcode %r unknown" % (
pos is None and "<unknown>" or pos,
code))
if opcode.arg is None:
arg = None
else:
arg = opcode.arg.reader(pickle)
yield opcode, arg, pos
if code == b'.':
assert opcode.name == 'STOP'
break
return _genops(pickle)
##############################################################################
# A pickle optimizer.
def optimize(p):
'Optimize a pickle string by removing unused PUT opcodes'
gets = set() # set of args used by a GET opcode
puts = [] # (arg, startpos, stoppos) for the PUT opcodes
prevpos = None # set to pos if previous opcode was a PUT
for opcode, arg, pos in genops(p):
if prevpos is not None:
puts.append((prevarg, prevpos, pos))
prevpos = None
not_a_put = object()
gets = { not_a_put } # set of args used by a GET opcode
opcodes = [] # (startpos, stoppos, putid)
proto = 0
for opcode, arg, pos, end_pos in _genops(p, yield_end_pos=True):
if 'PUT' in opcode.name:
prevarg, prevpos = arg, pos
elif 'GET' in opcode.name:
opcodes.append((pos, end_pos, arg))
elif 'FRAME' in opcode.name:
pass
else:
if 'GET' in opcode.name:
gets.add(arg)
elif opcode.name == 'PROTO':
assert pos == 0, pos
proto = arg
opcodes.append((pos, end_pos, not_a_put))
prevpos, prevarg = pos, None
# Copy the pickle string except for PUTS without a corresponding GET
s = []
i = 0
for arg, start, stop in puts:
j = stop if (arg in gets) else start
s.append(p[i:j])
i = stop
s.append(p[i:])
return b''.join(s)
# Copy the opcodes except for PUTS without a corresponding GET
out = io.BytesIO()
opcodes = iter(opcodes)
if proto >= 2:
# Write the PROTO header before any framing
start, stop, _ = next(opcodes)
out.write(p[start:stop])
buf = pickle._Framer(out.write)
if proto >= 4:
buf.start_framing()
for start, stop, putid in opcodes:
if putid in gets:
buf.write(p[start:stop])
if proto >= 4:
buf.end_framing()
return out.getvalue()
##############################################################################
# A symbolic pickle disassembler.
@ -2081,17 +2415,20 @@ def dis(pickle, out=None, memo=None, indentlevel=4, annotate=0):
errormsg = markmsg = "no MARK exists on stack"
# Check for correct memo usage.
if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT"):
if opcode.name in ("PUT", "BINPUT", "LONG_BINPUT", "MEMOIZE"):
if opcode.name == "MEMOIZE":
memo_idx = len(memo)
else:
assert arg is not None
if arg in memo:
memo_idx = arg
if memo_idx in memo:
errormsg = "memo key %r already defined" % arg
elif not stack:
errormsg = "stack is empty -- can't store into memo"
elif stack[-1] is markobject:
errormsg = "can't store markobject in the memo"
else:
memo[arg] = stack[-1]
memo[memo_idx] = stack[-1]
elif opcode.name in ("GET", "BINGET", "LONG_BINGET"):
if arg in memo:
assert len(after) == 1

View File

@ -1,9 +1,10 @@
import copyreg
import io
import unittest
import pickle
import pickletools
import random
import sys
import copyreg
import unittest
import weakref
from http.cookies import SimpleCookie
@ -95,6 +96,9 @@ class E(C):
def __getinitargs__(self):
return ()
class H(object):
pass
import __main__
__main__.C = C
C.__module__ = "__main__"
@ -102,6 +106,8 @@ __main__.D = D
D.__module__ = "__main__"
__main__.E = E
E.__module__ = "__main__"
__main__.H = H
H.__module__ = "__main__"
class myint(int):
def __init__(self, x):
@ -428,6 +434,7 @@ def create_data():
x.append(5)
return x
class AbstractPickleTests(unittest.TestCase):
# Subclass must define self.dumps, self.loads.
@ -436,23 +443,41 @@ class AbstractPickleTests(unittest.TestCase):
def setUp(self):
pass
def assert_is_copy(self, obj, objcopy, msg=None):
"""Utility method to verify if two objects are copies of each others.
"""
if msg is None:
msg = "{!r} is not a copy of {!r}".format(obj, objcopy)
self.assertEqual(obj, objcopy, msg=msg)
self.assertIs(type(obj), type(objcopy), msg=msg)
if hasattr(obj, '__dict__'):
self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg)
self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg)
if hasattr(obj, '__slots__'):
self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg)
for slot in obj.__slots__:
self.assertEqual(
hasattr(obj, slot), hasattr(objcopy, slot), msg=msg)
self.assertEqual(getattr(obj, slot, None),
getattr(objcopy, slot, None), msg=msg)
def test_misc(self):
# test various datatypes not tested by testdata
for proto in protocols:
x = myint(4)
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
x = (1, ())
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
x = initarg(1, x)
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
# XXX test __reduce__ protocol?
@ -461,16 +486,16 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(expected, proto)
got = self.loads(s)
self.assertEqual(expected, got)
self.assert_is_copy(expected, got)
def test_load_from_data0(self):
self.assertEqual(self._testdata, self.loads(DATA0))
self.assert_is_copy(self._testdata, self.loads(DATA0))
def test_load_from_data1(self):
self.assertEqual(self._testdata, self.loads(DATA1))
self.assert_is_copy(self._testdata, self.loads(DATA1))
def test_load_from_data2(self):
self.assertEqual(self._testdata, self.loads(DATA2))
self.assert_is_copy(self._testdata, self.loads(DATA2))
def test_load_classic_instance(self):
# See issue5180. Test loading 2.x pickles that
@ -492,7 +517,7 @@ class AbstractPickleTests(unittest.TestCase):
b"X\n"
b"p0\n"
b"(dp1\nb.").replace(b'X', xname)
self.assertEqual(X(*args), self.loads(pickle0))
self.assert_is_copy(X(*args), self.loads(pickle0))
# Protocol 1 (binary mode pickle)
"""
@ -509,7 +534,7 @@ class AbstractPickleTests(unittest.TestCase):
pickle1 = (b'(c__main__\n'
b'X\n'
b'q\x00oq\x01}q\x02b.').replace(b'X', xname)
self.assertEqual(X(*args), self.loads(pickle1))
self.assert_is_copy(X(*args), self.loads(pickle1))
# Protocol 2 (pickle2 = b'\x80\x02' + pickle1)
"""
@ -527,7 +552,7 @@ class AbstractPickleTests(unittest.TestCase):
pickle2 = (b'\x80\x02(c__main__\n'
b'X\n'
b'q\x00oq\x01}q\x02b.').replace(b'X', xname)
self.assertEqual(X(*args), self.loads(pickle2))
self.assert_is_copy(X(*args), self.loads(pickle2))
# There are gratuitous differences between pickles produced by
# pickle and cPickle, largely because cPickle starts PUT indices at
@ -552,6 +577,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(l, proto)
x = self.loads(s)
self.assertIsInstance(x, list)
self.assertEqual(len(x), 1)
self.assertTrue(x is x[0])
@ -561,6 +587,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(t, proto)
x = self.loads(s)
self.assertIsInstance(x, tuple)
self.assertEqual(len(x), 1)
self.assertEqual(len(x[0]), 1)
self.assertTrue(x is x[0][0])
@ -571,15 +598,39 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(d, proto)
x = self.loads(s)
self.assertIsInstance(x, dict)
self.assertEqual(list(x.keys()), [1])
self.assertTrue(x[1] is x)
def test_recursive_set(self):
h = H()
y = set({h})
h.attr = y
for proto in protocols:
s = self.dumps(y, proto)
x = self.loads(s)
self.assertIsInstance(x, set)
self.assertIs(list(x)[0].attr, x)
self.assertEqual(len(x), 1)
def test_recursive_frozenset(self):
h = H()
y = frozenset({h})
h.attr = y
for proto in protocols:
s = self.dumps(y, proto)
x = self.loads(s)
self.assertIsInstance(x, frozenset)
self.assertIs(list(x)[0].attr, x)
self.assertEqual(len(x), 1)
def test_recursive_inst(self):
i = C()
i.attr = i
for proto in protocols:
s = self.dumps(i, proto)
x = self.loads(s)
self.assertIsInstance(x, C)
self.assertEqual(dir(x), dir(i))
self.assertIs(x.attr, x)
@ -592,6 +643,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(l, proto)
x = self.loads(s)
self.assertIsInstance(x, list)
self.assertEqual(len(x), 1)
self.assertEqual(dir(x[0]), dir(i))
self.assertEqual(list(x[0].attr.keys()), [1])
@ -599,7 +651,8 @@ class AbstractPickleTests(unittest.TestCase):
def test_get(self):
self.assertRaises(KeyError, self.loads, b'g0\np0')
self.assertEqual(self.loads(b'((Kdtp0\nh\x00l.))'), [(100,), (100,)])
self.assert_is_copy([(100,), (100,)],
self.loads(b'((Kdtp0\nh\x00l.))'))
def test_unicode(self):
endcases = ['', '<\\u>', '<\\\u1234>', '<\n>',
@ -610,26 +663,26 @@ class AbstractPickleTests(unittest.TestCase):
for u in endcases:
p = self.dumps(u, proto)
u2 = self.loads(p)
self.assertEqual(u2, u)
self.assert_is_copy(u, u2)
def test_unicode_high_plane(self):
t = '\U00012345'
for proto in protocols:
p = self.dumps(t, proto)
t2 = self.loads(p)
self.assertEqual(t2, t)
self.assert_is_copy(t, t2)
def test_bytes(self):
for proto in protocols:
for s in b'', b'xyz', b'xyz'*100:
p = self.dumps(s, proto)
self.assertEqual(self.loads(p), s)
self.assert_is_copy(s, self.loads(p))
for s in [bytes([i]) for i in range(256)]:
p = self.dumps(s, proto)
self.assertEqual(self.loads(p), s)
self.assert_is_copy(s, self.loads(p))
for s in [bytes([i, i]) for i in range(256)]:
p = self.dumps(s, proto)
self.assertEqual(self.loads(p), s)
self.assert_is_copy(s, self.loads(p))
def test_ints(self):
import sys
@ -639,14 +692,14 @@ class AbstractPickleTests(unittest.TestCase):
for expected in (-n, n):
s = self.dumps(expected, proto)
n2 = self.loads(s)
self.assertEqual(expected, n2)
self.assert_is_copy(expected, n2)
n = n >> 1
def test_maxint64(self):
maxint64 = (1 << 63) - 1
data = b'I' + str(maxint64).encode("ascii") + b'\n.'
got = self.loads(data)
self.assertEqual(got, maxint64)
self.assert_is_copy(maxint64, got)
# Try too with a bogus literal.
data = b'I' + str(maxint64).encode("ascii") + b'JUNK\n.'
@ -661,7 +714,7 @@ class AbstractPickleTests(unittest.TestCase):
for n in npos, -npos:
pickle = self.dumps(n, proto)
got = self.loads(pickle)
self.assertEqual(n, got)
self.assert_is_copy(n, got)
# Try a monster. This is quadratic-time in protos 0 & 1, so don't
# bother with those.
nbase = int("deadbeeffeedface", 16)
@ -669,7 +722,7 @@ class AbstractPickleTests(unittest.TestCase):
for n in nbase, -nbase:
p = self.dumps(n, 2)
got = self.loads(p)
self.assertEqual(n, got)
self.assert_is_copy(n, got)
def test_float(self):
test_values = [0.0, 4.94e-324, 1e-310, 7e-308, 6.626e-34, 0.1, 0.5,
@ -679,7 +732,7 @@ class AbstractPickleTests(unittest.TestCase):
for value in test_values:
pickle = self.dumps(value, proto)
got = self.loads(pickle)
self.assertEqual(value, got)
self.assert_is_copy(value, got)
@run_with_locale('LC_ALL', 'de_DE', 'fr_FR')
def test_float_format(self):
@ -711,6 +764,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(a, proto)
b = self.loads(s)
self.assertEqual(a, b)
self.assertIs(type(a), type(b))
def test_structseq(self):
import time
@ -720,48 +774,48 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(t, proto)
u = self.loads(s)
self.assertEqual(t, u)
self.assert_is_copy(t, u)
if hasattr(os, "stat"):
t = os.stat(os.curdir)
s = self.dumps(t, proto)
u = self.loads(s)
self.assertEqual(t, u)
self.assert_is_copy(t, u)
if hasattr(os, "statvfs"):
t = os.statvfs(os.curdir)
s = self.dumps(t, proto)
u = self.loads(s)
self.assertEqual(t, u)
self.assert_is_copy(t, u)
def test_ellipsis(self):
for proto in protocols:
s = self.dumps(..., proto)
u = self.loads(s)
self.assertEqual(..., u)
self.assertIs(..., u)
def test_notimplemented(self):
for proto in protocols:
s = self.dumps(NotImplemented, proto)
u = self.loads(s)
self.assertEqual(NotImplemented, u)
self.assertIs(NotImplemented, u)
# Tests for protocol 2
def test_proto(self):
build_none = pickle.NONE + pickle.STOP
for proto in protocols:
expected = build_none
pickled = self.dumps(None, proto)
if proto >= 2:
expected = pickle.PROTO + bytes([proto]) + expected
p = self.dumps(None, proto)
self.assertEqual(p, expected)
proto_header = pickle.PROTO + bytes([proto])
self.assertTrue(pickled.startswith(proto_header))
else:
self.assertEqual(count_opcode(pickle.PROTO, pickled), 0)
oob = protocols[-1] + 1 # a future protocol
build_none = pickle.NONE + pickle.STOP
badpickle = pickle.PROTO + bytes([oob]) + build_none
try:
self.loads(badpickle)
except ValueError as detail:
self.assertTrue(str(detail).startswith(
"unsupported pickle protocol"))
except ValueError as err:
self.assertIn("unsupported pickle protocol", str(err))
else:
self.fail("expected bad protocol number to raise ValueError")
@ -770,7 +824,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2)
def test_long4(self):
@ -778,7 +832,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2)
def test_short_tuples(self):
@ -816,9 +870,9 @@ class AbstractPickleTests(unittest.TestCase):
for x in a, b, c, d, e:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y, (proto, x, s, y))
expected = expected_opcode[proto, len(x)]
self.assertEqual(opcode_in_pickle(expected, s), True)
self.assert_is_copy(x, y)
expected = expected_opcode[min(proto, 3), len(x)]
self.assertTrue(opcode_in_pickle(expected, s))
def test_singletons(self):
# Map (proto, singleton) to expected opcode.
@ -842,8 +896,8 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto)
y = self.loads(s)
self.assertTrue(x is y, (proto, x, s, y))
expected = expected_opcode[proto, x]
self.assertEqual(opcode_in_pickle(expected, s), True)
expected = expected_opcode[min(proto, 3), x]
self.assertTrue(opcode_in_pickle(expected, s))
def test_newobj_tuple(self):
x = MyTuple([1, 2, 3])
@ -852,8 +906,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(tuple(x), tuple(y))
self.assertEqual(x.__dict__, y.__dict__)
self.assert_is_copy(x, y)
def test_newobj_list(self):
x = MyList([1, 2, 3])
@ -862,8 +915,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(list(x), list(y))
self.assertEqual(x.__dict__, y.__dict__)
self.assert_is_copy(x, y)
def test_newobj_generic(self):
for proto in protocols:
@ -874,6 +926,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto)
y = self.loads(s)
detail = (proto, C, B, x, y, type(y))
self.assert_is_copy(x, y) # XXX revisit
self.assertEqual(B(x), B(y), detail)
self.assertEqual(x.__dict__, y.__dict__, detail)
@ -912,11 +965,10 @@ class AbstractPickleTests(unittest.TestCase):
s1 = self.dumps(x, 1)
self.assertIn(__name__.encode("utf-8"), s1)
self.assertIn(b"MyList", s1)
self.assertEqual(opcode_in_pickle(opcode, s1), False)
self.assertFalse(opcode_in_pickle(opcode, s1))
y = self.loads(s1)
self.assertEqual(list(x), list(y))
self.assertEqual(x.__dict__, y.__dict__)
self.assert_is_copy(x, y)
# Dump using protocol 2 for test.
s2 = self.dumps(x, 2)
@ -925,9 +977,7 @@ class AbstractPickleTests(unittest.TestCase):
self.assertEqual(opcode_in_pickle(opcode, s2), True, repr(s2))
y = self.loads(s2)
self.assertEqual(list(x), list(y))
self.assertEqual(x.__dict__, y.__dict__)
self.assert_is_copy(x, y)
finally:
e.restore()
@ -951,7 +1001,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
num_appends = count_opcode(pickle.APPENDS, s)
self.assertEqual(num_appends, proto > 0)
@ -960,7 +1010,7 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
num_appends = count_opcode(pickle.APPENDS, s)
if proto == 0:
self.assertEqual(num_appends, 0)
@ -974,7 +1024,7 @@ class AbstractPickleTests(unittest.TestCase):
s = self.dumps(x, proto)
self.assertIsInstance(s, bytes_types)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
num_setitems = count_opcode(pickle.SETITEMS, s)
self.assertEqual(num_setitems, proto > 0)
@ -983,22 +1033,49 @@ class AbstractPickleTests(unittest.TestCase):
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assertEqual(x, y)
self.assert_is_copy(x, y)
num_setitems = count_opcode(pickle.SETITEMS, s)
if proto == 0:
self.assertEqual(num_setitems, 0)
else:
self.assertTrue(num_setitems >= 2)
def test_set_chunking(self):
n = 10 # too small to chunk
x = set(range(n))
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assert_is_copy(x, y)
num_additems = count_opcode(pickle.ADDITEMS, s)
if proto < 4:
self.assertEqual(num_additems, 0)
else:
self.assertEqual(num_additems, 1)
n = 2500 # expect at least two chunks when proto >= 4
x = set(range(n))
for proto in protocols:
s = self.dumps(x, proto)
y = self.loads(s)
self.assert_is_copy(x, y)
num_additems = count_opcode(pickle.ADDITEMS, s)
if proto < 4:
self.assertEqual(num_additems, 0)
else:
self.assertGreaterEqual(num_additems, 2)
def test_simple_newobj(self):
x = object.__new__(SimpleNewObj) # avoid __init__
x.abc = 666
for proto in protocols:
s = self.dumps(x, proto)
self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s), proto >= 2)
self.assertEqual(opcode_in_pickle(pickle.NEWOBJ, s),
2 <= proto < 4)
self.assertEqual(opcode_in_pickle(pickle.NEWOBJ_EX, s),
proto >= 4)
y = self.loads(s) # will raise TypeError if __init__ called
self.assertEqual(y.abc, 666)
self.assertEqual(x.__dict__, y.__dict__)
self.assert_is_copy(x, y)
def test_newobj_list_slots(self):
x = SlotList([1, 2, 3])
@ -1006,10 +1083,7 @@ class AbstractPickleTests(unittest.TestCase):
x.bar = "hello"
s = self.dumps(x, 2)
y = self.loads(s)
self.assertEqual(list(x), list(y))
self.assertEqual(x.__dict__, y.__dict__)
self.assertEqual(x.foo, y.foo)
self.assertEqual(x.bar, y.bar)
self.assert_is_copy(x, y)
def test_reduce_overrides_default_reduce_ex(self):
for proto in protocols:
@ -1058,11 +1132,10 @@ class AbstractPickleTests(unittest.TestCase):
@no_tracing
def test_bad_getattr(self):
# Issue #3514: crash when there is an infinite loop in __getattr__
x = BadGetattr()
for proto in 0, 1:
for proto in protocols:
self.assertRaises(RuntimeError, self.dumps, x, proto)
# protocol 2 don't raise a RuntimeError.
d = self.dumps(x, 2)
def test_reduce_bad_iterator(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
@ -1095,11 +1168,10 @@ class AbstractPickleTests(unittest.TestCase):
obj = [dict(large_dict), dict(large_dict), dict(large_dict)]
for proto in protocols:
with self.subTest(proto=proto):
dumped = self.dumps(obj, proto)
loaded = self.loads(dumped)
self.assertEqual(loaded, obj,
"Failed protocol %d: %r != %r"
% (proto, obj, loaded))
self.assert_is_copy(obj, loaded)
def test_attribute_name_interning(self):
# Test that attribute names of pickled objects are interned when
@ -1155,11 +1227,14 @@ class AbstractPickleTests(unittest.TestCase):
def test_int_pickling_efficiency(self):
# Test compacity of int representation (see issue #12744)
for proto in protocols:
sizes = [len(self.dumps(2**n, proto)) for n in range(70)]
with self.subTest(proto=proto):
pickles = [self.dumps(2**n, proto) for n in range(70)]
sizes = list(map(len, pickles))
# the size function is monotonic
self.assertEqual(sorted(sizes), sizes)
if proto >= 2:
self.assertLessEqual(sizes[-1], 14)
for p in pickles:
self.assertFalse(opcode_in_pickle(pickle.LONG, p))
def check_negative_32b_binXXX(self, dumped):
if sys.maxsize > 2**32:
@ -1242,6 +1317,137 @@ class AbstractPickleTests(unittest.TestCase):
else:
self._check_pickling_with_opcode(obj, pickle.SETITEMS, proto)
# Exercise framing (proto >= 4) for significant workloads
FRAME_SIZE_TARGET = 64 * 1024
def test_framing_many_objects(self):
obj = list(range(10**5))
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
pickled = self.dumps(obj, proto)
unpickled = self.loads(pickled)
self.assertEqual(obj, unpickled)
# Test the framing heuristic is sane,
# assuming a given frame size target.
bytes_per_frame = (len(pickled) /
pickled.count(b'\x00\x00\x00\x00\x00'))
self.assertGreater(bytes_per_frame,
self.FRAME_SIZE_TARGET / 2)
self.assertLessEqual(bytes_per_frame,
self.FRAME_SIZE_TARGET * 1)
def test_framing_large_objects(self):
N = 1024 * 1024
obj = [b'x' * N, b'y' * N, b'z' * N]
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
pickled = self.dumps(obj, proto)
unpickled = self.loads(pickled)
self.assertEqual(obj, unpickled)
# At least one frame was emitted per large bytes object.
n_frames = pickled.count(b'\x00\x00\x00\x00\x00')
self.assertGreaterEqual(n_frames, len(obj))
def test_nested_names(self):
global Nested
class Nested:
class A:
class B:
class C:
pass
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
for obj in [Nested.A, Nested.A.B, Nested.A.B.C]:
with self.subTest(proto=proto, obj=obj):
unpickled = self.loads(self.dumps(obj, proto))
self.assertIs(obj, unpickled)
def test_py_methods(self):
global PyMethodsTest
class PyMethodsTest:
@staticmethod
def cheese():
return "cheese"
@classmethod
def wine(cls):
assert cls is PyMethodsTest
return "wine"
def biscuits(self):
assert isinstance(self, PyMethodsTest)
return "biscuits"
class Nested:
"Nested class"
@staticmethod
def ketchup():
return "ketchup"
@classmethod
def maple(cls):
assert cls is PyMethodsTest.Nested
return "maple"
def pie(self):
assert isinstance(self, PyMethodsTest.Nested)
return "pie"
py_methods = (
PyMethodsTest.cheese,
PyMethodsTest.wine,
PyMethodsTest().biscuits,
PyMethodsTest.Nested.ketchup,
PyMethodsTest.Nested.maple,
PyMethodsTest.Nested().pie
)
py_unbound_methods = (
(PyMethodsTest.biscuits, PyMethodsTest),
(PyMethodsTest.Nested.pie, PyMethodsTest.Nested)
)
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
for method in py_methods:
with self.subTest(proto=proto, method=method):
unpickled = self.loads(self.dumps(method, proto))
self.assertEqual(method(), unpickled())
for method, cls in py_unbound_methods:
obj = cls()
with self.subTest(proto=proto, method=method):
unpickled = self.loads(self.dumps(method, proto))
self.assertEqual(method(obj), unpickled(obj))
def test_c_methods(self):
global Subclass
class Subclass(tuple):
class Nested(str):
pass
c_methods = (
# bound built-in method
("abcd".index, ("c",)),
# unbound built-in method
(str.index, ("abcd", "c")),
# bound "slot" method
([1, 2, 3].__len__, ()),
# unbound "slot" method
(list.__len__, ([1, 2, 3],)),
# bound "coexist" method
({1, 2}.__contains__, (2,)),
# unbound "coexist" method
(set.__contains__, ({1, 2}, 2)),
# built-in class method
(dict.fromkeys, (("a", 1), ("b", 2))),
# built-in static method
(bytearray.maketrans, (b"abc", b"xyz")),
# subclass methods
(Subclass([1,2,2]).count, (2,)),
(Subclass.count, (Subclass([1,2,2]), 2)),
(Subclass.Nested("sweet").count, ("e",)),
(Subclass.Nested.count, (Subclass.Nested("sweet"), "e")),
)
for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
for method, args in c_methods:
with self.subTest(proto=proto, method=method):
unpickled = self.loads(self.dumps(method, proto))
self.assertEqual(method(*args), unpickled(*args))
class BigmemPickleTests(unittest.TestCase):
@ -1252,6 +1458,7 @@ class BigmemPickleTests(unittest.TestCase):
data = 1 << (8 * size)
try:
for proto in protocols:
with self.subTest(proto=proto):
if proto < 2:
continue
with self.assertRaises((ValueError, OverflowError)):
@ -1268,12 +1475,13 @@ class BigmemPickleTests(unittest.TestCase):
data = b"abcd" * (size // 4)
try:
for proto in protocols:
with self.subTest(proto=proto):
if proto < 3:
continue
try:
pickled = self.dumps(data, protocol=proto)
self.assertTrue(b"abcd" in pickled[:15])
self.assertTrue(b"abcd" in pickled[-15:])
self.assertTrue(b"abcd" in pickled[:19])
self.assertTrue(b"abcd" in pickled[-18:])
finally:
pickled = None
finally:
@ -1284,6 +1492,7 @@ class BigmemPickleTests(unittest.TestCase):
data = b"a" * size
try:
for proto in protocols:
with self.subTest(proto=proto):
if proto < 3:
continue
with self.assertRaises((ValueError, OverflowError)):
@ -1299,27 +1508,38 @@ class BigmemPickleTests(unittest.TestCase):
data = "abcd" * (size // 4)
try:
for proto in protocols:
with self.subTest(proto=proto):
try:
pickled = self.dumps(data, protocol=proto)
self.assertTrue(b"abcd" in pickled[:15])
self.assertTrue(b"abcd" in pickled[-15:])
self.assertTrue(b"abcd" in pickled[:19])
self.assertTrue(b"abcd" in pickled[-18:])
finally:
pickled = None
finally:
data = None
# BINUNICODE (protocols 1, 2 and 3) cannot carry more than
# 2**32 - 1 bytes of utf-8 encoded unicode.
# BINUNICODE (protocols 1, 2 and 3) cannot carry more than 2**32 - 1 bytes
# of utf-8 encoded unicode. BINUNICODE8 (protocol 4) supports these huge
# unicode strings however.
@bigmemtest(size=_4G, memuse=1 + ascii_char_size, dry_run=False)
@bigmemtest(size=_4G, memuse=2 + ascii_char_size, dry_run=False)
def test_huge_str_64b(self, size):
data = "a" * size
data = "abcd" * (size // 4)
try:
for proto in protocols:
with self.subTest(proto=proto):
if proto == 0:
continue
if proto < 4:
with self.assertRaises((ValueError, OverflowError)):
self.dumps(data, protocol=proto)
else:
try:
pickled = self.dumps(data, protocol=proto)
self.assertTrue(b"abcd" in pickled[:19])
self.assertTrue(b"abcd" in pickled[-18:])
finally:
pickled = None
finally:
data = None
@ -1363,8 +1583,8 @@ class REX_five(object):
return object.__reduce__(self)
class REX_six(object):
"""This class is used to check the 4th argument (list iterator) of the reduce
protocol.
"""This class is used to check the 4th argument (list iterator) of
the reduce protocol.
"""
def __init__(self, items=None):
self.items = items if items is not None else []
@ -1376,8 +1596,8 @@ class REX_six(object):
return type(self), (), None, iter(self.items), None
class REX_seven(object):
"""This class is used to check the 5th argument (dict iterator) of the reduce
protocol.
"""This class is used to check the 5th argument (dict iterator) of
the reduce protocol.
"""
def __init__(self, table=None):
self.table = table if table is not None else {}
@ -1415,10 +1635,16 @@ class MyList(list):
class MyDict(dict):
sample = {"a": 1, "b": 2}
class MySet(set):
sample = {"a", "b"}
class MyFrozenSet(frozenset):
sample = frozenset({"a", "b"})
myclasses = [MyInt, MyFloat,
MyComplex,
MyStr, MyUnicode,
MyTuple, MyList, MyDict]
MyTuple, MyList, MyDict, MySet, MyFrozenSet]
class SlotList(MyList):
@ -1428,6 +1654,8 @@ class SimpleNewObj(object):
def __init__(self, a, b, c):
# raise an error, to make sure this isn't called
raise TypeError("SimpleNewObj.__init__() didn't expect to get called")
def __eq__(self, other):
return self.__dict__ == other.__dict__
class BadGetattr:
def __getattr__(self, key):
@ -1464,7 +1692,7 @@ class AbstractPickleModuleTests(unittest.TestCase):
def test_highest_protocol(self):
# Of course this needs to be changed when HIGHEST_PROTOCOL changes.
self.assertEqual(pickle.HIGHEST_PROTOCOL, 3)
self.assertEqual(pickle.HIGHEST_PROTOCOL, 4)
def test_callapi(self):
f = io.BytesIO()
@ -1645,6 +1873,7 @@ class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
def _check_multiple_unpicklings(self, ioclass):
for proto in protocols:
with self.subTest(proto=proto):
data1 = [(x, str(x)) for x in range(2000)] + [b"abcde", len]
f = ioclass()
pickler = self.pickler_class(f, protocol=proto)

View File

@ -1,8 +1,11 @@
import builtins
import copyreg
import gc
import itertools
import math
import pickle
import sys
import types
import math
import unittest
import weakref
@ -3153,176 +3156,6 @@ order (MRO) for bases """
self.assertEqual(e.a, 1)
self.assertEqual(can_delete_dict(e), can_delete_dict(ValueError()))
def test_pickles(self):
# Testing pickling and copying new-style classes and objects...
import pickle
def sorteditems(d):
L = list(d.items())
L.sort()
return L
global C
class C(object):
def __init__(self, a, b):
super(C, self).__init__()
self.a = a
self.b = b
def __repr__(self):
return "C(%r, %r)" % (self.a, self.b)
global C1
class C1(list):
def __new__(cls, a, b):
return super(C1, cls).__new__(cls)
def __getnewargs__(self):
return (self.a, self.b)
def __init__(self, a, b):
self.a = a
self.b = b
def __repr__(self):
return "C1(%r, %r)<%r>" % (self.a, self.b, list(self))
global C2
class C2(int):
def __new__(cls, a, b, val=0):
return super(C2, cls).__new__(cls, val)
def __getnewargs__(self):
return (self.a, self.b, int(self))
def __init__(self, a, b, val=0):
self.a = a
self.b = b
def __repr__(self):
return "C2(%r, %r)<%r>" % (self.a, self.b, int(self))
global C3
class C3(object):
def __init__(self, foo):
self.foo = foo
def __getstate__(self):
return self.foo
def __setstate__(self, foo):
self.foo = foo
global C4classic, C4
class C4classic: # classic
pass
class C4(C4classic, object): # mixed inheritance
pass
for bin in 0, 1:
for cls in C, C1, C2:
s = pickle.dumps(cls, bin)
cls2 = pickle.loads(s)
self.assertIs(cls2, cls)
a = C1(1, 2); a.append(42); a.append(24)
b = C2("hello", "world", 42)
s = pickle.dumps((a, b), bin)
x, y = pickle.loads(s)
self.assertEqual(x.__class__, a.__class__)
self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__))
self.assertEqual(y.__class__, b.__class__)
self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__))
self.assertEqual(repr(x), repr(a))
self.assertEqual(repr(y), repr(b))
# Test for __getstate__ and __setstate__ on new style class
u = C3(42)
s = pickle.dumps(u, bin)
v = pickle.loads(s)
self.assertEqual(u.__class__, v.__class__)
self.assertEqual(u.foo, v.foo)
# Test for picklability of hybrid class
u = C4()
u.foo = 42
s = pickle.dumps(u, bin)
v = pickle.loads(s)
self.assertEqual(u.__class__, v.__class__)
self.assertEqual(u.foo, v.foo)
# Testing copy.deepcopy()
import copy
for cls in C, C1, C2:
cls2 = copy.deepcopy(cls)
self.assertIs(cls2, cls)
a = C1(1, 2); a.append(42); a.append(24)
b = C2("hello", "world", 42)
x, y = copy.deepcopy((a, b))
self.assertEqual(x.__class__, a.__class__)
self.assertEqual(sorteditems(x.__dict__), sorteditems(a.__dict__))
self.assertEqual(y.__class__, b.__class__)
self.assertEqual(sorteditems(y.__dict__), sorteditems(b.__dict__))
self.assertEqual(repr(x), repr(a))
self.assertEqual(repr(y), repr(b))
def test_pickle_slots(self):
# Testing pickling of classes with __slots__ ...
import pickle
# Pickling of classes with __slots__ but without __getstate__ should fail
# (if using protocol 0 or 1)
global B, C, D, E
class B(object):
pass
for base in [object, B]:
class C(base):
__slots__ = ['a']
class D(C):
pass
try:
pickle.dumps(C(), 0)
except TypeError:
pass
else:
self.fail("should fail: pickle C instance - %s" % base)
try:
pickle.dumps(C(), 0)
except TypeError:
pass
else:
self.fail("should fail: pickle D instance - %s" % base)
# Give C a nice generic __getstate__ and __setstate__
class C(base):
__slots__ = ['a']
def __getstate__(self):
try:
d = self.__dict__.copy()
except AttributeError:
d = {}
for cls in self.__class__.__mro__:
for sn in cls.__dict__.get('__slots__', ()):
try:
d[sn] = getattr(self, sn)
except AttributeError:
pass
return d
def __setstate__(self, d):
for k, v in list(d.items()):
setattr(self, k, v)
class D(C):
pass
# Now it should work
x = C()
y = pickle.loads(pickle.dumps(x))
self.assertNotHasAttr(y, 'a')
x.a = 42
y = pickle.loads(pickle.dumps(x))
self.assertEqual(y.a, 42)
x = D()
x.a = 42
x.b = 100
y = pickle.loads(pickle.dumps(x))
self.assertEqual(y.a + y.b, 142)
# A subclass that adds a slot should also work
class E(C):
__slots__ = ['b']
x = E()
x.a = 42
x.b = "foo"
y = pickle.loads(pickle.dumps(x))
self.assertEqual(y.a, x.a)
self.assertEqual(y.b, x.b)
def test_binary_operator_override(self):
# Testing overrides of binary operations...
class I(int):
@ -4690,11 +4523,439 @@ class MiscTests(unittest.TestCase):
self.assertEqual(X.mykey2, 'from Base2')
class PicklingTests(unittest.TestCase):
def _check_reduce(self, proto, obj, args=(), kwargs={}, state=None,
listitems=None, dictitems=None):
if proto >= 4:
reduce_value = obj.__reduce_ex__(proto)
self.assertEqual(reduce_value[:3],
(copyreg.__newobj_ex__,
(type(obj), args, kwargs),
state))
if listitems is not None:
self.assertListEqual(list(reduce_value[3]), listitems)
else:
self.assertIsNone(reduce_value[3])
if dictitems is not None:
self.assertDictEqual(dict(reduce_value[4]), dictitems)
else:
self.assertIsNone(reduce_value[4])
elif proto >= 2:
reduce_value = obj.__reduce_ex__(proto)
self.assertEqual(reduce_value[:3],
(copyreg.__newobj__,
(type(obj),) + args,
state))
if listitems is not None:
self.assertListEqual(list(reduce_value[3]), listitems)
else:
self.assertIsNone(reduce_value[3])
if dictitems is not None:
self.assertDictEqual(dict(reduce_value[4]), dictitems)
else:
self.assertIsNone(reduce_value[4])
else:
base_type = type(obj).__base__
reduce_value = (copyreg._reconstructor,
(type(obj),
base_type,
None if base_type is object else base_type(obj)))
if state is not None:
reduce_value += (state,)
self.assertEqual(obj.__reduce_ex__(proto), reduce_value)
self.assertEqual(obj.__reduce__(), reduce_value)
def test_reduce(self):
protocols = range(pickle.HIGHEST_PROTOCOL + 1)
args = (-101, "spam")
kwargs = {'bacon': -201, 'fish': -301}
state = {'cheese': -401}
class C1:
def __getnewargs__(self):
return args
obj = C1()
for proto in protocols:
self._check_reduce(proto, obj, args)
for name, value in state.items():
setattr(obj, name, value)
for proto in protocols:
self._check_reduce(proto, obj, args, state=state)
class C2:
def __getnewargs__(self):
return "bad args"
obj = C2()
for proto in protocols:
if proto >= 2:
with self.assertRaises(TypeError):
obj.__reduce_ex__(proto)
class C3:
def __getnewargs_ex__(self):
return (args, kwargs)
obj = C3()
for proto in protocols:
if proto >= 4:
self._check_reduce(proto, obj, args, kwargs)
elif proto >= 2:
with self.assertRaises(ValueError):
obj.__reduce_ex__(proto)
class C4:
def __getnewargs_ex__(self):
return (args, "bad dict")
class C5:
def __getnewargs_ex__(self):
return ("bad tuple", kwargs)
class C6:
def __getnewargs_ex__(self):
return ()
class C7:
def __getnewargs_ex__(self):
return "bad args"
for proto in protocols:
for cls in C4, C5, C6, C7:
obj = cls()
if proto >= 2:
with self.assertRaises((TypeError, ValueError)):
obj.__reduce_ex__(proto)
class C8:
def __getnewargs_ex__(self):
return (args, kwargs)
obj = C8()
for proto in protocols:
if 2 <= proto < 4:
with self.assertRaises(ValueError):
obj.__reduce_ex__(proto)
class C9:
def __getnewargs_ex__(self):
return (args, {})
obj = C9()
for proto in protocols:
self._check_reduce(proto, obj, args)
class C10:
def __getnewargs_ex__(self):
raise IndexError
obj = C10()
for proto in protocols:
if proto >= 2:
with self.assertRaises(IndexError):
obj.__reduce_ex__(proto)
class C11:
def __getstate__(self):
return state
obj = C11()
for proto in protocols:
self._check_reduce(proto, obj, state=state)
class C12:
def __getstate__(self):
return "not dict"
obj = C12()
for proto in protocols:
self._check_reduce(proto, obj, state="not dict")
class C13:
def __getstate__(self):
raise IndexError
obj = C13()
for proto in protocols:
with self.assertRaises(IndexError):
obj.__reduce_ex__(proto)
if proto < 2:
with self.assertRaises(IndexError):
obj.__reduce__()
class C14:
__slots__ = tuple(state)
def __init__(self):
for name, value in state.items():
setattr(self, name, value)
obj = C14()
for proto in protocols:
if proto >= 2:
self._check_reduce(proto, obj, state=(None, state))
else:
with self.assertRaises(TypeError):
obj.__reduce_ex__(proto)
with self.assertRaises(TypeError):
obj.__reduce__()
class C15(dict):
pass
obj = C15({"quebec": -601})
for proto in protocols:
self._check_reduce(proto, obj, dictitems=dict(obj))
class C16(list):
pass
obj = C16(["yukon"])
for proto in protocols:
self._check_reduce(proto, obj, listitems=list(obj))
def _assert_is_copy(self, obj, objcopy, msg=None):
"""Utility method to verify if two objects are copies of each others.
"""
if msg is None:
msg = "{!r} is not a copy of {!r}".format(obj, objcopy)
if type(obj).__repr__ is object.__repr__:
# We have this limitation for now because we use the object's repr
# to help us verify that the two objects are copies. This allows
# us to delegate the non-generic verification logic to the objects
# themselves.
raise ValueError("object passed to _assert_is_copy must " +
"override the __repr__ method.")
self.assertIsNot(obj, objcopy, msg=msg)
self.assertIs(type(obj), type(objcopy), msg=msg)
if hasattr(obj, '__dict__'):
self.assertDictEqual(obj.__dict__, objcopy.__dict__, msg=msg)
self.assertIsNot(obj.__dict__, objcopy.__dict__, msg=msg)
if hasattr(obj, '__slots__'):
self.assertListEqual(obj.__slots__, objcopy.__slots__, msg=msg)
for slot in obj.__slots__:
self.assertEqual(
hasattr(obj, slot), hasattr(objcopy, slot), msg=msg)
self.assertEqual(getattr(obj, slot, None),
getattr(objcopy, slot, None), msg=msg)
self.assertEqual(repr(obj), repr(objcopy), msg=msg)
@staticmethod
def _generate_pickle_copiers():
"""Utility method to generate the many possible pickle configurations.
"""
class PickleCopier:
"This class copies object using pickle."
def __init__(self, proto, dumps, loads):
self.proto = proto
self.dumps = dumps
self.loads = loads
def copy(self, obj):
return self.loads(self.dumps(obj, self.proto))
def __repr__(self):
# We try to be as descriptive as possible here since this is
# the string which we will allow us to tell the pickle
# configuration we are using during debugging.
return ("PickleCopier(proto={}, dumps={}.{}, loads={}.{})"
.format(self.proto,
self.dumps.__module__, self.dumps.__qualname__,
self.loads.__module__, self.loads.__qualname__))
return (PickleCopier(*args) for args in
itertools.product(range(pickle.HIGHEST_PROTOCOL + 1),
{pickle.dumps, pickle._dumps},
{pickle.loads, pickle._loads}))
def test_pickle_slots(self):
# Tests pickling of classes with __slots__.
# Pickling of classes with __slots__ but without __getstate__ should
# fail (if using protocol 0 or 1)
global C
class C:
__slots__ = ['a']
with self.assertRaises(TypeError):
pickle.dumps(C(), 0)
global D
class D(C):
pass
with self.assertRaises(TypeError):
pickle.dumps(D(), 0)
class C:
"A class with __getstate__ and __setstate__ implemented."
__slots__ = ['a']
def __getstate__(self):
state = getattr(self, '__dict__', {}).copy()
for cls in type(self).__mro__:
for slot in cls.__dict__.get('__slots__', ()):
try:
state[slot] = getattr(self, slot)
except AttributeError:
pass
return state
def __setstate__(self, state):
for k, v in state.items():
setattr(self, k, v)
def __repr__(self):
return "%s()<%r>" % (type(self).__name__, self.__getstate__())
class D(C):
"A subclass of a class with slots."
pass
global E
class E(C):
"A subclass with an extra slot."
__slots__ = ['b']
# Now it should work
for pickle_copier in self._generate_pickle_copiers():
with self.subTest(pickle_copier=pickle_copier):
x = C()
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
x.a = 42
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
x = D()
x.a = 42
x.b = 100
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
x = E()
x.a = 42
x.b = "foo"
y = pickle_copier.copy(x)
self._assert_is_copy(x, y)
def test_reduce_copying(self):
# Tests pickling and copying new-style classes and objects.
global C1
class C1:
"The state of this class is copyable via its instance dict."
ARGS = (1, 2)
NEED_DICT_COPYING = True
def __init__(self, a, b):
super().__init__()
self.a = a
self.b = b
def __repr__(self):
return "C1(%r, %r)" % (self.a, self.b)
global C2
class C2(list):
"A list subclass copyable via __getnewargs__."
ARGS = (1, 2)
NEED_DICT_COPYING = False
def __new__(cls, a, b):
self = super().__new__(cls)
self.a = a
self.b = b
return self
def __init__(self, *args):
super().__init__()
# This helps testing that __init__ is not called during the
# unpickling process, which would cause extra appends.
self.append("cheese")
@classmethod
def __getnewargs__(cls):
return cls.ARGS
def __repr__(self):
return "C2(%r, %r)<%r>" % (self.a, self.b, list(self))
global C3
class C3(list):
"A list subclass copyable via __getstate__."
ARGS = (1, 2)
NEED_DICT_COPYING = False
def __init__(self, a, b):
self.a = a
self.b = b
# This helps testing that __init__ is not called during the
# unpickling process, which would cause extra appends.
self.append("cheese")
@classmethod
def __getstate__(cls):
return cls.ARGS
def __setstate__(self, state):
a, b = state
self.a = a
self.b = b
def __repr__(self):
return "C3(%r, %r)<%r>" % (self.a, self.b, list(self))
global C4
class C4(int):
"An int subclass copyable via __getnewargs__."
ARGS = ("hello", "world", 1)
NEED_DICT_COPYING = False
def __new__(cls, a, b, value):
self = super().__new__(cls, value)
self.a = a
self.b = b
return self
@classmethod
def __getnewargs__(cls):
return cls.ARGS
def __repr__(self):
return "C4(%r, %r)<%r>" % (self.a, self.b, int(self))
global C5
class C5(int):
"An int subclass copyable via __getnewargs_ex__."
ARGS = (1, 2)
KWARGS = {'value': 3}
NEED_DICT_COPYING = False
def __new__(cls, a, b, *, value=0):
self = super().__new__(cls, value)
self.a = a
self.b = b
return self
@classmethod
def __getnewargs_ex__(cls):
return (cls.ARGS, cls.KWARGS)
def __repr__(self):
return "C5(%r, %r)<%r>" % (self.a, self.b, int(self))
test_classes = (C1, C2, C3, C4, C5)
# Testing copying through pickle
pickle_copiers = self._generate_pickle_copiers()
for cls, pickle_copier in itertools.product(test_classes, pickle_copiers):
with self.subTest(cls=cls, pickle_copier=pickle_copier):
kwargs = getattr(cls, 'KWARGS', {})
obj = cls(*cls.ARGS, **kwargs)
proto = pickle_copier.proto
if 2 <= proto < 4 and hasattr(cls, '__getnewargs_ex__'):
with self.assertRaises(ValueError):
pickle_copier.dumps(obj, proto)
continue
objcopy = pickle_copier.copy(obj)
self._assert_is_copy(obj, objcopy)
# For test classes that supports this, make sure we didn't go
# around the reduce protocol by simply copying the attribute
# dictionary. We clear attributes using the previous copy to
# not mutate the original argument.
if proto >= 2 and not cls.NEED_DICT_COPYING:
objcopy.__dict__.clear()
objcopy2 = pickle_copier.copy(objcopy)
self._assert_is_copy(obj, objcopy2)
# Testing copying through copy.deepcopy()
for cls in test_classes:
with self.subTest(cls=cls):
kwargs = getattr(cls, 'KWARGS', {})
obj = cls(*cls.ARGS, **kwargs)
# XXX: We need to modify the copy module to support PEP 3154's
# reduce protocol 4.
if hasattr(cls, '__getnewargs_ex__'):
continue
objcopy = deepcopy(obj)
self._assert_is_copy(obj, objcopy)
# For test classes that supports this, make sure we didn't go
# around the reduce protocol by simply copying the attribute
# dictionary. We clear attributes using the previous copy to
# not mutate the original argument.
if not cls.NEED_DICT_COPYING:
objcopy.__dict__.clear()
objcopy2 = deepcopy(objcopy)
self._assert_is_copy(obj, objcopy2)
def test_main():
# Run all local test cases, with PTypesLongInitTest first.
support.run_unittest(PTypesLongInitTest, OperatorsTest,
ClassPropertiesAndMethods, DictProxyTests,
MiscTests)
MiscTests, PicklingTests)
if __name__ == "__main__":
test_main()

View File

@ -68,6 +68,8 @@ Core and Builtins
Library
-------
- Issue #17810: Implement PEP 3154, pickle protocol 4.
- Issue #19668: Added support for the cp1125 encoding.
- Issue #19689: Add ssl.create_default_context() factory function. It creates

File diff suppressed because it is too large Load Diff

View File

@ -69,6 +69,30 @@ PyMethod_New(PyObject *func, PyObject *self)
return (PyObject *)im;
}
static PyObject *
method_reduce(PyMethodObject *im)
{
PyObject *self = PyMethod_GET_SELF(im);
PyObject *func = PyMethod_GET_FUNCTION(im);
PyObject *builtins;
PyObject *getattr;
PyObject *funcname;
_Py_IDENTIFIER(getattr);
funcname = _PyObject_GetAttrId(func, &PyId___name__);
if (funcname == NULL) {
return NULL;
}
builtins = PyEval_GetBuiltins();
getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
return Py_BuildValue("O(ON)", getattr, self, funcname);
}
static PyMethodDef method_methods[] = {
{"__reduce__", (PyCFunction)method_reduce, METH_NOARGS, NULL},
{NULL, NULL}
};
/* Descriptors for PyMethod attributes */
/* im_func and im_self are stored in the PyMethod object */
@ -367,7 +391,7 @@ PyTypeObject PyMethod_Type = {
offsetof(PyMethodObject, im_weakreflist), /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
method_methods, /* tp_methods */
method_memberlist, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */

View File

@ -398,6 +398,24 @@ descr_get_qualname(PyDescrObject *descr)
return descr->d_qualname;
}
static PyObject *
descr_reduce(PyDescrObject *descr)
{
PyObject *builtins;
PyObject *getattr;
_Py_IDENTIFIER(getattr);
builtins = PyEval_GetBuiltins();
getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
return Py_BuildValue("O(OO)", getattr, PyDescr_TYPE(descr),
PyDescr_NAME(descr));
}
static PyMethodDef descr_methods[] = {
{"__reduce__", (PyCFunction)descr_reduce, METH_NOARGS, NULL},
{NULL, NULL}
};
static PyMemberDef descr_members[] = {
{"__objclass__", T_OBJECT, offsetof(PyDescrObject, d_type), READONLY},
{"__name__", T_OBJECT, offsetof(PyDescrObject, d_name), READONLY},
@ -494,7 +512,7 @@ PyTypeObject PyMethodDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
descr_methods, /* tp_methods */
descr_members, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */
@ -532,7 +550,7 @@ PyTypeObject PyClassMethodDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
descr_methods, /* tp_methods */
descr_members, /* tp_members */
method_getset, /* tp_getset */
0, /* tp_base */
@ -569,7 +587,7 @@ PyTypeObject PyMemberDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
descr_methods, /* tp_methods */
descr_members, /* tp_members */
member_getset, /* tp_getset */
0, /* tp_base */
@ -643,7 +661,7 @@ PyTypeObject PyWrapperDescr_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
descr_methods, /* tp_methods */
descr_members, /* tp_members */
wrapperdescr_getset, /* tp_getset */
0, /* tp_base */
@ -1085,6 +1103,23 @@ wrapper_repr(wrapperobject *wp)
wp->self);
}
static PyObject *
wrapper_reduce(wrapperobject *wp)
{
PyObject *builtins;
PyObject *getattr;
_Py_IDENTIFIER(getattr);
builtins = PyEval_GetBuiltins();
getattr = _PyDict_GetItemId(builtins, &PyId_getattr);
return Py_BuildValue("O(OO)", getattr, wp->self, PyDescr_NAME(wp->descr));
}
static PyMethodDef wrapper_methods[] = {
{"__reduce__", (PyCFunction)wrapper_reduce, METH_NOARGS, NULL},
{NULL, NULL}
};
static PyMemberDef wrapper_members[] = {
{"__self__", T_OBJECT, offsetof(wrapperobject, self), READONLY},
{0}
@ -1193,7 +1228,7 @@ PyTypeObject _PyMethodWrapper_Type = {
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
wrapper_methods, /* tp_methods */
wrapper_members, /* tp_members */
wrapper_getsets, /* tp_getset */
0, /* tp_base */

View File

@ -3405,149 +3405,428 @@ import_copyreg(void)
return cached_copyreg_module;
}
static PyObject *
slotnames(PyObject *cls)
Py_LOCAL(PyObject *)
_PyType_GetSlotNames(PyTypeObject *cls)
{
PyObject *clsdict;
PyObject *copyreg;
PyObject *slotnames;
_Py_IDENTIFIER(__slotnames__);
_Py_IDENTIFIER(_slotnames);
clsdict = ((PyTypeObject *)cls)->tp_dict;
slotnames = _PyDict_GetItemId(clsdict, &PyId___slotnames__);
if (slotnames != NULL && PyList_Check(slotnames)) {
assert(PyType_Check(cls));
/* Get the slot names from the cache in the class if possible. */
slotnames = _PyDict_GetItemIdWithError(cls->tp_dict, &PyId___slotnames__);
if (slotnames != NULL) {
if (slotnames != Py_None && !PyList_Check(slotnames)) {
PyErr_Format(PyExc_TypeError,
"%.200s.__slotnames__ should be a list or None, "
"not %.200s",
cls->tp_name, Py_TYPE(slotnames)->tp_name);
return NULL;
}
Py_INCREF(slotnames);
return slotnames;
}
else {
if (PyErr_Occurred()) {
return NULL;
}
/* The class does not have the slot names cached yet. */
}
copyreg = import_copyreg();
if (copyreg == NULL)
return NULL;
slotnames = _PyObject_CallMethodId(copyreg, &PyId__slotnames, "O", cls);
/* Use _slotnames function from the copyreg module to find the slots
by this class and its bases. This function will cache the result
in __slotnames__. */
slotnames = _PyObject_CallMethodIdObjArgs(copyreg, &PyId__slotnames,
cls, NULL);
Py_DECREF(copyreg);
if (slotnames != NULL &&
slotnames != Py_None &&
!PyList_Check(slotnames))
{
if (slotnames == NULL)
return NULL;
if (slotnames != Py_None && !PyList_Check(slotnames)) {
PyErr_SetString(PyExc_TypeError,
"copyreg._slotnames didn't return a list or None");
Py_DECREF(slotnames);
slotnames = NULL;
return NULL;
}
return slotnames;
}
static PyObject *
reduce_2(PyObject *obj)
Py_LOCAL(PyObject *)
_PyObject_GetState(PyObject *obj)
{
PyObject *cls, *getnewargs;
PyObject *args = NULL, *args2 = NULL;
PyObject *getstate = NULL, *state = NULL, *names = NULL;
PyObject *slots = NULL, *listitems = NULL, *dictitems = NULL;
PyObject *copyreg = NULL, *newobj = NULL, *res = NULL;
Py_ssize_t i, n;
_Py_IDENTIFIER(__getnewargs__);
PyObject *state;
PyObject *getstate;
_Py_IDENTIFIER(__getstate__);
_Py_IDENTIFIER(__newobj__);
cls = (PyObject *) Py_TYPE(obj);
getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__);
if (getnewargs != NULL) {
args = PyObject_CallObject(getnewargs, NULL);
Py_DECREF(getnewargs);
if (args != NULL && !PyTuple_Check(args)) {
PyErr_Format(PyExc_TypeError,
"__getnewargs__ should return a tuple, "
"not '%.200s'", Py_TYPE(args)->tp_name);
goto end;
}
}
else {
PyErr_Clear();
args = PyTuple_New(0);
}
if (args == NULL)
goto end;
getstate = _PyObject_GetAttrId(obj, &PyId___getstate__);
if (getstate != NULL) {
if (getstate == NULL) {
PyObject *slotnames;
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return NULL;
}
PyErr_Clear();
{
PyObject **dict;
dict = _PyObject_GetDictPtr(obj);
/* It is possible that the object's dict is not initialized
yet. In this case, we will return None for the state.
We also return None if the dict is empty to make the behavior
consistent regardless whether the dict was initialized or not.
This make unit testing easier. */
if (dict != NULL && *dict != NULL && PyDict_Size(*dict) > 0) {
state = *dict;
}
else {
state = Py_None;
}
Py_INCREF(state);
}
slotnames = _PyType_GetSlotNames(Py_TYPE(obj));
if (slotnames == NULL) {
Py_DECREF(state);
return NULL;
}
assert(slotnames == Py_None || PyList_Check(slotnames));
if (slotnames != Py_None && Py_SIZE(slotnames) > 0) {
PyObject *slots;
Py_ssize_t slotnames_size, i;
slots = PyDict_New();
if (slots == NULL) {
Py_DECREF(slotnames);
Py_DECREF(state);
return NULL;
}
slotnames_size = Py_SIZE(slotnames);
for (i = 0; i < slotnames_size; i++) {
PyObject *name, *value;
name = PyList_GET_ITEM(slotnames, i);
value = PyObject_GetAttr(obj, name);
if (value == NULL) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
goto error;
}
/* It is not an error if the attribute is not present. */
PyErr_Clear();
}
else {
int err = PyDict_SetItem(slots, name, value);
Py_DECREF(value);
if (err) {
goto error;
}
}
/* The list is stored on the class so it may mutates while we
iterate over it */
if (slotnames_size != Py_SIZE(slotnames)) {
PyErr_Format(PyExc_RuntimeError,
"__slotsname__ changed size during iteration");
goto error;
}
/* We handle errors within the loop here. */
if (0) {
error:
Py_DECREF(slotnames);
Py_DECREF(slots);
Py_DECREF(state);
return NULL;
}
}
/* If we found some slot attributes, pack them in a tuple along
the orginal attribute dictionary. */
if (PyDict_Size(slots) > 0) {
PyObject *state2;
state2 = PyTuple_Pack(2, state, slots);
Py_DECREF(state);
if (state2 == NULL) {
Py_DECREF(slotnames);
Py_DECREF(slots);
return NULL;
}
state = state2;
}
Py_DECREF(slots);
}
Py_DECREF(slotnames);
}
else { /* getstate != NULL */
state = PyObject_CallObject(getstate, NULL);
Py_DECREF(getstate);
if (state == NULL)
goto end;
return NULL;
}
return state;
}
Py_LOCAL(int)
_PyObject_GetNewArguments(PyObject *obj, PyObject **args, PyObject **kwargs)
{
PyObject *getnewargs, *getnewargs_ex;
_Py_IDENTIFIER(__getnewargs_ex__);
_Py_IDENTIFIER(__getnewargs__);
if (args == NULL || kwargs == NULL) {
PyErr_BadInternalCall();
return -1;
}
/* We first attempt to fetch the arguments for __new__ by calling
__getnewargs_ex__ on the object. */
getnewargs_ex = _PyObject_GetAttrId(obj, &PyId___getnewargs_ex__);
if (getnewargs_ex != NULL) {
PyObject *newargs = PyObject_CallObject(getnewargs_ex, NULL);
Py_DECREF(getnewargs_ex);
if (newargs == NULL) {
return -1;
}
if (!PyTuple_Check(newargs)) {
PyErr_Format(PyExc_TypeError,
"__getnewargs_ex__ should return a tuple, "
"not '%.200s'", Py_TYPE(newargs)->tp_name);
Py_DECREF(newargs);
return -1;
}
if (Py_SIZE(newargs) != 2) {
PyErr_Format(PyExc_ValueError,
"__getnewargs_ex__ should return a tuple of "
"length 2, not %zd", Py_SIZE(newargs));
Py_DECREF(newargs);
return -1;
}
*args = PyTuple_GET_ITEM(newargs, 0);
Py_INCREF(*args);
*kwargs = PyTuple_GET_ITEM(newargs, 1);
Py_INCREF(*kwargs);
Py_DECREF(newargs);
/* XXX We should perhaps allow None to be passed here. */
if (!PyTuple_Check(*args)) {
PyErr_Format(PyExc_TypeError,
"first item of the tuple returned by "
"__getnewargs_ex__ must be a tuple, not '%.200s'",
Py_TYPE(*args)->tp_name);
Py_CLEAR(*args);
Py_CLEAR(*kwargs);
return -1;
}
if (!PyDict_Check(*kwargs)) {
PyErr_Format(PyExc_TypeError,
"second item of the tuple returned by "
"__getnewargs_ex__ must be a dict, not '%.200s'",
Py_TYPE(*kwargs)->tp_name);
Py_CLEAR(*args);
Py_CLEAR(*kwargs);
return -1;
}
return 0;
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return -1;
}
else {
PyObject **dict;
PyErr_Clear();
dict = _PyObject_GetDictPtr(obj);
if (dict && *dict)
state = *dict;
else
state = Py_None;
Py_INCREF(state);
names = slotnames(cls);
if (names == NULL)
goto end;
if (names != Py_None && PyList_GET_SIZE(names) > 0) {
assert(PyList_Check(names));
slots = PyDict_New();
if (slots == NULL)
goto end;
n = 0;
/* Can't pre-compute the list size; the list
is stored on the class so accessible to other
threads, which may be run by DECREF */
for (i = 0; i < PyList_GET_SIZE(names); i++) {
PyObject *name, *value;
name = PyList_GET_ITEM(names, i);
value = PyObject_GetAttr(obj, name);
if (value == NULL)
}
/* The object does not have __getnewargs_ex__ so we fallback on using
__getnewargs__ instead. */
getnewargs = _PyObject_GetAttrId(obj, &PyId___getnewargs__);
if (getnewargs != NULL) {
*args = PyObject_CallObject(getnewargs, NULL);
Py_DECREF(getnewargs);
if (*args == NULL) {
return -1;
}
if (!PyTuple_Check(*args)) {
PyErr_Format(PyExc_TypeError,
"__getnewargs__ should return a tuple, "
"not '%.200s'", Py_TYPE(*args)->tp_name);
Py_CLEAR(*args);
return -1;
}
*kwargs = NULL;
return 0;
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
return -1;
}
PyErr_Clear();
else {
int err = PyDict_SetItem(slots, name,
value);
Py_DECREF(value);
if (err)
goto end;
n++;
}
}
if (n) {
state = Py_BuildValue("(NO)", state, slots);
if (state == NULL)
goto end;
}
}
/* The object does not have __getnewargs_ex__ and __getnewargs__. This may
means __new__ does not takes any arguments on this object, or that the
object does not implement the reduce protocol for pickling or
copying. */
*args = NULL;
*kwargs = NULL;
return 0;
}
Py_LOCAL(int)
_PyObject_GetItemsIter(PyObject *obj, PyObject **listitems,
PyObject **dictitems)
{
if (listitems == NULL || dictitems == NULL) {
PyErr_BadInternalCall();
return -1;
}
if (!PyList_Check(obj)) {
listitems = Py_None;
Py_INCREF(listitems);
*listitems = Py_None;
Py_INCREF(*listitems);
}
else {
listitems = PyObject_GetIter(obj);
*listitems = PyObject_GetIter(obj);
if (listitems == NULL)
goto end;
return -1;
}
if (!PyDict_Check(obj)) {
dictitems = Py_None;
Py_INCREF(dictitems);
*dictitems = Py_None;
Py_INCREF(*dictitems);
}
else {
PyObject *items;
_Py_IDENTIFIER(items);
PyObject *items = _PyObject_CallMethodId(obj, &PyId_items, "");
if (items == NULL)
goto end;
dictitems = PyObject_GetIter(items);
Py_DECREF(items);
if (dictitems == NULL)
goto end;
items = _PyObject_CallMethodIdObjArgs(obj, &PyId_items, NULL);
if (items == NULL) {
Py_CLEAR(*listitems);
return -1;
}
*dictitems = PyObject_GetIter(items);
Py_DECREF(items);
if (*dictitems == NULL) {
Py_CLEAR(*listitems);
return -1;
}
}
assert(*listitems != NULL && *dictitems != NULL);
return 0;
}
static PyObject *
reduce_4(PyObject *obj)
{
PyObject *args = NULL, *kwargs = NULL;
PyObject *copyreg;
PyObject *newobj, *newargs, *state, *listitems, *dictitems;
PyObject *result;
_Py_IDENTIFIER(__newobj_ex__);
if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) {
return NULL;
}
if (args == NULL) {
args = PyTuple_New(0);
if (args == NULL)
return NULL;
}
if (kwargs == NULL) {
kwargs = PyDict_New();
if (kwargs == NULL)
return NULL;
}
copyreg = import_copyreg();
if (copyreg == NULL) {
Py_DECREF(args);
Py_DECREF(kwargs);
return NULL;
}
newobj = _PyObject_GetAttrId(copyreg, &PyId___newobj_ex__);
Py_DECREF(copyreg);
if (newobj == NULL) {
Py_DECREF(args);
Py_DECREF(kwargs);
return NULL;
}
newargs = PyTuple_Pack(3, Py_TYPE(obj), args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);
if (newargs == NULL) {
Py_DECREF(newobj);
return NULL;
}
state = _PyObject_GetState(obj);
if (state == NULL) {
Py_DECREF(newobj);
Py_DECREF(newargs);
return NULL;
}
if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0) {
Py_DECREF(newobj);
Py_DECREF(newargs);
Py_DECREF(state);
return NULL;
}
result = PyTuple_Pack(5, newobj, newargs, state, listitems, dictitems);
Py_DECREF(newobj);
Py_DECREF(newargs);
Py_DECREF(state);
Py_DECREF(listitems);
Py_DECREF(dictitems);
return result;
}
static PyObject *
reduce_2(PyObject *obj)
{
PyObject *cls;
PyObject *args = NULL, *args2 = NULL, *kwargs = NULL;
PyObject *state = NULL, *listitems = NULL, *dictitems = NULL;
PyObject *copyreg = NULL, *newobj = NULL, *res = NULL;
Py_ssize_t i, n;
_Py_IDENTIFIER(__newobj__);
if (_PyObject_GetNewArguments(obj, &args, &kwargs) < 0) {
return NULL;
}
if (args == NULL) {
assert(kwargs == NULL);
args = PyTuple_New(0);
if (args == NULL) {
return NULL;
}
}
else if (kwargs != NULL) {
if (PyDict_Size(kwargs) > 0) {
PyErr_SetString(PyExc_ValueError,
"must use protocol 4 or greater to copy this "
"object; since __getnewargs_ex__ returned "
"keyword arguments.");
Py_DECREF(args);
Py_DECREF(kwargs);
return NULL;
}
Py_CLEAR(kwargs);
}
state = _PyObject_GetState(obj);
if (state == NULL)
goto end;
if (_PyObject_GetItemsIter(obj, &listitems, &dictitems) < 0)
goto end;
copyreg = import_copyreg();
if (copyreg == NULL)
@ -3560,6 +3839,7 @@ reduce_2(PyObject *obj)
args2 = PyTuple_New(n+1);
if (args2 == NULL)
goto end;
cls = (PyObject *) Py_TYPE(obj);
Py_INCREF(cls);
PyTuple_SET_ITEM(args2, 0, cls);
for (i = 0; i < n; i++) {
@ -3573,9 +3853,7 @@ reduce_2(PyObject *obj)
end:
Py_XDECREF(args);
Py_XDECREF(args2);
Py_XDECREF(slots);
Py_XDECREF(state);
Py_XDECREF(names);
Py_XDECREF(listitems);
Py_XDECREF(dictitems);
Py_XDECREF(copyreg);
@ -3603,7 +3881,9 @@ _common_reduce(PyObject *self, int proto)
{
PyObject *copyreg, *res;
if (proto >= 2)
if (proto >= 4)
return reduce_4(self);
else if (proto >= 2)
return reduce_2(self);
copyreg = import_copyreg();