From a762af74b2de734c44f7dc00358325d4485e2530 Mon Sep 17 00:00:00 2001 From: Eric Snow Date: Mon, 1 Jun 2015 22:59:08 -0600 Subject: [PATCH] Issue #24347: Set KeyError if PyDict_GetItemWithError returns NULL. --- Include/odictobject.h | 2 ++ Lib/test/test_collections.py | 18 +++++++++++ Misc/NEWS | 2 ++ Objects/odictobject.c | 62 +++++++++++++++++++++++++++--------- 4 files changed, 69 insertions(+), 15 deletions(-) diff --git a/Include/odictobject.h b/Include/odictobject.h index 7930d2876d3..c1d9592a1db 100644 --- a/Include/odictobject.h +++ b/Include/odictobject.h @@ -30,6 +30,8 @@ PyAPI_FUNC(int) PyODict_DelItem(PyObject *od, PyObject *key); /* wrappers around PyDict* functions */ #define PyODict_GetItem(od, key) PyDict_GetItem((PyObject *)od, key) +#define PyODict_GetItemWithError(od, key) \ + PyDict_GetItemWithError((PyObject *)od, key) #define PyODict_Contains(od, key) PyDict_Contains((PyObject *)od, key) #define PyODict_Size(od) PyDict_Size((PyObject *)od) #define PyODict_GetItemString(od, key) \ diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py index 3f02129c1df..931ac0ff0fe 100644 --- a/Lib/test/test_collections.py +++ b/Lib/test/test_collections.py @@ -2037,6 +2037,24 @@ class CPythonOrderedDictTests(OrderedDictTests, unittest.TestCase): del od[colliding] self.assertEqual(list(od.items()), [(key, ...), ('after', ...)]) + def test_issue24347(self): + OrderedDict = self.module.OrderedDict + + class Key: + def __hash__(self): + return randrange(100000) + + od = OrderedDict() + for i in range(100): + key = Key() + od[key] = i + + # These should not crash. + with self.assertRaises(KeyError): + repr(od) + with self.assertRaises(KeyError): + od.copy() + class PurePythonGeneralMappingTests(mapping_tests.BasicTestMappingProtocol): diff --git a/Misc/NEWS b/Misc/NEWS index 2de47f713a8..6260d797b7a 100644 --- a/Misc/NEWS +++ b/Misc/NEWS @@ -15,6 +15,8 @@ Core and Builtins Library ------- +- Issue #24347: Set KeyError if PyDict_GetItemWithError returns NULL. + What's New in Python 3.5.0 beta 2? ================================== diff --git a/Objects/odictobject.c b/Objects/odictobject.c index e7a368f70d7..53155b5f527 100644 --- a/Objects/odictobject.c +++ b/Objects/odictobject.c @@ -511,7 +511,7 @@ struct _odictnode { (node->key) /* borrowed reference */ #define _odictnode_VALUE(node, od) \ - PyODict_GetItem((PyObject *)od, _odictnode_KEY(node)) + PyODict_GetItemWithError((PyObject *)od, _odictnode_KEY(node)) /* If needed we could also have _odictnode_HASH. */ #define _odictnode_PREV(node) (node->prev) #define _odictnode_NEXT(node) (node->next) @@ -1313,10 +1313,14 @@ odict_copy(register PyODictObject *od) if (PyODict_CheckExact(od)) { _odict_FOREACH(od, node) { - int res = PyODict_SetItem((PyObject *)od_copy, - _odictnode_KEY(node), - _odictnode_VALUE(node, od)); - if (res != 0) + PyObject *key = _odictnode_KEY(node); + PyObject *value = _odictnode_VALUE(node, od); + if (value == NULL) { + if (!PyErr_Occurred()) + PyErr_SetObject(PyExc_KeyError, key); + goto fail; + } + if (PyODict_SetItem((PyObject *)od_copy, key, value) != 0) goto fail; } } @@ -1538,7 +1542,6 @@ odict_repr(PyODictObject *self) Py_ssize_t count = -1; PyObject *pieces = NULL, *result = NULL, *cls = NULL; PyObject *classname = NULL, *format = NULL, *args = NULL; - _ODictNode *node; i = Py_ReprEnter((PyObject *)self); if (i != 0) { @@ -1551,13 +1554,21 @@ odict_repr(PyODictObject *self) } if (PyODict_CheckExact(self)) { + _ODictNode *node; pieces = PyList_New(PyODict_SIZE(self)); if (pieces == NULL) goto Done; _odict_FOREACH(self, node) { - PyObject *pair = PyTuple_Pack(2, _odictnode_KEY(node), - _odictnode_VALUE(node, self)); + PyObject *pair; + PyObject *key = _odictnode_KEY(node); + PyObject *value = _odictnode_VALUE(node, self); + if (value == NULL) { + if (!PyErr_Occurred()) + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + pair = PyTuple_Pack(2, key, value); if (pair == NULL) goto Done; @@ -1813,7 +1824,7 @@ static void odictiter_dealloc(odictiterobject *di) { _PyObject_GC_UNTRACK(di); - Py_DECREF(di->di_odict); + Py_XDECREF(di->di_odict); Py_XDECREF(di->di_current); if (di->kind & (_odict_ITER_KEYS | _odict_ITER_VALUES)) { Py_DECREF(di->di_result); @@ -1830,16 +1841,21 @@ odictiter_traverse(odictiterobject *di, visitproc visit, void *arg) return 0; } +/* In order to protect against modifications during iteration, we track + * the current key instead of the current node. */ static PyObject * odictiter_nextkey(odictiterobject *di) { - PyObject *key; + PyObject *key = NULL; _ODictNode *node; int reversed = di->kind & _odict_ITER_REVERSED; - /* Get the key. */ - if (di->di_current == NULL) + if (di->di_odict == NULL) return NULL; + if (di->di_current == NULL) + goto done; /* We're already done. */ + + /* Get the key. */ node = _odict_find_node(di->di_odict, di->di_current); if (node == NULL) { /* Must have been deleted. */ @@ -1860,6 +1876,10 @@ odictiter_nextkey(odictiterobject *di) } return key; + +done: + Py_CLEAR(di->di_odict); + return key; } static PyObject * @@ -1882,8 +1902,10 @@ odictiter_iternext(odictiterobject *di) value = PyODict_GetItem((PyObject *)di->di_odict, key); /* borrowed */ if (value == NULL) { + if (!PyErr_Occurred()) + PyErr_SetObject(PyExc_KeyError, key); Py_DECREF(key); - return NULL; + goto done; } Py_INCREF(value); @@ -1899,7 +1921,7 @@ odictiter_iternext(odictiterobject *di) if (result == NULL) { Py_DECREF(key); Py_DECREF(value); - return NULL; + goto done; } } @@ -1911,10 +1933,20 @@ odictiter_iternext(odictiterobject *di) /* Handle the values case. */ else { value = PyODict_GetItem((PyObject *)di->di_odict, key); - Py_XINCREF(value); Py_DECREF(key); + if (value == NULL) { + if (!PyErr_Occurred()) + PyErr_SetObject(PyExc_KeyError, key); + goto done; + } + Py_INCREF(value); return value; } + +done: + Py_CLEAR(di->di_current); + Py_CLEAR(di->di_odict); + return NULL; } /* No need for tp_clear because odictiterobject is not mutable. */