gh-119004: fix a crash in equality testing between `OrderedDict` (#121329)

This commit is contained in:
Bénédikt Tran 2024-09-24 01:44:36 +02:00 committed by GitHub
parent e80dd3035f
commit 38a887dc3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 145 additions and 11 deletions

View File

@ -1169,8 +1169,11 @@ Some differences from :class:`dict` still remain:
In addition to the usual mapping methods, ordered dictionaries also support In addition to the usual mapping methods, ordered dictionaries also support
reverse iteration using :func:`reversed`. reverse iteration using :func:`reversed`.
.. _collections_OrderedDict__eq__:
Equality tests between :class:`OrderedDict` objects are order-sensitive Equality tests between :class:`OrderedDict` objects are order-sensitive
and are implemented as ``list(od1.items())==list(od2.items())``. and are roughly equivalent to ``list(od1.items())==list(od2.items())``.
Equality tests between :class:`OrderedDict` objects and other Equality tests between :class:`OrderedDict` objects and other
:class:`~collections.abc.Mapping` objects are order-insensitive like regular :class:`~collections.abc.Mapping` objects are order-insensitive like regular
dictionaries. This allows :class:`OrderedDict` objects to be substituted dictionaries. This allows :class:`OrderedDict` objects to be substituted

View File

@ -2,7 +2,9 @@ import builtins
import contextlib import contextlib
import copy import copy
import gc import gc
import operator
import pickle import pickle
import re
from random import randrange, shuffle from random import randrange, shuffle
import struct import struct
import sys import sys
@ -740,11 +742,44 @@ class OrderedDictTests:
# when it's mutated and returned from __next__: # when it's mutated and returned from __next__:
self.assertTrue(gc.is_tracked(next(it))) self.assertTrue(gc.is_tracked(next(it)))
class _TriggerSideEffectOnEqual:
    """Hashable key that runs a hook during one chosen ``__eq__`` call.

    ``__eq__`` always returns True and ``__hash__`` always returns the same
    value, so every instance behaves as "the same key".  Subclasses
    override ``side_effect()``; it is invoked from ``__eq__`` on the call
    for which ``count`` equals ``trigger``.
    """

    count = 0    # number of calls to __eq__
    trigger = 1  # count value when to trigger side effect

    def __eq__(self, other):
        cls = type(self)
        if cls.count == cls.trigger:
            self.side_effect()
        # The counter is rebound on the instance's own class, so each
        # subclass keeps a tally separate from this base class.
        cls.count += 1
        return True

    def __hash__(self):
        # all instances represent the same key
        return -1

    def side_effect(self):
        """Hook fired from ``__eq__``; subclasses must provide it."""
        raise NotImplementedError
class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase): class PurePythonOrderedDictTests(OrderedDictTests, unittest.TestCase):
module = py_coll module = py_coll
OrderedDict = py_coll.OrderedDict OrderedDict = py_coll.OrderedDict
def test_issue119004_attribute_error(self):
    # Delete a key from the left operand from inside a key's __eq__ while
    # the two ordered dicts are being compared, then check the resulting
    # exception and the final contents of both operands.
    class Key(_TriggerSideEffectOnEqual):
        def side_effect(self):
            del left[doomed]

    doomed = Key()
    left = self.OrderedDict(dict.fromkeys((0, doomed, 4.2)))
    right = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))

    # This causes an AttributeError due to the linked list being changed
    pattern = re.escape("'NoneType' object has no attribute 'key'")
    self.assertRaisesRegex(AttributeError, pattern, operator.eq, left, right)

    # The counter is checked before the dict assertions below.
    self.assertEqual(Key.count, 2)
    self.assertDictEqual(left, dict.fromkeys((0, 4.2)))
    self.assertDictEqual(right, dict.fromkeys((0, Key(), 4.2)))
