diff --git a/Lib/test/test_dictviews.py b/Lib/test/test_dictviews.py index 607d182011b..3a5fce6e7ab 100644 --- a/Lib/test/test_dictviews.py +++ b/Lib/test/test_dictviews.py @@ -17,18 +17,33 @@ class DictSetTest(unittest.TestCase): def test_dict_keys(self): d = {1: 10, "a": "ABC"} keys = d.keys() - self.assertEqual(set(keys), {1, "a"}) self.assertEqual(len(keys), 2) + self.assertEqual(set(keys), {1, "a"}) + self.assertEqual(keys, {1, "a"}) + self.assertNotEqual(keys, {1, "a", "b"}) + self.assertNotEqual(keys, {1, "b"}) + self.assertNotEqual(keys, {1}) + self.assertNotEqual(keys, 42) self.assert_(1 in keys) self.assert_("a" in keys) self.assert_(10 not in keys) self.assert_("Z" not in keys) + self.assertEqual(d.keys(), d.keys()) + e = {1: 11, "a": "def"} + self.assertEqual(d.keys(), e.keys()) + del e["a"] + self.assertNotEqual(d.keys(), e.keys()) def test_dict_items(self): d = {1: 10, "a": "ABC"} items = d.items() - self.assertEqual(set(items), {(1, 10), ("a", "ABC")}) self.assertEqual(len(items), 2) + self.assertEqual(set(items), {(1, 10), ("a", "ABC")}) + self.assertEqual(items, {(1, 10), ("a", "ABC")}) + self.assertNotEqual(items, {(1, 10), ("a", "ABC"), "junk"}) + self.assertNotEqual(items, {(1, 10), ("a", "def")}) + self.assertNotEqual(items, {(1, 10)}) + self.assertNotEqual(items, 42) self.assert_((1, 10) in items) self.assert_(("a", "ABC") in items) self.assert_((1, 11) not in items) @@ -36,6 +51,17 @@ class DictSetTest(unittest.TestCase): self.assert_(() not in items) self.assert_((1,) not in items) self.assert_((1, 2, 3) not in items) + self.assertEqual(d.items(), d.items()) + e = d.copy() + self.assertEqual(d.items(), e.items()) + e["a"] = "def" + self.assertNotEqual(d.items(), e.items()) + + def test_dict_mixed_keys_items(self): + d = {(1, 1): 11, (2, 2): 22} + e = {1: 1, 2: 2} + self.assertEqual(d.keys(), e.items()) + self.assertNotEqual(d.items(), e.keys()) def test_dict_values(self): d = {1: 10, "a": "ABC"} diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 49907b4f48d..eeb8bd51f6f 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2399,6 +2399,83 @@ dictview_new(PyObject *dict, PyTypeObject *type) return (PyObject *)dv; } +/* Forward */ +PyTypeObject PyDictKeys_Type; +PyTypeObject PyDictItems_Type; +PyTypeObject PyDictValues_Type; + +#define PyDictKeys_Check(obj) ((obj)->ob_type == &PyDictKeys_Type) +#define PyDictItems_Check(obj) ((obj)->ob_type == &PyDictItems_Type) +#define PyDictValues_Check(obj) ((obj)->ob_type == &PyDictValues_Type) + +/* This excludes Values, since they are not sets. */ +# define PyDictViewSet_Check(obj) \ + (PyDictKeys_Check(obj) || PyDictItems_Check(obj)) + +static int +all_contained_in(PyObject *self, PyObject *other) +{ + PyObject *iter = PyObject_GetIter(self); + int ok = 1; + + if (iter == NULL) + return -1; + for (;;) { + PyObject *next = PyIter_Next(iter); + if (next == NULL) { + if (PyErr_Occurred()) + ok = -1; + break; + } + ok = PySequence_Contains(other, next); + Py_DECREF(next); + if (ok <= 0) + break; + } + Py_DECREF(iter); + return ok; +} + +static PyObject * +dictview_richcompare(PyObject *self, PyObject *other, int op) +{ + assert(self != NULL); + assert(PyDictViewSet_Check(self)); + assert(other != NULL); + if ((op == Py_EQ || op == Py_NE) && + (PyAnySet_Check(other) || PyDictViewSet_Check(other))) + { + Py_ssize_t len_self, len_other; + int ok; + PyObject *result; + + len_self = PyObject_Size(self); + if (len_self < 0) + return NULL; + len_other = PyObject_Size(other); + if (len_other < 0) + return NULL; + if (len_self != len_other) + ok = 0; + else if (len_self == 0) + ok = 1; + else + ok = all_contained_in(self, other); + if (ok < 0) + return NULL; + if (ok == (op == Py_EQ)) + result = Py_True; + else + result = Py_False; + Py_INCREF(result); + return result; + } + else { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } +} + /*** dict_keys ***/ static PyObject * @@ -2459,7 +2536,7 @@ PyTypeObject PyDictKeys_Type = { 0, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ - 0, /* tp_richcompare */ + dictview_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ (getiterfunc)dictkeys_iter, /* tp_iter */ 0, /* tp_iternext */ @@ -2544,7 +2621,7 @@ PyTypeObject PyDictItems_Type = { 0, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ - 0, /* tp_richcompare */ + dictview_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ (getiterfunc)dictitems_iter, /* tp_iter */ 0, /* tp_iternext */