Skip to content

Commit 33ec88a

Browse files
authored
bpo-43977: Make sure that tp_flags for pattern matching are inherited correctly. (GH-25813)
1 parent 9387fac commit 33ec88a

File tree

5 files changed

+69
-8
lines changed

5 files changed

+69
-8
lines changed

Lib/test/test_collections.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1967,6 +1967,12 @@ def insert(self, index, value):
19671967
self.assertEqual(len(mss), len(mss2))
19681968
self.assertEqual(list(mss), list(mss2))
19691969

1970+
def test_illegal_patma_flags(self):
1971+
with self.assertRaises(TypeError):
1972+
class Both(Collection):
1973+
__abc_tpflags__ = (Sequence.__flags__ | Mapping.__flags__)
1974+
1975+
19701976

19711977
################################################################################
19721978
### Counter

Lib/test/test_patma.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,6 +2979,47 @@ def f(x):
29792979
self.assertEqual(f((False, range(10, 20), True)), alts[4])
29802980

29812981

2982+
class TestInheritance(unittest.TestCase):
2983+
2984+
def test_multiple_inheritance(self):
2985+
class C:
2986+
pass
2987+
class S1(collections.UserList, collections.abc.Mapping):
2988+
pass
2989+
class S2(C, collections.UserList, collections.abc.Mapping):
2990+
pass
2991+
class S3(list, C, collections.abc.Mapping):
2992+
pass
2993+
class S4(collections.UserList, dict, C):
2994+
pass
2995+
class M1(collections.UserDict, collections.abc.Sequence):
2996+
pass
2997+
class M2(C, collections.UserDict, collections.abc.Sequence):
2998+
pass
2999+
class M3(collections.UserDict, C, list):
3000+
pass
3001+
class M4(dict, collections.abc.Sequence, C):
3002+
pass
3003+
def f(x):
3004+
match x:
3005+
case []:
3006+
return "seq"
3007+
case {}:
3008+
return "map"
3009+
def g(x):
3010+
match x:
3011+
case {}:
3012+
return "map"
3013+
case []:
3014+
return "seq"
3015+
for Seq in (S1, S2, S3, S4):
3016+
self.assertEqual(f(Seq()), "seq")
3017+
self.assertEqual(g(Seq()), "seq")
3018+
for Map in (M1, M2, M3, M4):
3019+
self.assertEqual(f(Map()), "map")
3020+
self.assertEqual(g(Map()), "map")
3021+
3022+
29823023
class PerfPatma(TestPatma):
29833024

29843025
def assertEqual(*_, **__):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Prevent classes being both a sequence and a mapping when pattern matching.

Modules/_abc.c

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,10 @@ _abc__abc_init(PyObject *module, PyObject *self)
467467
if (val == -1 && PyErr_Occurred()) {
468468
return NULL;
469469
}
470+
if ((val & COLLECTION_FLAGS) == COLLECTION_FLAGS) {
471+
PyErr_SetString(PyExc_TypeError, "__abc_tpflags__ cannot be both Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING");
472+
return NULL;
473+
}
470474
((PyTypeObject *)self)->tp_flags |= (val & COLLECTION_FLAGS);
471475
}
472476
if (_PyDict_DelItemId(cls->tp_dict, &PyId___abc_tpflags__) < 0) {
@@ -527,9 +531,12 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
527531
/* Invalidate negative cache */
528532
get_abc_state(module)->abc_invalidation_counter++;
529533

530-
if (PyType_Check(subclass) && PyType_Check(self) &&
531-
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE))
534+
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
535+
if (PyType_Check(self) &&
536+
!PyType_HasFeature((PyTypeObject *)subclass, Py_TPFLAGS_IMMUTABLETYPE) &&
537+
((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS)
532538
{
539+
((PyTypeObject *)subclass)->tp_flags &= ~COLLECTION_FLAGS;
533540
((PyTypeObject *)subclass)->tp_flags |= (((PyTypeObject *)self)->tp_flags & COLLECTION_FLAGS);
534541
}
535542
Py_INCREF(subclass);

Objects/typeobject.c

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5713,12 +5713,6 @@ inherit_special(PyTypeObject *type, PyTypeObject *base)
57135713
if (PyType_HasFeature(base, _Py_TPFLAGS_MATCH_SELF)) {
57145714
type->tp_flags |= _Py_TPFLAGS_MATCH_SELF;
57155715
}
5716-
if (PyType_HasFeature(base, Py_TPFLAGS_SEQUENCE)) {
5717-
type->tp_flags |= Py_TPFLAGS_SEQUENCE;
5718-
}
5719-
if (PyType_HasFeature(base, Py_TPFLAGS_MAPPING)) {
5720-
type->tp_flags |= Py_TPFLAGS_MAPPING;
5721-
}
57225716
}
57235717

57245718
static int
@@ -5936,6 +5930,7 @@ inherit_slots(PyTypeObject *type, PyTypeObject *base)
59365930
static int add_operators(PyTypeObject *);
59375931
static int add_tp_new_wrapper(PyTypeObject *type);
59385932

5933+
#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING)
59395934

59405935
static int
59415936
type_ready_checks(PyTypeObject *type)
@@ -5962,6 +5957,10 @@ type_ready_checks(PyTypeObject *type)
59625957
_PyObject_ASSERT((PyObject *)type, type->tp_as_async->am_send != NULL);
59635958
}
59645959

5960+
/* Consistency checks for pattern matching
5961+
* Py_TPFLAGS_SEQUENCE and Py_TPFLAGS_MAPPING are mutually exclusive */
5962+
_PyObject_ASSERT((PyObject *)type, (type->tp_flags & COLLECTION_FLAGS) != COLLECTION_FLAGS);
5963+
59655964
if (type->tp_name == NULL) {
59665965
PyErr_Format(PyExc_SystemError,
59675966
"Type does not define the tp_name field.");
@@ -6156,6 +6155,12 @@ type_ready_inherit_as_structs(PyTypeObject *type, PyTypeObject *base)
61566155
}
61576156
}
61586157

6158+
static void
6159+
inherit_patma_flags(PyTypeObject *type, PyTypeObject *base) {
6160+
if ((type->tp_flags & COLLECTION_FLAGS) == 0) {
6161+
type->tp_flags |= base->tp_flags & COLLECTION_FLAGS;
6162+
}
6163+
}
61596164

61606165
static int
61616166
type_ready_inherit(PyTypeObject *type)
@@ -6175,6 +6180,7 @@ type_ready_inherit(PyTypeObject *type)
61756180
if (inherit_slots(type, (PyTypeObject *)b) < 0) {
61766181
return -1;
61776182
}
6183+
inherit_patma_flags(type, (PyTypeObject *)b);
61786184
}
61796185
}
61806186

0 commit comments

Comments
 (0)