diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index cf0b95d73c7..31930fc763a 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -30,6 +30,16 @@ def signature(part): """ return the signature of a partial object """ return (part.func, part.args, part.keywords, part.__dict__) +class MyTuple(tuple): + pass + +class BadTuple(tuple): + def __add__(self, other): + return list(self) + list(other) + +class MyDict(dict): + pass + class TestPartial: @@ -208,11 +218,84 @@ class TestPartialC(TestPartial, unittest.TestCase): for kwargs_repr in kwargs_reprs]) def test_pickle(self): - f = self.partial(signature, 'asdf', bar=True) - f.add_something_to__dict__ = True + f = self.partial(signature, ['asdf'], bar=[True]) + f.attr = [] for proto in range(pickle.HIGHEST_PROTOCOL + 1): f_copy = pickle.loads(pickle.dumps(f, proto)) - self.assertEqual(signature(f), signature(f_copy)) + self.assertEqual(signature(f_copy), signature(f)) + + def test_copy(self): + f = self.partial(signature, ['asdf'], bar=[True]) + f.attr = [] + f_copy = copy.copy(f) + self.assertEqual(signature(f_copy), signature(f)) + self.assertIs(f_copy.attr, f.attr) + self.assertIs(f_copy.args, f.args) + self.assertIs(f_copy.keywords, f.keywords) + + def test_deepcopy(self): + f = self.partial(signature, ['asdf'], bar=[True]) + f.attr = [] + f_copy = copy.deepcopy(f) + self.assertEqual(signature(f_copy), signature(f)) + self.assertIsNot(f_copy.attr, f.attr) + self.assertIsNot(f_copy.args, f.args) + self.assertIsNot(f_copy.args[0], f.args[0]) + self.assertIsNot(f_copy.keywords, f.keywords) + self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar']) + + def test_setstate(self): + f = self.partial(signature) + f.__setstate__((capture, (1,), dict(a=10), dict(attr=[]))) + self.assertEqual(signature(f), + (capture, (1,), dict(a=10), dict(attr=[]))) + self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) + + f.__setstate__((capture, (1,), dict(a=10), None)) + self.assertEqual(signature(f), (capture, (1,), dict(a=10), {})) + self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20})) + + f.__setstate__((capture, (1,), None, None)) + #self.assertEqual(signature(f), (capture, (1,), {}, {})) + self.assertEqual(f(2, b=20), ((1, 2), {'b': 20})) + self.assertEqual(f(2), ((1, 2), {})) + self.assertEqual(f(), ((1,), {})) + + f.__setstate__((capture, (), {}, None)) + self.assertEqual(signature(f), (capture, (), {}, {})) + self.assertEqual(f(2, b=20), ((2,), {'b': 20})) + self.assertEqual(f(2), ((2,), {})) + self.assertEqual(f(), ((), {})) + + def test_setstate_errors(self): + f = self.partial(signature) + self.assertRaises(TypeError, f.__setstate__, (capture, (), {})) + self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None)) + self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None]) + self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None)) + self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None)) + self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None)) + self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None)) + + def test_setstate_subclasses(self): + f = self.partial(signature) + f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None)) + s = signature(f) + self.assertEqual(s, (capture, (1,), dict(a=10), {})) + self.assertIs(type(s[1]), tuple) + self.assertIs(type(s[2]), dict) + r = f() + self.assertEqual(r, ((1,), {'a': 10})) + self.assertIs(type(r[0]), tuple) + self.assertIs(type(r[1]), dict) + + f.__setstate__((capture, BadTuple((1,)), {}, None)) + s = signature(f) + self.assertEqual(s, (capture, (1,), {}, {})) + self.assertIs(type(s[1]), tuple) + r = f(2) + self.assertEqual(r, ((1, 2), {})) + self.assertIs(type(r[0]), tuple) # Issue 6083: Reference counting bug def test_setstate_refcount(self): @@ -229,9 +312,7 @@ class TestPartialC(TestPartial, unittest.TestCase): raise IndexError f = self.partial(object) - self.assertRaisesRegex(SystemError, - "new style getargs format but argument is not a tuple", - f.__setstate__, BadSequence()) + self.assertRaises(TypeError, f.__setstate__, BadSequence()) class TestPartialPy(TestPartial, unittest.TestCase): diff --git a/Misc/NEWS b/Misc/NEWS index ba1f4de5f5c..3a71354a043 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -70,6 +70,11 @@ Core and Builtins Library ------- +- Issue #25945: Fixed a crash when unpickle the functools.partial object with + wrong state. Fixed a leak in failed functools.partial constructor. + "args" and "keywords" attributes of functools.partial have now always types + tuple and dict correspondingly. + - Issue #26202: copy.deepcopy() now correctly copies range() objects with non-atomic attributes. diff --git a/Modules/_functoolsmodule.c b/Modules/_functoolsmodule.c index 035d3d9c591..8da5eb374ae 100644 --- a/Modules/_functoolsmodule.c +++ b/Modules/_functoolsmodule.c @@ -34,7 +34,7 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw) return NULL; } - pargs = pkw = Py_None; + pargs = pkw = NULL; func = PyTuple_GET_ITEM(args, 0); if (Py_TYPE(func) == &partial_type && type == &partial_type) { partialobject *part = (partialobject *)func; @@ -42,6 +42,8 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw) pargs = part->args; pkw = part->kw; func = part->fn; + assert(PyTuple_Check(pargs)); + assert(PyDict_Check(pkw)); } } if (!PyCallable_Check(func)) { @@ -60,12 +62,10 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw) nargs = PyTuple_GetSlice(args, 1, PY_SSIZE_T_MAX); if (nargs == NULL) { - pto->args = NULL; - pto->kw = NULL; Py_DECREF(pto); return NULL; } - if (pargs == Py_None || PyTuple_GET_SIZE(pargs) == 0) { + if (pargs == NULL || PyTuple_GET_SIZE(pargs) == 0) { pto->args = nargs; Py_INCREF(nargs); } @@ -76,47 +76,36 @@ partial_new(PyTypeObject *type, PyObject *args, PyObject *kw) else { pto->args = PySequence_Concat(pargs, nargs); if (pto->args == NULL) { - pto->kw = NULL; + Py_DECREF(nargs); Py_DECREF(pto); return NULL; } + assert(PyTuple_Check(pto->args)); } Py_DECREF(nargs); - if (kw != NULL) { - if (pkw == Py_None) { - pto->kw = PyDict_Copy(kw); + if (pkw == NULL || PyDict_Size(pkw) == 0) { + if (kw == NULL) { + pto->kw = PyDict_New(); } else { - pto->kw = PyDict_Copy(pkw); - if (pto->kw != NULL) { - if (PyDict_Merge(pto->kw, kw, 1) != 0) { - Py_DECREF(pto); - return NULL; - } - } - } - if (pto->kw == NULL) { - Py_DECREF(pto); - return NULL; + Py_INCREF(kw); + pto->kw = kw; } } else { - if (pkw == Py_None) { - pto->kw = PyDict_New(); - if (pto->kw == NULL) { + pto->kw = PyDict_Copy(pkw); + if (kw != NULL && pto->kw != NULL) { + if (PyDict_Merge(pto->kw, kw, 1) != 0) { Py_DECREF(pto); return NULL; } } - else { - pto->kw = pkw; - Py_INCREF(pkw); - } } - - pto->weakreflist = NULL; - pto->dict = NULL; + if (pto->kw == NULL) { + Py_DECREF(pto); + return NULL; + } return (PyObject *)pto; } @@ -138,11 +127,11 @@ static PyObject * partial_call(partialobject *pto, PyObject *args, PyObject *kw) { PyObject *ret; - PyObject *argappl = NULL, *kwappl = NULL; + PyObject *argappl, *kwappl; assert (PyCallable_Check(pto->fn)); assert (PyTuple_Check(pto->args)); - assert (pto->kw == Py_None || PyDict_Check(pto->kw)); + assert (PyDict_Check(pto->kw)); if (PyTuple_GET_SIZE(pto->args) == 0) { argappl = args; @@ -154,11 +143,12 @@ partial_call(partialobject *pto, PyObject *args, PyObject *kw) argappl = PySequence_Concat(pto->args, args); if (argappl == NULL) return NULL; + assert(PyTuple_Check(argappl)); } - if (pto->kw == Py_None) { + if (PyDict_Size(pto->kw) == 0) { kwappl = kw; - Py_XINCREF(kw); + Py_XINCREF(kwappl); } else { kwappl = PyDict_Copy(pto->kw); if (kwappl == NULL) { @@ -217,6 +207,7 @@ partial_repr(partialobject *pto) PyObject *arglist; PyObject *tmp; Py_ssize_t i, n; + PyObject *key, *value; arglist = PyUnicode_FromString(""); if (arglist == NULL) { @@ -234,17 +225,14 @@ partial_repr(partialobject *pto) arglist = tmp; } /* Pack keyword arguments */ - assert (pto->kw == Py_None || PyDict_Check(pto->kw)); - if (pto->kw != Py_None) { - PyObject *key, *value; - for (i = 0; PyDict_Next(pto->kw, &i, &key, &value);) { - tmp = PyUnicode_FromFormat("%U, %U=%R", arglist, - key, value); - Py_DECREF(arglist); - if (tmp == NULL) - return NULL; - arglist = tmp; - } + assert (PyDict_Check(pto->kw)); + for (i = 0; PyDict_Next(pto->kw, &i, &key, &value);) { + tmp = PyUnicode_FromFormat("%U, %U=%R", arglist, + key, value); + Py_DECREF(arglist); + if (tmp == NULL) + return NULL; + arglist = tmp; } result = PyUnicode_FromFormat("%s(%R%U)", Py_TYPE(pto)->tp_name, pto->fn, arglist); @@ -271,25 +259,45 @@ static PyObject * partial_setstate(partialobject *pto, PyObject *state) { PyObject *fn, *fnargs, *kw, *dict; - if (!PyArg_ParseTuple(state, "OOOO", - &fn, &fnargs, &kw, &dict)) + + if (!PyTuple_Check(state) || + !PyArg_ParseTuple(state, "OOOO", &fn, &fnargs, &kw, &dict) || + !PyCallable_Check(fn) || + !PyTuple_Check(fnargs) || + (kw != Py_None && !PyDict_Check(kw))) + { + PyErr_SetString(PyExc_TypeError, "invalid partial state"); return NULL; - Py_XDECREF(pto->fn); - Py_XDECREF(pto->args); - Py_XDECREF(pto->kw); - Py_XDECREF(pto->dict); - pto->fn = fn; - pto->args = fnargs; - pto->kw = kw; - if (dict != Py_None) { - pto->dict = dict; - Py_INCREF(dict); - } else { - pto->dict = NULL; } + + if(!PyTuple_CheckExact(fnargs)) + fnargs = PySequence_Tuple(fnargs); + else + Py_INCREF(fnargs); + if (fnargs == NULL) + return NULL; + + if (kw == Py_None) + kw = PyDict_New(); + else if(!PyDict_CheckExact(kw)) + kw = PyDict_Copy(kw); + else + Py_INCREF(kw); + if (kw == NULL) { + Py_DECREF(fnargs); + return NULL; + } + Py_INCREF(fn); - Py_INCREF(fnargs); - Py_INCREF(kw); + if (dict == Py_None) + dict = NULL; + else + Py_INCREF(dict); + + Py_SETREF(pto->fn, fn); + Py_SETREF(pto->args, fnargs); + Py_SETREF(pto->kw, kw); + Py_SETREF(pto->dict, dict); Py_RETURN_NONE; }