mirror of https://github.com/python/cpython
gh-89811: Check for valid tp_version_tag in specializer (GH-113558)
This commit is contained in:
parent
c65ae26f2b
commit
f653caa5a8
|
@ -1,5 +1,6 @@
|
|||
""" Tests for the internal type cache in CPython. """
|
||||
import unittest
|
||||
import dis
|
||||
from test import support
|
||||
from test.support import import_helper
|
||||
try:
|
||||
|
@ -8,8 +9,11 @@ except ImportError:
|
|||
_clear_type_cache = None
|
||||
|
||||
# Skip this test if the _testcapi module isn't available.
|
||||
type_get_version = import_helper.import_module('_testcapi').type_get_version
|
||||
type_assign_version = import_helper.import_module('_testcapi').type_assign_version
|
||||
_testcapi = import_helper.import_module("_testcapi")
|
||||
type_get_version = _testcapi.type_get_version
|
||||
type_assign_specific_version_unsafe = _testcapi.type_assign_specific_version_unsafe
|
||||
type_assign_version = _testcapi.type_assign_version
|
||||
type_modified = _testcapi.type_modified
|
||||
|
||||
|
||||
@support.cpython_only
|
||||
|
@ -56,6 +60,183 @@ class TypeCacheTests(unittest.TestCase):
|
|||
self.assertNotEqual(type_get_version(C), 0)
|
||||
self.assertNotEqual(type_get_version(C), c_ver)
|
||||
|
||||
def test_type_assign_specific_version(self):
|
||||
"""meta-test for type_assign_specific_version_unsafe"""
|
||||
class C:
|
||||
pass
|
||||
|
||||
type_assign_version(C)
|
||||
orig_version = type_get_version(C)
|
||||
self.assertNotEqual(orig_version, 0)
|
||||
|
||||
type_modified(C)
|
||||
type_assign_specific_version_unsafe(C, orig_version + 5)
|
||||
type_assign_version(C) # this should do nothing
|
||||
|
||||
new_version = type_get_version(C)
|
||||
self.assertEqual(new_version, orig_version + 5)
|
||||
|
||||
_clear_type_cache()
|
||||
|
||||
|
||||
@support.cpython_only
|
||||
class TypeCacheWithSpecializationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
_clear_type_cache()
|
||||
|
||||
def _assign_and_check_valid_version(self, user_type):
|
||||
type_modified(user_type)
|
||||
type_assign_version(user_type)
|
||||
self.assertNotEqual(type_get_version(user_type), 0)
|
||||
|
||||
def _assign_and_check_version_0(self, user_type):
|
||||
type_modified(user_type)
|
||||
type_assign_specific_version_unsafe(user_type, 0)
|
||||
self.assertEqual(type_get_version(user_type), 0)
|
||||
|
||||
def _all_opnames(self, func):
|
||||
return set(instr.opname for instr in dis.Bytecode(func, adaptive=True))
|
||||
|
||||
def _check_specialization(self, func, arg, opname, *, should_specialize):
|
||||
self.assertIn(opname, self._all_opnames(func))
|
||||
|
||||
for _ in range(100):
|
||||
func(arg)
|
||||
|
||||
if should_specialize:
|
||||
self.assertNotIn(opname, self._all_opnames(func))
|
||||
else:
|
||||
self.assertIn(opname, self._all_opnames(func))
|
||||
|
||||
def test_class_load_attr_specialization_user_type(self):
|
||||
class A:
|
||||
def foo(self):
|
||||
pass
|
||||
|
||||
self._assign_and_check_valid_version(A)
|
||||
|
||||
def load_foo_1(type_):
|
||||
type_.foo
|
||||
|
||||
self._check_specialization(load_foo_1, A, "LOAD_ATTR", should_specialize=True)
|
||||
del load_foo_1
|
||||
|
||||
self._assign_and_check_version_0(A)
|
||||
|
||||
def load_foo_2(type_):
|
||||
return type_.foo
|
||||
|
||||
self._check_specialization(load_foo_2, A, "LOAD_ATTR", should_specialize=False)
|
||||
|
||||
def test_class_load_attr_specialization_static_type(self):
|
||||
self._assign_and_check_valid_version(str)
|
||||
self._assign_and_check_valid_version(bytes)
|
||||
|
||||
def get_capitalize_1(type_):
|
||||
return type_.capitalize
|
||||
|
||||
self._check_specialization(get_capitalize_1, str, "LOAD_ATTR", should_specialize=True)
|
||||
self.assertEqual(get_capitalize_1(str)('hello'), 'Hello')
|
||||
self.assertEqual(get_capitalize_1(bytes)(b'hello'), b'Hello')
|
||||
del get_capitalize_1
|
||||
|
||||
# Permanently overflow the static type version counter, and force str and bytes
|
||||
# to have tp_version_tag == 0
|
||||
for _ in range(2**16):
|
||||
type_modified(str)
|
||||
type_assign_version(str)
|
||||
type_modified(bytes)
|
||||
type_assign_version(bytes)
|
||||
|
||||
self.assertEqual(type_get_version(str), 0)
|
||||
self.assertEqual(type_get_version(bytes), 0)
|
||||
|
||||
def get_capitalize_2(type_):
|
||||
return type_.capitalize
|
||||
|
||||
self._check_specialization(get_capitalize_2, str, "LOAD_ATTR", should_specialize=False)
|
||||
self.assertEqual(get_capitalize_2(str)('hello'), 'Hello')
|
||||
self.assertEqual(get_capitalize_2(bytes)(b'hello'), b'Hello')
|
||||
|
||||
def test_property_load_attr_specialization_user_type(self):
|
||||
class G:
|
||||
@property
|
||||
def x(self):
|
||||
return 9
|
||||
|
||||
self._assign_and_check_valid_version(G)
|
||||
|
||||
def load_x_1(instance):
|
||||
instance.x
|
||||
|
||||
self._check_specialization(load_x_1, G(), "LOAD_ATTR", should_specialize=True)
|
||||
del load_x_1
|
||||
|
||||
self._assign_and_check_version_0(G)
|
||||
|
||||
def load_x_2(instance):
|
||||
instance.x
|
||||
|
||||
self._check_specialization(load_x_2, G(), "LOAD_ATTR", should_specialize=False)
|
||||
|
||||
def test_store_attr_specialization_user_type(self):
|
||||
class B:
|
||||
__slots__ = ("bar",)
|
||||
|
||||
self._assign_and_check_valid_version(B)
|
||||
|
||||
def store_bar_1(type_):
|
||||
type_.bar = 10
|
||||
|
||||
self._check_specialization(store_bar_1, B(), "STORE_ATTR", should_specialize=True)
|
||||
del store_bar_1
|
||||
|
||||
self._assign_and_check_version_0(B)
|
||||
|
||||
def store_bar_2(type_):
|
||||
type_.bar = 10
|
||||
|
||||
self._check_specialization(store_bar_2, B(), "STORE_ATTR", should_specialize=False)
|
||||
|
||||
def test_class_call_specialization_user_type(self):
|
||||
class F:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
self._assign_and_check_valid_version(F)
|
||||
|
||||
def call_class_1(type_):
|
||||
type_()
|
||||
|
||||
self._check_specialization(call_class_1, F, "CALL", should_specialize=True)
|
||||
del call_class_1
|
||||
|
||||
self._assign_and_check_version_0(F)
|
||||
|
||||
def call_class_2(type_):
|
||||
type_()
|
||||
|
||||
self._check_specialization(call_class_2, F, "CALL", should_specialize=False)
|
||||
|
||||
def test_to_bool_specialization_user_type(self):
|
||||
class H:
|
||||
pass
|
||||
|
||||
self._assign_and_check_valid_version(H)
|
||||
|
||||
def to_bool_1(instance):
|
||||
not instance
|
||||
|
||||
self._check_specialization(to_bool_1, H(), "TO_BOOL", should_specialize=True)
|
||||
del to_bool_1
|
||||
|
||||
self._assign_and_check_version_0(H)
|
||||
|
||||
def to_bool_2(instance):
|
||||
not instance
|
||||
|
||||
self._check_specialization(to_bool_2, H(), "TO_BOOL", should_specialize=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
Check for a valid ``tp_version_tag`` before performing bytecode specializations that
|
||||
rely on this value being usable.
|
|
@ -2409,6 +2409,32 @@ type_get_version(PyObject *self, PyObject *type)
|
|||
return res;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
type_modified(PyObject *self, PyObject *type)
|
||||
{
|
||||
if (!PyType_Check(type)) {
|
||||
PyErr_SetString(PyExc_TypeError, "argument must be a type");
|
||||
return NULL;
|
||||
}
|
||||
PyType_Modified((PyTypeObject *)type);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
// Circumvents standard version assignment machinery - use with caution and only on
|
||||
// short-lived heap types
|
||||
static PyObject *
|
||||
type_assign_specific_version_unsafe(PyObject *self, PyObject *args)
|
||||
{
|
||||
PyTypeObject *type;
|
||||
unsigned int version;
|
||||
if (!PyArg_ParseTuple(args, "Oi:type_assign_specific_version_unsafe", &type, &version)) {
|
||||
return NULL;
|
||||
}
|
||||
assert(!PyType_HasFeature(type, Py_TPFLAGS_IMMUTABLETYPE));
|
||||
type->tp_version_tag = version;
|
||||
type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG;
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject *
|
||||
type_assign_version(PyObject *self, PyObject *type)
|
||||
|
@ -3342,6 +3368,9 @@ static PyMethodDef TestMethods[] = {
|
|||
{"test_py_is_macros", test_py_is_macros, METH_NOARGS},
|
||||
{"test_py_is_funcs", test_py_is_funcs, METH_NOARGS},
|
||||
{"type_get_version", type_get_version, METH_O, PyDoc_STR("type->tp_version_tag")},
|
||||
{"type_modified", type_modified, METH_O, PyDoc_STR("PyType_Modified")},
|
||||
{"type_assign_specific_version_unsafe", type_assign_specific_version_unsafe, METH_VARARGS,
|
||||
PyDoc_STR("forcefully assign type->tp_version_tag")},
|
||||
{"type_assign_version", type_assign_version, METH_O, PyDoc_STR("PyUnstable_Type_AssignVersionTag")},
|
||||
{"type_get_tp_bases", type_get_tp_bases, METH_O},
|
||||
{"type_get_tp_mro", type_get_tp_mro, METH_O},
|
||||
|
|
|
@ -586,6 +586,7 @@ _PyCode_Quicken(PyCodeObject *code)
|
|||
static int function_kind(PyCodeObject *code);
|
||||
static bool function_check_args(PyObject *o, int expected_argcount, int opcode);
|
||||
static uint32_t function_get_version(PyObject *o, int opcode);
|
||||
static uint32_t type_get_version(PyTypeObject *t, int opcode);
|
||||
|
||||
static int
|
||||
specialize_module_load_attr(
|
||||
|
@ -874,6 +875,9 @@ _Py_Specialize_LoadAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
|
|||
PyObject *descr = NULL;
|
||||
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 0);
|
||||
assert(descr != NULL || kind == ABSENT || kind == GETSET_OVERRIDDEN);
|
||||
if (type_get_version(type, LOAD_ATTR) == 0) {
|
||||
goto fail;
|
||||
}
|
||||
switch(kind) {
|
||||
case OVERRIDING:
|
||||
SPECIALIZATION_FAIL(LOAD_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
|
||||
|
@ -1057,6 +1061,9 @@ _Py_Specialize_StoreAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
|
|||
}
|
||||
PyObject *descr;
|
||||
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 1);
|
||||
if (type_get_version(type, STORE_ATTR) == 0) {
|
||||
goto fail;
|
||||
}
|
||||
switch(kind) {
|
||||
case OVERRIDING:
|
||||
SPECIALIZATION_FAIL(STORE_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
|
||||
|
@ -1183,6 +1190,9 @@ specialize_class_load_attr(PyObject *owner, _Py_CODEUNIT *instr,
|
|||
PyObject *descr = NULL;
|
||||
DescriptorClassification kind = 0;
|
||||
kind = analyze_descriptor((PyTypeObject *)owner, name, &descr, 0);
|
||||
if (type_get_version((PyTypeObject *)owner, LOAD_ATTR) == 0) {
|
||||
return -1;
|
||||
}
|
||||
switch (kind) {
|
||||
case METHOD:
|
||||
case NON_DESCRIPTOR:
|
||||
|
@ -1455,6 +1465,18 @@ function_get_version(PyObject *o, int opcode)
|
|||
return version;
|
||||
}
|
||||
|
||||
/* Returning 0 indicates a failure. */
|
||||
static uint32_t
|
||||
type_get_version(PyTypeObject *t, int opcode)
|
||||
{
|
||||
uint32_t version = t->tp_version_tag;
|
||||
if (version == 0) {
|
||||
SPECIALIZATION_FAIL(opcode, SPEC_FAIL_OUT_OF_VERSIONS);
|
||||
return 0;
|
||||
}
|
||||
return version;
|
||||
}
|
||||
|
||||
void
|
||||
_Py_Specialize_BinarySubscr(
|
||||
PyObject *container, PyObject *sub, _Py_CODEUNIT *instr)
|
||||
|
@ -1726,6 +1748,9 @@ specialize_class_call(PyObject *callable, _Py_CODEUNIT *instr, int nargs)
|
|||
}
|
||||
if (tp->tp_new == PyBaseObject_Type.tp_new) {
|
||||
PyFunctionObject *init = get_init_for_simple_managed_python_class(tp);
|
||||
if (type_get_version(tp, CALL) == 0) {
|
||||
return -1;
|
||||
}
|
||||
if (init != NULL) {
|
||||
if (((PyCodeObject *)init->func_code)->co_argcount != nargs+1) {
|
||||
SPECIALIZATION_FAIL(CALL, SPEC_FAIL_WRONG_NUMBER_ARGUMENTS);
|
||||
|
@ -2466,7 +2491,10 @@ _Py_Specialize_ToBool(PyObject *value, _Py_CODEUNIT *instr)
|
|||
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_OUT_OF_VERSIONS);
|
||||
goto failure;
|
||||
}
|
||||
uint32_t version = Py_TYPE(value)->tp_version_tag;
|
||||
uint32_t version = type_get_version(Py_TYPE(value), TO_BOOL);
|
||||
if (version == 0) {
|
||||
goto failure;
|
||||
}
|
||||
instr->op.code = TO_BOOL_ALWAYS_TRUE;
|
||||
write_u32(cache->version, version);
|
||||
assert(version);
|
||||
|
|
Loading…
Reference in New Issue