Skip to content

gh-129107: make bytearrayiter free-threading safe #130096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 19, 2025
41 changes: 38 additions & 3 deletions Lib/test/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2455,9 +2455,6 @@ def check(funcs, a=None, *args):
with threading_helper.start_threads(threads):
pass

for thread in threads:
threading_helper.join_thread(thread)

# hard errors

check([clear] + [reduce] * 10)
Expand Down Expand Up @@ -2519,6 +2516,44 @@ def check(funcs, a=None, *args):
check([clear] + [upper] * 10, bytearray(b'a' * 0x400000))
check([clear] + [zfill] * 10, bytearray(b'1' * 0x200000))

# Stress test for gh-129107: several threads operate on ONE shared bytearray
# iterator at the same time. Skipped unless support.Py_GIL_DISABLED is true.
@unittest.skipUnless(support.Py_GIL_DISABLED, 'this test can only possibly fail with GIL disabled')
@threading_helper.reap_threads
@threading_helper.requires_working_threading()
def test_free_threading_bytearrayiter(self):
# Non-deterministic but good chance to fail if bytearrayiter is not free-threading safe.
# We are fishing for an "Assertion failed: object has negative ref count" and tsan races.
# The test makes no assertions of its own; it passes if nothing crashes.

# Worker: wait on the barrier, then consume the shared iterator to exhaustion.
def iter_next(b, it):
b.wait()
list(it)

# Worker: wait on the barrier, then read the iterator's pickle state.
def iter_reduce(b, it):
b.wait()
it.__reduce__()

# Worker: wait on the barrier, then set the iterator's state to index 0.
def iter_setstate(b, it):
b.wait()
it.__setstate__(0)

# Run every callable in `funcs` in its own thread, all sharing the same
# iterator `it`.
def check(funcs, it):
# One barrier party per worker, so no worker proceeds past b.wait()
# until all of them have reached it.
barrier = threading.Barrier(len(funcs))
threads = []

for func in funcs:
thread = threading.Thread(target=func, args=(barrier, it))

threads.append(thread)

# NOTE(review): start_threads presumably starts the threads on entry and
# joins them on exit -- confirm against test.support.threading_helper.
with threading_helper.start_threads(threads):
pass

# Repeat the whole scenario 10 times; each check() gets a fresh iterator
# over the same 0x4000-byte bytearray.
for _ in range(10):
ba = bytearray(b'0' * 0x4000) # this is a load-bearing variable, do not remove

# 10 concurrent consumers of one iterator.
check([iter_next] * 10, iter(ba))
# 1 consumer racing against 10 __reduce__ callers.
check([iter_next] + [iter_reduce] * 10, iter(ba)) # for tsan
# 1 consumer racing against 10 __setstate__(0) callers.
check([iter_next] + [iter_setstate] * 10, iter(ba)) # for tsan


if __name__ == "__main__":
unittest.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make the :class:`bytearray` iterator safe under :term:`free threading`.
65 changes: 42 additions & 23 deletions Objects/bytearrayobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -2856,31 +2856,44 @@ static PyObject *
bytearrayiter_next(PyObject *self)
{
bytesiterobject *it = _bytesiterobject_CAST(self);
PyByteArrayObject *seq;
int val;

assert(it != NULL);
seq = it->it_seq;
if (seq == NULL)
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
if (index < 0) {
return NULL;
}
PyByteArrayObject *seq = it->it_seq;
assert(PyByteArray_Check(seq));

if (it->it_index < PyByteArray_GET_SIZE(seq)) {
return _PyLong_FromUnsignedChar(
(unsigned char)PyByteArray_AS_STRING(seq)[it->it_index++]);
Py_BEGIN_CRITICAL_SECTION(seq);
if (index < Py_SIZE(seq)) {
val = (unsigned char)PyByteArray_AS_STRING(seq)[index];
}
else {
val = -1;
}
Py_END_CRITICAL_SECTION();

it->it_seq = NULL;
Py_DECREF(seq);
return NULL;
if (val == -1) {
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, -1);
#ifndef Py_GIL_DISABLED
Py_CLEAR(it->it_seq);
#endif
return NULL;
}
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index + 1);
return _PyLong_FromUnsignedChar((unsigned char)val);
}

static PyObject *
bytearrayiter_length_hint(PyObject *self, PyObject *Py_UNUSED(ignored))
{
bytesiterobject *it = _bytesiterobject_CAST(self);
Py_ssize_t len = 0;
if (it->it_seq) {
len = PyByteArray_GET_SIZE(it->it_seq) - it->it_index;
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
if (index >= 0) {
len = PyByteArray_GET_SIZE(it->it_seq) - index;
if (len < 0) {
len = 0;
}
Expand All @@ -2900,27 +2913,33 @@ bytearrayiter_reduce(PyObject *self, PyObject *Py_UNUSED(ignored))
* call must be before access of iterator pointers.
* see issue #101765 */
bytesiterobject *it = _bytesiterobject_CAST(self);
if (it->it_seq != NULL) {
return Py_BuildValue("N(O)n", iter, it->it_seq, it->it_index);
} else {
return Py_BuildValue("N(())", iter);
Py_ssize_t index = FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index);
if (index >= 0) {
return Py_BuildValue("N(O)n", iter, it->it_seq, index);
}
return Py_BuildValue("N(())", iter);
}

static PyObject *
bytearrayiter_setstate(PyObject *self, PyObject *state)
{
Py_ssize_t index = PyLong_AsSsize_t(state);
if (index == -1 && PyErr_Occurred())
if (index == -1 && PyErr_Occurred()) {
return NULL;
}

bytesiterobject *it = _bytesiterobject_CAST(self);
if (it->it_seq != NULL) {
if (index < 0)
index = 0;
else if (index > PyByteArray_GET_SIZE(it->it_seq))
index = PyByteArray_GET_SIZE(it->it_seq); /* iterator exhausted */
it->it_index = index;
if (FT_ATOMIC_LOAD_SSIZE_RELAXED(it->it_index) >= 0) {
if (index < -1) {
index = -1;
}
else {
Py_ssize_t size = PyByteArray_GET_SIZE(it->it_seq);
if (index > size) {
index = size; /* iterator at end */
}
}
FT_ATOMIC_STORE_SSIZE_RELAXED(it->it_index, index);
}
Py_RETURN_NONE;
}
Expand Down Expand Up @@ -2982,7 +3001,7 @@ bytearray_iter(PyObject *seq)
it = PyObject_GC_New(bytesiterobject, &PyByteArrayIter_Type);
if (it == NULL)
return NULL;
it->it_index = 0;
it->it_index = 0; // -1 indicates exhausted
it->it_seq = (PyByteArrayObject *)Py_NewRef(seq);
_PyObject_GC_TRACK(it);
return (PyObject *)it;
Expand Down
Loading