From 3ac6741f7922b7fdf05f9ba231c6eeff73580f8a Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Sat, 10 Feb 2007 18:55:06 +0000 Subject: [PATCH] Implement __contains__ for dict_keys and dict_items. (Not for dict_values, where it can't be done faster than the default implementation which just iterates the elements.) --- Lib/test/test_dictviews.py | 22 ++++++++++ Objects/dictobject.c | 82 ++++++++++++++++++++++++++------------ 2 files changed, 78 insertions(+), 26 deletions(-) diff --git a/Lib/test/test_dictviews.py b/Lib/test/test_dictviews.py index c0c1da1bf25..4c436f7eae2 100644 --- a/Lib/test/test_dictviews.py +++ b/Lib/test/test_dictviews.py @@ -3,17 +3,39 @@ from test import test_support class DictSetTest(unittest.TestCase): + def test_constructors_not_callable(self): + kt = type({}.KEYS()) + self.assertRaises(TypeError, kt, {}) + self.assertRaises(TypeError, kt) + it = type({}.ITEMS()) + self.assertRaises(TypeError, it, {}) + self.assertRaises(TypeError, it) + vt = type({}.VALUES()) + self.assertRaises(TypeError, vt, {}) + self.assertRaises(TypeError, vt) + def test_dict_keys(self): d = {1: 10, "a": "ABC"} keys = d.KEYS() self.assertEqual(set(keys), {1, "a"}) self.assertEqual(len(keys), 2) + self.assert_(1 in keys) + self.assert_("a" in keys) + self.assert_(10 not in keys) + self.assert_("Z" not in keys) def test_dict_items(self): d = {1: 10, "a": "ABC"} items = d.ITEMS() self.assertEqual(set(items), {(1, 10), ("a", "ABC")}) self.assertEqual(len(items), 2) + self.assert_((1, 10) in items) + self.assert_(("a", "ABC") in items) + self.assert_((1, 11) not in items) + self.assert_(1 not in items) + self.assert_(() not in items) + self.assert_((1,) not in items) + self.assert_((1, 2, 3) not in items) def test_dict_values(self): d = {1: 10, "a": "ABC"} diff --git a/Objects/dictobject.c b/Objects/dictobject.c index e2e98db02f1..ec14fcbc316 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2336,37 +2336,40 @@ PyTypeObject PyDictIterItem_Type = { }; +/***********************************************/ /* View objects for keys(), items(), values(). */ +/***********************************************/ + /* While this is incomplete, we use KEYS(), ITEMS(), VALUES(). */ /* The instance lay-out is the same for all three; but the type differs. */ typedef struct { PyObject_HEAD - dictobject *ds_dict; + dictobject *dv_dict; } dictviewobject; static void -dictview_dealloc(dictviewobject *ds) +dictview_dealloc(dictviewobject *dv) { - Py_XDECREF(ds->ds_dict); - PyObject_Del(ds); + Py_XDECREF(dv->dv_dict); + PyObject_Del(dv); } static Py_ssize_t -dictview_len(dictviewobject *ds) +dictview_len(dictviewobject *dv) { Py_ssize_t len = 0; - if (ds->ds_dict != NULL) - len = ds->ds_dict->ma_used; + if (dv->dv_dict != NULL) + len = dv->dv_dict->ma_used; return len; } static PyObject * dictview_new(PyObject *dict, PyTypeObject *type) { - dictviewobject *ds; + dictviewobject *dv; if (dict == NULL) { PyErr_BadInternalCall(); return NULL; @@ -2378,23 +2381,31 @@ dictview_new(PyObject *dict, PyTypeObject *type) type->tp_name, dict->ob_type->tp_name); return NULL; } - ds = PyObject_New(dictviewobject, type); - if (ds == NULL) + dv = PyObject_New(dictviewobject, type); + if (dv == NULL) return NULL; Py_INCREF(dict); - ds->ds_dict = (dictobject *)dict; - return (PyObject *)ds; + dv->dv_dict = (dictobject *)dict; + return (PyObject *)dv; } -/* dict_keys */ +/*** dict_keys ***/ static PyObject * -dictkeys_iter(dictviewobject *ds) +dictkeys_iter(dictviewobject *dv) { - if (ds->ds_dict == NULL) { + if (dv->dv_dict == NULL) { Py_RETURN_NONE; } - return dictiter_new(ds->ds_dict, &PyDictIterKey_Type); + return dictiter_new(dv->dv_dict, &PyDictIterKey_Type); +} + +static int +dictkeys_contains(dictviewobject *dv, PyObject *obj) +{ + if (dv->dv_dict == NULL) + return 0; + return PyDict_Contains((PyObject *)dv->dv_dict, obj); } static PySequenceMethods dictkeys_as_sequence = { @@ -2405,7 +2416,7 @@ static PySequenceMethods dictkeys_as_sequence = { 0, /* sq_slice */ 0, /* sq_ass_item */ 0, /* sq_ass_slice */ - (objobjproc)0, /* sq_contains */ + (objobjproc)dictkeys_contains, /* sq_contains */ }; static PyMethodDef dictkeys_methods[] = { @@ -2452,15 +2463,34 @@ dictkeys_new(PyObject *dict) return dictview_new(dict, &PyDictKeys_Type); } -/* dict_items */ +/*** dict_items ***/ static PyObject * -dictitems_iter(dictviewobject *ds) +dictitems_iter(dictviewobject *dv) { - if (ds->ds_dict == NULL) { + if (dv->dv_dict == NULL) { Py_RETURN_NONE; } - return dictiter_new(ds->ds_dict, &PyDictIterItem_Type); + return dictiter_new(dv->dv_dict, &PyDictIterItem_Type); +} + +static int +dictitems_contains(dictviewobject *dv, PyObject *obj) +{ + PyObject *key, *value, *found; + if (dv->dv_dict == NULL) + return 0; + if (!PyTuple_Check(obj) || PyTuple_GET_SIZE(obj) != 2) + return 0; + key = PyTuple_GET_ITEM(obj, 0); + value = PyTuple_GET_ITEM(obj, 1); + found = PyDict_GetItem((PyObject *)dv->dv_dict, key); + if (found == NULL) { + if (PyErr_Occurred()) + return -1; + return 0; + } + return PyObject_RichCompareBool(value, found, Py_EQ); } static PySequenceMethods dictitems_as_sequence = { @@ -2471,7 +2501,7 @@ static PySequenceMethods dictitems_as_sequence = { 0, /* sq_slice */ 0, /* sq_ass_item */ 0, /* sq_ass_slice */ - (objobjproc)0, /* sq_contains */ + (objobjproc)dictitems_contains, /* sq_contains */ }; static PyMethodDef dictitems_methods[] = { @@ -2518,15 +2548,15 @@ dictitems_new(PyObject *dict) return dictview_new(dict, &PyDictItems_Type); } -/* dict_values */ +/*** dict_values ***/ static PyObject * -dictvalues_iter(dictviewobject *ds) +dictvalues_iter(dictviewobject *dv) { - if (ds->ds_dict == NULL) { + if (dv->dv_dict == NULL) { Py_RETURN_NONE; } - return dictiter_new(ds->ds_dict, &PyDictIterValue_Type); + return dictiter_new(dv->dv_dict, &PyDictIterValue_Type); } static PySequenceMethods dictvalues_as_sequence = {