bpo-43977: Make sure that tp_flags for pattern matching are inherited correctly. (GH-25813)

This commit is contained in:
Mark Shannon 2021-05-03 00:38:22 +01:00 committed by GitHub
parent 9387fac100
commit 33ec88ac81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 69 additions and 8 deletions

View File

@ -1967,6 +1967,12 @@ class TestCollectionABCs(ABCTestCase):
self.assertEqual(len(mss), len(mss2)) self.assertEqual(len(mss), len(mss2))
self.assertEqual(list(mss), list(mss2)) self.assertEqual(list(mss), list(mss2))
def test_illegal_patma_flags(self):
with self.assertRaises(TypeError):
class Both(Collection):
__abc_tpflags__ = (Sequence.__flags__ | Mapping.__flags__)
################################################################################ ################################################################################
### Counter ### Counter

View File

@ -2979,6 +2979,47 @@ class TestPatma(unittest.TestCase):
self.assertEqual(f((False, range(10, 20), True)), alts[4]) self.assertEqual(f((False, range(10, 20), True)), alts[4])
class TestInheritance(unittest.TestCase):
def test_multiple_inheritance(self):
class C:
pass
class S1(collections.UserList, collections.abc.Mapping):
pass
class S2(C, collections.UserList, collections.abc.Mapping):
pass
class S3(list, C, collections.abc.Mapping):
pass
class S4(collections.UserList, dict, C):
pass
class M1(collections.UserDict, collections.abc.Sequence):
pass
class M2(C, collections.UserDict, collections.abc.Sequence):
pass
class M3(collections.UserDict, C, list):
pass
class M4(dict, collections.abc.Sequence, C):
pass
def f(x):
match x:
case []:
return "seq"
case {}:
return "map"
def g(x):
match x:
case {}:
return "map"
case []:
return "seq"
for Seq in (S1, S2, S3, S4):
self.assertEqual(f(Seq()), "seq")
self.assertEqual(g(Seq()), "seq")
for Map in (M1, M2, M3, M4):
self.assertEqual(f(Map()), "map")
self.assertEqual(g(Map()), "map")
class PerfPatma(TestPatma): class PerfPatma(TestPatma):
def assertEqual(*_, **__): def assertEqual(*_, **__):

View File

@ -0,0 +1 @@
Prevent classes being both a sequence and a mapping when pattern matching.

View File

@ -467,6 +467,10 @@ _abc__abc_init(PyObject *module, PyObject *self)
if (val == -1 && PyErr_Occurred()) { if (val == -1 && PyErr_Occurred()) {
return NULL; return NULL;
} }
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
return NULL;
}
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS); ((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
} }
if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) { if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) {
@ -527,9 +531,12 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
/* Invalidate negative cache */ /* Invalidate negative cache */
get_abc_state(module)->abc_invalidation_counter++; get_abc_state(module)->abc_invalidation_counter++;
if (PyType_Check(subclass) && PyType_Check(self) && /* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE)) if (PyType_Check(self) &&
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) &&
((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS)
{ {
((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS;
((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS); ((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
} }
Py_INCREF(subclass); Py_INCREF(subclass);

View File

@ -5713,12 +5713,6 @@ inherit_special(PyTypeObject *type, PyTypeObject *base)
if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) { if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) {
type->tp_flags |= _Py_TPFLAGS_MATCH_SELF; type->tp_flags |= _Py_TPFLAGS_MATCH_SELF;
} }
if (PyType_HasFeature(base, Py_TPFLAGS_SEQUENCE)) {
type->tp_flags |= Py_TPFLAGS_SEQUENCE;
}
if (PyType_HasFeature(base, Py_TPFLAGS_MAPPING)) {
type->tp_flags |= Py_TPFLAGS_MAPPING;
}
} }
static int static int
@ -5936,6 +5930,7 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base)
static int add_operators(PyTypeObject *); static int add_operators(PyTypeObject *);
static int add_tp_new_wrapper(PyTypeObject *type); static int add_tp_new_wrapper(PyTypeObject *type);
#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING)
static int static int
type_ready_checks(PyTypeObject *type) type_ready_checks(PyTypeObject *type)
@ -5962,6 +5957,10 @@ type_ready_checks(PyTypeObject *type)
_PyObject_ASSERT((PyObject *)type, type->tp_as_async->am_send != NULL); _PyObject_ASSERT((PyObject *)type, type->tp_as_async->am_send != NULL);
} }
/* Consistency checks for pattern matching
* Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING are mutually exclusive */
_PyObject_ASSERT((PyObject *)type, (type->tp_flags & COLLECTION_FLAGS) != COLLECTION_FLAGS);
if (type->tp_name == NULL) { if (type->tp_name == NULL) {
PyErr_Format(PyExc_SystemError, PyErr_Format(PyExc_SystemError,
"Type does not define the tp_name field."); "Type does not define the tp_name field.");
@ -6156,6 +6155,12 @@ type_ready_inherit_as_structs(PyTypeObject *type, PyTypeObject *base)
} }
} }
static void
inherit_patma_flags(PyTypeObject *type, PyTypeObject *base) {
if ((type->tp_flags & COLLECTION_FLAGS) == 0) {
type->tp_flags |= base->tp_flags & COLLECTION_FLAGS;
}
}
static int static int
type_ready_inherit(PyTypeObject *type) type_ready_inherit(PyTypeObject *type)
@ -6175,6 +6180,7 @@ type_ready_inherit(PyTypeObject *type)
if (inherit_slots(type, (PyTypeObject *)b) < 0) { if (inherit_slots(type, (PyTypeObject *)b) < 0) {
return -1; return -1;
} }
inherit_patma_flags(type, (PyTypeObject *)b);
} }
} }