From de61d4bd4db868ce49a729a283763b94f2fda961 Mon Sep 17 00:00:00 2001 From: Sam Gross Date: Tue, 6 Feb 2024 11:36:23 -0500 Subject: [PATCH] gh-112066: Add `PyDict_SetDefaultRef` function. (#112123) The `PyDict_SetDefaultRef` function is similar to `PyDict_SetDefault`, but returns a strong reference through the optional `**result` pointer instead of a borrowed reference. Co-authored-by: Petr Viktorin --- Doc/c-api/dict.rst | 20 ++++ Doc/whatsnew/3.13.rst | 6 ++ Include/cpython/dictobject.h | 10 ++ Lib/test/test_capi/test_dict.py | 22 +++++ ...-11-15-13-47-48.gh-issue-112066.22WsqR.rst | 5 + Modules/_testcapi/dict.c | 26 ++++++ Objects/dictobject.c | 91 +++++++++++++++---- 7 files changed, 160 insertions(+), 20 deletions(-) create mode 100644 Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst diff --git a/Doc/c-api/dict.rst b/Doc/c-api/dict.rst index 8471c98d044..03f3d28187b 100644 --- a/Doc/c-api/dict.rst +++ b/Doc/c-api/dict.rst @@ -174,6 +174,26 @@ Dictionary Objects .. versionadded:: 3.4 +.. c:function:: int PyDict_SetDefaultRef(PyObject *p, PyObject *key, PyObject *default_value, PyObject **result) + + Inserts *default_value* into the dictionary *p* with a key of *key* if the + key is not already present in the dictionary. If *result* is not ``NULL``, + then *\*result* is set to a :term:`strong reference` to either + *default_value*, if the key was not present, or the existing value, if *key* + was already present in the dictionary. + Returns ``1`` if the key was present and *default_value* was not inserted, + or ``0`` if the key was not present and *default_value* was inserted. + On failure, returns ``-1``, sets an exception, and sets ``*result`` + to ``NULL``. + + For clarity: if you have a strong reference to *default_value* before + calling this function, then after it returns, you hold a strong reference + to both *default_value* and *\*result* (if it's not ``NULL``). + These may refer to the same object: in that case you hold two separate + references to it. + .. versionadded:: 3.13 + + .. c:function:: int PyDict_Pop(PyObject *p, PyObject *key, PyObject **result) Remove *key* from dictionary *p* and optionally return the removed value. diff --git a/Doc/whatsnew/3.13.rst b/Doc/whatsnew/3.13.rst index 372757759b9..e034d34c5fb 100644 --- a/Doc/whatsnew/3.13.rst +++ b/Doc/whatsnew/3.13.rst @@ -1440,6 +1440,12 @@ New Features not needed. (Contributed by Victor Stinner in :gh:`106004`.) +* Added :c:func:`PyDict_SetDefaultRef`, which is similar to + :c:func:`PyDict_SetDefault` but returns a :term:`strong reference` instead of + a :term:`borrowed reference`. This function returns ``-1`` on error, ``0`` on + insertion, and ``1`` if the key was already present in the dictionary. + (Contributed by Sam Gross in :gh:`112066`.) + * Add :c:func:`PyDict_ContainsString` function: same as :c:func:`PyDict_Contains`, but *key* is specified as a :c:expr:`const char*` UTF-8 encoded bytes string, rather than a :c:expr:`PyObject*`. diff --git a/Include/cpython/dictobject.h b/Include/cpython/dictobject.h index 1720fe6f01e..35b6a822a0d 100644 --- a/Include/cpython/dictobject.h +++ b/Include/cpython/dictobject.h @@ -41,6 +41,16 @@ PyAPI_FUNC(PyObject *) _PyDict_GetItemStringWithError(PyObject *, const char *); PyAPI_FUNC(PyObject *) PyDict_SetDefault( PyObject *mp, PyObject *key, PyObject *defaultobj); +// Inserts `key` with a value `default_value`, if `key` is not already present +// in the dictionary. If `result` is not NULL, then the value associated +// with `key` is returned in `*result` (either the existing value, or the now +// inserted `default_value`). +// Returns: +// -1 on error +// 0 if `key` was not present and `default_value` was inserted +// 1 if `key` was present and `default_value` was not inserted +PyAPI_FUNC(int) PyDict_SetDefaultRef(PyObject *mp, PyObject *key, PyObject *default_value, PyObject **result); + /* Get the number of items of a dictionary. */ static inline Py_ssize_t PyDict_GET_SIZE(PyObject *op) { PyDictObject *mp; diff --git a/Lib/test/test_capi/test_dict.py b/Lib/test/test_capi/test_dict.py index 57a7238588e..cca6145bc90 100644 --- a/Lib/test/test_capi/test_dict.py +++ b/Lib/test/test_capi/test_dict.py @@ -339,6 +339,28 @@ class CAPITest(unittest.TestCase): # CRASHES setdefault({}, 'a', NULL) # CRASHES setdefault(NULL, 'a', 5) + def test_dict_setdefaultref(self): + setdefault = _testcapi.dict_setdefaultref + dct = {} + self.assertEqual(setdefault(dct, 'a', 5), 5) + self.assertEqual(dct, {'a': 5}) + self.assertEqual(setdefault(dct, 'a', 8), 5) + self.assertEqual(dct, {'a': 5}) + + dct2 = DictSubclass() + self.assertEqual(setdefault(dct2, 'a', 5), 5) + self.assertEqual(dct2, {'a': 5}) + self.assertEqual(setdefault(dct2, 'a', 8), 5) + self.assertEqual(dct2, {'a': 5}) + + self.assertRaises(TypeError, setdefault, {}, [], 5) # unhashable + self.assertRaises(SystemError, setdefault, UserDict(), 'a', 5) + self.assertRaises(SystemError, setdefault, [1], 0, 5) + self.assertRaises(SystemError, setdefault, 42, 'a', 5) + # CRASHES setdefault({}, NULL, 5) + # CRASHES setdefault({}, 'a', NULL) + # CRASHES setdefault(NULL, 'a', 5) + def test_mapping_keys_valuesitems(self): class BadMapping(dict): def keys(self): diff --git a/Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst b/Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst new file mode 100644 index 00000000000..ae2b8b2444d --- /dev/null +++ b/Misc/NEWS.d/next/C API/2023-11-15-13-47-48.gh-issue-112066.22WsqR.rst @@ -0,0 +1,5 @@ +Add :c:func:`PyDict_SetDefaultRef`: insert a key and value into a dictionary +if the key is not already present. This is similar to +:meth:`dict.setdefault`, but returns an integer value indicating if the key +was already present. It is also similar to :c:func:`PyDict_SetDefault`, but +returns a strong reference instead of a borrowed reference. diff --git a/Modules/_testcapi/dict.c b/Modules/_testcapi/dict.c index 42e056b7d07..fe03c24f75e 100644 --- a/Modules/_testcapi/dict.c +++ b/Modules/_testcapi/dict.c @@ -225,6 +225,31 @@ dict_setdefault(PyObject *self, PyObject *args) return PyDict_SetDefault(mapping, key, defaultobj); } +static PyObject * +dict_setdefaultref(PyObject *self, PyObject *args) +{ + PyObject *obj, *key, *default_value, *result = UNINITIALIZED_PTR; + if (!PyArg_ParseTuple(args, "OOO", &obj, &key, &default_value)) { + return NULL; + } + NULLABLE(obj); + NULLABLE(key); + NULLABLE(default_value); + switch (PyDict_SetDefaultRef(obj, key, default_value, &result)) { + case -1: + assert(result == NULL); + return NULL; + case 0: + assert(result == default_value); + return result; + case 1: + return result; + default: + Py_FatalError("PyDict_SetDefaultRef() returned invalid code"); + Py_UNREACHABLE(); + } +} + static PyObject * dict_delitem(PyObject *self, PyObject *args) { @@ -433,6 +458,7 @@ static PyMethodDef test_methods[] = { {"dict_delitem", dict_delitem, METH_VARARGS}, {"dict_delitemstring", dict_delitemstring, METH_VARARGS}, {"dict_setdefault", dict_setdefault, METH_VARARGS}, + {"dict_setdefaultref", dict_setdefaultref, METH_VARARGS}, {"dict_keys", dict_keys, METH_O}, {"dict_values", dict_values, METH_O}, {"dict_items", dict_items, METH_O}, diff --git a/Objects/dictobject.c b/Objects/dictobject.c index 4bb818b90a4..11b388d9f4a 100644 --- a/Objects/dictobject.c +++ b/Objects/dictobject.c @@ -3355,8 +3355,9 @@ dict_get_impl(PyDictObject *self, PyObject *key, PyObject *default_value) return Py_NewRef(val); } -PyObject * -PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj) +static int +dict_setdefault_ref(PyObject *d, PyObject *key, PyObject *default_value, + PyObject **result, int incref_result) { PyDictObject *mp = (PyDictObject *)d; PyObject *value; @@ -3365,41 +3366,64 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj) if (!PyDict_Check(d)) { PyErr_BadInternalCall(); - return NULL; + if (result) { + *result = NULL; + } + return -1; } if (!PyUnicode_CheckExact(key) || (hash = unicode_get_hash(key)) == -1) { hash = PyObject_Hash(key); - if (hash == -1) - return NULL; + if (hash == -1) { + if (result) { + *result = NULL; + } + return -1; + } } if (mp->ma_keys == Py_EMPTY_KEYS) { if (insert_to_emptydict(interp, mp, Py_NewRef(key), hash, - Py_NewRef(defaultobj)) < 0) { - return NULL; + Py_NewRef(default_value)) < 0) { + if (result) { + *result = NULL; + } + return -1; } - return defaultobj; + if (result) { + *result = incref_result ? Py_NewRef(default_value) : default_value; + } + return 0; } if (!PyUnicode_CheckExact(key) && DK_IS_UNICODE(mp->ma_keys)) { if (insertion_resize(interp, mp, 0) < 0) { - return NULL; + if (result) { + *result = NULL; + } + return -1; } } Py_ssize_t ix = _Py_dict_lookup(mp, key, hash, &value); - if (ix == DKIX_ERROR) - return NULL; + if (ix == DKIX_ERROR) { + if (result) { + *result = NULL; + } + return -1; + } if (ix == DKIX_EMPTY) { uint64_t new_version = _PyDict_NotifyEvent( - interp, PyDict_EVENT_ADDED, mp, key, defaultobj); + interp, PyDict_EVENT_ADDED, mp, key, default_value); mp->ma_keys->dk_version = 0; - value = defaultobj; + value = default_value; if (mp->ma_keys->dk_usable <= 0) { if (insertion_resize(interp, mp, 1) < 0) { - return NULL; + if (result) { + *result = NULL; + } + return -1; } } Py_ssize_t hashpos = find_empty_slot(mp->ma_keys, hash); @@ -3431,11 +3455,16 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj) mp->ma_keys->dk_usable--; mp->ma_keys->dk_nentries++; assert(mp->ma_keys->dk_usable >= 0); + ASSERT_CONSISTENT(mp); + if (result) { + *result = incref_result ? Py_NewRef(value) : value; + } + return 0; } else if (value == NULL) { uint64_t new_version = _PyDict_NotifyEvent( - interp, PyDict_EVENT_ADDED, mp, key, defaultobj); - value = defaultobj; + interp, PyDict_EVENT_ADDED, mp, key, default_value); + value = default_value; assert(_PyDict_HasSplitTable(mp)); assert(mp->ma_values->values[ix] == NULL); MAINTAIN_TRACKING(mp, key, value); @@ -3443,10 +3472,33 @@ PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj) _PyDictValues_AddToInsertionOrder(mp->ma_values, ix); mp->ma_used++; mp->ma_version_tag = new_version; + ASSERT_CONSISTENT(mp); + if (result) { + *result = incref_result ? Py_NewRef(value) : value; + } + return 0; } ASSERT_CONSISTENT(mp); - return value; + if (result) { + *result = incref_result ? Py_NewRef(value) : value; + } + return 1; +} + +int +PyDict_SetDefaultRef(PyObject *d, PyObject *key, PyObject *default_value, + PyObject **result) +{ + return dict_setdefault_ref(d, key, default_value, result, 1); +} + +PyObject * +PyDict_SetDefault(PyObject *d, PyObject *key, PyObject *defaultobj) +{ + PyObject *result; + dict_setdefault_ref(d, key, defaultobj, &result, 0); + return result; } /*[clinic input] @@ -3467,9 +3519,8 @@ dict_setdefault_impl(PyDictObject *self, PyObject *key, /*[clinic end generated code: output=f8c1101ebf69e220 input=0f063756e815fd9d]*/ { PyObject *val; - - val = PyDict_SetDefault((PyObject *)self, key, default_value); - return Py_XNewRef(val); + PyDict_SetDefaultRef((PyObject *)self, key, default_value, &val); + return val; }