"""Unit tests for collections.defaultdict.""" import os import copy import pickle import tempfile import unittest from collections import defaultdict def foobar(): return list class TestDefaultDict(unittest.TestCase): def test_basic(self): d1 = defaultdict() self.assertEqual(d1.default_factory, None) d1.default_factory = list d1[12].append(42) self.assertEqual(d1, {12: [42]}) d1[12].append(24) self.assertEqual(d1, {12: [42, 24]}) d1[13] d1[14] self.assertEqual(d1, {12: [42, 24], 13: [], 14: []}) self.assertTrue(d1[12] is not d1[13] is not d1[14]) d2 = defaultdict(list, foo=1, bar=2) self.assertEqual(d2.default_factory, list) self.assertEqual(d2, {"foo": 1, "bar": 2}) self.assertEqual(d2["foo"], 1) self.assertEqual(d2["bar"], 2) self.assertEqual(d2[42], []) self.assertIn("foo", d2) self.assertIn("foo", d2.keys()) self.assertIn("bar", d2) self.assertIn("bar", d2.keys()) self.assertIn(42, d2) self.assertIn(42, d2.keys()) self.assertNotIn(12, d2) self.assertNotIn(12, d2.keys()) d2.default_factory = None self.assertEqual(d2.default_factory, None) try: d2[15] except KeyError as err: self.assertEqual(err.args, (15,)) else: self.fail("d2[15] didn't raise KeyError") self.assertRaises(TypeError, defaultdict, 1) def test_missing(self): d1 = defaultdict() self.assertRaises(KeyError, d1.__missing__, 42) d1.default_factory = list self.assertEqual(d1.__missing__(42), []) def test_repr(self): d1 = defaultdict() self.assertEqual(d1.default_factory, None) self.assertEqual(repr(d1), "defaultdict(None, {})") self.assertEqual(eval(repr(d1)), d1) d1[11] = 41 self.assertEqual(repr(d1), "defaultdict(None, {11: 41})") d2 = defaultdict(int) self.assertEqual(d2.default_factory, int) d2[12] = 42 self.assertRegex(repr(d2), r"defaultdict\(, {12: 42}\)") def foo(): return 43 d3 = defaultdict(foo) self.assertTrue(d3.default_factory is foo) d3[13] self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo)) def test_print(self): d1 = defaultdict() def foo(): return 42 d2 = defaultdict(foo, {1: 2}) # NOTE: We can't use tempfile.[Named]TemporaryFile since this # code must exercise the tp_print C code, which only gets # invoked for *real* files. tfn = tempfile.mktemp() try: f = open(tfn, "w+") try: print(d1, file=f) print(d2, file=f) f.seek(0) self.assertEqual(f.readline(), repr(d1) + "\n") self.assertEqual(f.readline(), repr(d2) + "\n") finally: f.close() finally: os.remove(tfn) def test_copy(self): d1 = defaultdict() d2 = d1.copy() self.assertEqual(type(d2), defaultdict) self.assertEqual(d2.default_factory, None) self.assertEqual(d2, {}) d1.default_factory = list d3 = d1.copy() self.assertEqual(type(d3), defaultdict) self.assertEqual(d3.default_factory, list) self.assertEqual(d3, {}) d1[42] d4 = d1.copy() self.assertEqual(type(d4), defaultdict) self.assertEqual(d4.default_factory, list) self.assertEqual(d4, {42: []}) d4[12] self.assertEqual(d4, {42: [], 12: []}) # Issue 6637: Copy fails for empty default dict d = defaultdict() d['a'] = 42 e = d.copy() self.assertEqual(e['a'], 42) def test_shallow_copy(self): d1 = defaultdict(foobar, {1: 1}) d2 = copy.copy(d1) self.assertEqual(d2.default_factory, foobar) self.assertEqual(d2, d1) d1.default_factory = list d2 = copy.copy(d1) self.assertEqual(d2.default_factory, list) self.assertEqual(d2, d1) def test_deep_copy(self): d1 = defaultdict(foobar, {1: [1]}) d2 = copy.deepcopy(d1) self.assertEqual(d2.default_factory, foobar) self.assertEqual(d2, d1) self.assertTrue(d1[1] is not d2[1]) d1.default_factory = list d2 = copy.deepcopy(d1) self.assertEqual(d2.default_factory, list) self.assertEqual(d2, d1) def test_keyerror_without_factory(self): d1 = defaultdict() try: d1[(1,)] except KeyError as err: self.assertEqual(err.args[0], (1,)) else: self.fail("expected KeyError") def test_recursive_repr(self): # Issue2045: stack overflow when default_factory is a bound method class sub(defaultdict): def __init__(self): self.default_factory = self._factory def _factory(self): return [] d = sub() self.assertRegex(repr(d), r"defaultdict\(, \{\}\)") # NOTE: printing a subclass of a builtin type does not call its # tp_print slot. So this part is essentially the same test as above. tfn = tempfile.mktemp() try: f = open(tfn, "w+") try: print(d, file=f) finally: f.close() finally: os.remove(tfn) def test_callable_arg(self): self.assertRaises(TypeError, defaultdict, {}) def test_pickleing(self): d = defaultdict(int) d[1] for proto in range(pickle.HIGHEST_PROTOCOL + 1): s = pickle.dumps(d, proto) o = pickle.loads(s) self.assertEqual(d, o) if __name__ == "__main__": unittest.main()