diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index fdbfe19c91f..99d5c70e0a3 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -1660,6 +1660,39 @@ class TestVariousIteratorArgs(unittest.TestCase): self.assertRaises(TypeError, getattr(set('january'), methname), N(data)) self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data)) +class bad_eq: + def __eq__(self, other): + if be_bad: + set2.clear() + raise ZeroDivisionError + return self is other + def __hash__(self): + return 0 + +class bad_dict_clear: + def __eq__(self, other): + if be_bad: + dict2.clear() + return self is other + def __hash__(self): + return 0 + +class TestWeirdBugs(unittest.TestCase): + def test_8420_set_merge(self): + # This used to segfault + global be_bad, set2, dict2 + be_bad = False + set1 = {bad_eq()} + set2 = {bad_eq() for i in range(75)} + be_bad = True + self.assertRaises(ZeroDivisionError, set1.update, set2) + + be_bad = False + set1 = {bad_dict_clear()} + dict2 = {bad_dict_clear(): None} + be_bad = True + set1.symmetric_difference_update(dict2) + # Application tests (based on David Eppstein's graph recipes ==================================== def powerset(U): @@ -1804,6 +1837,7 @@ def test_main(verbose=None): TestIdentities, TestVariousIteratorArgs, TestGraphs, + TestWeirdBugs, ) support.run_unittest(*test_classes) diff --git a/Objects/setobject.c b/Objects/setobject.c index 30afc7c1eef..7aa1a7faee3 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -364,12 +364,14 @@ static int set_add_entry(register PySetObject *so, setentry *entry) { register Py_ssize_t n_used; + PyObject *key = entry->key; + long hash = entry->hash; assert(so->fill <= so->mask); /* at least one empty slot */ n_used = so->used; - Py_INCREF(entry->key); - if (set_insert_key(so, entry->key, (long) entry->hash) == -1) { - Py_DECREF(entry->key); + Py_INCREF(key); + if (set_insert_key(so, key, hash) == -1) { + Py_DECREF(key); return -1; } if (!(so->used > n_used && so->fill*3 >= (so->mask+1)*2)) @@ -637,6 +639,8 @@ static int set_merge(PySetObject *so, PyObject *otherset) { PySetObject *other; + PyObject *key; + long hash; register Py_ssize_t i; register setentry *entry; @@ -657,11 +661,13 @@ set_merge(PySetObject *so, PyObject *otherset) } for (i = 0; i <= other->mask; i++) { entry = &other->table[i]; - if (entry->key != NULL && - entry->key != dummy) { - Py_INCREF(entry->key); - if (set_insert_key(so, entry->key, (long) entry->hash) == -1) { - Py_DECREF(entry->key); + key = entry->key; + hash = entry->hash; + if (key != NULL && + key != dummy) { + Py_INCREF(key); + if (set_insert_key(so, key, hash) == -1) { + Py_DECREF(key); return -1; } } @@ -1642,15 +1648,22 @@ set_symmetric_difference_update(PySetObject *so, PyObject *other) while (_PyDict_Next(other, &pos, &key, &value, &hash)) { setentry an_entry; + Py_INCREF(key); an_entry.hash = hash; an_entry.key = key; + rv = set_discard_entry(so, &an_entry); - if (rv == -1) + if (rv == -1) { + Py_DECREF(key); return NULL; - if (rv == DISCARD_NOTFOUND) { - if (set_add_entry(so, &an_entry) == -1) - return NULL; } + if (rv == DISCARD_NOTFOUND) { + if (set_add_entry(so, &an_entry) == -1) { + Py_DECREF(key); + return NULL; + } + } + Py_DECREF(key); } Py_RETURN_NONE; }