[3.13] gh-122311: Add more tests for pickle (GH-122376) (GH-122377)

(cherry picked from commit bc93923a2d)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
This commit is contained in:
Miss Islington (bot) 2024-07-28 10:56:49 +02:00 committed by GitHub
parent c9b7e2d097
commit d113359341
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 660 additions and 128 deletions

View File

@ -144,6 +144,14 @@ class E(C):
def __getinitargs__(self): def __getinitargs__(self):
return () return ()
import __main__
__main__.C = C
C.__module__ = "__main__"
__main__.D = D
D.__module__ = "__main__"
__main__.E = E
E.__module__ = "__main__"
# Simple mutable object. # Simple mutable object.
class Object: class Object:
pass pass
@ -157,14 +165,6 @@ class K:
# Shouldn't support the recursion itself # Shouldn't support the recursion itself
return K, (self.value,) return K, (self.value,)
import __main__
__main__.C = C
C.__module__ = "__main__"
__main__.D = D
D.__module__ = "__main__"
__main__.E = E
E.__module__ = "__main__"
class myint(int): class myint(int):
def __init__(self, x): def __init__(self, x):
self.str = str(x) self.str = str(x)
@ -1179,6 +1179,124 @@ class AbstractUnpickleTests:
self.assertIs(type(unpickled), collections.UserDict) self.assertIs(type(unpickled), collections.UserDict)
self.assertEqual(unpickled, collections.UserDict({1: 2})) self.assertEqual(unpickled, collections.UserDict({1: 2}))
def test_load_global(self):
self.assertIs(self.loads(b'cbuiltins\nstr\n.'), str)
self.assertIs(self.loads(b'cmath\nlog\n.'), math.log)
self.assertIs(self.loads(b'cos.path\njoin\n.'), os.path.join)
self.assertIs(self.loads(b'\x80\x04cbuiltins\nstr.upper\n.'), str.upper)
with support.swap_item(sys.modules, 'mödule', types.SimpleNamespace(glöbal=42)):
self.assertEqual(self.loads(b'\x80\x04cm\xc3\xb6dule\ngl\xc3\xb6bal\n.'), 42)
self.assertRaises(UnicodeDecodeError, self.loads, b'c\xff\nlog\n.')
self.assertRaises(UnicodeDecodeError, self.loads, b'cmath\n\xff\n.')
self.assertRaises(self.truncated_errors, self.loads, b'c\nlog\n.')
self.assertRaises(self.truncated_errors, self.loads, b'cmath\n\n.')
self.assertRaises(self.truncated_errors, self.loads, b'\x80\x04cmath\n\n.')
def test_load_stack_global(self):
self.assertIs(self.loads(b'\x8c\x08builtins\x8c\x03str\x93.'), str)
self.assertIs(self.loads(b'\x8c\x04math\x8c\x03log\x93.'), math.log)
self.assertIs(self.loads(b'\x8c\x07os.path\x8c\x04join\x93.'),
os.path.join)
self.assertIs(self.loads(b'\x80\x04\x8c\x08builtins\x8c\x09str.upper\x93.'),
str.upper)
with support.swap_item(sys.modules, 'mödule', types.SimpleNamespace(glöbal=42)):
self.assertEqual(self.loads(b'\x80\x04\x8c\x07m\xc3\xb6dule\x8c\x07gl\xc3\xb6bal\x93.'), 42)
self.assertRaises(UnicodeDecodeError, self.loads, b'\x8c\x01\xff\x8c\x03log\x93.')
self.assertRaises(UnicodeDecodeError, self.loads, b'\x8c\x04math\x8c\x01\xff\x93.')
self.assertRaises(ValueError, self.loads, b'\x8c\x00\x8c\x03log\x93.')
self.assertRaises(AttributeError, self.loads, b'\x8c\x04math\x8c\x00\x93.')
self.assertRaises(AttributeError, self.loads, b'\x80\x04\x8c\x04math\x8c\x00\x93.')
self.assertRaises(pickle.UnpicklingError, self.loads, b'N\x8c\x03log\x93.')
self.assertRaises(pickle.UnpicklingError, self.loads, b'\x8c\x04mathN\x93.')
self.assertRaises(pickle.UnpicklingError, self.loads, b'\x80\x04\x8c\x04mathN\x93.')
def test_find_class(self):
unpickler = self.unpickler(io.BytesIO())
unpickler_nofix = self.unpickler(io.BytesIO(), fix_imports=False)
unpickler4 = self.unpickler(io.BytesIO(b'\x80\x04N.'))
unpickler4.load()
self.assertIs(unpickler.find_class('__builtin__', 'str'), str)
self.assertRaises(ModuleNotFoundError,
unpickler_nofix.find_class, '__builtin__', 'str')
self.assertIs(unpickler.find_class('builtins', 'str'), str)
self.assertIs(unpickler_nofix.find_class('builtins', 'str'), str)
self.assertIs(unpickler.find_class('math', 'log'), math.log)
self.assertIs(unpickler.find_class('os.path', 'join'), os.path.join)
self.assertIs(unpickler.find_class('os.path', 'join'), os.path.join)
self.assertIs(unpickler4.find_class('builtins', 'str.upper'), str.upper)
with self.assertRaises(AttributeError):
unpickler.find_class('builtins', 'str.upper')
with self.assertRaises(AttributeError):
unpickler.find_class('math', 'spam')
with self.assertRaises(AttributeError):
unpickler4.find_class('math', 'spam')
with self.assertRaises(AttributeError):
unpickler.find_class('math', 'log.spam')
with self.assertRaises(AttributeError):
unpickler4.find_class('math', 'log.spam')
with self.assertRaises(AttributeError):
unpickler.find_class('math', 'log.<locals>.spam')
with self.assertRaises(AttributeError):
unpickler4.find_class('math', 'log.<locals>.spam')
with self.assertRaises(AttributeError):
unpickler.find_class('math', '')
with self.assertRaises(AttributeError):
unpickler4.find_class('math', '')
self.assertRaises(ModuleNotFoundError, unpickler.find_class, 'spam', 'log')
self.assertRaises(ValueError, unpickler.find_class, '', 'log')
self.assertRaises(TypeError, unpickler.find_class, None, 'log')
self.assertRaises(TypeError, unpickler.find_class, 'math', None)
self.assertRaises((TypeError, AttributeError), unpickler4.find_class, 'math', None)
def test_custom_find_class(self):
def loads(data):
class Unpickler(self.unpickler):
def find_class(self, module_name, global_name):
return (module_name, global_name)
return Unpickler(io.BytesIO(data)).load()
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
def loads(data):
class Unpickler(self.unpickler):
@staticmethod
def find_class(module_name, global_name):
return (module_name, global_name)
return Unpickler(io.BytesIO(data)).load()
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
def loads(data):
class Unpickler(self.unpickler):
@classmethod
def find_class(cls, module_name, global_name):
return (module_name, global_name)
return Unpickler(io.BytesIO(data)).load()
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
def loads(data):
class Unpickler(self.unpickler):
pass
def find_class(module_name, global_name):
return (module_name, global_name)
unpickler = Unpickler(io.BytesIO(data))
unpickler.find_class = find_class
return unpickler.load()
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
def test_bad_reduce(self): def test_bad_reduce(self):
self.assertEqual(self.loads(b'cbuiltins\nint\n)R.'), 0) self.assertEqual(self.loads(b'cbuiltins\nint\n)R.'), 0)
self.check_unpickling_error(TypeError, b'N)R.') self.check_unpickling_error(TypeError, b'N)R.')
@ -1443,6 +1561,474 @@ class AbstractUnpickleTests:
[ToBeUnpickled] * 2) [ToBeUnpickled] * 2)
class AbstractPicklingErrorTests:
# Subclass must define self.dumps, self.pickler.
def test_bad_reduce_result(self):
obj = REX([print, ()])
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj = REX((print,))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj = REX((print, (), None, None, None, None, None))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_bad_reconstructor(self):
obj = REX((42, ()))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_reconstructor(self):
obj = REX((UnpickleableCallable(), ()))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_reconstructor_args(self):
obj = REX((print, []))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_reconstructor_args(self):
obj = REX((print, (1, 2, UNPICKLEABLE)))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_newobj_args(self):
obj = REX((copyreg.__newobj__, ()))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((IndexError, pickle.PicklingError)) as cm:
self.dumps(obj, proto)
obj = REX((copyreg.__newobj__, [REX]))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((IndexError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_bad_newobj_class(self):
obj = REX((copyreg.__newobj__, (NoNew(),)))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_wrong_newobj_class(self):
obj = REX((copyreg.__newobj__, (str,)))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_newobj_class(self):
class LocalREX(REX): pass
obj = LocalREX((copyreg.__newobj__, (LocalREX,)))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((pickle.PicklingError, AttributeError)):
self.dumps(obj, proto)
def test_unpickleable_newobj_args(self):
obj = REX((copyreg.__newobj__, (REX, 1, 2, UNPICKLEABLE)))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_newobj_ex_args(self):
obj = REX((copyreg.__newobj_ex__, ()))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((ValueError, pickle.PicklingError)):
self.dumps(obj, proto)
obj = REX((copyreg.__newobj_ex__, 42))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj = REX((copyreg.__newobj_ex__, (REX, 42, {})))
is_py = self.pickler is pickle._Pickler
for proto in protocols[2:4] if is_py else protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
obj = REX((copyreg.__newobj_ex__, (REX, (), [])))
for proto in protocols[2:4] if is_py else protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_bad_newobj_ex__class(self):
obj = REX((copyreg.__newobj_ex__, (NoNew(), (), {})))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_wrong_newobj_ex_class(self):
if self.pickler is not pickle._Pickler:
self.skipTest('only verified in the Python implementation')
obj = REX((copyreg.__newobj_ex__, (str, (), {})))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_newobj_ex_class(self):
class LocalREX(REX): pass
obj = LocalREX((copyreg.__newobj_ex__, (LocalREX, (), {})))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((pickle.PicklingError, AttributeError)):
self.dumps(obj, proto)
def test_unpickleable_newobj_ex_args(self):
obj = REX((copyreg.__newobj_ex__, (REX, (1, 2, UNPICKLEABLE), {})))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_newobj_ex_kwargs(self):
obj = REX((copyreg.__newobj_ex__, (REX, (), {'a': UNPICKLEABLE})))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_state(self):
obj = REX_state(UNPICKLEABLE)
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_state_setter(self):
if self.pickler is pickle._Pickler:
self.skipTest('only verified in the C implementation')
obj = REX((print, (), 'state', None, None, 42))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_state_setter(self):
obj = REX((print, (), 'state', None, None, UnpickleableCallable()))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_state_with_state_setter(self):
obj = REX((print, (), UNPICKLEABLE, None, None, print))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_object_list_items(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
# are not iterators
obj = REX((list, (), None, 42))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
if self.pickler is not pickle._Pickler:
# Python implementation is less strict and also accepts iterables.
obj = REX((list, (), None, []))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_unpickleable_object_list_items(self):
obj = REX_six([1, 2, UNPICKLEABLE])
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_object_dict_items(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
# are not iterators
obj = REX((dict, (), None, None, 42))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
for proto in protocols:
obj = REX((dict, (), None, None, iter([('a',)])))
with self.subTest(proto=proto):
with self.assertRaises((ValueError, TypeError)):
self.dumps(obj, proto)
if self.pickler is not pickle._Pickler:
# Python implementation is less strict and also accepts iterables.
obj = REX((dict, (), None, None, []))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_unpickleable_object_dict_items(self):
obj = REX_seven({'a': UNPICKLEABLE})
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_list_items(self):
obj = [1, [2, 3, UNPICKLEABLE]]
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
for n in [0, 1, 1000, 1005]:
obj = [*range(n), UNPICKLEABLE]
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_tuple_items(self):
obj = (1, (2, 3, UNPICKLEABLE))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
obj = (*range(10), UNPICKLEABLE)
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_dict_items(self):
obj = {'a': {'b': UNPICKLEABLE}}
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
for n in [0, 1, 1000, 1005]:
obj = dict.fromkeys(range(n))
obj['a'] = UNPICKLEABLE
for proto in protocols:
with self.subTest(proto=proto, n=n):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_set_items(self):
obj = {UNPICKLEABLE}
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_frozenset_items(self):
obj = frozenset({frozenset({UNPICKLEABLE})})
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_global_lookup_error(self):
# Global name does not exist
obj = REX('spam')
obj.__module__ = __name__
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj.__module__ = 'nonexisting'
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj.__module__ = ''
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((ValueError, pickle.PicklingError)):
self.dumps(obj, proto)
obj.__module__ = None
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_nonencodable_global_name_error(self):
for proto in protocols[:4]:
with self.subTest(proto=proto):
name = 'nonascii\xff' if proto < 3 else 'nonencodable\udbff'
obj = REX(name)
obj.__module__ = __name__
with support.swap_item(globals(), name, obj):
with self.assertRaises((UnicodeEncodeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_nonencodable_module_name_error(self):
for proto in protocols[:4]:
with self.subTest(proto=proto):
name = 'nonascii\xff' if proto < 3 else 'nonencodable\udbff'
obj = REX('test')
obj.__module__ = name
mod = types.SimpleNamespace(test=obj)
with support.swap_item(sys.modules, name, mod):
with self.assertRaises((UnicodeEncodeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_nested_lookup_error(self):
# Nested name does not exist
obj = REX('AbstractPickleTests.spam')
obj.__module__ = __name__
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj.__module__ = None
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_wrong_object_lookup_error(self):
# Name is bound to different object
obj = REX('AbstractPickleTests')
obj.__module__ = __name__
AbstractPickleTests.ham = []
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj.__module__ = None
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_local_lookup_error(self):
# Test that whichmodule() errors out cleanly when looking up
# an assumed globally-reachable object fails.
def f():
pass
# Since the function is local, lookup will fail
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((AttributeError, pickle.PicklingError)):
self.dumps(f, proto)
# Same without a __module__ attribute (exercises a different path
# in _pickle.c).
del f.__module__
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((AttributeError, pickle.PicklingError)):
self.dumps(f, proto)
# Yet a different path.
f.__name__ = f.__qualname__
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((AttributeError, pickle.PicklingError)):
self.dumps(f, proto)
def test_reduce_ex_None(self):
c = REX_None()
with self.assertRaises(TypeError):
self.dumps(c)
def test_reduce_None(self):
c = R_None()
with self.assertRaises(TypeError):
self.dumps(c)
@no_tracing
def test_bad_getattr(self):
# Issue #3514: crash when there is an infinite loop in __getattr__
x = BadGetattr()
for proto in range(2):
with support.infinite_recursion(25):
self.assertRaises(RuntimeError, self.dumps, x, proto)
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(x, proto)
def test_picklebuffer_error(self):
# PickleBuffer forbidden with protocol < 5
pb = pickle.PickleBuffer(b"foobar")
for proto in range(0, 5):
with self.subTest(proto=proto):
with self.assertRaises(pickle.PickleError):
self.dumps(pb, proto)
def test_non_continuous_buffer(self):
if self.pickler is pickle._Pickler:
self.skipTest('CRASHES (see gh-122306)')
for proto in protocols[5:]:
with self.subTest(proto=proto):
pb = pickle.PickleBuffer(memoryview(b"foobar")[::2])
with self.assertRaises(pickle.PicklingError):
self.dumps(pb, proto)
def test_buffer_callback_error(self):
def buffer_callback(buffers):
raise CustomError
pb = pickle.PickleBuffer(b"foobar")
with self.assertRaises(CustomError):
self.dumps(pb, 5, buffer_callback=buffer_callback)
def test_evil_pickler_mutating_collection(self):
# https://github.com/python/cpython/issues/92930
global Clearer
class Clearer:
pass
def check(collection):
class EvilPickler(self.pickler):
def persistent_id(self, obj):
if isinstance(obj, Clearer):
collection.clear()
return None
pickler = EvilPickler(io.BytesIO(), proto)
try:
pickler.dump(collection)
except RuntimeError as e:
expected = "changed size during iteration"
self.assertIn(expected, str(e))
for proto in protocols:
check([Clearer()])
check([Clearer(), Clearer()])
check({Clearer()})
check({Clearer(), Clearer()})
check({Clearer(): 1})
check({Clearer(): 1, Clearer(): 2})
check({1: Clearer(), 2: Clearer()})
class AbstractPickleTests: class AbstractPickleTests:
# Subclass must define self.dumps, self.loads. # Subclass must define self.dumps, self.loads.
@ -2453,55 +3039,12 @@ class AbstractPickleTests:
y = self.loads(s) y = self.loads(s)
self.assertEqual(y._reduce_called, 1) self.assertEqual(y._reduce_called, 1)
def test_reduce_ex_None(self):
c = REX_None()
with self.assertRaises(TypeError):
self.dumps(c)
def test_reduce_None(self):
c = R_None()
with self.assertRaises(TypeError):
self.dumps(c)
def test_pickle_setstate_None(self): def test_pickle_setstate_None(self):
c = C_None_setstate() c = C_None_setstate()
p = self.dumps(c) p = self.dumps(c)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
self.loads(p) self.loads(p)
@no_tracing
def test_bad_getattr(self):
# Issue #3514: crash when there is an infinite loop in __getattr__
x = BadGetattr()
for proto in range(2):
with support.infinite_recursion(25):
self.assertRaises(RuntimeError, self.dumps, x, proto)
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(x, proto)
def test_reduce_bad_iterator(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
# are not iterators
class C(object):
def __reduce__(self):
# 4th item is not an iterator
return list, (), None, [], None
class D(object):
def __reduce__(self):
# 5th item is not an iterator
return dict, (), None, None, []
# Python implementation is less strict and also accepts iterables.
for proto in protocols:
try:
self.dumps(C(), proto)
except pickle.PicklingError:
pass
try:
self.dumps(D(), proto)
except pickle.PicklingError:
pass
def test_many_puts_and_gets(self): def test_many_puts_and_gets(self):
# Test that internal data structures correctly deal with lots of # Test that internal data structures correctly deal with lots of
# puts/gets. # puts/gets.
@ -2950,27 +3493,6 @@ class AbstractPickleTests:
self.assertIn(('c%s\n%s' % (mod, name)).encode(), pickled) self.assertIn(('c%s\n%s' % (mod, name)).encode(), pickled)
self.assertIs(type(self.loads(pickled)), type(val)) self.assertIs(type(self.loads(pickled)), type(val))
def test_local_lookup_error(self):
# Test that whichmodule() errors out cleanly when looking up
# an assumed globally-reachable object fails.
def f():
pass
# Since the function is local, lookup will fail
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((AttributeError, pickle.PicklingError)):
pickletools.dis(self.dumps(f, proto))
# Same without a __module__ attribute (exercises a different path
# in _pickle.c).
del f.__module__
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((AttributeError, pickle.PicklingError)):
pickletools.dis(self.dumps(f, proto))
# Yet a different path.
f.__name__ = f.__qualname__
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((AttributeError, pickle.PicklingError)):
pickletools.dis(self.dumps(f, proto))
# #
# PEP 574 tests below # PEP 574 tests below
# #
@ -3081,20 +3603,6 @@ class AbstractPickleTests:
self.assertIs(type(new), type(obj)) self.assertIs(type(new), type(obj))
self.assertEqual(new, obj) self.assertEqual(new, obj)
def test_picklebuffer_error(self):
# PickleBuffer forbidden with protocol < 5
pb = pickle.PickleBuffer(b"foobar")
for proto in range(0, 5):
with self.assertRaises(pickle.PickleError):
self.dumps(pb, proto)
def test_buffer_callback_error(self):
def buffer_callback(buffers):
1/0
pb = pickle.PickleBuffer(b"foobar")
with self.assertRaises(ZeroDivisionError):
self.dumps(pb, 5, buffer_callback=buffer_callback)
def test_buffers_error(self): def test_buffers_error(self):
pb = pickle.PickleBuffer(b"foobar") pb = pickle.PickleBuffer(b"foobar")
for proto in range(5, pickle.HIGHEST_PROTOCOL + 1): for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
@ -3186,37 +3694,6 @@ class AbstractPickleTests:
expected = "changed size during iteration" expected = "changed size during iteration"
self.assertIn(expected, str(e)) self.assertIn(expected, str(e))
def test_evil_pickler_mutating_collection(self):
# https://github.com/python/cpython/issues/92930
if not hasattr(self, "pickler"):
raise self.skipTest(f"{type(self)} has no associated pickler type")
global Clearer
class Clearer:
pass
def check(collection):
class EvilPickler(self.pickler):
def persistent_id(self, obj):
if isinstance(obj, Clearer):
collection.clear()
return None
pickler = EvilPickler(io.BytesIO(), proto)
try:
pickler.dump(collection)
except RuntimeError as e:
expected = "changed size during iteration"
self.assertIn(expected, str(e))
for proto in protocols:
check([Clearer()])
check([Clearer(), Clearer()])
check({Clearer()})
check({Clearer(), Clearer()})
check({Clearer(): 1})
check({Clearer(): 1, Clearer(): 2})
check({1: Clearer(), 2: Clearer()})
class BigmemPickleTests: class BigmemPickleTests:
@ -3347,6 +3824,18 @@ class BigmemPickleTests:
# Test classes for reduce_ex # Test classes for reduce_ex
class R:
def __init__(self, reduce=None):
self.reduce = reduce
def __reduce__(self, proto):
return self.reduce
class REX:
def __init__(self, reduce_ex=None):
self.reduce_ex = reduce_ex
def __reduce_ex__(self, proto):
return self.reduce_ex
class REX_one(object): class REX_one(object):
"""No __reduce_ex__ here, but inheriting it from object""" """No __reduce_ex__ here, but inheriting it from object"""
_reduce_called = 0 _reduce_called = 0
@ -3437,6 +3926,19 @@ class C_None_setstate:
__setstate__ = None __setstate__ = None
class CustomError(Exception):
pass
class Unpickleable:
def __reduce__(self):
raise CustomError
UNPICKLEABLE = Unpickleable()
class UnpickleableCallable(Unpickleable):
def __call__(self, *args, **kwargs):
pass
# Test classes for newobj # Test classes for newobj
@ -3505,6 +4007,12 @@ class BadGetattr:
def __getattr__(self, key): def __getattr__(self, key):
self.foo self.foo
class NoNew:
def __getattribute__(self, name):
if name == '__new__':
raise AttributeError
return super().__getattribute__(name)
class AbstractPickleModuleTests: class AbstractPickleModuleTests:
@ -3577,7 +4085,7 @@ class AbstractPickleModuleTests:
raise OSError raise OSError
@property @property
def bad_property(self): def bad_property(self):
1/0 raise CustomError
# File without read and readline # File without read and readline
class F: class F:
@ -3598,23 +4106,23 @@ class AbstractPickleModuleTests:
class F: class F:
read = bad_property read = bad_property
readline = raises_oserror readline = raises_oserror
self.assertRaises(ZeroDivisionError, self.Unpickler, F()) self.assertRaises(CustomError, self.Unpickler, F())
# File with bad readline # File with bad readline
class F: class F:
readline = bad_property readline = bad_property
read = raises_oserror read = raises_oserror
self.assertRaises(ZeroDivisionError, self.Unpickler, F()) self.assertRaises(CustomError, self.Unpickler, F())
# File with bad readline, no read # File with bad readline, no read
class F: class F:
readline = bad_property readline = bad_property
self.assertRaises(ZeroDivisionError, self.Unpickler, F()) self.assertRaises(CustomError, self.Unpickler, F())
# File with bad read, no readline # File with bad read, no readline
class F: class F:
read = bad_property read = bad_property
self.assertRaises((AttributeError, ZeroDivisionError), self.Unpickler, F()) self.assertRaises((AttributeError, CustomError), self.Unpickler, F())
# File with bad peek # File with bad peek
class F: class F:
@ -3623,7 +4131,7 @@ class AbstractPickleModuleTests:
readline = raises_oserror readline = raises_oserror
try: try:
self.Unpickler(F()) self.Unpickler(F())
except ZeroDivisionError: except CustomError:
pass pass
# File with bad readinto # File with bad readinto
@ -3633,7 +4141,7 @@ class AbstractPickleModuleTests:
readline = raises_oserror readline = raises_oserror
try: try:
self.Unpickler(F()) self.Unpickler(F())
except ZeroDivisionError: except CustomError:
pass pass
def test_pickler_bad_file(self): def test_pickler_bad_file(self):
@ -3646,8 +4154,8 @@ class AbstractPickleModuleTests:
class F: class F:
@property @property
def write(self): def write(self):
1/0 raise CustomError
self.assertRaises(ZeroDivisionError, self.Pickler, F()) self.assertRaises(CustomError, self.Pickler, F())
def check_dumps_loads_oob_buffers(self, dumps, loads): def check_dumps_loads_oob_buffers(self, dumps, loads):
# No need to do the full gamut of tests here, just enough to # No need to do the full gamut of tests here, just enough to
@ -3755,9 +4263,15 @@ class AbstractIdentityPersistentPicklerTests:
def test_protocol0_is_ascii_only(self): def test_protocol0_is_ascii_only(self):
non_ascii_str = "\N{EMPTY SET}" non_ascii_str = "\N{EMPTY SET}"
self.assertRaises(pickle.PicklingError, self.dumps, non_ascii_str, 0) with self.assertRaises(pickle.PicklingError) as cm:
self.dumps(non_ascii_str, 0)
self.assertEqual(str(cm.exception),
'persistent IDs in protocol 0 must be ASCII strings')
pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.' pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.'
self.assertRaises(pickle.UnpicklingError, self.loads, pickled) with self.assertRaises(pickle.UnpicklingError) as cm:
self.loads(pickled)
self.assertEqual(str(cm.exception),
'persistent IDs in protocol 0 must be ASCII strings')
class AbstractPicklerUnpicklerObjectTests: class AbstractPicklerUnpicklerObjectTests:

View File

@ -16,6 +16,7 @@ from test.support import import_helper
from test.pickletester import AbstractHookTests from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPicklingErrorTests
from test.pickletester import AbstractPickleTests from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests from test.pickletester import AbstractPersistentPicklerTests
@ -55,6 +56,18 @@ class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
return u.load() return u.load()
class PyPicklingErrorTests(AbstractPicklingErrorTests, unittest.TestCase):
pickler = pickle._Pickler
def dumps(self, arg, proto=None, **kwargs):
f = io.BytesIO()
p = self.pickler(f, proto, **kwargs)
p.dump(arg)
f.seek(0)
return bytes(f.read())
class PyPicklerTests(AbstractPickleTests, unittest.TestCase): class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
pickler = pickle._Pickler pickler = pickle._Pickler
@ -88,6 +101,8 @@ class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
return pickle.loads(buf, **kwds) return pickle.loads(buf, **kwds)
test_framed_write_sizes_with_delayed_writer = None test_framed_write_sizes_with_delayed_writer = None
test_find_class = None
test_custom_find_class = None
class PersistentPicklerUnpicklerMixin(object): class PersistentPicklerUnpicklerMixin(object):
@ -267,6 +282,9 @@ if has_c_implementation:
bad_stack_errors = (pickle.UnpicklingError,) bad_stack_errors = (pickle.UnpicklingError,)
truncated_errors = (pickle.UnpicklingError,) truncated_errors = (pickle.UnpicklingError,)
class CPicklingErrorTests(PyPicklingErrorTests):
pickler = _pickle.Pickler
class CPicklerTests(PyPicklerTests): class CPicklerTests(PyPicklerTests):
pickler = _pickle.Pickler pickler = _pickle.Pickler
unpickler = _pickle.Unpickler unpickler = _pickle.Unpickler