[3.13] gh-122311: Add more tests for pickle (GH-122376) (GH-122377)

(cherry picked from commit bc93923a2d)

Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
This commit is contained in:
Miss Islington (bot) 2024-07-28 10:56:49 +02:00 committed by GitHub
parent c9b7e2d097
commit d113359341
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 660 additions and 128 deletions

View File

@ -144,6 +144,14 @@ class E(C):
def __getinitargs__(self):
return ()
import __main__
__main__.C = C
C.__module__ = "__main__"
__main__.D = D
D.__module__ = "__main__"
__main__.E = E
E.__module__ = "__main__"
# Simple mutable object.
class Object:
pass
@ -157,14 +165,6 @@ class K:
# Shouldn't support the recursion itself
return K, (self.value,)
import __main__
__main__.C = C
C.__module__ = "__main__"
__main__.D = D
D.__module__ = "__main__"
__main__.E = E
E.__module__ = "__main__"
class myint(int):
def __init__(self, x):
self.str = str(x)
@ -1179,6 +1179,124 @@ class AbstractUnpickleTests:
self.assertIs(type(unpickled), collections.UserDict)
self.assertEqual(unpickled, collections.UserDict({1: 2}))
def test_load_global(self):
self.assertIs(self.loads(b'cbuiltins\nstr\n.'), str)
self.assertIs(self.loads(b'cmath\nlog\n.'), math.log)
self.assertIs(self.loads(b'cos.path\njoin\n.'), os.path.join)
self.assertIs(self.loads(b'\x80\x04cbuiltins\nstr.upper\n.'), str.upper)
with support.swap_item(sys.modules, 'mödule', types.SimpleNamespace(glöbal=42)):
self.assertEqual(self.loads(b'\x80\x04cm\xc3\xb6dule\ngl\xc3\xb6bal\n.'), 42)
self.assertRaises(UnicodeDecodeError, self.loads, b'c\xff\nlog\n.')
self.assertRaises(UnicodeDecodeError, self.loads, b'cmath\n\xff\n.')
self.assertRaises(self.truncated_errors, self.loads, b'c\nlog\n.')
self.assertRaises(self.truncated_errors, self.loads, b'cmath\n\n.')
self.assertRaises(self.truncated_errors, self.loads, b'\x80\x04cmath\n\n.')
def test_load_stack_global(self):
self.assertIs(self.loads(b'\x8c\x08builtins\x8c\x03str\x93.'), str)
self.assertIs(self.loads(b'\x8c\x04math\x8c\x03log\x93.'), math.log)
self.assertIs(self.loads(b'\x8c\x07os.path\x8c\x04join\x93.'),
os.path.join)
self.assertIs(self.loads(b'\x80\x04\x8c\x08builtins\x8c\x09str.upper\x93.'),
str.upper)
with support.swap_item(sys.modules, 'mödule', types.SimpleNamespace(glöbal=42)):
self.assertEqual(self.loads(b'\x80\x04\x8c\x07m\xc3\xb6dule\x8c\x07gl\xc3\xb6bal\x93.'), 42)
self.assertRaises(UnicodeDecodeError, self.loads, b'\x8c\x01\xff\x8c\x03log\x93.')
self.assertRaises(UnicodeDecodeError, self.loads, b'\x8c\x04math\x8c\x01\xff\x93.')
self.assertRaises(ValueError, self.loads, b'\x8c\x00\x8c\x03log\x93.')
self.assertRaises(AttributeError, self.loads, b'\x8c\x04math\x8c\x00\x93.')
self.assertRaises(AttributeError, self.loads, b'\x80\x04\x8c\x04math\x8c\x00\x93.')
self.assertRaises(pickle.UnpicklingError, self.loads, b'N\x8c\x03log\x93.')
self.assertRaises(pickle.UnpicklingError, self.loads, b'\x8c\x04mathN\x93.')
self.assertRaises(pickle.UnpicklingError, self.loads, b'\x80\x04\x8c\x04mathN\x93.')
def test_find_class(self):
unpickler = self.unpickler(io.BytesIO())
unpickler_nofix = self.unpickler(io.BytesIO(), fix_imports=False)
unpickler4 = self.unpickler(io.BytesIO(b'\x80\x04N.'))
unpickler4.load()
self.assertIs(unpickler.find_class('__builtin__', 'str'), str)
self.assertRaises(ModuleNotFoundError,
unpickler_nofix.find_class, '__builtin__', 'str')
self.assertIs(unpickler.find_class('builtins', 'str'), str)
self.assertIs(unpickler_nofix.find_class('builtins', 'str'), str)
self.assertIs(unpickler.find_class('math', 'log'), math.log)
self.assertIs(unpickler.find_class('os.path', 'join'), os.path.join)
self.assertIs(unpickler.find_class('os.path', 'join'), os.path.join)
self.assertIs(unpickler4.find_class('builtins', 'str.upper'), str.upper)
with self.assertRaises(AttributeError):
unpickler.find_class('builtins', 'str.upper')
with self.assertRaises(AttributeError):
unpickler.find_class('math', 'spam')
with self.assertRaises(AttributeError):
unpickler4.find_class('math', 'spam')
with self.assertRaises(AttributeError):
unpickler.find_class('math', 'log.spam')
with self.assertRaises(AttributeError):
unpickler4.find_class('math', 'log.spam')
with self.assertRaises(AttributeError):
unpickler.find_class('math', 'log.<locals>.spam')
with self.assertRaises(AttributeError):
unpickler4.find_class('math', 'log.<locals>.spam')
with self.assertRaises(AttributeError):
unpickler.find_class('math', '')
with self.assertRaises(AttributeError):
unpickler4.find_class('math', '')
self.assertRaises(ModuleNotFoundError, unpickler.find_class, 'spam', 'log')
self.assertRaises(ValueError, unpickler.find_class, '', 'log')
self.assertRaises(TypeError, unpickler.find_class, None, 'log')
self.assertRaises(TypeError, unpickler.find_class, 'math', None)
self.assertRaises((TypeError, AttributeError), unpickler4.find_class, 'math', None)
def test_custom_find_class(self):
def loads(data):
class Unpickler(self.unpickler):
def find_class(self, module_name, global_name):
return (module_name, global_name)
return Unpickler(io.BytesIO(data)).load()
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
def loads(data):
class Unpickler(self.unpickler):
@staticmethod
def find_class(module_name, global_name):
return (module_name, global_name)
return Unpickler(io.BytesIO(data)).load()
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
def loads(data):
class Unpickler(self.unpickler):
@classmethod
def find_class(cls, module_name, global_name):
return (module_name, global_name)
return Unpickler(io.BytesIO(data)).load()
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
def loads(data):
class Unpickler(self.unpickler):
pass
def find_class(module_name, global_name):
return (module_name, global_name)
unpickler = Unpickler(io.BytesIO(data))
unpickler.find_class = find_class
return unpickler.load()
self.assertEqual(loads(b'cmath\nlog\n.'), ('math', 'log'))
self.assertEqual(loads(b'\x8c\x04math\x8c\x03log\x93.'), ('math', 'log'))
def test_bad_reduce(self):
self.assertEqual(self.loads(b'cbuiltins\nint\n)R.'), 0)
self.check_unpickling_error(TypeError, b'N)R.')
@ -1443,6 +1561,474 @@ class AbstractUnpickleTests:
[ToBeUnpickled] * 2)
class AbstractPicklingErrorTests:
# Subclass must define self.dumps, self.pickler.
def test_bad_reduce_result(self):
obj = REX([print, ()])
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj = REX((print,))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj = REX((print, (), None, None, None, None, None))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_bad_reconstructor(self):
obj = REX((42, ()))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_reconstructor(self):
obj = REX((UnpickleableCallable(), ()))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_reconstructor_args(self):
obj = REX((print, []))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_reconstructor_args(self):
obj = REX((print, (1, 2, UNPICKLEABLE)))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_newobj_args(self):
obj = REX((copyreg.__newobj__, ()))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((IndexError, pickle.PicklingError)) as cm:
self.dumps(obj, proto)
obj = REX((copyreg.__newobj__, [REX]))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((IndexError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_bad_newobj_class(self):
obj = REX((copyreg.__newobj__, (NoNew(),)))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_wrong_newobj_class(self):
obj = REX((copyreg.__newobj__, (str,)))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_newobj_class(self):
class LocalREX(REX): pass
obj = LocalREX((copyreg.__newobj__, (LocalREX,)))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((pickle.PicklingError, AttributeError)):
self.dumps(obj, proto)
def test_unpickleable_newobj_args(self):
obj = REX((copyreg.__newobj__, (REX, 1, 2, UNPICKLEABLE)))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_newobj_ex_args(self):
obj = REX((copyreg.__newobj_ex__, ()))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((ValueError, pickle.PicklingError)):
self.dumps(obj, proto)
obj = REX((copyreg.__newobj_ex__, 42))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj = REX((copyreg.__newobj_ex__, (REX, 42, {})))
is_py = self.pickler is pickle._Pickler
for proto in protocols[2:4] if is_py else protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
obj = REX((copyreg.__newobj_ex__, (REX, (), [])))
for proto in protocols[2:4] if is_py else protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_bad_newobj_ex__class(self):
obj = REX((copyreg.__newobj_ex__, (NoNew(), (), {})))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_wrong_newobj_ex_class(self):
if self.pickler is not pickle._Pickler:
self.skipTest('only verified in the Python implementation')
obj = REX((copyreg.__newobj_ex__, (str, (), {})))
for proto in protocols[2:]:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_newobj_ex_class(self):
class LocalREX(REX): pass
obj = LocalREX((copyreg.__newobj_ex__, (LocalREX, (), {})))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((pickle.PicklingError, AttributeError)):
self.dumps(obj, proto)
def test_unpickleable_newobj_ex_args(self):
obj = REX((copyreg.__newobj_ex__, (REX, (1, 2, UNPICKLEABLE), {})))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_newobj_ex_kwargs(self):
obj = REX((copyreg.__newobj_ex__, (REX, (), {'a': UNPICKLEABLE})))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_state(self):
obj = REX_state(UNPICKLEABLE)
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_state_setter(self):
if self.pickler is pickle._Pickler:
self.skipTest('only verified in the C implementation')
obj = REX((print, (), 'state', None, None, 42))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_unpickleable_state_setter(self):
obj = REX((print, (), 'state', None, None, UnpickleableCallable()))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_state_with_state_setter(self):
obj = REX((print, (), UNPICKLEABLE, None, None, print))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_object_list_items(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
# are not iterators
obj = REX((list, (), None, 42))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
if self.pickler is not pickle._Pickler:
# Python implementation is less strict and also accepts iterables.
obj = REX((list, (), None, []))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_unpickleable_object_list_items(self):
obj = REX_six([1, 2, UNPICKLEABLE])
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_bad_object_dict_items(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
# are not iterators
obj = REX((dict, (), None, None, 42))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
for proto in protocols:
obj = REX((dict, (), None, None, iter([('a',)])))
with self.subTest(proto=proto):
with self.assertRaises((ValueError, TypeError)):
self.dumps(obj, proto)
if self.pickler is not pickle._Pickler:
# Python implementation is less strict and also accepts iterables.
obj = REX((dict, (), None, None, []))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((TypeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_unpickleable_object_dict_items(self):
obj = REX_seven({'a': UNPICKLEABLE})
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_list_items(self):
obj = [1, [2, 3, UNPICKLEABLE]]
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
for n in [0, 1, 1000, 1005]:
obj = [*range(n), UNPICKLEABLE]
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_tuple_items(self):
obj = (1, (2, 3, UNPICKLEABLE))
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
obj = (*range(10), UNPICKLEABLE)
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_dict_items(self):
obj = {'a': {'b': UNPICKLEABLE}}
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
for n in [0, 1, 1000, 1005]:
obj = dict.fromkeys(range(n))
obj['a'] = UNPICKLEABLE
for proto in protocols:
with self.subTest(proto=proto, n=n):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_set_items(self):
obj = {UNPICKLEABLE}
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_unpickleable_frozenset_items(self):
obj = frozenset({frozenset({UNPICKLEABLE})})
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(CustomError):
self.dumps(obj, proto)
def test_global_lookup_error(self):
# Global name does not exist
obj = REX('spam')
obj.__module__ = __name__
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj.__module__ = 'nonexisting'
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj.__module__ = ''
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((ValueError, pickle.PicklingError)):
self.dumps(obj, proto)
obj.__module__ = None
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_nonencodable_global_name_error(self):
for proto in protocols[:4]:
with self.subTest(proto=proto):
name = 'nonascii\xff' if proto < 3 else 'nonencodable\udbff'
obj = REX(name)
obj.__module__ = __name__
with support.swap_item(globals(), name, obj):
with self.assertRaises((UnicodeEncodeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_nonencodable_module_name_error(self):
for proto in protocols[:4]:
with self.subTest(proto=proto):
name = 'nonascii\xff' if proto < 3 else 'nonencodable\udbff'
obj = REX('test')
obj.__module__ = name
mod = types.SimpleNamespace(test=obj)
with support.swap_item(sys.modules, name, mod):
with self.assertRaises((UnicodeEncodeError, pickle.PicklingError)):
self.dumps(obj, proto)
def test_nested_lookup_error(self):
# Nested name does not exist
obj = REX('AbstractPickleTests.spam')
obj.__module__ = __name__
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj.__module__ = None
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_wrong_object_lookup_error(self):
# Name is bound to different object
obj = REX('AbstractPickleTests')
obj.__module__ = __name__
AbstractPickleTests.ham = []
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
obj.__module__ = None
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises(pickle.PicklingError):
self.dumps(obj, proto)
def test_local_lookup_error(self):
# Test that whichmodule() errors out cleanly when looking up
# an assumed globally-reachable object fails.
def f():
pass
# Since the function is local, lookup will fail
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((AttributeError, pickle.PicklingError)):
self.dumps(f, proto)
# Same without a __module__ attribute (exercises a different path
# in _pickle.c).
del f.__module__
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((AttributeError, pickle.PicklingError)):
self.dumps(f, proto)
# Yet a different path.
f.__name__ = f.__qualname__
for proto in protocols:
with self.subTest(proto=proto):
with self.assertRaises((AttributeError, pickle.PicklingError)):
self.dumps(f, proto)
def test_reduce_ex_None(self):
c = REX_None()
with self.assertRaises(TypeError):
self.dumps(c)
def test_reduce_None(self):
c = R_None()
with self.assertRaises(TypeError):
self.dumps(c)
@no_tracing
def test_bad_getattr(self):
# Issue #3514: crash when there is an infinite loop in __getattr__
x = BadGetattr()
for proto in range(2):
with support.infinite_recursion(25):
self.assertRaises(RuntimeError, self.dumps, x, proto)
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(x, proto)
def test_picklebuffer_error(self):
# PickleBuffer forbidden with protocol < 5
pb = pickle.PickleBuffer(b"foobar")
for proto in range(0, 5):
with self.subTest(proto=proto):
with self.assertRaises(pickle.PickleError):
self.dumps(pb, proto)
def test_non_continuous_buffer(self):
if self.pickler is pickle._Pickler:
self.skipTest('CRASHES (see gh-122306)')
for proto in protocols[5:]:
with self.subTest(proto=proto):
pb = pickle.PickleBuffer(memoryview(b"foobar")[::2])
with self.assertRaises(pickle.PicklingError):
self.dumps(pb, proto)
def test_buffer_callback_error(self):
def buffer_callback(buffers):
raise CustomError
pb = pickle.PickleBuffer(b"foobar")
with self.assertRaises(CustomError):
self.dumps(pb, 5, buffer_callback=buffer_callback)
def test_evil_pickler_mutating_collection(self):
# https://github.com/python/cpython/issues/92930
global Clearer
class Clearer:
pass
def check(collection):
class EvilPickler(self.pickler):
def persistent_id(self, obj):
if isinstance(obj, Clearer):
collection.clear()
return None
pickler = EvilPickler(io.BytesIO(), proto)
try:
pickler.dump(collection)
except RuntimeError as e:
expected = "changed size during iteration"
self.assertIn(expected, str(e))
for proto in protocols:
check([Clearer()])
check([Clearer(), Clearer()])
check({Clearer()})
check({Clearer(), Clearer()})
check({Clearer(): 1})
check({Clearer(): 1, Clearer(): 2})
check({1: Clearer(), 2: Clearer()})
class AbstractPickleTests:
# Subclass must define self.dumps, self.loads.
@ -2453,55 +3039,12 @@ class AbstractPickleTests:
y = self.loads(s)
self.assertEqual(y._reduce_called, 1)
def test_reduce_ex_None(self):
c = REX_None()
with self.assertRaises(TypeError):
self.dumps(c)
def test_reduce_None(self):
c = R_None()
with self.assertRaises(TypeError):
self.dumps(c)
def test_pickle_setstate_None(self):
c = C_None_setstate()
p = self.dumps(c)
with self.assertRaises(TypeError):
self.loads(p)
@no_tracing
def test_bad_getattr(self):
# Issue #3514: crash when there is an infinite loop in __getattr__
x = BadGetattr()
for proto in range(2):
with support.infinite_recursion(25):
self.assertRaises(RuntimeError, self.dumps, x, proto)
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
s = self.dumps(x, proto)
def test_reduce_bad_iterator(self):
# Issue4176: crash when 4th and 5th items of __reduce__()
# are not iterators
class C(object):
def __reduce__(self):
# 4th item is not an iterator
return list, (), None, [], None
class D(object):
def __reduce__(self):
# 5th item is not an iterator
return dict, (), None, None, []
# Python implementation is less strict and also accepts iterables.
for proto in protocols:
try:
self.dumps(C(), proto)
except pickle.PicklingError:
pass
try:
self.dumps(D(), proto)
except pickle.PicklingError:
pass
def test_many_puts_and_gets(self):
# Test that internal data structures correctly deal with lots of
# puts/gets.
@ -2950,27 +3493,6 @@ class AbstractPickleTests:
self.assertIn(('c%s\n%s' % (mod, name)).encode(), pickled)
self.assertIs(type(self.loads(pickled)), type(val))
def test_local_lookup_error(self):
# Test that whichmodule() errors out cleanly when looking up
# an assumed globally-reachable object fails.
def f():
pass
# Since the function is local, lookup will fail
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((AttributeError, pickle.PicklingError)):
pickletools.dis(self.dumps(f, proto))
# Same without a __module__ attribute (exercises a different path
# in _pickle.c).
del f.__module__
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((AttributeError, pickle.PicklingError)):
pickletools.dis(self.dumps(f, proto))
# Yet a different path.
f.__name__ = f.__qualname__
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises((AttributeError, pickle.PicklingError)):
pickletools.dis(self.dumps(f, proto))
#
# PEP 574 tests below
#
@ -3081,20 +3603,6 @@ class AbstractPickleTests:
self.assertIs(type(new), type(obj))
self.assertEqual(new, obj)
def test_picklebuffer_error(self):
# PickleBuffer forbidden with protocol < 5
pb = pickle.PickleBuffer(b"foobar")
for proto in range(0, 5):
with self.assertRaises(pickle.PickleError):
self.dumps(pb, proto)
def test_buffer_callback_error(self):
def buffer_callback(buffers):
1/0
pb = pickle.PickleBuffer(b"foobar")
with self.assertRaises(ZeroDivisionError):
self.dumps(pb, 5, buffer_callback=buffer_callback)
def test_buffers_error(self):
pb = pickle.PickleBuffer(b"foobar")
for proto in range(5, pickle.HIGHEST_PROTOCOL + 1):
@ -3186,37 +3694,6 @@ class AbstractPickleTests:
expected = "changed size during iteration"
self.assertIn(expected, str(e))
def test_evil_pickler_mutating_collection(self):
# https://github.com/python/cpython/issues/92930
if not hasattr(self, "pickler"):
raise self.skipTest(f"{type(self)} has no associated pickler type")
global Clearer
class Clearer:
pass
def check(collection):
class EvilPickler(self.pickler):
def persistent_id(self, obj):
if isinstance(obj, Clearer):
collection.clear()
return None
pickler = EvilPickler(io.BytesIO(), proto)
try:
pickler.dump(collection)
except RuntimeError as e:
expected = "changed size during iteration"
self.assertIn(expected, str(e))
for proto in protocols:
check([Clearer()])
check([Clearer(), Clearer()])
check({Clearer()})
check({Clearer(), Clearer()})
check({Clearer(): 1})
check({Clearer(): 1, Clearer(): 2})
check({1: Clearer(), 2: Clearer()})
class BigmemPickleTests:
@ -3347,6 +3824,18 @@ class BigmemPickleTests:
# Test classes for reduce_ex
class R:
def __init__(self, reduce=None):
self.reduce = reduce
def __reduce__(self, proto):
return self.reduce
class REX:
def __init__(self, reduce_ex=None):
self.reduce_ex = reduce_ex
def __reduce_ex__(self, proto):
return self.reduce_ex
class REX_one(object):
"""No __reduce_ex__ here, but inheriting it from object"""
_reduce_called = 0
@ -3437,6 +3926,19 @@ class C_None_setstate:
__setstate__ = None
class CustomError(Exception):
pass
class Unpickleable:
def __reduce__(self):
raise CustomError
UNPICKLEABLE = Unpickleable()
class UnpickleableCallable(Unpickleable):
def __call__(self, *args, **kwargs):
pass
# Test classes for newobj
@ -3505,6 +4007,12 @@ class BadGetattr:
def __getattr__(self, key):
self.foo
class NoNew:
def __getattribute__(self, name):
if name == '__new__':
raise AttributeError
return super().__getattribute__(name)
class AbstractPickleModuleTests:
@ -3577,7 +4085,7 @@ class AbstractPickleModuleTests:
raise OSError
@property
def bad_property(self):
1/0
raise CustomError
# File without read and readline
class F:
@ -3598,23 +4106,23 @@ class AbstractPickleModuleTests:
class F:
read = bad_property
readline = raises_oserror
self.assertRaises(ZeroDivisionError, self.Unpickler, F())
self.assertRaises(CustomError, self.Unpickler, F())
# File with bad readline
class F:
readline = bad_property
read = raises_oserror
self.assertRaises(ZeroDivisionError, self.Unpickler, F())
self.assertRaises(CustomError, self.Unpickler, F())
# File with bad readline, no read
class F:
readline = bad_property
self.assertRaises(ZeroDivisionError, self.Unpickler, F())
self.assertRaises(CustomError, self.Unpickler, F())
# File with bad read, no readline
class F:
read = bad_property
self.assertRaises((AttributeError, ZeroDivisionError), self.Unpickler, F())
self.assertRaises((AttributeError, CustomError), self.Unpickler, F())
# File with bad peek
class F:
@ -3623,7 +4131,7 @@ class AbstractPickleModuleTests:
readline = raises_oserror
try:
self.Unpickler(F())
except ZeroDivisionError:
except CustomError:
pass
# File with bad readinto
@ -3633,7 +4141,7 @@ class AbstractPickleModuleTests:
readline = raises_oserror
try:
self.Unpickler(F())
except ZeroDivisionError:
except CustomError:
pass
def test_pickler_bad_file(self):
@ -3646,8 +4154,8 @@ class AbstractPickleModuleTests:
class F:
@property
def write(self):
1/0
self.assertRaises(ZeroDivisionError, self.Pickler, F())
raise CustomError
self.assertRaises(CustomError, self.Pickler, F())
def check_dumps_loads_oob_buffers(self, dumps, loads):
# No need to do the full gamut of tests here, just enough to
@ -3755,9 +4263,15 @@ class AbstractIdentityPersistentPicklerTests:
def test_protocol0_is_ascii_only(self):
non_ascii_str = "\N{EMPTY SET}"
self.assertRaises(pickle.PicklingError, self.dumps, non_ascii_str, 0)
with self.assertRaises(pickle.PicklingError) as cm:
self.dumps(non_ascii_str, 0)
self.assertEqual(str(cm.exception),
'persistent IDs in protocol 0 must be ASCII strings')
pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.'
self.assertRaises(pickle.UnpicklingError, self.loads, pickled)
with self.assertRaises(pickle.UnpicklingError) as cm:
self.loads(pickled)
self.assertEqual(str(cm.exception),
'persistent IDs in protocol 0 must be ASCII strings')
class AbstractPicklerUnpicklerObjectTests:

View File

@ -16,6 +16,7 @@ from test.support import import_helper
from test.pickletester import AbstractHookTests
from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPicklingErrorTests
from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
@ -55,6 +56,18 @@ class PyUnpicklerTests(AbstractUnpickleTests, unittest.TestCase):
return u.load()
class PyPicklingErrorTests(AbstractPicklingErrorTests, unittest.TestCase):
pickler = pickle._Pickler
def dumps(self, arg, proto=None, **kwargs):
f = io.BytesIO()
p = self.pickler(f, proto, **kwargs)
p.dump(arg)
f.seek(0)
return bytes(f.read())
class PyPicklerTests(AbstractPickleTests, unittest.TestCase):
pickler = pickle._Pickler
@ -88,6 +101,8 @@ class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
return pickle.loads(buf, **kwds)
test_framed_write_sizes_with_delayed_writer = None
test_find_class = None
test_custom_find_class = None
class PersistentPicklerUnpicklerMixin(object):
@ -267,6 +282,9 @@ if has_c_implementation:
bad_stack_errors = (pickle.UnpicklingError,)
truncated_errors = (pickle.UnpicklingError,)
class CPicklingErrorTests(PyPicklingErrorTests):
pickler = _pickle.Pickler
class CPicklerTests(PyPicklerTests):
pickler = _pickle.Pickler
unpickler = _pickle.Unpickler