diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst index 8dcf9451d72..f4a383c8ea5 100644 --- a/Doc/library/collections.rst +++ b/Doc/library/collections.rst @@ -729,6 +729,10 @@ stack manipulations such as ``dup``, ``drop``, ``swap``, ``over``, ``pick``, initialized from the first argument to the constructor, if present, or to ``None``, if absent. + .. versionchanged:: 3.9 + Added merge (``|``) and update (``|=``) operators, specified in + :pep:`584`. + :class:`defaultdict` Examples ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/Lib/test/test_defaultdict.py b/Lib/test/test_defaultdict.py index b9f1fb9f23d..b48c649fce6 100644 --- a/Lib/test/test_defaultdict.py +++ b/Lib/test/test_defaultdict.py @@ -183,5 +183,43 @@ class TestDefaultDict(unittest.TestCase): o = pickle.loads(s) self.assertEqual(d, o) + def test_union(self): + i = defaultdict(int, {1: 1, 2: 2}) + s = defaultdict(str, {0: "zero", 1: "one"}) + + i_s = i | s + self.assertIs(i_s.default_factory, int) + self.assertDictEqual(i_s, {1: "one", 2: 2, 0: "zero"}) + self.assertEqual(list(i_s), [1, 2, 0]) + + s_i = s | i + self.assertIs(s_i.default_factory, str) + self.assertDictEqual(s_i, {0: "zero", 1: 1, 2: 2}) + self.assertEqual(list(s_i), [0, 1, 2]) + + i_ds = i | dict(s) + self.assertIs(i_ds.default_factory, int) + self.assertDictEqual(i_ds, {1: "one", 2: 2, 0: "zero"}) + self.assertEqual(list(i_ds), [1, 2, 0]) + + ds_i = dict(s) | i + self.assertIs(ds_i.default_factory, int) + self.assertDictEqual(ds_i, {0: "zero", 1: 1, 2: 2}) + self.assertEqual(list(ds_i), [0, 1, 2]) + + with self.assertRaises(TypeError): + i | list(s.items()) + with self.assertRaises(TypeError): + list(s.items()) | i + + # We inherit a fine |= from dict, so just a few sanity checks here: + i |= list(s.items()) + self.assertIs(i.default_factory, int) + self.assertDictEqual(i, {1: "one", 2: 2, 0: "zero"}) + self.assertEqual(list(i), [1, 2, 0]) + + with self.assertRaises(TypeError): + i |= None + if __name__ == "__main__": unittest.main() diff --git a/Misc/NEWS.d/next/Library/2020-02-29-15-54-08.bpo-36144.4GgTZs.rst b/Misc/NEWS.d/next/Library/2020-02-29-15-54-08.bpo-36144.4GgTZs.rst new file mode 100644 index 00000000000..416d5ac3a27 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-02-29-15-54-08.bpo-36144.4GgTZs.rst @@ -0,0 +1 @@ +:class:`collections.defaultdict` now implements ``|`` (:pep:`584`). diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c index 4d5d874b44d..d0a381deabf 100644 --- a/Modules/_collectionsmodule.c +++ b/Modules/_collectionsmodule.c @@ -1990,6 +1990,13 @@ defdict_missing(defdictobject *dd, PyObject *key) return value; } +static inline PyObject* +new_defdict(defdictobject *dd, PyObject *arg) +{ + return PyObject_CallFunctionObjArgs((PyObject*)Py_TYPE(dd), + dd->default_factory ? dd->default_factory : Py_None, arg, NULL); +} + PyDoc_STRVAR(defdict_copy_doc, "D.copy() -> a shallow copy of D."); static PyObject * @@ -1999,11 +2006,7 @@ defdict_copy(defdictobject *dd, PyObject *Py_UNUSED(ignored)) whose class constructor has the same signature. Subclasses that define a different constructor signature must override copy(). */ - - if (dd->default_factory == NULL) - return PyObject_CallFunctionObjArgs((PyObject*)Py_TYPE(dd), Py_None, dd, NULL); - return PyObject_CallFunctionObjArgs((PyObject*)Py_TYPE(dd), - dd->default_factory, dd, NULL); + return new_defdict(dd, (PyObject*)dd); } static PyObject * @@ -2127,6 +2130,42 @@ defdict_repr(defdictobject *dd) return result; } +static PyObject* +defdict_or(PyObject* left, PyObject* right) +{ + int left_is_self = PyObject_IsInstance(left, (PyObject*)&defdict_type); + if (left_is_self < 0) { + return NULL; + } + PyObject *self, *other; + if (left_is_self) { + self = left; + other = right; + } + else { + self = right; + other = left; + } + if (!PyDict_Check(other)) { + Py_RETURN_NOTIMPLEMENTED; + } + // Like copy(), this calls the object's class. + // Override __or__/__ror__ for subclasses with different constructors. + PyObject *new = new_defdict((defdictobject*)self, left); + if (!new) { + return NULL; + } + if (PyDict_Update(new, right)) { + Py_DECREF(new); + return NULL; + } + return new; +} + +static PyNumberMethods defdict_as_number = { + .nb_or = defdict_or, +}; + static int defdict_traverse(PyObject *self, visitproc visit, void *arg) { @@ -2198,7 +2237,7 @@ static PyTypeObject defdict_type = { 0, /* tp_setattr */ 0, /* tp_as_async */ (reprfunc)defdict_repr, /* tp_repr */ - 0, /* tp_as_number */ + &defdict_as_number, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */