gh-106905: Use separate structs to track recursion depth in each PyAST_mod2obj call. (GH-113035)

Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org>
This commit is contained in:
Yilei Yang 2023-12-25 09:36:59 -08:00 committed by GitHub
parent 3f5eb3e6c7
commit 48c49739f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 412 additions and 339 deletions

View File

@ -16,8 +16,6 @@ extern "C" {
struct ast_state {
_PyOnceFlag once;
int finalized;
int recursion_depth;
int recursion_limit;
PyObject *AST_type;
PyObject *Add_singleton;
PyObject *Add_type;

View File

@ -0,0 +1,7 @@
Use per AST-parser state rather than global state to track recursion depth
within the AST parser to prevent potential race condition due to
simultaneous parsing.
The issue primarily showed up in 3.11 by multithreaded users of
:func:`ast.parse`. In 3.12 a change to when garbage collection can be
triggered prevented the race condition from occurring.

View File

@ -731,7 +731,7 @@ class SequenceConstructorVisitor(EmitVisitor):
class PyTypesDeclareVisitor(PickleVisitor):
def visitProduct(self, prod, name):
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
if prod.attributes:
self.emit("static const char * const %s_attributes[] = {" % name, 0)
for a in prod.attributes:
@ -752,7 +752,7 @@ class PyTypesDeclareVisitor(PickleVisitor):
ptype = "void*"
if is_simple(sum):
ptype = get_c_type(name)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
for t in sum.types:
self.visitConstructor(t, name)
@ -984,7 +984,8 @@ add_attributes(struct ast_state *state, PyObject *type, const char * const *attr
/* Conversion AST -> Python */
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
{
Py_ssize_t i, n = asdl_seq_LEN(seq);
PyObject *result = PyList_New(n);
@ -992,7 +993,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject*
if (!result)
return NULL;
for (i = 0; i < n; i++) {
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
if (!value) {
Py_DECREF(result);
return NULL;
@ -1002,7 +1003,7 @@ static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject*
return result;
}
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
{
PyObject *op = (PyObject*)o;
if (!op) {
@ -1014,7 +1015,7 @@ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
#define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
{
return PyLong_FromLong(b);
}
@ -1116,8 +1117,6 @@ static int add_ast_fields(struct ast_state *state)
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
state->recursion_depth = 0;
state->recursion_limit = 0;
return 0;
}
'''))
@ -1260,7 +1259,7 @@ class ObjVisitor(PickleVisitor):
def func_begin(self, name):
ctype = get_c_type(name)
self.emit("PyObject*", 0)
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
self.emit("{", 0)
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
self.emit("PyObject *result = NULL, *value = NULL;", 1)
@ -1268,17 +1267,17 @@ class ObjVisitor(PickleVisitor):
self.emit('if (!o) {', 1)
self.emit("Py_RETURN_NONE;", 2)
self.emit("}", 1)
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
self.emit("return NULL;", 2)
self.emit("}", 1)
def func_end(self):
self.emit("state->recursion_depth--;", 1)
self.emit("vstate->recursion_depth--;", 1)
self.emit("return result;", 1)
self.emit("failed:", 0)
self.emit("state->recursion_depth--;", 1)
self.emit("vstate->recursion_depth--;", 1)
self.emit("Py_XDECREF(value);", 1)
self.emit("Py_XDECREF(result);", 1)
self.emit("return NULL;", 1)
@ -1296,7 +1295,7 @@ class ObjVisitor(PickleVisitor):
self.visitConstructor(t, i + 1, name)
self.emit("}", 1)
for a in sum.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
self.emit('goto failed;', 2)
@ -1304,7 +1303,7 @@ class ObjVisitor(PickleVisitor):
self.func_end()
def simpleSum(self, sum, name):
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
self.emit("{", 0)
self.emit("switch(o) {", 1)
for t in sum.types:
@ -1322,7 +1321,7 @@ class ObjVisitor(PickleVisitor):
for field in prod.fields:
self.visitField(field, name, 1, True)
for a in prod.attributes:
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
self.emit("if (!value) goto failed;", 1)
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
self.emit('goto failed;', 2)
@ -1363,7 +1362,7 @@ class ObjVisitor(PickleVisitor):
self.emit("for(i = 0; i < n; i++)", depth+1)
# This cannot fail, so no need for error handling
self.emit(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
field.type,
value
),
@ -1372,9 +1371,9 @@ class ObjVisitor(PickleVisitor):
)
self.emit("}", depth)
else:
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
else:
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
class PartingShots(StaticVisitor):
@ -1394,18 +1393,19 @@ PyObject* PyAST_mod2obj(mod_ty t)
if (!tstate) {
return NULL;
}
state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
struct validator vstate;
vstate.recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
state->recursion_depth = starting_recursion_depth;
vstate.recursion_depth = starting_recursion_depth;
PyObject *result = ast2obj_mod(state, t);
PyObject *result = ast2obj_mod(state, &vstate, t);
/* Check that the recursion depth counting balanced correctly */
if (result && state->recursion_depth != starting_recursion_depth) {
if (result && vstate.recursion_depth != starting_recursion_depth) {
PyErr_Format(PyExc_SystemError,
"AST constructor recursion depth mismatch (before=%d, after=%d)",
starting_recursion_depth, state->recursion_depth);
starting_recursion_depth, vstate.recursion_depth);
return NULL;
}
return result;
@ -1475,8 +1475,6 @@ def generate_ast_state(module_state, f):
f.write('struct ast_state {\n')
f.write(' _PyOnceFlag once;\n')
f.write(' int finalized;\n')
f.write(' int recursion_depth;\n')
f.write(' int recursion_limit;\n')
for s in module_state:
f.write(' PyObject *' + s + ';\n')
f.write('};')
@ -1539,6 +1537,11 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include <stddef.h>
struct validator {
int recursion_depth; /* current recursion depth */
int recursion_limit; /* recursion limit */
};
// Forward declaration
static int init_types(struct ast_state *state);

689
Python/Python-ast.c generated

File diff suppressed because it is too large Load Diff