mirror of https://github.com/python/cpython
Check for a common user error with defaultdict().
This commit is contained in:
parent
3156316823
commit
5a0217efea
|
@ -47,6 +47,7 @@ class TestDefaultDict(unittest.TestCase):
|
||||||
self.assertEqual(err.args, (15,))
|
self.assertEqual(err.args, (15,))
|
||||||
else:
|
else:
|
||||||
self.fail("d2[15] didn't raise KeyError")
|
self.fail("d2[15] didn't raise KeyError")
|
||||||
|
self.assertRaises(TypeError, defaultdict, 1)
|
||||||
|
|
||||||
def test_missing(self):
|
def test_missing(self):
|
||||||
d1 = defaultdict()
|
d1 = defaultdict()
|
||||||
|
@ -60,10 +61,10 @@ class TestDefaultDict(unittest.TestCase):
|
||||||
self.assertEqual(repr(d1), "defaultdict(None, {})")
|
self.assertEqual(repr(d1), "defaultdict(None, {})")
|
||||||
d1[11] = 41
|
d1[11] = 41
|
||||||
self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
|
self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
|
||||||
d2 = defaultdict(0)
|
d2 = defaultdict(int)
|
||||||
self.assertEqual(d2.default_factory, 0)
|
self.assertEqual(d2.default_factory, int)
|
||||||
d2[12] = 42
|
d2[12] = 42
|
||||||
self.assertEqual(repr(d2), "defaultdict(0, {12: 42})")
|
self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
|
||||||
def foo(): return 43
|
def foo(): return 43
|
||||||
d3 = defaultdict(foo)
|
d3 = defaultdict(foo)
|
||||||
self.assert_(d3.default_factory is foo)
|
self.assert_(d3.default_factory is foo)
|
||||||
|
|
|
@ -1252,8 +1252,14 @@ defdict_init(PyObject *self, PyObject *args, PyObject *kwds)
|
||||||
newargs = PyTuple_New(0);
|
newargs = PyTuple_New(0);
|
||||||
else {
|
else {
|
||||||
Py_ssize_t n = PyTuple_GET_SIZE(args);
|
Py_ssize_t n = PyTuple_GET_SIZE(args);
|
||||||
if (n > 0)
|
if (n > 0) {
|
||||||
newdefault = PyTuple_GET_ITEM(args, 0);
|
newdefault = PyTuple_GET_ITEM(args, 0);
|
||||||
|
if (!PyCallable_Check(newdefault)) {
|
||||||
|
PyErr_SetString(PyExc_TypeError,
|
||||||
|
"first argument must be callable");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
newargs = PySequence_GetSlice(args, 1, n);
|
newargs = PySequence_GetSlice(args, 1, n);
|
||||||
}
|
}
|
||||||
if (newargs == NULL)
|
if (newargs == NULL)
|
||||||
|
|
Loading…
Reference in New Issue