Issue #17711: Fixed unpickling of objects by persistent ID with protocol 0.

Original patch by Alexandre Vassalotti.
This commit is contained in:
Serhiy Storchaka 2016-07-17 11:24:17 +03:00
parent 6fd76bceda
commit dec25afab1
5 changed files with 89 additions and 22 deletions

View File

@ -529,7 +529,11 @@ class _Pickler:
self.save(pid, save_persistent_id=False)
self.write(BINPERSID)
else:
self.write(PERSID + str(pid).encode("ascii") + b'\n')
try:
self.write(PERSID + str(pid).encode("ascii") + b'\n')
except UnicodeEncodeError:
raise PicklingError(
"persistent IDs in protocol 0 must be ASCII strings")
def save_reduce(self, func, args, state=None, listitems=None,
dictitems=None, obj=None):
@ -1075,7 +1079,11 @@ class _Unpickler:
dispatch[FRAME[0]] = load_frame
def load_persid(self):
pid = self.readline()[:-1].decode("ascii")
try:
pid = self.readline()[:-1].decode("ascii")
except UnicodeDecodeError:
raise UnpicklingError(
"persistent IDs in protocol 0 must be ASCII strings")
self.append(self.persistent_load(pid))
dispatch[PERSID[0]] = load_persid

View File

@ -2629,6 +2629,35 @@ class AbstractPersistentPicklerTests(unittest.TestCase):
self.assertEqual(self.load_false_count, 1)
# Persistent-ID tests in which every pickled object serves as its own
# persistent ID, and every persistent ID is loaded back as itself.
# NOTE(review): self.dumps and self.loads are not defined in this class;
# presumably a concrete subclass or mixin supplies them -- confirm.
class AbstractIdentityPersistentPicklerTests(unittest.TestCase):
# Identity mapping: the object itself is returned as its persistent ID.
def persistent_id(self, obj):
return obj
# Identity mapping: the persistent ID is returned as the loaded object.
def persistent_load(self, pid):
return pid
# Round-trip obj through dumps/loads with the given protocol and check that
# both the type and the value of the result match the original.
def _check_return_correct_type(self, obj, proto):
unpickled = self.loads(self.dumps(obj, proto))
self.assertIsInstance(unpickled, type(obj))
self.assertEqual(unpickled, obj)
# Check the round trip for every protocol in `protocols` (defined outside
# this class).
def test_return_correct_type(self):
for proto in protocols:
# Protocol 0 supports only ASCII strings.
if proto == 0:
self._check_return_correct_type("abc", 0)
else:
# Other protocols are checked with bytes, str (both containing a newline),
# a negative int, a float and a class object.
for obj in [b"abc\n", "abc\n", -1, -1.1 * 0.1, str]:
self._check_return_correct_type(obj, proto)
# A non-ASCII persistent ID must be rejected in both directions under
# protocol 0: PicklingError on dump, UnpicklingError on load.
def test_protocol0_is_ascii_only(self):
non_ascii_str = "\N{EMPTY SET}"
self.assertRaises(pickle.PicklingError, self.dumps, non_ascii_str, 0)
# Hand-built pickle: the PERSID opcode, the UTF-8 (hence non-ASCII) bytes of
# the ID, the terminating newline, then '.' (the STOP opcode).
pickled = pickle.PERSID + non_ascii_str.encode('utf-8') + b'\n.'
self.assertRaises(pickle.UnpicklingError, self.loads, pickled)
class AbstractPicklerUnpicklerObjectTests(unittest.TestCase):
pickler_class = None

View File

@ -14,6 +14,7 @@ from test.pickletester import AbstractUnpickleTests
from test.pickletester import AbstractPickleTests
from test.pickletester import AbstractPickleModuleTests
from test.pickletester import AbstractPersistentPicklerTests
from test.pickletester import AbstractIdentityPersistentPicklerTests
from test.pickletester import AbstractPicklerUnpicklerObjectTests
from test.pickletester import AbstractDispatchTableTests
from test.pickletester import BigmemPickleTests
@ -82,10 +83,7 @@ class InMemoryPickleTests(AbstractPickleTests, AbstractUnpickleTests,
return pickle.loads(buf, **kwds)
class PyPersPicklerTests(AbstractPersistentPicklerTests):
pickler = pickle._Pickler
unpickler = pickle._Unpickler
class PersistentPicklerUnpicklerMixin(object):
def dumps(self, arg, proto=None):
class PersPickler(self.pickler):
@ -94,8 +92,7 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
f = io.BytesIO()
p = PersPickler(f, proto)
p.dump(arg)
f.seek(0)
return f.read()
return f.getvalue()
def loads(self, buf, **kwds):
class PersUnpickler(self.unpickler):
@ -106,6 +103,20 @@ class PyPersPicklerTests(AbstractPersistentPicklerTests):
return u.load()
# Runs AbstractPersistentPicklerTests with pickle._Pickler / pickle._Unpickler;
# the second base class is the shared PersistentPicklerUnpicklerMixin.
class PyPersPicklerTests(AbstractPersistentPicklerTests,
PersistentPicklerUnpicklerMixin):
pickler = pickle._Pickler
unpickler = pickle._Unpickler
# Runs AbstractIdentityPersistentPicklerTests with pickle._Pickler /
# pickle._Unpickler; the second base class is the shared
# PersistentPicklerUnpicklerMixin.
class PyIdPersPicklerTests(AbstractIdentityPersistentPicklerTests,
PersistentPicklerUnpicklerMixin):
pickler = pickle._Pickler
unpickler = pickle._Unpickler
class PyPicklerUnpicklerObjectTests(AbstractPicklerUnpicklerObjectTests):
pickler_class = pickle._Pickler
@ -144,6 +155,10 @@ if has_c_implementation:
pickler = _pickle.Pickler
unpickler = _pickle.Unpickler
# Same tests as PyIdPersPicklerTests, but with the pickler and unpickler
# classes taken from the _pickle module instead of pickle.
class CIdPersPicklerTests(PyIdPersPicklerTests):
pickler = _pickle.Pickler
unpickler = _pickle.Unpickler
class CDumpPickle_LoadPickle(PyPicklerTests):
pickler = _pickle.Pickler
unpickler = pickle._Unpickler
@ -409,11 +424,13 @@ class CompatPickleTests(unittest.TestCase):
def test_main():
tests = [PickleTests, PyUnpicklerTests, PyPicklerTests, PyPersPicklerTests,
tests = [PickleTests, PyUnpicklerTests, PyPicklerTests,
PyPersPicklerTests, PyIdPersPicklerTests,
PyDispatchTableTests, PyChainDispatchTableTests,
CompatPickleTests]
if has_c_implementation:
tests.extend([CUnpicklerTests, CPicklerTests, CPersPicklerTests,
tests.extend([CUnpicklerTests, CPicklerTests,
CPersPicklerTests, CIdPersPicklerTests,
CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
PyPicklerUnpicklerObjectTests,
CPicklerUnpicklerObjectTests,

View File

@ -24,6 +24,9 @@ Core and Builtins
Library
-------
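- Issue #17711: Fixed unpickling of objects by persistent ID with protocol 0.
Persistent IDs in protocol 0 must be ASCII strings; a non-ASCII ID now raises
PicklingError on pickling and UnpicklingError on unpickling.
Original patch by Alexandre Vassalotti.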
- Issue #27522: Avoid an unintentional reference cycle in email.feedparser.
- Issue #26844: Fix error message for imp.find_module() to refer to 'path'

View File

@ -3406,26 +3406,30 @@ save_pers(PicklerObject *self, PyObject *obj, PyObject *func)
goto error;
}
else {
PyObject *pid_str = NULL;
char *pid_ascii_bytes;
Py_ssize_t size;
PyObject *pid_str;
pid_str = PyObject_Str(pid);
if (pid_str == NULL)
goto error;
/* XXX: Should it check whether the persistent id only contains
ASCII characters? And what if the pid contains embedded
/* XXX: Should it check whether the pid contains embedded
newlines? */
pid_ascii_bytes = _PyUnicode_AsStringAndSize(pid_str, &size);
Py_DECREF(pid_str);
if (pid_ascii_bytes == NULL)
if (!PyUnicode_IS_ASCII(pid_str)) {
PyErr_SetString(_Pickle_GetGlobalState()->PicklingError,
"persistent IDs in protocol 0 must be "
"ASCII strings");
Py_DECREF(pid_str);
goto error;
}
if (_Pickler_Write(self, &persid_op, 1) < 0 ||
_Pickler_Write(self, pid_ascii_bytes, size) < 0 ||
_Pickler_Write(self, "\n", 1) < 0)
_Pickler_Write(self, PyUnicode_DATA(pid_str),
PyUnicode_GET_LENGTH(pid_str)) < 0 ||
_Pickler_Write(self, "\n", 1) < 0) {
Py_DECREF(pid_str);
goto error;
}
Py_DECREF(pid_str);
}
status = 1;
}
@ -5389,9 +5393,15 @@ load_persid(UnpicklerObject *self)
if (len < 1)
return bad_readline();
pid = PyBytes_FromStringAndSize(s, len - 1);
if (pid == NULL)
pid = PyUnicode_DecodeASCII(s, len - 1, "strict");
if (pid == NULL) {
if (PyErr_ExceptionMatches(PyExc_UnicodeDecodeError)) {
PyErr_SetString(_Pickle_GetGlobalState()->UnpicklingError,
"persistent IDs in protocol 0 must be "
"ASCII strings");
}
return -1;
}
/* This does not leak since _Pickle_FastCall() steals the reference
to pid first. */