diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 2ebeff65e3f..77df31b622d 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -15,6 +15,12 @@ def check_pass_thru(): raise PassThru yield 1 +class BadCmp: + def __hash__(self): + return 1 + def __cmp__(self, other): + raise RuntimeError + class TestJointOps(unittest.TestCase): # Tests common to both set and frozenset @@ -227,6 +233,17 @@ class TestJointOps(unittest.TestCase): f.add(s) f.discard(s) + def test_badcmp(self): + s = self.thetype([BadCmp()]) + # Detect comparison errors during insertion and lookup + self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()]) + self.assertRaises(RuntimeError, s.__contains__, BadCmp()) + # Detect errors during mutating operations + if hasattr(s, 'add'): + self.assertRaises(RuntimeError, s.add, BadCmp()) + self.assertRaises(RuntimeError, s.discard, BadCmp()) + self.assertRaises(RuntimeError, s.remove, BadCmp()) + class TestSet(TestJointOps): thetype = set diff --git a/Objects/setobject.c b/Objects/setobject.c index dbfa79cc3e7..c4cd562d582 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -44,11 +44,8 @@ probe indices are computed as explained in Objects/dictobject.c. All arithmetic on hash should ignore overflow. -The lookup function always succeeds and nevers return NULL. This simplifies -and speeds client functions which do won't have to test for and handle -errors. To meet that requirement, any errors generated by a user defined -__cmp__() function are simply cleared and ignored. -Previously outstanding exceptions are maintained. +Unlike the dictionary implementation, the lookkey functions can return +NULL if the rich comparison returns an error. 
*/ static setentry * @@ -60,10 +57,7 @@ set_lookkey(PySetObject *so, PyObject *key, register long hash) register unsigned int mask = so->mask; setentry *table = so->table; register setentry *entry; - register int restore_error; - register int checked_error; register int cmp; - PyObject *err_type, *err_value, *err_tb; PyObject *startkey; i = hash & mask; @@ -71,31 +65,23 @@ set_lookkey(PySetObject *so, PyObject *key, register long hash) if (entry->key == NULL || entry->key == key) return entry; - restore_error = checked_error = 0; if (entry->key == dummy) freeslot = entry; else { if (entry->hash == hash) { - /* error can't have been checked yet */ - checked_error = 1; - if (_PyErr_OCCURRED()) { - restore_error = 1; - PyErr_Fetch(&err_type, &err_value, &err_tb); - } startkey = entry->key; cmp = PyObject_RichCompareBool(startkey, key, Py_EQ); if (cmp < 0) - PyErr_Clear(); + return NULL; if (table == so->table && entry->key == startkey) { if (cmp > 0) - goto Done; + return entry; } else { /* The compare did major nasty stuff to the * set: start over. */ - entry = set_lookkey(so, key, hash); - goto Done; + return set_lookkey(so, key, hash); } } freeslot = NULL; @@ -114,18 +100,10 @@ set_lookkey(PySetObject *so, PyObject *key, register long hash) if (entry->key == key) break; if (entry->hash == hash && entry->key != dummy) { - if (!checked_error) { - checked_error = 1; - if (_PyErr_OCCURRED()) { - restore_error = 1; - PyErr_Fetch(&err_type, &err_value, - &err_tb); - } - } startkey = entry->key; cmp = PyObject_RichCompareBool(startkey, key, Py_EQ); if (cmp < 0) - PyErr_Clear(); + return NULL; if (table == so->table && entry->key == startkey) { if (cmp > 0) break; @@ -134,29 +112,19 @@ set_lookkey(PySetObject *so, PyObject *key, register long hash) /* The compare did major nasty stuff to the * set: start over. 
*/ - entry = set_lookkey(so, key, hash); - break; + return set_lookkey(so, key, hash); } } else if (entry->key == dummy && freeslot == NULL) freeslot = entry; } - -Done: - if (restore_error) - PyErr_Restore(err_type, err_value, err_tb); return entry; } /* * Hacked up version of set_lookkey which can assume keys are always strings; - * this assumption allows testing for errors during PyObject_Compare() to - * be dropped; string-string comparisons never raise exceptions. This also - * means we don't need to go through PyObject_Compare(); we can always use - * _PyString_Eq directly. - * - * This is valuable because the general-case error handling in set_lookkey() is - * expensive, and sets with pure-string keys may be very common. + * this means we can always use _PyString_Eq directly and not have to check to + * see if the comparison altered the table. */ static setentry * set_lookkey_string(PySetObject *so, PyObject *key, register long hash) @@ -210,7 +178,7 @@ Internal routine to insert a new key into the table. Used both by the internal resize routine and by the public insert routine. Eats a reference to key. 
*/ -static void +static int set_insert_key(register PySetObject *so, PyObject *key, long hash) { register setentry *entry; @@ -218,6 +186,8 @@ set_insert_key(register PySetObject *so, PyObject *key, long hash) assert(so->lookup != NULL); entry = so->lookup(so, key, hash); + if (entry == NULL) + return -1; if (entry->key == NULL) { /* UNUSED */ so->fill++; @@ -234,6 +204,7 @@ set_insert_key(register PySetObject *so, PyObject *key, long hash) /* ACTIVE */ Py_DECREF(key); } + return 0; } /* @@ -317,7 +288,11 @@ set_table_resize(PySetObject *so, int minused) } else { /* ACTIVE */ --i; - set_insert_key(so, entry->key, entry->hash); + if (set_insert_key(so, entry->key, entry->hash) == -1) { + if (is_oldtable_malloced) + PyMem_DEL(oldtable); + return -1; + } } } @@ -336,7 +311,10 @@ set_add_entry(register PySetObject *so, setentry *entry) assert(so->fill <= so->mask); /* at least one empty slot */ n_used = so->used; Py_INCREF(entry->key); - set_insert_key(so, entry->key, entry->hash); + if (set_insert_key(so, entry->key, entry->hash) == -1) { + Py_DECREF(entry->key); + return -1; + } if (!(so->used > n_used && so->fill*3 >= (so->mask+1)*2)) return 0; return set_table_resize(so, so->used>50000 ? so->used*2 : so->used*4); @@ -357,7 +335,10 @@ set_add_key(register PySetObject *so, PyObject *key) assert(so->fill <= so->mask); /* at least one empty slot */ n_used = so->used; Py_INCREF(key); - set_insert_key(so, key, hash); + if (set_insert_key(so, key, hash) == -1) { + Py_DECREF(key); + return -1; + } if (!(so->used > n_used && so->fill*3 >= (so->mask+1)*2)) return 0; return set_table_resize(so, so->used>50000 ? 
so->used*2 : so->used*4); @@ -372,6 +353,8 @@ set_discard_entry(PySetObject *so, setentry *oldentry) PyObject *old_key; entry = (so->lookup)(so, oldentry->key, oldentry->hash); + if (entry == NULL) + return -1; if (entry->key == NULL || entry->key == dummy) return DISCARD_NOTFOUND; old_key = entry->key; @@ -397,6 +380,8 @@ set_discard_key(PySetObject *so, PyObject *key) return -1; } entry = (so->lookup)(so, key, hash); + if (entry == NULL) + return -1; if (entry->key == NULL || entry->key == dummy) return DISCARD_NOTFOUND; old_key = entry->key; @@ -601,7 +586,10 @@ set_merge(PySetObject *so, PyObject *otherset) if (entry->key != NULL && entry->key != dummy) { Py_INCREF(entry->key); - set_insert_key(so, entry->key, entry->hash); + if (set_insert_key(so, entry->key, entry->hash) == -1) { + Py_DECREF(entry->key); + return -1; + } } } return 0; @@ -611,6 +599,7 @@ static int set_contains_key(PySetObject *so, PyObject *key) { long hash; + setentry *entry; if (!PyString_CheckExact(key) || (hash = ((PyStringObject *) key)->ob_shash) == -1) { @@ -618,7 +607,10 @@ set_contains_key(PySetObject *so, PyObject *key) if (hash == -1) return -1; } - key = (so->lookup)(so, key, hash)->key; + entry = (so->lookup)(so, key, hash); + if (entry == NULL) + return -1; + key = entry->key; return key != NULL && key != dummy; } @@ -626,8 +618,12 @@ static int set_contains_entry(PySetObject *so, setentry *entry) { PyObject *key; + setentry *lu_entry; - key = (so->lookup)(so, entry->key, entry->hash)->key; + lu_entry = (so->lookup)(so, entry->key, entry->hash); + if (lu_entry == NULL) + return -1; + key = lu_entry->key; return key != NULL && key != dummy; } @@ -2096,4 +2092,6 @@ test_c_api(PySetObject *so) Py_RETURN_TRUE; } +#undef assertRaises + #endif