From 2e3e00f0ba48cf82bdb6523c54d7004a28e68f27 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Wed, 5 Jun 2024 08:03:02 -0700 Subject: [PATCH 1/4] gh-120108: Fix deepcopying of AST trees with .parent attributes --- Lib/test/test_ast.py | 89 ++++++++++++++++--- ...-06-05-08-02-46.gh-issue-120108.4U9BL8.rst | 2 + Parser/asdl_c.py | 28 +++--- Python/Python-ast.c | 28 +++--- 4 files changed, 109 insertions(+), 38 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2024-06-05-08-02-46.gh-issue-120108.4U9BL8.rst diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 18b2f7ffca6083..e998d4ad64f593 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -1,5 +1,6 @@ import ast import builtins +import copy import dis import enum import os @@ -20,7 +21,7 @@ from test.support.ast_helper import ASTTestMixin def to_tuple(t): - if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis: + if t is None or isinstance(t, (str, int, complex, float, bytes)) or t is Ellipsis: return t elif isinstance(t, list): return [to_tuple(e) for e in t] @@ -123,7 +124,10 @@ def to_tuple(t): # Global "global v", # Expr + "b'x'", + "'x'", "1", + "1.9", # Pass, "pass", # Break @@ -666,15 +670,6 @@ def test_no_fields(self): x = ast.Sub() self.assertEqual(x._fields, ()) - def test_pickling(self): - import pickle - - for protocol in range(pickle.HIGHEST_PROTOCOL + 1): - for ast in (compile(i, "?", "exec", 0x400) for i in exec_tests): - with self.subTest(ast=ast, protocol=protocol): - ast2 = pickle.loads(pickle.dumps(ast, protocol)) - self.assertEqual(to_tuple(ast2), to_tuple(ast)) - def test_invalid_sum(self): pos = dict(lineno=2, col_offset=3) m = ast.Module([ast.Expr(ast.expr(**pos), **pos)], []) @@ -1026,6 +1021,80 @@ def test_none_checks(self) -> None: self.assert_none_check(node, attr, source) +class CopyTests(unittest.TestCase): + """Test copying and pickling AST nodes.""" + + def test_pickling(self): + import pickle + + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + for code in exec_tests: + with self.subTest(code=code, protocol=protocol): + tree = compile(code, "?", "exec", 0x400) + print(ast.dump(tree), tree.__dict__) + ast2 = pickle.loads(pickle.dumps(tree, protocol)) + self.assertEqual(to_tuple(ast2), to_tuple(tree)) + + def test_copy_with_parents(self): + # gh-120108 + code = """ + ('',) + while i < n: + if ch == '': + ch = format[i] + if ch == '': + if freplace is None: + '' % getattr(object) + elif ch == '': + if zreplace is None: + if hasattr: + offset = object.utcoffset() + if offset is not None: + if offset.days < 0: + offset = -offset + h = divmod(timedelta(hours=0)) + if u: + zreplace = '' % (sign,) + elif s: + zreplace = '' % (sign,) + else: + zreplace = '' % (sign,) + elif ch == '': + if Zreplace is None: + Zreplace = '' + if hasattr(object): + s = object.tzname() + if s is not None: + Zreplace = s.replace('') + newformat.append(Zreplace) + else: + push('') + else: + push(ch) + + """ + tree = ast.parse(textwrap.dedent(code)) + for node in ast.walk(tree): + for child in ast.iter_child_nodes(node): + child.parent = node + try: + with support.infinite_recursion(200): + tree2 = copy.deepcopy(tree) + finally: + # Singletons like ast.Load() are shared; make sure we don't + # leave them mutated after this test. + for node in ast.walk(tree): + if hasattr(node, "parent"): + del node.parent + + for node in ast.walk(tree2): + for child in ast.iter_child_nodes(node): + if hasattr(child, "parent") and not isinstance(child, ( + ast.expr_context, ast.boolop, ast.unaryop, ast.cmpop, ast.operator, + )): + self.assertEqual(to_tuple(child.parent), to_tuple(node)) + + class ASTHelpers_Test(unittest.TestCase): maxDiff = None diff --git a/Misc/NEWS.d/next/Library/2024-06-05-08-02-46.gh-issue-120108.4U9BL8.rst b/Misc/NEWS.d/next/Library/2024-06-05-08-02-46.gh-issue-120108.4U9BL8.rst new file mode 100644 index 00000000000000..e310695656255d --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-06-05-08-02-46.gh-issue-120108.4U9BL8.rst @@ -0,0 +1,2 @@ +Fix calling :func:`copy.deepcopy` on :mod:`ast` trees that have been +modified to have references to parent nodes. Patch by Jelle Zijlstra. diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index 9961d23629abc5..aa97b04c2c38e3 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -1065,16 +1065,22 @@ def visitModule(self, mod): } PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL, - *remaining_dict = NULL, *positional_args = NULL; + *positional_args = NULL; if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) { return NULL; } PyObject *result = NULL; if (dict) { - // Serialize the fields as positional args if possible, because if we - // serialize them as a dict, during unpickling they are set only *after* - // the object is constructed, which will now trigger a DeprecationWarning - // if the AST type has required fields. + // Unpickling (or copying) works as follows: + // - Construct the object with only positional arguments + // - Set the fields from the dict + // We have two constraints: + // - We must set all the required fields in the initial constructor call, + // or the unpickling or deepcopying of the object will trigger DeprecationWarnings. + // - We must not include child nodes in the positional args, because + // that may trigger runaway recursion during copying (gh-120108). + // To satisfy both constraints, we set all the fields to None in the + // initial list of positional args, and then set the fields from the dict. if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) { goto cleanup; } @@ -1084,11 +1090,6 @@ def visitModule(self, mod): Py_DECREF(dict); goto cleanup; } - remaining_dict = PyDict_Copy(dict); - Py_DECREF(dict); - if (!remaining_dict) { - goto cleanup; - } positional_args = PyList_New(0); if (!positional_args) { goto cleanup; @@ -1099,7 +1100,7 @@ def visitModule(self, mod): goto cleanup; } PyObject *value; - int rc = PyDict_Pop(remaining_dict, name, &value); + int rc = PyDict_GetItemRef(dict, name, &value); Py_DECREF(name); if (rc < 0) { goto cleanup; @@ -1107,7 +1108,7 @@ def visitModule(self, mod): if (!value) { break; } - rc = PyList_Append(positional_args, value); + rc = PyList_Append(positional_args, Py_None); Py_DECREF(value); if (rc < 0) { goto cleanup; @@ -1118,7 +1119,7 @@ def visitModule(self, mod): goto cleanup; } result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple, - remaining_dict); + dict); } else { result = Py_BuildValue("O()N", Py_TYPE(self), dict); @@ -1130,7 +1131,6 @@ def visitModule(self, mod): cleanup: Py_XDECREF(fields); Py_XDECREF(remaining_fields); - Py_XDECREF(remaining_dict); Py_XDECREF(positional_args); return result; } diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 7aa1c5119d8f28..8d10772839a392 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -5264,16 +5264,22 @@ ast_type_reduce(PyObject *self, PyObject *unused) } PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL, - *remaining_dict = NULL, *positional_args = NULL; + *positional_args = NULL; if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) { return NULL; } PyObject *result = NULL; if (dict) { - // Serialize the fields as positional args if possible, because if we - // serialize them as a dict, during unpickling they are set only *after* - // the object is constructed, which will now trigger a DeprecationWarning - // if the AST type has required fields. + // Unpickling (or copying) works as follows: + // - Construct the object with only positional arguments + // - Set the fields from the dict + // We have two constraints: + // - We must set all the required fields in the initial constructor call, + // or the unpickling or deepcopying of the object will trigger DeprecationWarnings. + // - We must not include child nodes in the positional args, because + // that may trigger runaway recursion during copying (gh-120108). + // To satisfy both constraints, we set all the fields to None in the + // initial list of positional args, and then set the fields from the dict. if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) { goto cleanup; } @@ -5283,11 +5289,6 @@ ast_type_reduce(PyObject *self, PyObject *unused) Py_DECREF(dict); goto cleanup; } - remaining_dict = PyDict_Copy(dict); - Py_DECREF(dict); - if (!remaining_dict) { - goto cleanup; - } positional_args = PyList_New(0); if (!positional_args) { goto cleanup; @@ -5298,7 +5299,7 @@ ast_type_reduce(PyObject *self, PyObject *unused) goto cleanup; } PyObject *value; - int rc = PyDict_Pop(remaining_dict, name, &value); + int rc = PyDict_GetItemRef(dict, name, &value); Py_DECREF(name); if (rc < 0) { goto cleanup; @@ -5306,7 +5307,7 @@ ast_type_reduce(PyObject *self, PyObject *unused) if (!value) { break; } - rc = PyList_Append(positional_args, value); + rc = PyList_Append(positional_args, Py_None); Py_DECREF(value); if (rc < 0) { goto cleanup; @@ -5317,7 +5318,7 @@ ast_type_reduce(PyObject *self, PyObject *unused) goto cleanup; } result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple, - remaining_dict); + dict); } else { result = Py_BuildValue("O()N", Py_TYPE(self), dict); @@ -5329,7 +5330,6 @@ ast_type_reduce(PyObject *self, PyObject *unused) cleanup: Py_XDECREF(fields); Py_XDECREF(remaining_fields); - Py_XDECREF(remaining_dict); Py_XDECREF(positional_args); return result; } From 1599c90316bbf09bafd4d7cd2c6ab52270850503 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Wed, 5 Jun 2024 08:19:32 -0700 Subject: [PATCH 2/4] drop extra tests --- Lib/test/test_ast.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index e998d4ad64f593..c20afffdd25595 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -124,10 +124,7 @@ def to_tuple(t): # Global "global v", # Expr - "b'x'", - "'x'", "1", - "1.9", # Pass, "pass", # Break From 8aaa543a973551f7a78b90ff0c3c6864addab974 Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Wed, 5 Jun 2024 08:22:18 -0700 Subject: [PATCH 3/4] fixes --- Lib/test/test_ast.py | 1 - Parser/asdl_c.py | 2 +- Python/Python-ast.c | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index c20afffdd25595..c675dc64145886 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -1028,7 +1028,6 @@ def test_pickling(self): for code in exec_tests: with self.subTest(code=code, protocol=protocol): tree = compile(code, "?", "exec", 0x400) - print(ast.dump(tree), tree.__dict__) ast2 = pickle.loads(pickle.dumps(tree, protocol)) self.assertEqual(to_tuple(ast2), to_tuple(tree)) diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index aa97b04c2c38e3..d9002232ef0b96 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -1118,7 +1118,7 @@ def visitModule(self, mod): if (!args_tuple) { goto cleanup; } - result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple, + result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, dict); } else { diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 8d10772839a392..4714caea3b6e8a 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -5317,7 +5317,7 @@ ast_type_reduce(PyObject *self, PyObject *unused) if (!args_tuple) { goto cleanup; } - result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple, + result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, dict); } else { From 58dd30ca67915836e015aea029aacd2ed97f22bf Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 10 Jun 2024 23:23:15 -0600 Subject: [PATCH 4/4] Apply changes from review --- Parser/asdl_c.py | 7 ++----- Python/Python-ast.c | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index d9002232ef0b96..e338656a5b1eb9 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -1064,8 +1064,7 @@ def visitModule(self, mod): return NULL; } - PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL, - *positional_args = NULL; + PyObject *dict = NULL, *fields = NULL, *positional_args = NULL; if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) { return NULL; } @@ -1118,8 +1117,7 @@ def visitModule(self, mod): if (!args_tuple) { goto cleanup; } - result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, - dict); + result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, dict); } else { result = Py_BuildValue("O()N", Py_TYPE(self), dict); @@ -1130,7 +1128,6 @@ def visitModule(self, mod): } cleanup: Py_XDECREF(fields); - Py_XDECREF(remaining_fields); Py_XDECREF(positional_args); return result; } diff --git a/Python/Python-ast.c b/Python/Python-ast.c index 4714caea3b6e8a..01ffea1869350b 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -5263,8 +5263,7 @@ ast_type_reduce(PyObject *self, PyObject *unused) return NULL; } - PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL, - *positional_args = NULL; + PyObject *dict = NULL, *fields = NULL, *positional_args = NULL; if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) { return NULL; } @@ -5317,8 +5316,7 @@ ast_type_reduce(PyObject *self, PyObject *unused) if (!args_tuple) { goto cleanup; } - result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, - dict); + result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, dict); } else { result = Py_BuildValue("O()N", Py_TYPE(self), dict); @@ -5329,7 +5327,6 @@ ast_type_reduce(PyObject *self, PyObject *unused) } cleanup: Py_XDECREF(fields); - Py_XDECREF(remaining_fields); Py_XDECREF(positional_args); return result; }