Skip to content

Commit 5576e64

Browse files
wiggin15colesbury
authored andcommitted
[3.13] gh-117657: Fix itertools.count thread safety (GH-119268)
Fix itertools.count in free-threading mode (cherry picked from commit 87939bd) Co-authored-by: Arnon Yaari <wiggin15@yahoo.com>
1 parent ca37034 commit 5576e64

File tree

3 files changed

+54
-11
lines changed

3 files changed

+54
-11
lines changed

Lib/test/test_itertools.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ def test_count(self):
644644
count(1, maxsize+5); sys.exc_info()
645645

646646
@pickle_deprecated
647-
def test_count_with_stride(self):
647+
def test_count_with_step(self):
648648
self.assertEqual(lzip('abc',count(2,3)), [('a', 2), ('b', 5), ('c', 8)])
649649
self.assertEqual(lzip('abc',count(start=2,step=3)),
650650
[('a', 2), ('b', 5), ('c', 8)])
@@ -699,6 +699,28 @@ def test_count_with_stride(self):
699699
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
700700
self.pickletest(proto, count(i, j))
701701

702+
@threading_helper.requires_working_threading()
703+
def test_count_threading(self, step=1):
704+
# this test verifies multithreading consistency, which is
705+
# mostly for testing builds without GIL, but nice to test anyway
706+
count_to = 10_000
707+
num_threads = 10
708+
c = count(step=step)
709+
def counting_thread():
710+
for i in range(count_to):
711+
next(c)
712+
threads = []
713+
for i in range(num_threads):
714+
thread = threading.Thread(target=counting_thread)
715+
thread.start()
716+
threads.append(thread)
717+
for thread in threads:
718+
thread.join()
719+
self.assertEqual(next(c), count_to * num_threads * step)
720+
721+
def test_count_with_step_threading(self):
722+
self.test_count_threading(step=5)
723+
702724
def test_cycle(self):
703725
self.assertEqual(take(10, cycle('abc')), list('abcabcabca'))
704726
self.assertEqual(list(cycle('')), [])

Modules/itertoolsmodule.c

+31-9
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#include "Python.h"
2-
#include "pycore_call.h" // _PyObject_CallNoArgs()
3-
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
4-
#include "pycore_long.h" // _PyLong_GetZero()
5-
#include "pycore_moduleobject.h" // _PyModule_GetState()
6-
#include "pycore_typeobject.h" // _PyType_GetModuleState()
7-
#include "pycore_object.h" // _PyObject_GC_TRACK()
8-
#include "pycore_tuple.h" // _PyTuple_ITEMS()
2+
#include "pycore_call.h" // _PyObject_CallNoArgs()
3+
#include "pycore_ceval.h" // _PyEval_GetBuiltin()
4+
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
5+
#include "pycore_long.h" // _PyLong_GetZero()
6+
#include "pycore_moduleobject.h" // _PyModule_GetState()
7+
#include "pycore_typeobject.h" // _PyType_GetModuleState()
8+
#include "pycore_object.h" // _PyObject_GC_TRACK()
9+
#include "pycore_tuple.h" // _PyTuple_ITEMS()
910

10-
#include <stddef.h> // offsetof()
11+
#include <stddef.h> // offsetof()
1112

1213
/* Itertools module written and maintained
1314
by Raymond D. Hettinger <python@rcn.com>
@@ -4037,7 +4038,7 @@ fast_mode: when cnt an integer < PY_SSIZE_T_MAX and no step is specified.
40374038
40384039
assert(cnt != PY_SSIZE_T_MAX && long_cnt == NULL && long_step==PyLong(1));
40394040
Advances with: cnt += 1
4040-
When count hits Y_SSIZE_T_MAX, switch to slow_mode.
4041+
When count hits PY_SSIZE_T_MAX, switch to slow_mode.
40414042
40424043
slow_mode: when cnt == PY_SSIZE_T_MAX, step is not int(1), or cnt is a float.
40434044
@@ -4186,9 +4187,30 @@ count_nextlong(countobject *lz)
41864187
static PyObject *
41874188
count_next(countobject *lz)
41884189
{
4190+
#ifndef Py_GIL_DISABLED
41894191
if (lz->cnt == PY_SSIZE_T_MAX)
41904192
return count_nextlong(lz);
41914193
return PyLong_FromSsize_t(lz->cnt++);
4194+
#else
4195+
// free-threading version
4196+
// fast mode uses compare-exchange loop
4197+
// slow mode uses a critical section
4198+
PyObject *returned;
4199+
Py_ssize_t cnt;
4200+
4201+
cnt = _Py_atomic_load_ssize_relaxed(&lz->cnt);
4202+
for (;;) {
4203+
if (cnt == PY_SSIZE_T_MAX) {
4204+
Py_BEGIN_CRITICAL_SECTION(lz);
4205+
returned = count_nextlong(lz);
4206+
Py_END_CRITICAL_SECTION();
4207+
return returned;
4208+
}
4209+
if (_Py_atomic_compare_exchange_ssize(&lz->cnt, &cnt, cnt + 1)) {
4210+
return PyLong_FromSsize_t(cnt);
4211+
}
4212+
}
4213+
#endif
41924214
}
41934215

41944216
static PyObject *

Tools/tsan/suppressions_free_threading.txt

-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ race_top:_Py_dict_lookup_threadsafe
4848
race_top:_imp_release_lock
4949
race_top:_multiprocessing_SemLock_acquire_impl
5050
race_top:builtin_compile_impl
51-
race_top:count_next
5251
race_top:dictiter_new
5352
race_top:dictresize
5453
race_top:insert_to_emptydict

0 commit comments

Comments
 (0)