gh-82951: Fix serializing by name in pickle protocols < 4 (GH-122149)

Serializing objects with complex __qualname__ (such as unbound methods and
nested classes) by name no longer involves serializing parent objects by value
in pickle protocols < 4.
This commit is contained in:
Serhiy Storchaka 2024-07-25 11:45:19 +03:00 committed by GitHub
parent ca0f7c447c
commit dc07f65a53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 82 additions and 26 deletions

View File

@ -1110,11 +1110,35 @@ class _Pickler:
self.save(module_name) self.save(module_name)
self.save(name) self.save(name)
write(STACK_GLOBAL) write(STACK_GLOBAL)
elif parent is not module: elif '.' in name:
self.save_reduce(getattr, (parent, lastname)) # In protocol < 4, objects with multi-part __qualname__
elif self.proto >= 3: # are represented as
write(GLOBAL + bytes(module_name, "utf-8") + b'\n' + # getattr(getattr(..., attrname1), attrname2).
bytes(name, "utf-8") + b'\n') dotted_path = name.split('.')
name = dotted_path.pop(0)
save = self.save
for attrname in dotted_path:
save(getattr)
if self.proto < 2:
write(MARK)
self._save_toplevel_by_name(module_name, name)
for attrname in dotted_path:
save(attrname)
if self.proto < 2:
write(TUPLE)
else:
write(TUPLE2)
write(REDUCE)
else:
self._save_toplevel_by_name(module_name, name)
self.memoize(obj)
def _save_toplevel_by_name(self, module_name, name):
if self.proto >= 3:
# Non-ASCII identifiers are supported only with protocols >= 3.
self.write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
bytes(name, "utf-8") + b'\n')
else: else:
if self.fix_imports: if self.fix_imports:
r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
@ -1124,15 +1148,13 @@ class _Pickler:
elif module_name in r_import_mapping: elif module_name in r_import_mapping:
module_name = r_import_mapping[module_name] module_name = r_import_mapping[module_name]
try: try:
write(GLOBAL + bytes(module_name, "ascii") + b'\n' + self.write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
bytes(name, "ascii") + b'\n') bytes(name, "ascii") + b'\n')
except UnicodeEncodeError: except UnicodeEncodeError:
raise PicklingError( raise PicklingError(
"can't pickle global identifier '%s.%s' using " "can't pickle global identifier '%s.%s' using "
"pickle protocol %i" % (module, name, self.proto)) from None "pickle protocol %i" % (module, name, self.proto)) from None
self.memoize(obj)
def save_type(self, obj): def save_type(self, obj):
if obj is type(None): if obj is type(None):
return self.save_reduce(type, (None,), obj=obj) return self.save_reduce(type, (None,), obj=obj)

View File

@ -2818,6 +2818,18 @@ class AbstractPickleTests:
self.assertIs(unpickled, Recursive) self.assertIs(unpickled, Recursive)
del Recursive.mod # break reference loop del Recursive.mod # break reference loop
def test_recursive_nested_names2(self):
global Recursive
class Recursive:
pass
Recursive.ref = Recursive
Recursive.__qualname__ = 'Recursive.ref'
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.subTest(proto=proto):
unpickled = self.loads(self.dumps(Recursive, proto))
self.assertIs(unpickled, Recursive)
del Recursive.ref # break reference loop
def test_py_methods(self): def test_py_methods(self):
global PyMethodsTest global PyMethodsTest
class PyMethodsTest: class PyMethodsTest:

View File

@ -0,0 +1,3 @@
Serializing objects with complex ``__qualname__`` (such as unbound methods
and nested classes) by name no longer involves serializing parent objects by
value in pickle protocols < 4.

View File

