diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index 20547df3bb6..cd077ff0e0c 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -1055,6 +1055,37 @@ class DictTest(unittest.TestCase): support.check_free_after_iterating(self, lambda d: iter(d.values()), dict) support.check_free_after_iterating(self, lambda d: iter(d.items()), dict) + +class CAPITest(unittest.TestCase): + + # Test _PyDict_GetItem_KnownHash() + @support.cpython_only + def test_getitem_knownhash(self): + from _testcapi import dict_getitem_knownhash + + d = {'x': 1, 'y': 2, 'z': 3} + self.assertEqual(dict_getitem_knownhash(d, 'x', hash('x')), 1) + self.assertEqual(dict_getitem_knownhash(d, 'y', hash('y')), 2) + self.assertEqual(dict_getitem_knownhash(d, 'z', hash('z')), 3) + + # not a dict + self.assertRaises(SystemError, dict_getitem_knownhash, [], 1, hash(1)) + # key does not exist + self.assertRaises(KeyError, dict_getitem_knownhash, {}, 1, hash(1)) + + class Exc(Exception): pass + class BadEq: + def __eq__(self, other): + raise Exc + def __hash__(self): + return 7 + + k1, k2 = BadEq(), BadEq() + d = {k1: 1} + self.assertEqual(dict_getitem_knownhash(d, k1, hash(k1)), 1) + self.assertRaises(Exc, dict_getitem_knownhash, d, k2, hash(k2)) + + from test import mapping_tests class GeneralMappingTests(mapping_tests.BasicTestMappingProtocol): diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 9ed6f14bec3..66087982474 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -2300,6 +2300,8 @@ _count_elements(PyObject *self, PyObject *args) oldval = _PyDict_GetItem_KnownHash(mapping, key, hash); if (oldval == NULL) { + if (PyErr_Occurred()) + goto done; if (_PyDict_SetItem_KnownHash(mapping, key, one, hash) < 0) goto done; } else { diff --git a/Modules/_testcapimodule.c b/Modules/_testcapimodule.c index 74422a2f8db..ecbec972caf 100644 --- a/Modules/_testcapimodule.c +++ b/Modules/_testcapimodule.c @@ -238,6 +238,26 @@ test_dict_iteration(PyObject* self) return Py_None; } +static PyObject* +dict_getitem_knownhash(PyObject *self, PyObject *args) +{ + PyObject *mp, *key, *result; + Py_ssize_t hash; + + if (!PyArg_ParseTuple(args, "OOn:dict_getitem_knownhash", + &mp, &key, &hash)) { + return NULL; + } + + result = _PyDict_GetItem_KnownHash(mp, key, (Py_hash_t)hash); + if (result == NULL && !PyErr_Occurred()) { + _PyErr_SetKeyError(key); + return NULL; + } + + Py_XINCREF(result); + return result; +} /* Issue #4701: Check that PyObject_Hash implicitly calls * PyType_Ready if it hasn't already been called @@ -4003,6 +4023,7 @@ static PyMethodDef TestMethods[] = { {"test_datetime_capi", test_datetime_capi, METH_NOARGS}, {"test_list_api", (PyCFunction)test_list_api, METH_NOARGS}, {"test_dict_iteration", (PyCFunction)test_dict_iteration,METH_NOARGS}, + {"dict_getitem_knownhash", dict_getitem_knownhash, METH_VARARGS}, {"test_lazy_hash_inheritance", (PyCFunction)test_lazy_hash_inheritance,METH_NOARGS}, {"test_long_api", (PyCFunction)test_long_api, METH_NOARGS}, {"test_xincref_doesnt_leak",(PyCFunction)test_xincref_doesnt_leak, METH_NOARGS}, diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 3cbc3fdeb5a..d8ab91fb12d 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -723,8 +723,10 @@ top: Py_INCREF(startkey); cmp = PyObject_RichCompareBool(startkey, key, Py_EQ); Py_DECREF(startkey); - if (cmp < 0) + if (cmp < 0) { + *value_addr = NULL; return DKIX_ERROR; + } if (dk == mp->ma_keys && ep->me_key == startkey) { if (cmp > 0) { *value_addr = &ep->me_value; @@ -1424,39 +1426,25 @@ PyDict_GetItem(PyObject *op, PyObject *key) return *value_addr; } +/* Same as PyDict_GetItemWithError() but with hash supplied by caller. + This returns NULL *with* an exception set if an exception occurred. + It returns NULL *without* an exception set if the key wasn't present. +*/ PyObject * _PyDict_GetItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash) { Py_ssize_t ix; PyDictObject *mp = (PyDictObject *)op; - PyThreadState *tstate; PyObject **value_addr; - if (!PyDict_Check(op)) + if (!PyDict_Check(op)) { + PyErr_BadInternalCall(); return NULL; - - /* We can arrive here with a NULL tstate during initialization: try - running "python -Wi" for an example related to string interning. - Let's just hope that no exception occurs then... This must be - _PyThreadState_Current and not PyThreadState_GET() because in debug - mode, the latter complains if tstate is NULL. */ - tstate = _PyThreadState_UncheckedGet(); - if (tstate != NULL && tstate->curexc_type != NULL) { - /* preserve the existing exception */ - PyObject *err_type, *err_value, *err_tb; - PyErr_Fetch(&err_type, &err_value, &err_tb); - ix = (mp->ma_keys->dk_lookup)(mp, key, hash, &value_addr, NULL); - /* ignore errors */ - PyErr_Restore(err_type, err_value, err_tb); - if (ix == DKIX_EMPTY) - return NULL; } - else { - ix = (mp->ma_keys->dk_lookup)(mp, key, hash, &value_addr, NULL); - if (ix == DKIX_EMPTY) { - PyErr_Clear(); - return NULL; - } + + ix = (mp->ma_keys->dk_lookup)(mp, key, hash, &value_addr, NULL); + if (ix < 0) { + return NULL; } return *value_addr; } @@ -2431,8 +2419,16 @@ dict_merge(PyObject *a, PyObject *b, int override) int err = 0; Py_INCREF(key); Py_INCREF(value); - if (override == 1 || _PyDict_GetItem_KnownHash(a, key, hash) == NULL) + if (override == 1) err = insertdict(mp, key, hash, value); + else if (_PyDict_GetItem_KnownHash(a, key, hash) == NULL) { + if (PyErr_Occurred()) { + Py_DECREF(value); + Py_DECREF(key); + return -1; + } + err = insertdict(mp, key, hash, value); + } else if (override != 0) { _PyErr_SetKeyError(key); Py_DECREF(value);