Skip to content

gh-134935: Use RecursionError to check for circular references #134937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 18 additions & 31 deletions Lib/json/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,23 @@ def encode(self, o):
return encode_basestring_ascii(o)
else:
return encode_basestring(o)
# This doesn't pass the iterator directly to ''.join() because the
# exceptions aren't as detailed. The list call should be roughly
# equivalent to the PySequence_Fast that ''.join() would do.
chunks = self.iterencode(o, _one_shot=True)
if not isinstance(chunks, (list, tuple)):
chunks = list(chunks)

# There are tests for bad bool args
bool(self.check_circular)
try:
# This doesn't pass the iterator directly to ''.join() because the
# exceptions aren't as detailed. The list call should be roughly
# equivalent to the PySequence_Fast that ''.join() would do.
chunks = self.iterencode(o, _one_shot=True)
if not isinstance(chunks, (list, tuple)):
chunks = list(chunks)
except RecursionError as exc:
if self.check_circular:
err = ValueError("Circular reference detected")
if notes := getattr(exc, "__notes__", None):
err.__notes__ = notes
raise err
raise
return ''.join(chunks)

def iterencode(self, o, _one_shot=False):
Expand All @@ -212,10 +223,7 @@ def iterencode(self, o, _one_shot=False):
mysocket.write(chunk)

