Make dict.keys() and dict.items() comparable to sets, using == and !=.

(PEP 3106 requires subset comparisons too, those will come later if someone
really wants them. :-)
This commit is contained in:
Guido van Rossum 2007-02-12 02:23:40 +00:00
parent e19aad4c7b
commit d9214d1f2c
2 changed files with 107 additions and 4 deletions

View File

@ -17,18 +17,33 @@ class DictSetTest(unittest.TestCase):
def test_dict_keys(self):
d = {1: 10, "a": "ABC"}
keys = d.keys()
self.assertEqual(set(keys), {1, "a"})
self.assertEqual(len(keys), 2)
self.assertEqual(set(keys), {1, "a"})
self.assertEqual(keys, {1, "a"})
self.assertNotEqual(keys, {1, "a", "b"})
self.assertNotEqual(keys, {1, "b"})
self.assertNotEqual(keys, {1})
self.assertNotEqual(keys, 42)
self.assert_(1 in keys)
self.assert_("a" in keys)
self.assert_(10 not in keys)
self.assert_("Z" not in keys)
self.assertEqual(d.keys(), d.keys())
e = {1: 11, "a": "def"}
self.assertEqual(d.keys(), e.keys())
del e["a"]
self.assertNotEqual(d.keys(), e.keys())
def test_dict_items(self):
d = {1: 10, "a": "ABC"}
items = d.items()
self.assertEqual(set(items), {(1, 10), ("a", "ABC")})
self.assertEqual(len(items), 2)
self.assertEqual(set(items), {(1, 10), ("a", "ABC")})
self.assertEqual(items, {(1, 10), ("a", "ABC")})
self.assertNotEqual(items, {(1, 10), ("a", "ABC"), "junk"})
self.assertNotEqual(items, {(1, 10), ("a", "def")})
self.assertNotEqual(items, {(1, 10)})
self.assertNotEqual(items, 42)
self.assert_((1, 10) in items)
self.assert_(("a", "ABC") in items)
self.assert_((1, 11) not in items)
@ -36,6 +51,17 @@ class DictSetTest(unittest.TestCase):
self.assert_(() not in items)
self.assert_((1,) not in items)
self.assert_((1, 2, 3) not in items)
self.assertEqual(d.items(), d.items())
e = d.copy()
self.assertEqual(d.items(), e.items())
e["a"] = "def"
self.assertNotEqual(d.items(), e.items())
def test_dict_mixed_keys_items(self):
d = {(1, 1): 11, (2, 2): 22}
e = {1: 1, 2: 2}
self.assertEqual(d.keys(), e.items())
self.assertNotEqual(d.items(), e.keys())
def test_dict_values(self):
d = {1: 10, "a": "ABC"}

View File

@ -2399,6 +2399,83 @@ dictview_new(PyObject *dict, PyTypeObject *type)
return (PyObject *)dv;
}
/* Forward */
PyTypeObject PyDictKeys_Type;
PyTypeObject PyDictItems_Type;
PyTypeObject PyDictValues_Type;
#define PyDictKeys_Check(obj) ((obj)->ob_type == &PyDictKeys_Type)
#define PyDictItems_Check(obj) ((obj)->ob_type == &PyDictItems_Type)
#define PyDictValues_Check(obj) ((obj)->ob_type == &PyDictValues_Type)
/* This excludes Values, since they are not sets. */
# define PyDictViewSet_Check(obj) \
(PyDictKeys_Check(obj) || PyDictItems_Check(obj))
static int
all_contained_in(PyObject *self, PyObject *other)
{
PyObject *iter = PyObject_GetIter(self);
int ok = 1;
if (iter == NULL)
return -1;
for (;;) {
PyObject *next = PyIter_Next(iter);
if (next == NULL) {
if (PyErr_Occurred())
ok = -1;
break;
}
ok = PySequence_Contains(other, next);
Py_DECREF(next);
if (ok <= 0)
break;
}
Py_DECREF(iter);
return ok;
}
static PyObject *
dictview_richcompare(PyObject *self, PyObject *other, int op)
{
assert(self != NULL);
assert(PyDictViewSet_Check(self));
assert(other != NULL);
if ((op == Py_EQ || op == Py_NE) &&
(PyAnySet_Check(other) || PyDictViewSet_Check(other)))
{
Py_ssize_t len_self, len_other;
int ok;
PyObject *result;
len_self = PyObject_Size(self);
if (len_self < 0)
return NULL;
len_other = PyObject_Size(other);
if (len_other < 0)
return NULL;
if (len_self != len_other)
ok = 0;
else if (len_self == 0)
ok = 1;
else
ok = all_contained_in(self, other);
if (ok < 0)
return NULL;
if (ok == (op == Py_EQ))
result = Py_True;
else
result = Py_False;
Py_INCREF(result);
return result;
}
else {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
}
/*** dict_keys ***/
static PyObject *
@ -2459,7 +2536,7 @@ PyTypeObject PyDictKeys_Type = {
0, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
dictview_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
(getiterfunc)dictkeys_iter, /* tp_iter */
0, /* tp_iternext */
@ -2544,7 +2621,7 @@ PyTypeObject PyDictItems_Type = {
0, /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
dictview_richcompare, /* tp_richcompare */
0, /* tp_weaklistoffset */
(getiterfunc)dictitems_iter, /* tp_iter */
0, /* tp_iternext */