mirror of https://github.com/python/cpython
bpo-43977: Make sure that tp_flags for pattern matching are inherited correctly. (GH-25813)
This commit is contained in:
parent
9387fac100
commit
33ec88ac81
|
@ -1967,6 +1967,12 @@ class TestCollectionABCs(ABCTestCase):
|
||||||
self.assertEqual(len(mss), len(mss2))
|
self.assertEqual(len(mss), len(mss2))
|
||||||
self.assertEqual(list(mss), list(mss2))
|
self.assertEqual(list(mss), list(mss2))
|
||||||
|
|
||||||
|
def test_illegal_patma_flags(self):
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
class Both(Collection):
|
||||||
|
__abc_tpflags__ = (Sequence.__flags__ | Mapping.__flags__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
################################################################################
|
################################################################################
|
||||||
### Counter
|
### Counter
|
||||||
|
|
|
@ -2979,6 +2979,47 @@ class TestPatma(unittest.TestCase):
|
||||||
self.assertEqual(f((False, range(10, 20), True)), alts[4])
|
self.assertEqual(f((False, range(10, 20), True)), alts[4])
|
||||||
|
|
||||||
|
|
||||||
|
class TestInheritance(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_multiple_inheritance(self):
|
||||||
|
class C:
|
||||||
|
pass
|
||||||
|
class S1(collections.UserList, collections.abc.Mapping):
|
||||||
|
pass
|
||||||
|
class S2(C, collections.UserList, collections.abc.Mapping):
|
||||||
|
pass
|
||||||
|
class S3(list, C, collections.abc.Mapping):
|
||||||
|
pass
|
||||||
|
class S4(collections.UserList, dict, C):
|
||||||
|
pass
|
||||||
|
class M1(collections.UserDict, collections.abc.Sequence):
|
||||||
|
pass
|
||||||
|
class M2(C, collections.UserDict, collections.abc.Sequence):
|
||||||
|
pass
|
||||||
|
class M3(collections.UserDict, C, list):
|
||||||
|
pass
|
||||||
|
class M4(dict, collections.abc.Sequence, C):
|
||||||
|
pass
|
||||||
|
def f(x):
|
||||||
|
match x:
|
||||||
|
case []:
|
||||||
|
return "seq"
|
||||||
|
case {}:
|
||||||
|
return "map"
|
||||||
|
def g(x):
|
||||||
|
match x:
|
||||||
|
case {}:
|
||||||
|
return "map"
|
||||||
|
case []:
|
||||||
|
return "seq"
|
||||||
|
for Seq in (S1, S2, S3, S4):
|
||||||
|
self.assertEqual(f(Seq()), "seq")
|
||||||
|
self.assertEqual(g(Seq()), "seq")
|
||||||
|
for Map in (M1, M2, M3, M4):
|
||||||
|
self.assertEqual(f(Map()), "map")
|
||||||
|
self.assertEqual(g(Map()), "map")
|
||||||
|
|
||||||
|
|
||||||
class PerfPatma(TestPatma):
|
class PerfPatma(TestPatma):
|
||||||
|
|
||||||
def assertEqual(*_, **__):
|
def assertEqual(*_, **__):
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Prevent classes being both a sequence and a mapping when pattern matching.
|
|
@ -467,6 +467,10 @@ _abc__abc_init(PyObject *module, PyObject *self)
|
||||||
if (val == -1 && PyErr_Occurred()) {
|
if (val == -1 && PyErr_Occurred()) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
|
||||||
|
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
|
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
|
||||||
}
|
}
|
||||||
if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) {
|
if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) {
|
||||||
|
@ -527,9 +531,12 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
|
||||||
/* Invalidate negative cache */
|
/* Invalidate negative cache */
|
||||||
get_abc_state(module)->abc_invalidation_counter++;
|
get_abc_state(module)->abc_invalidation_counter++;
|
||||||
|
|
||||||
if (PyType_Check(subclass) && PyType_Check(self) &&
|
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
|
||||||
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE))
|
if (PyType_Check(self) &&
|
||||||
|
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) &&
|
||||||
|
((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS)
|
||||||
{
|
{
|
||||||
|
((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS;
|
||||||
((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
|
((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
|
||||||
}
|
}
|
||||||
Py_INCREF(subclass);
|
Py_INCREF(subclass);
|
||||||
|
|
|
@ -5713,12 +5713,6 @@ inherit_special(PyTypeObject *type, PyTypeObject *base)
|
||||||
if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) {
|
if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) {
|
||||||
type->tp_flags |= _Py_TPFLAGS_MATCH_SELF;
|
type->tp_flags |= _Py_TPFLAGS_MATCH_SELF;
|
||||||
}
|
}
|
||||||
if (PyType_HasFeature(base, Py_TPFLAGS_SEQUENCE)) {
|
|
||||||
type->tp_flags |= Py_TPFLAGS_SEQUENCE;
|
|
||||||
}
|
|
||||||
if (PyType_HasFeature(base, Py_TPFLAGS_MAPPING)) {
|
|
||||||
type->tp_flags |= Py_TPFLAGS_MAPPING;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static int
|
static int
|
||||||
|
@ -5936,6 +5930,7 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base)
|
||||||
static int add_operators(PyTypeObject *);
|
static int add_operators(PyTypeObject *);
|
||||||
static int add_tp_new_wrapper(PyTypeObject *type);
|
static int add_tp_new_wrapper(PyTypeObject *type);
|
||||||
|
|
||||||
|
#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING)
|
||||||
|
|
||||||
static int
|
static int
|
||||||
type_ready_checks(PyTypeObject *type)
|
type_ready_checks(PyTypeObject *type)
|
||||||
|
@ -5962,6 +5957,10 @@ type_ready_checks(PyTypeObject *type)
|
||||||
_PyObject_ASSERT((PyObject *)type, type->tp_as_async->am_send != NULL);
|
_PyObject_ASSERT((PyObject *)type, type->tp_as_async->am_send != NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Consistency checks for pattern matching
|
||||||
|
* Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING are mutually exclusive */
|
||||||
|
_PyObject_ASSERT((PyObject *)type, (type->tp_flags & COLLECTION_FLAGS) != COLLECTION_FLAGS);
|
||||||
|
|
||||||
if (type->tp_name == NULL) {
|
if (type->tp_name == NULL) {
|
||||||
PyErr_Format(PyExc_SystemError,
|
PyErr_Format(PyExc_SystemError,
|
||||||
"Type does not define the tp_name field.");
|
"Type does not define the tp_name field.");
|
||||||
|
@ -6156,6 +6155,12 @@ type_ready_inherit_as_structs(PyTypeObject *type, PyTypeObject *base)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void
|
||||||
|
inherit_patma_flags(PyTypeObject *type, PyTypeObject *base) {
|
||||||
|
if ((type->tp_flags & COLLECTION_FLAGS) == 0) {
|
||||||
|
type->tp_flags |= base->tp_flags & COLLECTION_FLAGS;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static int
|
static int
|
||||||
type_ready_inherit(PyTypeObject *type)
|
type_ready_inherit(PyTypeObject *type)
|
||||||
|
@ -6175,6 +6180,7 @@ type_ready_inherit(PyTypeObject *type)
|
||||||
if (inherit_slots(type, (PyTypeObject *)b) < 0) {
|
if (inherit_slots(type, (PyTypeObject *)b) < 0) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
inherit_patma_flags(type, (PyTypeObject *)b);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue