Skip to content

gh-120108: Fix deepcopying of AST trees with .parent attributes #120114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 75 additions & 10 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import builtins
import copy
import dis
import enum
import os
Expand All @@ -20,7 +21,7 @@
from test.support.ast_helper import ASTTestMixin

def to_tuple(t):
if t is None or isinstance(t, (str, int, complex)) or t is Ellipsis:
if t is None or isinstance(t, (str, int, complex, float, bytes)) or t is Ellipsis:
return t
elif isinstance(t, list):
return [to_tuple(e) for e in t]
Expand Down Expand Up @@ -775,15 +776,6 @@ def test_no_fields(self):
x = ast.Sub()
self.assertEqual(x._fields, ())

def test_pickling(self):
import pickle

for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
for ast in (compile(i, "?", "exec", 0x400) for i in exec_tests):
with self.subTest(ast=ast, protocol=protocol):
ast2 = pickle.loads(pickle.dumps(ast, protocol))
self.assertEqual(to_tuple(ast2), to_tuple(ast))

def test_invalid_sum(self):
pos = dict(lineno=2, col_offset=3)
m = ast.Module([ast.Expr(ast.expr(**pos), **pos)], [])
Expand Down Expand Up @@ -1135,6 +1127,79 @@ def test_none_checks(self) -> None:
self.assert_none_check(node, attr, source)


class CopyTests(unittest.TestCase):
"""Test copying and pickling AST nodes."""

def test_pickling(self):
import pickle

for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
for code in exec_tests:
with self.subTest(code=code, protocol=protocol):
tree = compile(code, "?", "exec", 0x400)
ast2 = pickle.loads(pickle.dumps(tree, protocol))
self.assertEqual(to_tuple(ast2), to_tuple(tree))

def test_copy_with_parents(self):
# gh-120108
code = """
('',)
while i < n:
if ch == '':
ch = format[i]
if ch == '':
if freplace is None:
'' % getattr(object)
elif ch == '':
if zreplace is None:
if hasattr:
offset = object.utcoffset()
if offset is not None:
if offset.days < 0:
offset = -offset
h = divmod(timedelta(hours=0))
if u:
zreplace = '' % (sign,)
elif s:
zreplace = '' % (sign,)
else:
zreplace = '' % (sign,)
elif ch == '':
if Zreplace is None:
Zreplace = ''
if hasattr(object):
s = object.tzname()
if s is not None:
Zreplace = s.replace('')
newformat.append(Zreplace)
else:
push('')
else:
push(ch)

"""
tree = ast.parse(textwrap.dedent(code))
for node in ast.walk(tree):
for child in ast.iter_child_nodes(node):
child.parent = node
try:
with support.infinite_recursion(200):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe bisect the exact recursion limit being needed so that we exactly know whether something changed or not?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should guarantee the exact number of stack frames needed for this test; future changes may mean we need slightly more or fewer. The point of the test is to ensure the number of stack frames does not grow to an unreasonable level.

tree2 = copy.deepcopy(tree)
finally:
# Singletons like ast.Load() are shared; make sure we don't
# leave them mutated after this test.
for node in ast.walk(tree):
if hasattr(node, "parent"):
del node.parent

for node in ast.walk(tree2):
for child in ast.iter_child_nodes(node):
if hasattr(child, "parent") and not isinstance(child, (
ast.expr_context, ast.boolop, ast.unaryop, ast.cmpop, ast.operator,
)):
self.assertEqual(to_tuple(child.parent), to_tuple(node))


class ASTHelpers_Test(unittest.TestCase):
maxDiff = None

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix calling :func:`copy.deepcopy` on :mod:`ast` trees that have been
modified to have references to parent nodes. Patch by Jelle Zijlstra.
31 changes: 14 additions & 17 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,17 +1064,22 @@ def visitModule(self, mod):
return NULL;
}

PyObject *dict = NULL, *fields = NULL, *remaining_fields = NULL,
*remaining_dict = NULL, *positional_args = NULL;
PyObject *dict = NULL, *fields = NULL, *positional_args = NULL;
if (PyObject_GetOptionalAttr(self, state->__dict__, &dict) < 0) {
return NULL;
}
PyObject *result = NULL;
if (dict) {
// Serialize the fields as positional args if possible, because if we
// serialize them as a dict, during unpickling they are set only *after*
// the object is constructed, which will now trigger a DeprecationWarning
// if the AST type has required fields.
// Unpickling (or copying) works as follows:
// - Construct the object with only positional arguments
// - Set the fields from the dict
// We have two constraints:
// - We must set all the required fields in the initial constructor call,
// or the unpickling or deepcopying of the object will trigger DeprecationWarnings.
// - We must not include child nodes in the positional args, because
// that may trigger runaway recursion during copying (gh-120108).
// To satisfy both constraints, we set all the fields to None in the
// initial list of positional args, and then set the fields from the dict.
if (PyObject_GetOptionalAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
goto cleanup;
}
Expand All @@ -1084,11 +1089,6 @@ def visitModule(self, mod):
Py_DECREF(dict);
goto cleanup;
}
remaining_dict = PyDict_Copy(dict);
Py_DECREF(dict);
if (!remaining_dict) {
goto cleanup;
}
positional_args = PyList_New(0);
if (!positional_args) {
goto cleanup;
Expand All @@ -1099,15 +1099,15 @@ def visitModule(self, mod):
goto cleanup;
}
PyObject *value;
int rc = PyDict_Pop(remaining_dict, name, &value);
int rc = PyDict_GetItemRef(dict, name, &value);
Py_DECREF(name);
if (rc < 0) {
goto cleanup;
}
if (!value) {
break;
}
rc = PyList_Append(positional_args, value);
rc = PyList_Append(positional_args, Py_None);
Py_DECREF(value);
if (rc < 0) {
goto cleanup;
Expand All @@ -1117,8 +1117,7 @@ def visitModule(self, mod):
if (!args_tuple) {
goto cleanup;
}
result = Py_BuildValue("ONO", Py_TYPE(self), args_tuple,
remaining_dict);
result = Py_BuildValue("ONN", Py_TYPE(self), args_tuple, dict);
}
else {
result = Py_BuildValue("O()N", Py_TYPE(self), dict);
Expand All @@ -1129,8 +1128,6 @@ def visitModule(self, mod):
}
cleanup:
Py_XDECREF(fields);
Py_XDECREF(remaining_fields);
Py_XDECREF(remaining_dict);
Py_XDECREF(positional_args);
return result;
}
Expand Down
31 changes: 14 additions & 17 deletions Python/Python-ast.c

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading