mirror of https://github.com/python/cpython
Issue #18594: Make the C code more closely match the pure python code.
This commit is contained in:
parent
5b22dd87aa
commit
cb1d96f782
|
@ -818,6 +818,24 @@ class TestCollectionABCs(ABCTestCase):
|
|||
### Counter
|
||||
################################################################################
|
||||
|
||||
class CounterSubclassWithSetItem(Counter):
|
||||
# Test a counter subclass that overrides __setitem__
|
||||
def __init__(self, *args, **kwds):
|
||||
self.called = False
|
||||
Counter.__init__(self, *args, **kwds)
|
||||
def __setitem__(self, key, value):
|
||||
self.called = True
|
||||
Counter.__setitem__(self, key, value)
|
||||
|
||||
class CounterSubclassWithGet(Counter):
|
||||
# Test a counter subclass that overrides get()
|
||||
def __init__(self, *args, **kwds):
|
||||
self.called = False
|
||||
Counter.__init__(self, *args, **kwds)
|
||||
def get(self, key, default):
|
||||
self.called = True
|
||||
return Counter.get(self, key, default)
|
||||
|
||||
class TestCounter(unittest.TestCase):
|
||||
|
||||
def test_basics(self):
|
||||
|
@ -1022,6 +1040,12 @@ class TestCounter(unittest.TestCase):
|
|||
self.assertEqual(m,
|
||||
OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
|
||||
|
||||
# test fidelity to the pure python version
|
||||
c = CounterSubclassWithSetItem('abracadabra')
|
||||
self.assertTrue(c.called)
|
||||
c = CounterSubclassWithGet('abracadabra')
|
||||
self.assertTrue(c.called)
|
||||
|
||||
|
||||
################################################################################
|
||||
### OrderedDict
|
||||
|
|
|
@ -1689,17 +1689,17 @@ Count elements in the iterable, updating the mappping");
|
|||
static PyObject *
|
||||
_count_elements(PyObject *self, PyObject *args)
|
||||
{
|
||||
_Py_IDENTIFIER(__getitem__);
|
||||
_Py_IDENTIFIER(get);
|
||||
_Py_IDENTIFIER(__setitem__);
|
||||
PyObject *it, *iterable, *mapping, *oldval;
|
||||
PyObject *newval = NULL;
|
||||
PyObject *key = NULL;
|
||||
PyObject *zero = NULL;
|
||||
PyObject *one = NULL;
|
||||
PyObject *mapping_get = NULL;
|
||||
PyObject *mapping_getitem;
|
||||
PyObject *bound_get = NULL;
|
||||
PyObject *mapping_get;
|
||||
PyObject *dict_get;
|
||||
PyObject *mapping_setitem;
|
||||
PyObject *dict_getitem;
|
||||
PyObject *dict_setitem;
|
||||
|
||||
if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable))
|
||||
|
@ -1713,15 +1713,16 @@ _count_elements(PyObject *self, PyObject *args)
|
|||
if (one == NULL)
|
||||
goto done;
|
||||
|
||||
mapping_getitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___getitem__);
|
||||
dict_getitem = _PyType_LookupId(&PyDict_Type, &PyId___getitem__);
|
||||
/* Only take the fast path when get() and __setitem__()
|
||||
* have not been overridden.
|
||||
*/
|
||||
mapping_get = _PyType_LookupId(Py_TYPE(mapping), &PyId_get);
|
||||
dict_get = _PyType_LookupId(&PyDict_Type, &PyId_get);
|
||||
mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__);
|
||||
dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__);
|
||||
|
||||
if (mapping_getitem != NULL &&
|
||||
mapping_getitem == dict_getitem &&
|
||||
mapping_setitem != NULL &&
|
||||
mapping_setitem == dict_setitem) {
|
||||
if (mapping_get != NULL && mapping_get == dict_get &&
|
||||
mapping_setitem != NULL && mapping_setitem == dict_setitem) {
|
||||
while (1) {
|
||||
key = PyIter_Next(it);
|
||||
if (key == NULL)
|
||||
|
@ -1741,8 +1742,8 @@ _count_elements(PyObject *self, PyObject *args)
|
|||
Py_DECREF(key);
|
||||
}
|
||||
} else {
|
||||
mapping_get = PyObject_GetAttrString(mapping, "get");
|
||||
if (mapping_get == NULL)
|
||||
bound_get = PyObject_GetAttrString(mapping, "get");
|
||||
if (bound_get == NULL)
|
||||
goto done;
|
||||
|
||||
zero = PyLong_FromLong(0);
|
||||
|
@ -1753,7 +1754,7 @@ _count_elements(PyObject *self, PyObject *args)
|
|||
key = PyIter_Next(it);
|
||||
if (key == NULL)
|
||||
break;
|
||||
oldval = PyObject_CallFunctionObjArgs(mapping_get, key, zero, NULL);
|
||||
oldval = PyObject_CallFunctionObjArgs(bound_get, key, zero, NULL);
|
||||
if (oldval == NULL)
|
||||
break;
|
||||
newval = PyNumber_Add(oldval, one);
|
||||
|
@ -1771,7 +1772,7 @@ done:
|
|||
Py_DECREF(it);
|
||||
Py_XDECREF(key);
|
||||
Py_XDECREF(newval);
|
||||
Py_XDECREF(mapping_get);
|
||||
Py_XDECREF(bound_get);
|
||||
Py_XDECREF(zero);
|
||||
Py_XDECREF(one);
|
||||
if (PyErr_Occurred())
|
||||
|
|
Loading…
Reference in New Issue