From eb8ac57af26c4eb96a8230eba7492ce5ceef7886 Mon Sep 17 00:00:00 2001 From: Brandt Bucher Date: Mon, 24 Feb 2020 19:47:34 -0800 Subject: [PATCH] bpo-36144: Dictionary Union (PEP 584) (#12088) --- Lib/collections/__init__.py | 20 ++++++ Lib/test/test_dict.py | 32 +++++++++ .../2019-03-02-23-03-34.bpo-36144.LRl4LS.rst | 2 + Objects/dictobject.c | 71 ++++++++++++++----- 4 files changed, 107 insertions(+), 18 deletions(-) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2019-03-02-23-03-34.bpo-36144.LRl4LS.rst diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 178cdb1fa5b..1aa7d10ad22 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -994,6 +994,26 @@ class UserDict(_collections_abc.MutableMapping): # Now, add the methods in dicts but not in MutableMapping def __repr__(self): return repr(self.data) + + def __or__(self, other): + if isinstance(other, UserDict): + return self.__class__(self.data | other.data) + if isinstance(other, dict): + return self.__class__(self.data | other) + return NotImplemented + def __ror__(self, other): + if isinstance(other, UserDict): + return self.__class__(other.data | self.data) + if isinstance(other, dict): + return self.__class__(other | self.data) + return NotImplemented + def __ior__(self, other): + if isinstance(other, UserDict): + self.data |= other.data + else: + self.data |= other + return self + def __copy__(self): inst = self.__class__.__new__(self.__class__) inst.__dict__.update(self.__dict__) diff --git a/Lib/test/test_dict.py b/Lib/test/test_dict.py index de483ab5521..d5a3d9e8945 100644 --- a/Lib/test/test_dict.py +++ b/Lib/test/test_dict.py @@ -37,6 +37,38 @@ class DictTest(unittest.TestCase): dictliteral = '{' + ', '.join(formatted_items) + '}' self.assertEqual(eval(dictliteral), dict(items)) + def test_merge_operator(self): + + a = {0: 0, 1: 1, 2: 1} + b = {1: 1, 2: 2, 3: 3} + + c = a.copy() + c |= b + + self.assertEqual(a | b, {0: 0, 1: 1, 2: 2, 3: 3}) + self.assertEqual(c, {0: 0, 1: 1, 2: 2, 3: 3}) + + c = b.copy() + c |= a + + self.assertEqual(b | a, {1: 1, 2: 1, 3: 3, 0: 0}) + self.assertEqual(c, {1: 1, 2: 1, 3: 3, 0: 0}) + + c = a.copy() + c |= [(1, 1), (2, 2), (3, 3)] + + self.assertEqual(c, {0: 0, 1: 1, 2: 2, 3: 3}) + + self.assertIs(a.__or__(None), NotImplemented) + self.assertIs(a.__or__(()), NotImplemented) + self.assertIs(a.__or__("BAD"), NotImplemented) + self.assertIs(a.__or__(""), NotImplemented) + + self.assertRaises(TypeError, a.__ior__, None) + self.assertEqual(a.__ior__(()), {0: 0, 1: 1, 2: 1}) + self.assertRaises(ValueError, a.__ior__, "BAD") + self.assertEqual(a.__ior__(""), {0: 0, 1: 1, 2: 1}) + def test_bool(self): self.assertIs(not {}, True) self.assertTrue({1: 2}) diff --git a/Misc/NEWS.d/next/Core and Builtins/2019-03-02-23-03-34.bpo-36144.LRl4LS.rst b/Misc/NEWS.d/next/Core and Builtins/2019-03-02-23-03-34.bpo-36144.LRl4LS.rst new file mode 100644 index 00000000000..7d6d076ea7d --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2019-03-02-23-03-34.bpo-36144.LRl4LS.rst @@ -0,0 +1,2 @@ +:class:`dict` (and :class:`collections.UserDict`) objects now support PEP 584's merge (``|``) and update (``|=``) operators. +Patch by Brandt Bucher. diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 86ac4ef4816..4aa927afd9c 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -2320,6 +2320,25 @@ dict_fromkeys_impl(PyTypeObject *type, PyObject *iterable, PyObject *value) return _PyDict_FromKeys((PyObject *)type, iterable, value); } +/* Single-arg dict update; used by dict_update_common and operators. */ +static int +dict_update_arg(PyObject *self, PyObject *arg) +{ + if (PyDict_CheckExact(arg)) { + return PyDict_Merge(self, arg, 1); + } + _Py_IDENTIFIER(keys); + PyObject *func; + if (_PyObject_LookupAttrId(arg, &PyId_keys, &func) < 0) { + return -1; + } + if (func != NULL) { + Py_DECREF(func); + return PyDict_Merge(self, arg, 1); + } + return PyDict_MergeFromSeq2(self, arg, 1); +} + static int dict_update_common(PyObject *self, PyObject *args, PyObject *kwds, const char *methname) @@ -2331,23 +2350,7 @@ dict_update_common(PyObject *self, PyObject *args, PyObject *kwds, result = -1; } else if (arg != NULL) { - if (PyDict_CheckExact(arg)) { - result = PyDict_Merge(self, arg, 1); - } - else { - _Py_IDENTIFIER(keys); - PyObject *func; - if (_PyObject_LookupAttrId(arg, &PyId_keys, &func) < 0) { - result = -1; - } - else if (func != NULL) { - Py_DECREF(func); - result = PyDict_Merge(self, arg, 1); - } - else { - result = PyDict_MergeFromSeq2(self, arg, 1); - } - } + result = dict_update_arg(self, arg); } if (result == 0 && kwds != NULL) { @@ -3169,6 +3172,33 @@ dict_sizeof(PyDictObject *mp, PyObject *Py_UNUSED(ignored)) return PyLong_FromSsize_t(_PyDict_SizeOf(mp)); } +static PyObject * +dict_or(PyObject *self, PyObject *other) +{ + if (!PyDict_Check(self) || !PyDict_Check(other)) { + Py_RETURN_NOTIMPLEMENTED; + } + PyObject *new = PyDict_Copy(self); + if (new == NULL) { + return NULL; + } + if (dict_update_arg(new, other)) { + Py_DECREF(new); + return NULL; + } + return new; +} + +static PyObject * +dict_ior(PyObject *self, PyObject *other) +{ + if (dict_update_arg(self, other)) { + return NULL; + } + Py_INCREF(self); + return self; +} + PyDoc_STRVAR(getitem__doc__, "x.__getitem__(y) <==> x[y]"); PyDoc_STRVAR(sizeof__doc__, @@ -3274,6 +3304,11 @@ static PySequenceMethods dict_as_sequence = { 0, /* sq_inplace_repeat */ }; +static PyNumberMethods dict_as_number = { + .nb_or = dict_or, + .nb_inplace_or = dict_ior, +}; + static PyObject * dict_new(PyTypeObject *type, PyObject *args, PyObject *kwds) { @@ -3335,7 +3370,7 @@ PyTypeObject PyDict_Type = { 0, /* tp_setattr */ 0, /* tp_as_async */ (reprfunc)dict_repr, /* tp_repr */ - 0, /* tp_as_number */ + &dict_as_number, /* tp_as_number */ &dict_as_sequence, /* tp_as_sequence */ &dict_as_mapping, /* tp_as_mapping */ PyObject_HashNotImplemented, /* tp_hash */