diff --git a/Doc/library/collections.rst b/Doc/library/collections.rst index f4a383c8ea5..72d1052064a 100644 --- a/Doc/library/collections.rst +++ b/Doc/library/collections.rst @@ -1123,6 +1123,10 @@ anywhere a regular dictionary is used. passed to the :class:`OrderedDict` constructor and its :meth:`update` method. +.. versionchanged:: 3.9 + Added merge (``|``) and update (``|=``) operators, specified in :pep:`584`. + + :class:`OrderedDict` Examples and Recipes ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 1aa7d10ad22..18255da1759 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -293,6 +293,24 @@ class OrderedDict(dict): return dict.__eq__(self, other) and all(map(_eq, self, other)) return dict.__eq__(self, other) + def __ior__(self, other): + self.update(other) + return self + + def __or__(self, other): + if not isinstance(other, dict): + return NotImplemented + new = self.__class__(self) + new.update(other) + return new + + def __ror__(self, other): + if not isinstance(other, dict): + return NotImplemented + new = self.__class__(other) + new.update(self) + return new + try: from _collections import OrderedDict diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py index eb0a8f4ba0d..fdea44e4d85 100644 --- a/Lib/test/test_ordered_dict.py +++ b/Lib/test/test_ordered_dict.py @@ -654,6 +654,49 @@ class OrderedDictTests: support.check_free_after_iterating(self, lambda d: iter(d.values()), self.OrderedDict) support.check_free_after_iterating(self, lambda d: iter(d.items()), self.OrderedDict) + def test_merge_operator(self): + OrderedDict = self.OrderedDict + + a = OrderedDict({0: 0, 1: 1, 2: 1}) + b = OrderedDict({1: 1, 2: 2, 3: 3}) + + c = a.copy() + d = a.copy() + c |= b + d |= list(b.items()) + expected = OrderedDict({0: 0, 1: 1, 2: 2, 3: 3}) + self.assertEqual(a | dict(b), expected) + self.assertEqual(a | b, expected) + self.assertEqual(c, expected) + self.assertEqual(d, expected) + + c = b.copy() + c |= a + expected = OrderedDict({1: 1, 2: 1, 3: 3, 0: 0}) + self.assertEqual(dict(b) | a, expected) + self.assertEqual(b | a, expected) + self.assertEqual(c, expected) + + self.assertIs(type(a | b), OrderedDict) + self.assertIs(type(dict(a) | b), OrderedDict) + self.assertIs(type(a | dict(b)), OrderedDict) + + expected = a.copy() + a |= () + a |= "" + self.assertEqual(a, expected) + + with self.assertRaises(TypeError): + a | None + with self.assertRaises(TypeError): + a | () + with self.assertRaises(TypeError): + a | "BAD" + with self.assertRaises(TypeError): + a | "" + with self.assertRaises(ValueError): + a |= "BAD" + class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2020-03-12-11-55-16.bpo-36144.9bxGH_.rst b/Misc/NEWS.d/next/Library/2020-03-12-11-55-16.bpo-36144.9bxGH_.rst new file mode 100644 index 00000000000..6cc35a21428 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2020-03-12-11-55-16.bpo-36144.9bxGH_.rst @@ -0,0 +1,2 @@ +:class:`collections.OrderedDict` now implements ``|`` and ``|=`` +(:pep:`584`). diff --git a/Objects/odictobject.c b/Objects/odictobject.c index 6813cddfddc..220ae92ec92 100644 --- a/Objects/odictobject.c +++ b/Objects/odictobject.c @@ -851,6 +851,57 @@ static PyMappingMethods odict_as_mapping = { }; +/* ---------------------------------------------- + * OrderedDict number methods + */ + +static int mutablemapping_update_arg(PyObject*, PyObject*); + +static PyObject * +odict_or(PyObject *left, PyObject *right) +{ + PyTypeObject *type; + PyObject *other; + if (PyODict_Check(left)) { + type = Py_TYPE(left); + other = right; + } + else { + type = Py_TYPE(right); + other = left; + } + if (!PyDict_Check(other)) { + Py_RETURN_NOTIMPLEMENTED; + } + PyObject *new = PyObject_CallOneArg((PyObject*)type, left); + if (!new) { + return NULL; + } + if (mutablemapping_update_arg(new, right) < 0) { + Py_DECREF(new); + return NULL; + } + return new; +} + +static PyObject * +odict_inplace_or(PyObject *self, PyObject *other) +{ + if (mutablemapping_update_arg(self, other) < 0) { + return NULL; + } + Py_INCREF(self); + return self; +} + +/* tp_as_number */ + +static PyNumberMethods odict_as_number = { + .nb_or = odict_or, + .nb_inplace_or = odict_inplace_or, +}; + + /* ---------------------------------------------- * OrderedDict methods */ @@ -1555,7 +1606,7 @@ PyTypeObject PyODict_Type = { 0, /* tp_setattr */ 0, /* tp_as_async */ (reprfunc)odict_repr, /* tp_repr */ - 0, /* tp_as_number */ + &odict_as_number, /* tp_as_number */ 0, /* tp_as_sequence */ &odict_as_mapping, /* tp_as_mapping */ 0, /* tp_hash */ @@ -2189,16 +2240,77 @@ Done: return 0; } +static int +mutablemapping_update_arg(PyObject *self, PyObject *arg) +{ + int res = 0; + if (PyDict_CheckExact(arg)) { + PyObject *items = PyDict_Items(arg); + if (items == NULL) { + return -1; + } + res = mutablemapping_add_pairs(self, items); + Py_DECREF(items); + return res; + } + _Py_IDENTIFIER(keys); + PyObject *func; + if (_PyObject_LookupAttrId(arg, &PyId_keys, &func) < 0) { + return -1; + } + if (func != NULL) { + PyObject *keys = _PyObject_CallNoArg(func); + Py_DECREF(func); + if (keys == NULL) { + return -1; + } + PyObject *iterator = PyObject_GetIter(keys); + Py_DECREF(keys); + if (iterator == NULL) { + return -1; + } + PyObject *key; + while (res == 0 && (key = PyIter_Next(iterator))) { + PyObject *value = PyObject_GetItem(arg, key); + if (value != NULL) { + res = PyObject_SetItem(self, key, value); + Py_DECREF(value); + } + else { + res = -1; + } + Py_DECREF(key); + } + Py_DECREF(iterator); + if (res != 0 || PyErr_Occurred()) { + return -1; + } + return 0; + } + if (_PyObject_LookupAttrId(arg, &PyId_items, &func) < 0) { + return -1; + } + if (func != NULL) { + PyObject *items = _PyObject_CallNoArg(func); + Py_DECREF(func); + if (items == NULL) { + return -1; + } + res = mutablemapping_add_pairs(self, items); + Py_DECREF(items); + return res; + } + res = mutablemapping_add_pairs(self, arg); + return res; +} + static PyObject * mutablemapping_update(PyObject *self, PyObject *args, PyObject *kwargs) { - int res = 0; - Py_ssize_t len; - _Py_IDENTIFIER(keys); - + int res; /* first handle args, if any */ assert(args == NULL || PyTuple_Check(args)); - len = (args != NULL) ? PyTuple_GET_SIZE(args) : 0; + Py_ssize_t len = (args != NULL) ? PyTuple_GET_SIZE(args) : 0; if (len > 1) { const char *msg = "update() takes at most 1 positional argument (%zd given)"; PyErr_Format(PyExc_TypeError, msg, len); @@ -2206,83 +2318,16 @@ mutablemapping_update(PyObject *self, PyObject *args, PyObject *kwargs) } if (len) { - PyObject *func; PyObject *other = PyTuple_GET_ITEM(args, 0); /* borrowed reference */ assert(other != NULL); Py_INCREF(other); - if (PyDict_CheckExact(other)) { - PyObject *items = PyDict_Items(other); - Py_DECREF(other); - if (items == NULL) - return NULL; - res = mutablemapping_add_pairs(self, items); - Py_DECREF(items); - if (res == -1) - return NULL; - goto handle_kwargs; - } - - if (_PyObject_LookupAttrId(other, &PyId_keys, &func) < 0) { - Py_DECREF(other); - return NULL; - } - if (func != NULL) { - PyObject *keys, *iterator, *key; - keys = _PyObject_CallNoArg(func); - Py_DECREF(func); - if (keys == NULL) { - Py_DECREF(other); - return NULL; - } - iterator = PyObject_GetIter(keys); - Py_DECREF(keys); - if (iterator == NULL) { - Py_DECREF(other); - return NULL; - } - while (res == 0 && (key = PyIter_Next(iterator))) { - PyObject *value = PyObject_GetItem(other, key); - if (value != NULL) { - res = PyObject_SetItem(self, key, value); - Py_DECREF(value); - } - else { - res = -1; - } - Py_DECREF(key); - } - Py_DECREF(other); - Py_DECREF(iterator); - if (res != 0 || PyErr_Occurred()) - return NULL; - goto handle_kwargs; - } - - if (_PyObject_LookupAttrId(other, &PyId_items, &func) < 0) { - Py_DECREF(other); - return NULL; - } - if (func != NULL) { - PyObject *items; - Py_DECREF(other); - items = _PyObject_CallNoArg(func); - Py_DECREF(func); - if (items == NULL) - return NULL; - res = mutablemapping_add_pairs(self, items); - Py_DECREF(items); - if (res == -1) - return NULL; - goto handle_kwargs; - } - - res = mutablemapping_add_pairs(self, other); + res = mutablemapping_update_arg(self, other); Py_DECREF(other); - if (res != 0) + if (res < 0) { return NULL; + } } - handle_kwargs: /* now handle kwargs */ assert(kwargs == NULL || PyDict_Check(kwargs)); if (kwargs != NULL && PyDict_GET_SIZE(kwargs)) {