@ -3592,7 +3592,6 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
PyObject *module = NULL; PyObject *module = NULL;
PyObject *parent = NULL; PyObject *parent = NULL;
PyObject *dotted_path = NULL; PyObject *dotted_path = NULL;
PyObject *lastname = NULL;
PyObject *cls; PyObject *cls;
int status = 0; int status = 0;
@ -3633,10 +3632,7 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
obj, module_name); obj, module_name);
goto error; goto error;
} }
lastname = Py_NewRef(PyList_GET_ITEM(dotted_path,
PyList_GET_SIZE(dotted_path) - 1));
cls = get_deep_attribute(module, dotted_path, &parent); cls = get_deep_attribute(module, dotted_path, &parent);
Py_CLEAR(dotted_path);
if (cls == NULL) { if (cls == NULL) {
PyErr_Format(st->PicklingError, PyErr_Format(st->PicklingError,
"Can't pickle %R: attribute lookup %S on %S failed", "Can't pickle %R: attribute lookup %S on %S failed",
@ -3724,7 +3720,10 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
else { else {
gen_global: gen_global:
if (parent == module) { if (parent == module) {
Py_SETREF(global_name, Py_NewRef(lastname)); Py_SETREF(global_name,
Py_NewRef(PyList_GET_ITEM(dotted_path,
PyList_GET_SIZE(dotted_path) - 1)));
Py_CLEAR(dotted_path);
} }
if (self->proto >= 4) { if (self->proto >= 4) {
const char stack_global_op = STACK_GLOBAL; const char stack_global_op = STACK_GLOBAL;
@ -3737,20 +3736,30 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
if (_Pickler_Write(self, &stack_global_op, 1) < 0) if (_Pickler_Write(self, &stack_global_op, 1) < 0)
goto error; goto error;
} }
else if (parent != module) {
PyObject *reduce_value = Py_BuildValue("(O(OO))",
st->getattr, parent, lastname);
if (reduce_value == NULL)
goto error;
status = save_reduce(st, self, reduce_value, NULL);
Py_DECREF(reduce_value);
if (status < 0)
goto error;
}
else { else {
/* Generate a normal global opcode if we are using a pickle /* Generate a normal global opcode if we are using a pickle
protocol < 4, or if the object is not registered in the protocol < 4, or if the object is not registered in the
extension registry. */ extension registry.
Objects with multi-part __qualname__ are represented as
getattr(getattr(..., attrname1), attrname2). */
const char mark_op = MARK;
const char tupletwo_op = (self->proto < 2) ? TUPLE : TUPLE2;
const char reduce_op = REDUCE;
Py_ssize_t i;
if (dotted_path) {
if (PyList_GET_SIZE(dotted_path) > 1) {
Py_SETREF(global_name, Py_NewRef(PyList_GET_ITEM(dotted_path, 0)));
}
for (i = 1; i < PyList_GET_SIZE(dotted_path); i++) {
if (save(st, self, st->getattr, 0) < 0 ||
(self->proto < 2 && _Pickler_Write(self, &mark_op, 1) < 0))
{
goto error;
}
}
}
PyObject *encoded; PyObject *encoded;
PyObject *(*unicode_encoder)(PyObject *); PyObject *(*unicode_encoder)(PyObject *);
@ -3812,6 +3821,17 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
Py_DECREF(encoded); Py_DECREF(encoded);
if (_Pickler_Write(self, "\n", 1) < 0) if (_Pickler_Write(self, "\n", 1) < 0)
goto error; goto error;
if (dotted_path) {
for (i = 1; i < PyList_GET_SIZE(dotted_path); i++) {
if (save(st, self, PyList_GET_ITEM(dotted_path, i), 0) < 0 ||
_Pickler_Write(self, &tupletwo_op, 1) < 0 ||
_Pickler_Write(self, &reduce_op, 1) < 0)
{
goto error;
}
}
}
} }
/* Memoize the object. */ /* Memoize the object. */
if (memo_put(st, self, obj) < 0) if (memo_put(st, self, obj) < 0)
@ -3827,7 +3847,6 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
Py_XDECREF(module); Py_XDECREF(module);
Py_XDECREF(parent); Py_XDECREF(parent);
Py_XDECREF(dotted_path); Py_XDECREF(dotted_path);
Py_XDECREF(lastname);
return status; return status;
} }