From 9728ead36181fb3f0a4b2e8a7291a3e0a702b952 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Thu, 4 Jul 2024 05:10:54 +0200 Subject: [PATCH] gh-121141: add support for `copy.replace` to AST nodes (#121162) --- Doc/whatsnew/3.14.rst | 8 +- Lib/test/test_ast.py | 272 +++++++++++++++++ ...-06-29-15-21-12.gh-issue-121141.4evD6q.rst | 1 + Parser/asdl_c.py | 279 ++++++++++++++++++ Python/Python-ast.c | 279 ++++++++++++++++++ 5 files changed, 837 insertions(+), 2 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2024-06-29-15-21-12.gh-issue-121141.4evD6q.rst diff --git a/Doc/whatsnew/3.14.rst b/Doc/whatsnew/3.14.rst index 9578ba0c9c9..d02c10ec9cf 100644 --- a/Doc/whatsnew/3.14.rst +++ b/Doc/whatsnew/3.14.rst @@ -89,8 +89,12 @@ Improved Modules ast --- -Added :func:`ast.compare` for comparing two ASTs. -(Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`.) +* Added :func:`ast.compare` for comparing two ASTs. + (Contributed by Batuhan Taskaya and Jeremy Hylton in :issue:`15987`.) + +* Add support for :func:`copy.replace` for AST nodes. + + (Contributed by Bénédikt Tran in :gh:`121141`.) os -- diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index fbd19620311..eb3aefd5c26 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -1149,6 +1149,25 @@ class AST_Tests(unittest.TestCase): class CopyTests(unittest.TestCase): """Test copying and pickling AST nodes.""" + @staticmethod + def iter_ast_classes(): + """Iterate over the (native) subclasses of ast.AST recursively. + + This excludes the special class ast.Index since its constructor + returns an integer. + """ + def do(cls): + if cls.__module__ != 'ast': + return + if cls is ast.Index: + return + + yield cls + for sub in cls.__subclasses__(): + yield from do(sub) + + yield from do(ast.AST) + def test_pickling(self): import pickle @@ -1218,6 +1237,259 @@ class CopyTests(unittest.TestCase): )): self.assertEqual(to_tuple(child.parent), to_tuple(node)) + def test_replace_interface(self): + for klass in self.iter_ast_classes(): + with self.subTest(klass=klass): + self.assertTrue(hasattr(klass, '__replace__')) + + fields = set(klass._fields) + with self.subTest(klass=klass, fields=fields): + node = klass(**dict.fromkeys(fields)) + # forbid positional arguments in replace() + self.assertRaises(TypeError, copy.replace, node, 1) + self.assertRaises(TypeError, node.__replace__, 1) + + def test_replace_native(self): + for klass in self.iter_ast_classes(): + fields = set(klass._fields) + attributes = set(klass._attributes) + + with self.subTest(klass=klass, fields=fields, attributes=attributes): + # use of object() to ensure that '==' and 'is' + # behave similarly in ast.compare(node, repl) + old_fields = {field: object() for field in fields} + old_attrs = {attr: object() for attr in attributes} + + # check shallow copy + node = klass(**old_fields) + repl = copy.replace(node) + self.assertTrue(ast.compare(node, repl, compare_attributes=True)) + # check when passing using attributes (they may be optional!) + node = klass(**old_fields, **old_attrs) + repl = copy.replace(node) + self.assertTrue(ast.compare(node, repl, compare_attributes=True)) + + for field in fields: + # check when we sometimes have attributes and sometimes not + for init_attrs in [{}, old_attrs]: + node = klass(**old_fields, **init_attrs) + # only change a single field (do not change attributes) + new_value = object() + repl = copy.replace(node, **{field: new_value}) + for f in fields: + old_value = old_fields[f] + # assert that there is no side-effect + self.assertIs(getattr(node, f), old_value) + # check the changes + if f != field: + self.assertIs(getattr(repl, f), old_value) + else: + self.assertIs(getattr(repl, f), new_value) + self.assertFalse(ast.compare(node, repl, compare_attributes=True)) + + for attribute in attributes: + node = klass(**old_fields, **old_attrs) + # only change a single attribute (do not change fields) + new_attr = object() + repl = copy.replace(node, **{attribute: new_attr}) + for a in attributes: + old_attr = old_attrs[a] + # assert that there is no side-effect + self.assertIs(getattr(node, a), old_attr) + # check the changes + if a != attribute: + self.assertIs(getattr(repl, a), old_attr) + else: + self.assertIs(getattr(repl, a), new_attr) + self.assertFalse(ast.compare(node, repl, compare_attributes=True)) + + def test_replace_accept_known_class_fields(self): + nid, ctx = object(), object() + + node = ast.Name(id=nid, ctx=ctx) + self.assertIs(node.id, nid) + self.assertIs(node.ctx, ctx) + + new_nid = object() + repl = copy.replace(node, id=new_nid) + # assert that there is no side-effect + self.assertIs(node.id, nid) + self.assertIs(node.ctx, ctx) + # check the changes + self.assertIs(repl.id, new_nid) + self.assertIs(repl.ctx, node.ctx) # no changes + + def test_replace_accept_known_class_attributes(self): + node = ast.parse('x').body[0].value + self.assertEqual(node.id, 'x') + self.assertEqual(node.lineno, 1) + + # constructor allows any type so replace() should do the same + lineno = object() + repl = copy.replace(node, lineno=lineno) + # assert that there is no side-effect + self.assertEqual(node.lineno, 1) + # check the changes + self.assertEqual(repl.id, node.id) + self.assertEqual(repl.ctx, node.ctx) + self.assertEqual(repl.lineno, lineno) + + _, _, state = node.__reduce__() + self.assertEqual(state['id'], 'x') + self.assertEqual(state['ctx'], node.ctx) + self.assertEqual(state['lineno'], 1) + + _, _, state = repl.__reduce__() + self.assertEqual(state['id'], 'x') + self.assertEqual(state['ctx'], node.ctx) + self.assertEqual(state['lineno'], lineno) + + def test_replace_accept_known_custom_class_fields(self): + class MyNode(ast.AST): + _fields = ('name', 'data') + __annotations__ = {'name': str, 'data': object} + __match_args__ = ('name', 'data') + + name, data = 'name', object() + + node = MyNode(name, data) + self.assertIs(node.name, name) + self.assertIs(node.data, data) + # check shallow copy + repl = copy.replace(node) + # assert that there is no side-effect + self.assertIs(node.name, name) + self.assertIs(node.data, data) + # check the shallow copy + self.assertIs(repl.name, name) + self.assertIs(repl.data, data) + + node = MyNode(name, data) + repl_data = object() + # replace custom but known field + repl = copy.replace(node, data=repl_data) + # assert that there is no side-effect + self.assertIs(node.name, name) + self.assertIs(node.data, data) + # check the changes + self.assertIs(repl.name, node.name) + self.assertIs(repl.data, repl_data) + + def test_replace_accept_known_custom_class_attributes(self): + class MyNode(ast.AST): + x = 0 + y = 1 + _attributes = ('x', 'y') + + node = MyNode() + self.assertEqual(node.x, 0) + self.assertEqual(node.y, 1) + + y = object() + # custom attributes are currently not supported and raise a warning + # because the allowed attributes are hard-coded ! + msg = ( + "MyNode.__init__ got an unexpected keyword argument 'y'. " + "Support for arbitrary keyword arguments is deprecated and " + "will be removed in Python 3.15" + ) + with self.assertWarnsRegex(DeprecationWarning, re.escape(msg)): + repl = copy.replace(node, y=y) + # assert that there is no side-effect + self.assertEqual(node.x, 0) + self.assertEqual(node.y, 1) + # check the changes + self.assertEqual(repl.x, 0) + self.assertEqual(repl.y, y) + + def test_replace_ignore_known_custom_instance_fields(self): + node = ast.parse('x').body[0].value + node.extra = extra = object() # add instance 'extra' field + context = node.ctx + + # assert initial values + self.assertIs(node.id, 'x') + self.assertIs(node.ctx, context) + self.assertIs(node.extra, extra) + # shallow copy, but drops extra fields + repl = copy.replace(node) + # assert that there is no side-effect + self.assertIs(node.id, 'x') + self.assertIs(node.ctx, context) + self.assertIs(node.extra, extra) + # verify that the 'extra' field is not kept + self.assertIs(repl.id, 'x') + self.assertIs(repl.ctx, context) + self.assertRaises(AttributeError, getattr, repl, 'extra') + + # change known native field + repl = copy.replace(node, id='y') + # assert that there is no side-effect + self.assertIs(node.id, 'x') + self.assertIs(node.ctx, context) + self.assertIs(node.extra, extra) + # verify that the 'extra' field is not kept + self.assertIs(repl.id, 'y') + self.assertIs(repl.ctx, context) + self.assertRaises(AttributeError, getattr, repl, 'extra') + + def test_replace_reject_missing_field(self): + # case: warn if deleted field is not replaced + node = ast.parse('x').body[0].value + context = node.ctx + del node.id + + self.assertRaises(AttributeError, getattr, node, 'id') + self.assertIs(node.ctx, context) + msg = "Name.__replace__ missing 1 keyword argument: 'id'." + with self.assertRaisesRegex(TypeError, re.escape(msg)): + copy.replace(node) + # assert that there is no side-effect + self.assertRaises(AttributeError, getattr, node, 'id') + self.assertIs(node.ctx, context) + + # case: do not raise if deleted field is replaced + node = ast.parse('x').body[0].value + context = node.ctx + del node.id + + self.assertRaises(AttributeError, getattr, node, 'id') + self.assertIs(node.ctx, context) + repl = copy.replace(node, id='y') + # assert that there is no side-effect + self.assertRaises(AttributeError, getattr, node, 'id') + self.assertIs(node.ctx, context) + self.assertIs(repl.id, 'y') + self.assertIs(repl.ctx, context) + + def test_replace_reject_known_custom_instance_fields_commits(self): + node = ast.parse('x').body[0].value + node.extra = extra = object() # add instance 'extra' field + context = node.ctx + + # explicit rejection of known instance fields + self.assertTrue(hasattr(node, 'extra')) + msg = "Name.__replace__ got an unexpected keyword argument 'extra'." + with self.assertRaisesRegex(TypeError, re.escape(msg)): + copy.replace(node, extra=1) + # assert that there is no side-effect + self.assertIs(node.id, 'x') + self.assertIs(node.ctx, context) + self.assertIs(node.extra, extra) + + def test_replace_reject_unknown_instance_fields(self): + node = ast.parse('x').body[0].value + context = node.ctx + + # explicit rejection of unknown extra fields + self.assertRaises(AttributeError, getattr, node, 'unknown') + msg = "Name.__replace__ got an unexpected keyword argument 'unknown'." + with self.assertRaisesRegex(TypeError, re.escape(msg)): + copy.replace(node, unknown=1) + # assert that there is no side-effect + self.assertIs(node.id, 'x') + self.assertIs(node.ctx, context) + self.assertRaises(AttributeError, getattr, node, 'unknown') class ASTHelpers_Test(unittest.TestCase): maxDiff = None diff --git a/Misc/NEWS.d/next/Library/2024-06-29-15-21-12.gh-issue-121141.4evD6q.rst b/Misc/NEWS.d/next/Library/2024-06-29-15-21-12.gh-issue-121141.4evD6q.rst new file mode 100644 index 00000000000..f2dc621050f --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-06-29-15-21-12.gh-issue-121141.4evD6q.rst @@ -0,0 +1 @@ +Add support for :func:`copy.replace` to AST nodes. Patch by Bénédikt Tran. diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index e338656a5b1..f3667801782 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -1132,6 +1132,279 @@ cleanup: return result; } +/* + * Perform the following validations: + * + * - All keyword arguments are known 'fields' or 'attributes'. + * - No field or attribute would be left unfilled after copy.replace(). + * + * On success, this returns 1. Otherwise, set a TypeError + * exception and returns -1 (no exception is set if some + * other internal errors occur). + * + * Parameters + * + * self The AST node instance. + * dict The AST node instance dictionary (self.__dict__). + * fields The list of fields (self._fields). + * attributes The list of attributes (self._attributes). + * kwargs Keyword arguments passed to ast_type_replace(). + * + * The 'dict', 'fields', 'attributes' and 'kwargs' arguments can be NULL. + * + * Note: this function can be removed in 3.15 since the verification + * will be done inside the constructor. + */ +static inline int +ast_type_replace_check(PyObject *self, + PyObject *dict, + PyObject *fields, + PyObject *attributes, + PyObject *kwargs) +{ + // While it is possible to make some fast paths that would avoid + // allocating objects on the stack, this would cost us readability. + // For instance, if 'fields' and 'attributes' are both empty, and + // 'kwargs' is not empty, we could raise a TypeError immediately. + PyObject *expecting = PySet_New(fields); + if (expecting == NULL) { + return -1; + } + if (attributes) { + if (_PySet_Update(expecting, attributes) < 0) { + Py_DECREF(expecting); + return -1; + } + } + // Any keyword argument that is neither a field nor attribute is rejected. + // We first need to check whether a keyword argument is accepted or not. + // If all keyword arguments are accepted, we compute the required fields + // and attributes. A field or attribute is not needed if: + // + // 1) it is given in 'kwargs', or + // 2) it already exists on 'self'. + if (kwargs) { + Py_ssize_t pos = 0; + PyObject *key, *value; + while (PyDict_Next(kwargs, &pos, &key, &value)) { + int rc = PySet_Discard(expecting, key); + if (rc < 0) { + Py_DECREF(expecting); + return -1; + } + if (rc == 0) { + PyErr_Format(PyExc_TypeError, + "%.400s.__replace__ got an unexpected keyword " + "argument '%U'.", Py_TYPE(self)->tp_name, key); + Py_DECREF(expecting); + return -1; + } + } + } + // check that the remaining fields or attributes would be filled + if (dict) { + Py_ssize_t pos = 0; + PyObject *key, *value; + while (PyDict_Next(dict, &pos, &key, &value)) { + // Mark fields or attributes that are found on the instance + // as non-mandatory. If they are not given in 'kwargs', they + // will be shallow-coied; otherwise, they would be replaced + // (not in this function). + if (PySet_Discard(expecting, key) < 0) { + Py_DECREF(expecting); + return -1; + } + } + if (attributes) { + // Some attributes may or may not be present at runtime. + // In particular, now that we checked whether 'kwargs' + // is correct or not, we allow any attribute to be missing. + // + // Note that fields must still be entirely determined when + // calling the constructor later. + PyObject *unused = PyObject_CallMethodOneArg(expecting, + &_Py_ID(difference_update), + attributes); + if (unused == NULL) { + Py_DECREF(expecting); + return -1; + } + Py_DECREF(unused); + } + } + // Now 'expecting' contains the fields or attributes + // that would not be filled inside ast_type_replace(). + Py_ssize_t m = PySet_GET_SIZE(expecting); + if (m > 0) { + PyObject *names = PyList_New(m); + if (names == NULL) { + Py_DECREF(expecting); + return -1; + } + Py_ssize_t i = 0, pos = 0; + PyObject *item; + Py_hash_t hash; + while (_PySet_NextEntry(expecting, &pos, &item, &hash)) { + PyObject *name = PyObject_Repr(item); + if (name == NULL) { + Py_DECREF(expecting); + Py_DECREF(names); + return -1; + } + // steal the reference 'name' + PyList_SET_ITEM(names, i++, name); + } + Py_DECREF(expecting); + if (PyList_Sort(names) < 0) { + Py_DECREF(names); + return -1; + } + PyObject *sep = PyUnicode_FromString(", "); + if (sep == NULL) { + Py_DECREF(names); + return -1; + } + PyObject *str_names = PyUnicode_Join(sep, names); + Py_DECREF(sep); + Py_DECREF(names); + if (str_names == NULL) { + return -1; + } + PyErr_Format(PyExc_TypeError, + "%.400s.__replace__ missing %ld keyword argument%s: %U.", + Py_TYPE(self)->tp_name, m, m == 1 ? "" : "s", str_names); + Py_DECREF(str_names); + return -1; + } + else { + Py_DECREF(expecting); + return 1; + } +} + +/* + * Python equivalent: + * + * for key in keys: + * if hasattr(self, key): + * payload[key] = getattr(self, key) + * + * The 'keys' argument is a sequence corresponding to + * the '_fields' or the '_attributes' of an AST node. + * + * This returns -1 if an error occurs and 0 otherwise. + * + * Parameters + * + * payload A dictionary to fill. + * keys A sequence of keys or NULL for an empty sequence. + * dict The AST node instance dictionary (must not be NULL). + */ +static inline int +ast_type_replace_update_payload(PyObject *payload, + PyObject *keys, + PyObject *dict) +{ + assert(dict != NULL); + if (keys == NULL) { + return 0; + } + Py_ssize_t n = PySequence_Size(keys); + if (n == -1) { + return -1; + } + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *key = PySequence_GetItem(keys, i); + if (key == NULL) { + return -1; + } + PyObject *value; + if (PyDict_GetItemRef(dict, key, &value) < 0) { + Py_DECREF(key); + return -1; + } + if (value == NULL) { + Py_DECREF(key); + // If a field or attribute is not present at runtime, it should + // be explicitly given in 'kwargs'. If not, the constructor will + // issue a warning (which becomes an error in 3.15). + continue; + } + int rc = PyDict_SetItem(payload, key, value); + Py_DECREF(key); + Py_DECREF(value); + if (rc < 0) { + return -1; + } + } + return 0; +} + +/* copy.replace() support (shallow copy) */ +static PyObject * +ast_type_replace(PyObject *self, PyObject *args, PyObject *kwargs) +{ + if (!_PyArg_NoPositional("__replace__", args)) { + return NULL; + } + + struct ast_state *state = get_ast_state(); + if (state == NULL) { + return NULL; + } + + PyObject *result = NULL; + // known AST class fields and attributes + PyObject *fields = NULL, *attributes = NULL; + // current instance dictionary + PyObject *dict = NULL; + // constructor positional and keyword arguments + PyObject *empty_tuple = NULL, *payload = NULL; + + PyObject *type = (PyObject *)Py_TYPE(self); + if (PyObject_GetOptionalAttr(type, state->_fields, &fields) < 0) { + goto cleanup; + } + if (PyObject_GetOptionalAttr(type, state->_attributes, &attributes) < 0) { + goto cleanup; + } + if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) { + goto cleanup; + } + if (ast_type_replace_check(self, dict, fields, attributes, kwargs) < 0) { + goto cleanup; + } + empty_tuple = PyTuple_New(0); + if (empty_tuple == NULL) { + goto cleanup; + } + payload = PyDict_New(); + if (payload == NULL) { + goto cleanup; + } + if (dict) { // in case __dict__ is missing (for some obscure reason) + // copy the instance's fields (possibly NULL) + if (ast_type_replace_update_payload(payload, fields, dict) < 0) { + goto cleanup; + } + // copy the instance's attributes (possibly NULL) + if (ast_type_replace_update_payload(payload, attributes, dict) < 0) { + goto cleanup; + } + } + if (kwargs && PyDict_Update(payload, kwargs) < 0) { + goto cleanup; + } + result = PyObject_Call(type, empty_tuple, payload); +cleanup: + Py_XDECREF(payload); + Py_XDECREF(empty_tuple); + Py_XDECREF(dict); + Py_XDECREF(attributes); + Py_XDECREF(fields); + return result; +} + static PyMemberDef ast_type_members[] = { {"__dictoffset__", Py_T_PYSSIZET, offsetof(AST_object, dict), Py_READONLY}, {NULL} /* Sentinel */ @@ -1139,6 +1412,10 @@ static PyMemberDef ast_type_members[] = { static PyMethodDef ast_type_methods[] = { {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, + {"__replace__", _PyCFunction_CAST(ast_type_replace), METH_VARARGS | METH_KEYWORDS, + PyDoc_STR("__replace__($self, /, **fields)\\n--\\n\\n" + "Return a copy of the AST node with new values " + "for the specified fields.")}, {NULL} }; @@ -1773,7 +2050,9 @@ def generate_module_def(mod, metadata, f, internal_h): #include "pycore_ceval.h" // _Py_EnterRecursiveCall #include "pycore_lock.h" // _PyOnceFlag #include "pycore_interp.h" // _PyInterpreterState.ast + #include "pycore_modsupport.h" // _PyArg_NoPositional() #include "pycore_pystate.h" // _PyInterpreterState_GET() + #include "pycore_setobject.h" // _PySet_NextEntry(), _PySet_Update() #include "pycore_unionobject.h" // _Py_union_type_or #include "structmember.h" #include diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 01ffea18693..cca2ee409e7 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -6,7 +6,9 @@ #include "pycore_ceval.h" // _Py_EnterRecursiveCall #include "pycore_lock.h" // _PyOnceFlag #include "pycore_interp.h" // _PyInterpreterState.ast +#include "pycore_modsupport.h" // _PyArg_NoPositional() #include "pycore_pystate.h" // _PyInterpreterState_GET() +#include "pycore_setobject.h" // _PySet_NextEntry(), _PySet_Update() #include "pycore_unionobject.h" // _Py_union_type_or #include "structmember.h" #include @@ -5331,6 +5333,279 @@ cleanup: return result; } +/* + * Perform the following validations: + * + * - All keyword arguments are known 'fields' or 'attributes'. + * - No field or attribute would be left unfilled after copy.replace(). + * + * On success, this returns 1. Otherwise, set a TypeError + * exception and returns -1 (no exception is set if some + * other internal errors occur). + * + * Parameters + * + * self The AST node instance. + * dict The AST node instance dictionary (self.__dict__). + * fields The list of fields (self._fields). + * attributes The list of attributes (self._attributes). + * kwargs Keyword arguments passed to ast_type_replace(). + * + * The 'dict', 'fields', 'attributes' and 'kwargs' arguments can be NULL. + * + * Note: this function can be removed in 3.15 since the verification + * will be done inside the constructor. + */ +static inline int +ast_type_replace_check(PyObject *self, + PyObject *dict, + PyObject *fields, + PyObject *attributes, + PyObject *kwargs) +{ + // While it is possible to make some fast paths that would avoid + // allocating objects on the stack, this would cost us readability. + // For instance, if 'fields' and 'attributes' are both empty, and + // 'kwargs' is not empty, we could raise a TypeError immediately. + PyObject *expecting = PySet_New(fields); + if (expecting == NULL) { + return -1; + } + if (attributes) { + if (_PySet_Update(expecting, attributes) < 0) { + Py_DECREF(expecting); + return -1; + } + } + // Any keyword argument that is neither a field nor attribute is rejected. + // We first need to check whether a keyword argument is accepted or not. + // If all keyword arguments are accepted, we compute the required fields + // and attributes. A field or attribute is not needed if: + // + // 1) it is given in 'kwargs', or + // 2) it already exists on 'self'. + if (kwargs) { + Py_ssize_t pos = 0; + PyObject *key, *value; + while (PyDict_Next(kwargs, &pos, &key, &value)) { + int rc = PySet_Discard(expecting, key); + if (rc < 0) { + Py_DECREF(expecting); + return -1; + } + if (rc == 0) { + PyErr_Format(PyExc_TypeError, + "%.400s.__replace__ got an unexpected keyword " + "argument '%U'.", Py_TYPE(self)->tp_name, key); + Py_DECREF(expecting); + return -1; + } + } + } + // check that the remaining fields or attributes would be filled + if (dict) { + Py_ssize_t pos = 0; + PyObject *key, *value; + while (PyDict_Next(dict, &pos, &key, &value)) { + // Mark fields or attributes that are found on the instance + // as non-mandatory. If they are not given in 'kwargs', they + // will be shallow-coied; otherwise, they would be replaced + // (not in this function). + if (PySet_Discard(expecting, key) < 0) { + Py_DECREF(expecting); + return -1; + } + } + if (attributes) { + // Some attributes may or may not be present at runtime. + // In particular, now that we checked whether 'kwargs' + // is correct or not, we allow any attribute to be missing. + // + // Note that fields must still be entirely determined when + // calling the constructor later. + PyObject *unused = PyObject_CallMethodOneArg(expecting, + &_Py_ID(difference_update), + attributes); + if (unused == NULL) { + Py_DECREF(expecting); + return -1; + } + Py_DECREF(unused); + } + } + // Now 'expecting' contains the fields or attributes + // that would not be filled inside ast_type_replace(). + Py_ssize_t m = PySet_GET_SIZE(expecting); + if (m > 0) { + PyObject *names = PyList_New(m); + if (names == NULL) { + Py_DECREF(expecting); + return -1; + } + Py_ssize_t i = 0, pos = 0; + PyObject *item; + Py_hash_t hash; + while (_PySet_NextEntry(expecting, &pos, &item, &hash)) { + PyObject *name = PyObject_Repr(item); + if (name == NULL) { + Py_DECREF(expecting); + Py_DECREF(names); + return -1; + } + // steal the reference 'name' + PyList_SET_ITEM(names, i++, name); + } + Py_DECREF(expecting); + if (PyList_Sort(names) < 0) { + Py_DECREF(names); + return -1; + } + PyObject *sep = PyUnicode_FromString(", "); + if (sep == NULL) { + Py_DECREF(names); + return -1; + } + PyObject *str_names = PyUnicode_Join(sep, names); + Py_DECREF(sep); + Py_DECREF(names); + if (str_names == NULL) { + return -1; + } + PyErr_Format(PyExc_TypeError, + "%.400s.__replace__ missing %ld keyword argument%s: %U.", + Py_TYPE(self)->tp_name, m, m == 1 ? "" : "s", str_names); + Py_DECREF(str_names); + return -1; + } + else { + Py_DECREF(expecting); + return 1; + } +} + +/* + * Python equivalent: + * + * for key in keys: + * if hasattr(self, key): + * payload[key] = getattr(self, key) + * + * The 'keys' argument is a sequence corresponding to + * the '_fields' or the '_attributes' of an AST node. + * + * This returns -1 if an error occurs and 0 otherwise. + * + * Parameters + * + * payload A dictionary to fill. + * keys A sequence of keys or NULL for an empty sequence. + * dict The AST node instance dictionary (must not be NULL). + */ +static inline int +ast_type_replace_update_payload(PyObject *payload, + PyObject *keys, + PyObject *dict) +{ + assert(dict != NULL); + if (keys == NULL) { + return 0; + } + Py_ssize_t n = PySequence_Size(keys); + if (n == -1) { + return -1; + } + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *key = PySequence_GetItem(keys, i); + if (key == NULL) { + return -1; + } + PyObject *value; + if (PyDict_GetItemRef(dict, key, &value) < 0) { + Py_DECREF(key); + return -1; + } + if (value == NULL) { + Py_DECREF(key); + // If a field or attribute is not present at runtime, it should + // be explicitly given in 'kwargs'. If not, the constructor will + // issue a warning (which becomes an error in 3.15). + continue; + } + int rc = PyDict_SetItem(payload, key, value); + Py_DECREF(key); + Py_DECREF(value); + if (rc < 0) { + return -1; + } + } + return 0; +} + +/* copy.replace() support (shallow copy) */ +static PyObject * +ast_type_replace(PyObject *self, PyObject *args, PyObject *kwargs) +{ + if (!_PyArg_NoPositional("__replace__", args)) { + return NULL; + } + + struct ast_state *state = get_ast_state(); + if (state == NULL) { + return NULL; + } + + PyObject *result = NULL; + // known AST class fields and attributes + PyObject *fields = NULL, *attributes = NULL; + // current instance dictionary + PyObject *dict = NULL; + // constructor positional and keyword arguments + PyObject *empty_tuple = NULL, *payload = NULL; + + PyObject *type = (PyObject *)Py_TYPE(self); + if (PyObject_GetOptionalAttr(type, state->_fields, &fields) < 0) { + goto cleanup; + } + if (PyObject_GetOptionalAttr(type, state->_attributes, &attributes) < 0) { + goto cleanup; + } + if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) { + goto cleanup; + } + if (ast_type_replace_check(self, dict, fields, attributes, kwargs) < 0) { + goto cleanup; + } + empty_tuple = PyTuple_New(0); + if (empty_tuple == NULL) { + goto cleanup; + } + payload = PyDict_New(); + if (payload == NULL) { + goto cleanup; + } + if (dict) { // in case __dict__ is missing (for some obscure reason) + // copy the instance's fields (possibly NULL) + if (ast_type_replace_update_payload(payload, fields, dict) < 0) { + goto cleanup; + } + // copy the instance's attributes (possibly NULL) + if (ast_type_replace_update_payload(payload, attributes, dict) < 0) { + goto cleanup; + } + } + if (kwargs && PyDict_Update(payload, kwargs) < 0) { + goto cleanup; + } + result = PyObject_Call(type, empty_tuple, payload); +cleanup: + Py_XDECREF(payload); + Py_XDECREF(empty_tuple); + Py_XDECREF(dict); + Py_XDECREF(attributes); + Py_XDECREF(fields); + return result; +} + static PyMemberDef ast_type_members[] = { {"__dictoffset__", Py_T_PYSSIZET, offsetof(AST_object, dict), Py_READONLY}, {NULL} /* Sentinel */ @@ -5338,6 +5613,10 @@ static PyMemberDef ast_type_members[] = { static PyMethodDef ast_type_methods[] = { {"__reduce__", ast_type_reduce, METH_NOARGS, NULL}, + {"__replace__", _PyCFunction_CAST(ast_type_replace), METH_VARARGS | METH_KEYWORDS, + PyDoc_STR("__replace__($self, /, **fields)\n--\n\n" + "Return a copy of the AST node with new values " + "for the specified fields.")}, {NULL} };