"""
if self.check_circular:
markers = {}
else:
markers = None
markers = None
if self.ensure_ascii:
_encoder = encode_basestring_ascii
else:
Expand Down Expand Up @@ -279,11 +287,6 @@ def _iterencode_list(lst, _current_indent_level):
if not lst:
yield '[]'
return
if markers is not None:
markerid = id(lst)
if markerid in markers:
raise ValueError("Circular reference detected")
markers[markerid] = lst
buf = '['
if _indent is not None:
_current_indent_level += 1
Expand Down Expand Up @@ -331,18 +334,11 @@ def _iterencode_list(lst, _current_indent_level):
_current_indent_level -= 1
yield '\n' + _indent * _current_indent_level
yield ']'
if markers is not None:
del markers[markerid]

def _iterencode_dict(dct, _current_indent_level):
if not dct:
yield '{}'
return
if markers is not None:
markerid = id(dct)
if markerid in markers:
raise ValueError("Circular reference detected")
markers[markerid] = dct
yield '{'
if _indent is not None:
_current_indent_level += 1
Expand Down Expand Up @@ -417,8 +413,6 @@ def _iterencode_dict(dct, _current_indent_level):
_current_indent_level -= 1
yield '\n' + _indent * _current_indent_level
yield '}'
if markers is not None:
del markers[markerid]

def _iterencode(o, _current_indent_level):
if isinstance(o, str):
Expand All @@ -440,11 +434,6 @@ def _iterencode(o, _current_indent_level):
elif isinstance(o, dict):
yield from _iterencode_dict(o, _current_indent_level)
else:
if markers is not None:
markerid = id(o)
if markerid in markers:
raise ValueError("Circular reference detected")
markers[markerid] = o
newobj = _default(o)
try:
yield from _iterencode(newobj, _current_indent_level)
Expand All @@ -453,6 +442,4 @@ def _iterencode(o, _current_indent_level):
except BaseException as exc:
exc.add_note(f'when serializing {type(o).__name__} object')
raise
if markers is not None:
del markers[markerid]
return _iterencode
15 changes: 8 additions & 7 deletions Lib/test/test_json/test_recursion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_listrecursion(self):
try:
self.dumps(x)
except ValueError as exc:
self.assertEqual(exc.__notes__, ["when serializing list item 0"])
self.assertEqual(exc.__notes__[:1], ["when serializing list item 0"])
else:
self.fail("didn't raise ValueError on list recursion")
x = []
Expand All @@ -22,7 +22,7 @@ def test_listrecursion(self):
try:
self.dumps(x)
except ValueError as exc:
self.assertEqual(exc.__notes__, ["when serializing list item 0"]*2)
self.assertEqual(exc.__notes__[:2], ["when serializing list item 0"]*2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what consequences of this change led to the need to limit how many notes we assert on?

what additional notes are now added?

else:
self.fail("didn't raise ValueError on alternating list recursion")
y = []
Expand All @@ -36,7 +36,7 @@ def test_dictrecursion(self):
try:
self.dumps(x)
except ValueError as exc:
self.assertEqual(exc.__notes__, ["when serializing dict item 'test'"])
self.assertEqual(exc.__notes__[:1], ["when serializing dict item 'test'"])
else:
self.fail("didn't raise ValueError on dict recursion")
x = {}
Expand All @@ -59,9 +59,10 @@ def default(self, o):
self.assertEqual(enc.encode(JSONTestObject), '"JSONTestObject"')
enc.recurse = True
try:
enc.encode(JSONTestObject)
with support.infinite_recursion(5000):
enc.encode(JSONTestObject)
except ValueError as exc:
self.assertEqual(exc.__notes__,
self.assertEqual(exc.__notes__[:2],
["when serializing list item 0",
"when serializing type object"])
else:
Expand Down Expand Up @@ -94,10 +95,10 @@ def test_highly_nested_objects_encoding(self):
l, d = [l], {'k':d}
with self.assertRaises(RecursionError):
with support.infinite_recursion(5000):
self.dumps(l)
self.dumps(l, check_circular=False)
with self.assertRaises(RecursionError):
with support.infinite_recursion(5000):
self.dumps(d)
self.dumps(d, check_circular=False)

@support.skip_emscripten_stack_overflow()
@support.skip_wasi_stack_overflow()
Expand Down
78 changes: 0 additions & 78 deletions Modules/_json.c
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,6 @@ encoder_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
if (s == NULL)
return NULL;

s->markers = Py_NewRef(markers);
s->defaultfn = Py_NewRef(defaultfn);
s->encoder = Py_NewRef(encoder);
s->indent = Py_NewRef(indent);
Expand Down Expand Up @@ -1521,33 +1520,13 @@ encoder_listencode_obj(PyEncoderObject *s, PyUnicodeWriter *writer,
return rv;
}
else {
PyObject *ident = NULL;
if (s->markers != Py_None) {
int has_key;
ident = PyLong_FromVoidPtr(obj);
if (ident == NULL)
return -1;
has_key = PyDict_Contains(s->markers, ident);
if (has_key) {
if (has_key != -1)
PyErr_SetString(PyExc_ValueError, "Circular reference detected");
Py_DECREF(ident);
return -1;
}
if (PyDict_SetItem(s->markers, ident, obj)) {
Py_DECREF(ident);
return -1;
}
}
newobj = PyObject_CallOneArg(s->defaultfn, obj);
if (newobj == NULL) {
Py_XDECREF(ident);
return -1;
}

if (_Py_EnterRecursiveCall(" while encoding a JSON object")) {
Py_DECREF(newobj);
Py_XDECREF(ident);
return -1;
}
rv = encoder_listencode_obj(s, writer, newobj, indent_level, indent_cache);
Expand All @@ -1556,16 +1535,8 @@ encoder_listencode_obj(PyEncoderObject *s, PyUnicodeWriter *writer,
Py_DECREF(newobj);
if (rv) {
_PyErr_FormatNote("when serializing %T object", obj);
Py_XDECREF(ident);
return -1;
}
if (ident != NULL) {
if (PyDict_DelItem(s->markers, ident)) {
Py_XDECREF(ident);
return -1;
}
Py_XDECREF(ident);
}
return rv;
}
}
Expand Down Expand Up @@ -1642,7 +1613,6 @@ encoder_listencode_dict(PyEncoderObject *s, PyUnicodeWriter *writer,
Py_ssize_t indent_level, PyObject *indent_cache)
{
/* Encode Python dict dct a JSON term */
PyObject *ident = NULL;
PyObject *items = NULL;
PyObject *key, *value;
bool first = true;
Expand All @@ -1652,22 +1622,6 @@ encoder_listencode_dict(PyEncoderObject *s, PyUnicodeWriter *writer,
return PyUnicodeWriter_WriteASCII(writer, "{}", 2);
}

if (s->markers != Py_None) {
int has_key;
ident = PyLong_FromVoidPtr(dct);
if (ident == NULL)
goto bail;
has_key = PyDict_Contains(s->markers, ident);
if (has_key) {
if (has_key != -1)
PyErr_SetString(PyExc_ValueError, "Circular reference detected");
goto bail;
}
if (PyDict_SetItem(s->markers, ident, dct)) {
goto bail;
}
}

if (PyUnicodeWriter_WriteChar(writer, '{')) {
goto bail;
}
Expand Down Expand Up @@ -1715,11 +1669,6 @@ encoder_listencode_dict(PyEncoderObject *s, PyUnicodeWriter *writer,
}
}

if (ident != NULL) {
if (PyDict_DelItem(s->markers, ident))
goto bail;
Py_CLEAR(ident);
}
if (s->indent != Py_None) {
indent_level--;
if (write_newline_indent(writer, indent_level, indent_cache) < 0) {
Expand All @@ -1734,7 +1683,6 @@ encoder_listencode_dict(PyEncoderObject *s, PyUnicodeWriter *writer,

bail:
Py_XDECREF(items);
Py_XDECREF(ident);
return -1;
}

Expand All @@ -1743,11 +1691,9 @@ encoder_listencode_list(PyEncoderObject *s, PyUnicodeWriter *writer,
PyObject *seq,
Py_ssize_t indent_level, PyObject *indent_cache)
{
PyObject *ident = NULL;
PyObject *s_fast = NULL;
Py_ssize_t i;

ident = NULL;
s_fast = PySequence_Fast(seq, "_iterencode_list needs a sequence");
if (s_fast == NULL)
return -1;
Expand All @@ -1756,22 +1702,6 @@ encoder_listencode_list(PyEncoderObject *s, PyUnicodeWriter *writer,
return PyUnicodeWriter_WriteASCII(writer, "[]", 2);
}

if (s->markers != Py_None) {
int has_key;
ident = PyLong_FromVoidPtr(seq);
if (ident == NULL)
goto bail;
has_key = PyDict_Contains(s->markers, ident);
if (has_key) {
if (has_key != -1)
PyErr_SetString(PyExc_ValueError, "Circular reference detected");
goto bail;
}
if (PyDict_SetItem(s->markers, ident, seq)) {
goto bail;
}
}

if (PyUnicodeWriter_WriteChar(writer, '[')) {
goto bail;
}
Expand All @@ -1797,11 +1727,6 @@ encoder_listencode_list(PyEncoderObject *s, PyUnicodeWriter *writer,
goto bail;
}
}
if (ident != NULL) {
if (PyDict_DelItem(s->markers, ident))
goto bail;
Py_CLEAR(ident);
}

if (s->indent != Py_None) {
indent_level--;
Expand All @@ -1817,7 +1742,6 @@ encoder_listencode_list(PyEncoderObject *s, PyUnicodeWriter *writer,
return 0;

bail:
Py_XDECREF(ident);
Py_DECREF(s_fast);
return -1;
}
Expand All @@ -1838,7 +1762,6 @@ encoder_traverse(PyObject *op, visitproc visit, void *arg)
{
PyEncoderObject *self = PyEncoderObject_CAST(op);
Py_VISIT(Py_TYPE(self));
Py_VISIT(self->markers);
Py_VISIT(self->defaultfn);
Py_VISIT(self->encoder);
Py_VISIT(self->indent);
Expand All @@ -1852,7 +1775,6 @@ encoder_clear(PyObject *op)
{
PyEncoderObject *self = PyEncoderObject_CAST(op);
/* Deallocate Encoder */
Py_CLEAR(self->markers);
Py_CLEAR(self->defaultfn);
Py_CLEAR(self->encoder);
Py_CLEAR(self->indent);
Expand Down
Loading