Skip to content

[3.13] gh-117657: Fix itertools.count thread safety (GH-119268) #120007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion Lib/test/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ def test_count(self):
count(1, maxsize+5); sys.exc_info()

@pickle_deprecated
def test_count_with_stride(self):
def test_count_with_step(self):
self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
self.assertEqual(lzip('abc',count(start=2,step=3)),
[('a', 2), ('b', 5), ('c', 8)])
Expand Down Expand Up @@ -699,6 +699,28 @@ def test_count_with_stride(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
self.pickletest(proto, count(i, j))

@threading_helper.requires_working_threading()
def test_count_threading(self, step=1):
# this test verifies multithreading consistency, which is
# mostly for testing builds without GIL, but nice to test anyway
count_to = 10_000
num_threads = 10
c = count(step=step)
def counting_thread():
for i in range(count_to):
next(c)
threads = []
for i in range(num_threads):
thread = threading.Thread(target=counting_thread)
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
self.assertEqual(next(c), count_to * num_threads * step)

def test_count_with_step_threading(self):
self.test_count_threading(step=5)

def test_cycle(self):
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
self.assertEqual(list(cycle('')), [])
Expand Down
40 changes: 31 additions & 9 deletions Modules/itertoolsmodule.c
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
#include "Python.h"
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_typeobject.h" // _PyType_GetModuleState()
#include "pycore_object.h" // _PyObject_GC_TRACK()
#include "pycore_tuple.h" // _PyTuple_ITEMS()
#include "pycore_call.h" // _PyObject_CallNoArgs()
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
#include "pycore_long.h" // _PyLong_GetZero()
#include "pycore_moduleobject.h" // _PyModule_GetState()
#include "pycore_typeobject.h" // _PyType_GetModuleState()
#include "pycore_object.h" // _PyObject_GC_TRACK()
#include "pycore_tuple.h" // _PyTuple_ITEMS()

#include <stddef.h> // offsetof()
#include <stddef.h> // offsetof()

/* Itertools module written and maintained
by Raymond D. Hettinger <python@rcn.com>
Expand Down Expand Up @@ -4037,7 +4038,7 @@ fast_mode: when cnt an integer < PY_SSIZE_T_MAX and no step is specified.

assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
Advances with: cnt += 1
When count hits Y_SSIZE_T_MAX, switch to slow_mode.
When count hits PY_SSIZE_T_MAX, switch to slow_mode.

slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.

Expand Down Expand Up @@ -4186,9 +4187,30 @@ count_nextlong(countobject *lz)
static PyObject *
count_next(countobject *lz)
{
#ifndef Py_GIL_DISABLED
if (lz->cnt == PY_SSIZE_T_MAX)
return count_nextlong(lz);
return PyLong_FromSsize_t(lz->cnt++);
#else
// free-threading version
// fast mode uses compare-exchange loop
// slow mode uses a critical section
PyObject *returned;
Py_ssize_t cnt;

cnt = _Py_atomic_load_ssize_relaxed(&lz->cnt);
for (;;) {
if (cnt == PY_SSIZE_T_MAX) {
Py_BEGIN_CRITICAL_SECTION(lz);
returned = count_nextlong(lz);
Py_END_CRITICAL_SECTION();
return returned;
}
if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
return PyLong_FromSsize_t(cnt);
}
}
#endif
}

static PyObject *
Expand Down
1 change: 0 additions & 1 deletion Tools/tsan/suppressions_free_threading.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ race_top:_PyImport_AcquireLock
race_top:_Py_dict_lookup_threadsafe
race_top:_imp_release_lock
race_top:_multiprocessing_SemLock_acquire_impl
race_top:count_next
race_top:dictiter_new
race_top:dictresize
race_top:insert_to_emptydict
Expand Down
Loading