class CPythonBuiltinDictTests(unittest.TestCase): class CPythonBuiltinDictTests(unittest.TestCase):
"""Builtin dict preserves insertion order. """Builtin dict preserves insertion order.
@ -765,8 +800,85 @@ for method in (
del method del method
class CPythonOrderedDictSideEffects:
    """Mixin of gh-119004 tests: operands mutated during ``==``.

    The host class must be a ``unittest.TestCase`` exposing the ordered
    dictionary type under test as ``self.OrderedDict``.  Each test builds
    two three-key ordered dicts whose middle key fires a side effect from
    its ``__eq__`` and then checks the outcome of comparing them.
    """

    def check_runtime_error_issue119004(self, dict1, dict2):
        # Comparing the operands must raise the "mutated" RuntimeError.
        pattern = re.escape("OrderedDict mutated during iteration")
        self.assertRaisesRegex(RuntimeError, pattern,
                               operator.eq, dict1, dict2)

    def test_issue119004_change_size_by_clear(self):
        # Side effect: empty the left operand.
        class Key(_TriggerSideEffectOnEqual):
            def side_effect(self):
                left.clear()

        left = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
        right = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
        self.check_runtime_error_issue119004(left, right)
        self.assertEqual(Key.count, 2)
        self.assertDictEqual(left, {})
        self.assertDictEqual(right, dict.fromkeys((0, Key(), 4.2)))

    def test_issue119004_change_size_by_delete_key(self):
        # Side effect: remove the triggering key from the left operand.
        class Key(_TriggerSideEffectOnEqual):
            def side_effect(self):
                del left[doomed]

        doomed = Key()
        left = self.OrderedDict(dict.fromkeys((0, doomed, 4.2)))
        right = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
        self.check_runtime_error_issue119004(left, right)
        self.assertEqual(Key.count, 2)
        self.assertDictEqual(left, dict.fromkeys((0, 4.2)))
        self.assertDictEqual(right, dict.fromkeys((0, Key(), 4.2)))

    def test_issue119004_change_linked_list_by_clear(self):
        # Side effect: empty the left operand, then insert two new keys.
        class Key(_TriggerSideEffectOnEqual):
            def side_effect(self):
                left.clear()
                left['a'] = 'c'
                left['b'] = 'c'

        left = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
        right = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
        self.check_runtime_error_issue119004(left, right)
        self.assertEqual(Key.count, 2)
        self.assertDictEqual(left, dict.fromkeys(('a', 'b'), 'c'))
        self.assertDictEqual(right, dict.fromkeys((0, Key(), 4.2)))

    def test_issue119004_change_linked_list_by_delete_key(self):
        # Side effect: remove one key and add another, so the left operand
        # ends up with the same number of keys it started with.
        class Key(_TriggerSideEffectOnEqual):
            def side_effect(self):
                del left[doomed]
                left['a'] = 'c'

        doomed = Key()
        left = self.OrderedDict(dict.fromkeys((0, doomed, 4.2)))
        right = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
        self.check_runtime_error_issue119004(left, right)
        self.assertEqual(Key.count, 2)
        self.assertDictEqual(left, {0: None, 'a': 'c', 4.2: None})
        self.assertDictEqual(right, dict.fromkeys((0, Key(), 4.2)))

    def test_issue119004_change_size_by_delete_key_in_dict_eq(self):
        # Same deletion as above, but fired on the very first __eq__ call.
        class Key(_TriggerSideEffectOnEqual):
            trigger = 0

            def side_effect(self):
                del left[doomed]

        doomed = Key()
        left = self.OrderedDict(dict.fromkeys((0, doomed, 4.2)))
        right = self.OrderedDict(dict.fromkeys((0, Key(), 4.2)))
        self.assertEqual(Key.count, 0)
        # the side effect is in dict.__eq__ and modifies the length
        self.assertNotEqual(left, right)
        self.assertEqual(Key.count, 2)
        self.assertDictEqual(left, dict.fromkeys((0, 4.2)))
        self.assertDictEqual(right, dict.fromkeys((0, Key(), 4.2)))
@unittest.skipUnless(c_coll, 'requires the C version of the collections module') @unittest.skipUnless(c_coll, 'requires the C version of the collections module')
class CPythonOrderedDictTests(OrderedDictTests, unittest.TestCase): class CPythonOrderedDictTests(OrderedDictTests,
CPythonOrderedDictSideEffects,
unittest.TestCase):
module = c_coll module = c_coll
OrderedDict = c_coll.OrderedDict OrderedDict = c_coll.OrderedDict

View File

@ -0,0 +1,2 @@
Fix a crash in :ref:`OrderedDict.__eq__ <collections_OrderedDict__eq__>`
when operands are mutated during the check. Patch by Bénédikt Tran.

View File

@ -796,6 +796,7 @@ _odict_clear_nodes(PyODictObject *od)
_odictnode_DEALLOC(node); _odictnode_DEALLOC(node);
node = next; node = next;
} }
od->od_state++;
} }
/* There isn't any memory management of nodes past this point. */ /* There isn't any memory management of nodes past this point. */
@ -806,24 +807,40 @@ _odict_keys_equal(PyODictObject *a, PyODictObject *b)
{ {
_ODictNode *node_a, *node_b; _ODictNode *node_a, *node_b;
// keep operands' state to detect undesired mutations
const size_t state_a = a->od_state;
const size_t state_b = b->od_state;
node_a = _odict_FIRST(a); node_a = _odict_FIRST(a);
node_b = _odict_FIRST(b); node_b = _odict_FIRST(b);
while (1) { while (1) {
if (node_a == NULL && node_b == NULL) if (node_a == NULL && node_b == NULL) {
/* success: hit the end of each at the same time */ /* success: hit the end of each at the same time */
return 1; return 1;
else if (node_a == NULL || node_b == NULL) }
else if (node_a == NULL || node_b == NULL) {
/* unequal length */ /* unequal length */
return 0; return 0;
}
else { else {
int res = PyObject_RichCompareBool( PyObject *key_a = Py_NewRef(_odictnode_KEY(node_a));
(PyObject *)_odictnode_KEY(node_a), PyObject *key_b = Py_NewRef(_odictnode_KEY(node_b));
(PyObject *)_odictnode_KEY(node_b), int res = PyObject_RichCompareBool(key_a, key_b, Py_EQ);
Py_EQ); Py_DECREF(key_a);
if (res < 0) Py_DECREF(key_b);
if (res < 0) {
return res; return res;
else if (res == 0) }
else if (a->od_state != state_a || b->od_state != state_b) {
PyErr_SetString(PyExc_RuntimeError,
"OrderedDict mutated during iteration");
return -1;
}
else if (res == 0) {
// This check comes after the check on the state
// in order for the exception to be set correctly.
return 0; return 0;
}
/* otherwise it must match, so move on to the next one */ /* otherwise it must match, so move on to the next one */
node_a = _odictnode_NEXT(node_a); node_a = _odictnode_NEXT(node_a);