180 lines
5.6 KiB
Python
180 lines
5.6 KiB
Python
"""Unit tests for collections.defaultdict."""
|
|
|
|
import os
|
|
import copy
|
|
import tempfile
|
|
import unittest
|
|
from test import test_support
|
|
|
|
from collections import defaultdict
|
|
|
|
def foobar():
    """Default factory shared by the shallow/deep copy tests.

    Returns the ``list`` type itself (not a list instance).
    """
    factory_result = list
    return factory_result
|
|
|
|
class TestDefaultDict(unittest.TestCase):
    """Behavioural tests for collections.defaultdict.

    Covers construction, __missing__, repr(), the print statement,
    copying (dict.copy, copy.copy, copy.deepcopy), KeyError reporting
    when no factory is set, and repr() of a self-referential instance.
    """

    # NOTE: test methods are documented with comments rather than
    # docstrings on purpose: unittest prints a test's docstring instead of
    # its name in verbose mode.

    def _print_to_real_file(self, *objects):
        # Print each object with the print statement to an on-disk file
        # and return the lines read back.
        #
        # We can't use tempfile.[Named]TemporaryFile since this code must
        # exercise the tp_print C code, which only gets invoked for *real*
        # files.  mkstemp() creates the file atomically (unlike the
        # race-prone mktemp()); the descriptor is closed right away and
        # the file is reopened by name with the builtin open().
        fd, path = tempfile.mkstemp()
        os.close(fd)
        try:
            with open(path, "w+") as f:
                for obj in objects:
                    print >>f, obj
                f.seek(0)
                return f.readlines()
        finally:
            os.remove(path)

    def test_basic(self):
        # A defaultdict built without arguments has no factory.
        d1 = defaultdict()
        self.assertEqual(d1.default_factory, None)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        # Bare lookups: evaluated only for the side effect of inserting
        # the default value.
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every default must be a distinct object.  Compare all three
        # pairs explicitly: a chained "a is not b is not c" never compares
        # a with c.
        self.assertTrue(d1[12] is not d1[13])
        self.assertTrue(d1[13] is not d1[14])
        self.assertTrue(d1[12] is not d1[14])

        # Factory plus initial keyword items.
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertTrue("foo" in d2)
        self.assertTrue("foo" in d2.keys())
        self.assertTrue("bar" in d2)
        self.assertTrue("bar" in d2.keys())
        self.assertTrue(42 in d2)
        self.assertTrue(42 in d2.keys())
        self.assertTrue(12 not in d2)
        self.assertTrue(12 not in d2.keys())

        # With the factory removed, a missing key raises KeyError again.
        d2.default_factory = None
        self.assertEqual(d2.default_factory, None)
        try:
            d2[15]
        except KeyError as err:
            self.assertEqual(err.args, (15,))
        else:
            self.fail("d2[15] didn't raise KeyError")

        # The first positional argument must be callable or None.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        # __missing__ raises KeyError without a factory and returns the
        # factory's result once one is set.
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertEqual(d1.default_factory, None)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        # The repr of a factory-less defaultdict round-trips through
        # eval(); the input is our own repr, not external data.
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")

        d2 = defaultdict(int)
        self.assertEqual(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")

        def foo():
            return 43
        d3 = defaultdict(foo)
        self.assertTrue(d3.default_factory is foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_print(self):
        # Printing to a real file must produce repr() followed by a
        # newline, both without and with a default factory.
        d1 = defaultdict()

        def foo():
            return 42
        d2 = defaultdict(foo, {1: 2})

        lines = self._print_to_real_file(d1, d2)
        self.assertEqual(lines, [repr(d1) + "\n", repr(d2) + "\n"])

    def test_copy(self):
        # dict.copy() must preserve both the type and the factory.
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertEqual(type(d2), defaultdict)
        self.assertEqual(d2.default_factory, None)
        self.assertEqual(d2, {})

        d1.default_factory = list
        d3 = d1.copy()
        self.assertEqual(type(d3), defaultdict)
        self.assertEqual(d3.default_factory, list)
        self.assertEqual(d3, {})

        d1[42]
        d4 = d1.copy()
        self.assertEqual(type(d4), defaultdict)
        self.assertEqual(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})

        # Issue 6637: Copy fails for empty default dict
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

    def test_shallow_copy(self):
        # copy.copy() keeps the factory, whether it is a plain function
        # or a builtin type.
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)

        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        # copy.deepcopy() keeps the factory but duplicates the values.
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        self.assertTrue(d1[1] is not d2[1])

        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        # A tuple key must be reported intact as the sole KeyError
        # argument, not unpacked into the args tuple.
        d1 = defaultdict()
        try:
            d1[(1,)]
        except KeyError as err:
            self.assertEqual(err.args[0], (1,))
        else:
            self.fail("expected KeyError")

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory

            def _factory(self):
                return []

        d = sub()
        self.assertTrue(repr(d).startswith(
            "defaultdict(<bound method sub._factory of defaultdict(..."))

        # NOTE: printing a subclass of a builtin type does not call its
        # tp_print slot. So this part is essentially the same test as above.
        # Only successful completion matters; the output is not inspected.
        self._print_to_real_file(d)
|
|
|
|
|
|
def test_main():
    """Entry point: hand the defaultdict test case to test_support."""
    case_classes = (TestDefaultDict,)
    test_support.run_unittest(*case_classes)
|
|
|
|
# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    test_main()
|