Check for a common user error with defaultdict().

This commit is contained in:
Raymond Hettinger 2007-02-07 21:40:49 +00:00
parent 05d59e2df7
commit 113776c411
3 changed files with 13 additions and 4 deletions

View File

@ -47,6 +47,7 @@ class TestDefaultDict(unittest.TestCase):
self.assertEqual(err.args, (15,)) self.assertEqual(err.args, (15,))
else: else:
self.fail("d2[15] didn't raise KeyError") self.fail("d2[15] didn't raise KeyError")
self.assertRaises(TypeError, defaultdict, 1)
def test_missing(self): def test_missing(self):
d1 = defaultdict() d1 = defaultdict()
@ -60,10 +61,10 @@ class TestDefaultDict(unittest.TestCase):
self.assertEqual(repr(d1), "defaultdict(None, {})") self.assertEqual(repr(d1), "defaultdict(None, {})")
d1[11] = 41 d1[11] = 41
self.assertEqual(repr(d1), "defaultdict(None, {11: 41})") self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
d2 = defaultdict(0) d2 = defaultdict(int)
self.assertEqual(d2.default_factory, 0) self.assertEqual(d2.default_factory, int)
d2[12] = 42 d2[12] = 42
self.assertEqual(repr(d2), "defaultdict(0, {12: 42})") self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
def foo(): return 43 def foo(): return 43
d3 = defaultdict(foo) d3 = defaultdict(foo)
self.assert_(d3.default_factory is foo) self.assert_(d3.default_factory is foo)

View File

@ -103,6 +103,8 @@ Core and builtins
Extension Modules Extension Modules
----------------- -----------------
- collections.defaultdict() now verifies that the factory function is callable.
- Bug #1486663: don't reject keyword arguments for subclasses of builtin - Bug #1486663: don't reject keyword arguments for subclasses of builtin
types. types.

View File

@ -1252,8 +1252,14 @@ defdict_init(PyObject *self, PyObject *args, PyObject *kwds)
newargs = PyTuple_New(0); newargs = PyTuple_New(0);
else { else {
Py_ssize_t n = PyTuple_GET_SIZE(args); Py_ssize_t n = PyTuple_GET_SIZE(args);
if (n > 0) if (n > 0) {
newdefault = PyTuple_GET_ITEM(args, 0); newdefault = PyTuple_GET_ITEM(args, 0);
if (!PyCallable_Check(newdefault)) {
PyErr_SetString(PyExc_TypeError,
"first argument must be callable");
return -1;
}
}
newargs = PySequence_GetSlice(args, 1, n); newargs = PySequence_GetSlice(args, 1, n);
} }
if (newargs == NULL) if (newargs == NULL)