mirror of https://github.com/python/cpython
bpo-43977: Properly update the tp_flags of existing subclasses when their parents are registered (GH-26864)
This commit is contained in:
parent
22e7effad5
commit
ca2009d72a
|
@ -772,17 +772,20 @@ iterations of the loop.
|
||||||
|
|
||||||
.. opcode:: MATCH_MAPPING
|
.. opcode:: MATCH_MAPPING
|
||||||
|
|
||||||
If TOS is an instance of :class:`collections.abc.Mapping`, push ``True`` onto
|
If TOS is an instance of :class:`collections.abc.Mapping` (or, more technically: if
|
||||||
the stack. Otherwise, push ``False``.
|
it has the :const:`Py_TPFLAGS_MAPPING` flag set in its
|
||||||
|
:c:member:`~PyTypeObject.tp_flags`), push ``True`` onto the stack. Otherwise, push
|
||||||
|
``False``.
|
||||||
|
|
||||||
.. versionadded:: 3.10
|
.. versionadded:: 3.10
|
||||||
|
|
||||||
|
|
||||||
.. opcode:: MATCH_SEQUENCE
|
.. opcode:: MATCH_SEQUENCE
|
||||||
|
|
||||||
If TOS is an instance of :class:`collections.abc.Sequence` and is *not* an
|
If TOS is an instance of :class:`collections.abc.Sequence` and is *not* an instance
|
||||||
instance of :class:`str`/:class:`bytes`/:class:`bytearray`, push ``True``
|
of :class:`str`/:class:`bytes`/:class:`bytearray` (or, more technically: if it has
|
||||||
onto the stack. Otherwise, push ``False``.
|
the :const:`Py_TPFLAGS_SEQUENCE` flag set in its :c:member:`~PyTypeObject.tp_flags`),
|
||||||
|
push ``True`` onto the stack. Otherwise, push ``False``.
|
||||||
|
|
||||||
.. versionadded:: 3.10
|
.. versionadded:: 3.10
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,43 @@ class TestCompiler(unittest.TestCase):
|
||||||
|
|
||||||
class TestInheritance(unittest.TestCase):
|
class TestInheritance(unittest.TestCase):
|
||||||
|
|
||||||
def test_multiple_inheritance(self):
|
@staticmethod
|
||||||
|
def check_sequence_then_mapping(x):
|
||||||
|
match x:
|
||||||
|
case [*_]:
|
||||||
|
return "seq"
|
||||||
|
case {}:
|
||||||
|
return "map"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_mapping_then_sequence(x):
|
||||||
|
match x:
|
||||||
|
case {}:
|
||||||
|
return "map"
|
||||||
|
case [*_]:
|
||||||
|
return "seq"
|
||||||
|
|
||||||
|
def test_multiple_inheritance_mapping(self):
|
||||||
|
class C:
|
||||||
|
pass
|
||||||
|
class M1(collections.UserDict, collections.abc.Sequence):
|
||||||
|
pass
|
||||||
|
class M2(C, collections.UserDict, collections.abc.Sequence):
|
||||||
|
pass
|
||||||
|
class M3(collections.UserDict, C, list):
|
||||||
|
pass
|
||||||
|
class M4(dict, collections.abc.Sequence, C):
|
||||||
|
pass
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(M1()), "map")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(M2()), "map")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(M3()), "map")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(M4()), "map")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(M1()), "map")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(M2()), "map")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(M3()), "map")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(M4()), "map")
|
||||||
|
|
||||||
|
def test_multiple_inheritance_sequence(self):
|
||||||
class C:
|
class C:
|
||||||
pass
|
pass
|
||||||
class S1(collections.UserList, collections.abc.Mapping):
|
class S1(collections.UserList, collections.abc.Mapping):
|
||||||
|
@ -35,32 +71,60 @@ class TestInheritance(unittest.TestCase):
|
||||||
pass
|
pass
|
||||||
class S4(collections.UserList, dict, C):
|
class S4(collections.UserList, dict, C):
|
||||||
pass
|
pass
|
||||||
class M1(collections.UserDict, collections.abc.Sequence):
|
self.assertEqual(self.check_sequence_then_mapping(S1()), "seq")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(S2()), "seq")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(S3()), "seq")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(S4()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(S1()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(S2()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(S3()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(S4()), "seq")
|
||||||
|
|
||||||
|
def test_late_registration_mapping(self):
|
||||||
|
class Parent:
|
||||||
pass
|
pass
|
||||||
class M2(C, collections.UserDict, collections.abc.Sequence):
|
class ChildPre(Parent):
|
||||||
pass
|
pass
|
||||||
class M3(collections.UserDict, C, list):
|
class GrandchildPre(ChildPre):
|
||||||
pass
|
pass
|
||||||
class M4(dict, collections.abc.Sequence, C):
|
collections.abc.Mapping.register(Parent)
|
||||||
|
class ChildPost(Parent):
|
||||||
pass
|
pass
|
||||||
def f(x):
|
class GrandchildPost(ChildPost):
|
||||||
match x:
|
pass
|
||||||
case []:
|
self.assertEqual(self.check_sequence_then_mapping(Parent()), "map")
|
||||||
return "seq"
|
self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "map")
|
||||||
case {}:
|
self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "map")
|
||||||
return "map"
|
self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "map")
|
||||||
def g(x):
|
self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "map")
|
||||||
match x:
|
self.assertEqual(self.check_mapping_then_sequence(Parent()), "map")
|
||||||
case {}:
|
self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "map")
|
||||||
return "map"
|
self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "map")
|
||||||
case []:
|
self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "map")
|
||||||
return "seq"
|
self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "map")
|
||||||
for Seq in (S1, S2, S3, S4):
|
|
||||||
self.assertEqual(f(Seq()), "seq")
|
def test_late_registration_sequence(self):
|
||||||
self.assertEqual(g(Seq()), "seq")
|
class Parent:
|
||||||
for Map in (M1, M2, M3, M4):
|
pass
|
||||||
self.assertEqual(f(Map()), "map")
|
class ChildPre(Parent):
|
||||||
self.assertEqual(g(Map()), "map")
|
pass
|
||||||
|
class GrandchildPre(ChildPre):
|
||||||
|
pass
|
||||||
|
collections.abc.Sequence.register(Parent)
|
||||||
|
class ChildPost(Parent):
|
||||||
|
pass
|
||||||
|
class GrandchildPost(ChildPost):
|
||||||
|
pass
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(Parent()), "seq")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(ChildPre()), "seq")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(GrandchildPre()), "seq")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(ChildPost()), "seq")
|
||||||
|
self.assertEqual(self.check_sequence_then_mapping(GrandchildPost()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(Parent()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(ChildPre()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(GrandchildPre()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(ChildPost()), "seq")
|
||||||
|
self.assertEqual(self.check_mapping_then_sequence(GrandchildPost()), "seq")
|
||||||
|
|
||||||
|
|
||||||
class TestPatma(unittest.TestCase):
|
class TestPatma(unittest.TestCase):
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
Set the proper :const:`Py_TPFLAGS_MAPPING` and :const:`Py_TPFLAGS_SEQUENCE`
|
||||||
|
flags for subclasses created before a parent has been registered as a
|
||||||
|
:class:`collections.abc.Mapping` or :class:`collections.abc.Sequence`.
|
|
@ -481,6 +481,32 @@ _abc__abc_init(PyObject *module, PyObject *self)
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
set_collection_flag_recursive(PyTypeObject *child, unsigned long flag)
|
||||||
|
{
|
||||||
|
assert(flag == Py_TPFLAGS_MAPPING || flag == Py_TPFLAGS_SEQUENCE);
|
||||||
|
if (PyType_HasFeature(child, Py_TPFLAGS_IMMUTABLETYPE) ||
|
||||||
|
(child->tp_flags & COLLECTION_FLAGS) == flag)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
child->tp_flags &= ~COLLECTION_FLAGS;
|
||||||
|
child->tp_flags |= flag;
|
||||||
|
PyObject *grandchildren = child->tp_subclasses;
|
||||||
|
if (grandchildren == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
assert(PyDict_CheckExact(grandchildren));
|
||||||
|
Py_ssize_t i = 0;
|
||||||
|
while (PyDict_Next(grandchildren, &i, NULL, &grandchildren)) {
|
||||||
|
assert(PyWeakref_CheckRef(grandchildren));
|
||||||
|
PyObject *grandchild = PyWeakref_GET_OBJECT(grandchildren);
|
||||||
|
if (PyType_Check(grandchild)) {
|
||||||
|
set_collection_flag_recursive((PyTypeObject *)grandchild, flag);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*[clinic input]
|
/*[clinic input]
|
||||||
_abc._abc_register
|
_abc._abc_register
|
||||||
|
|
||||||
|
@ -532,12 +558,11 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
|
||||||
get_abc_state(module)->abc_invalidation_counter++;
|
get_abc_state(module)->abc_invalidation_counter++;
|
||||||
|
|
||||||
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
|
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
|
||||||
if (PyType_Check(self) &&
|
if (PyType_Check(self)) {
|
||||||
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) &&
|
unsigned long collection_flag = ((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS;
|
||||||
((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS)
|
if (collection_flag) {
|
||||||
{
|
set_collection_flag_recursive((PyTypeObject *)subclass, collection_flag);
|
||||||
((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS;
|
}
|
||||||
((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
|
|
||||||
}
|
}
|
||||||
Py_INCREF(subclass);
|
Py_INCREF(subclass);
|
||||||
return subclass;
|
return subclass;
|
||||||
|
|
Loading…
Reference in New Issue