Skip to content

Commit a58026a

Browse files
yoney and mpage authored
gh-116738: Make _heapq module thread-safe (#135036)
Use critical sections to make heapq methods that update the heap thread-safe when the GIL is disabled.

Co-authored-by: mpage <mpage@meta.com>
1 parent cc8e6d2 commit a58026a

File tree

4 files changed

+303
-15
lines changed

4 files changed

+303
-15
lines changed
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import unittest
2+
3+
import heapq
4+
5+
from enum import Enum
6+
from threading import Thread, Barrier
7+
from random import shuffle, randint
8+
9+
from test.support import threading_helper
10+
from test import test_heapq
11+
12+
13+
# Number of worker threads started by each racing test.
NTHREADS = 10
# Number of items placed in (or pushed onto) the shared heap.  The pop tests
# split this evenly between threads, so it must stay a multiple of NTHREADS.
OBJECT_COUNT = 5_000
15+
16+
17+
class Heap(Enum):
    """Kind of heap a helper should build: min-ordered or max-ordered."""

    MIN = 1
    MAX = 2
20+
21+
22+
@threading_helper.requires_working_threading()
class TestHeapq(unittest.TestCase):
    """Call each mutating :mod:`heapq` function from several threads at once.

    Every test shares a single list between ``NTHREADS`` workers that are
    released together by a barrier, then checks the list afterwards (heap
    invariant and/or length) and that no worker raised.
    """

    def setUp(self):
        # Instantiated only to reuse its check_invariant() and
        # check_max_invariant() helpers on the shared list.
        self.test_heapq = test_heapq.TestHeapPython()

    def test_racing_heapify(self):
        heap = list(range(OBJECT_COUNT))
        shuffle(heap)

        self.run_concurrently(
            worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
        )
        self.test_heapq.check_invariant(heap)

    def test_racing_heappush(self):
        heap = []

        def heappush_func(heap):
            for item in reversed(range(OBJECT_COUNT)):
                heapq.heappush(heap, item)

        self.run_concurrently(
            worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
        )
        self.test_heapq.check_invariant(heap)

    def test_racing_heappop(self):
        heap = self.create_heap(OBJECT_COUNT, Heap.MIN)

        # Each thread pops (OBJECT_COUNT / NTHREADS) items
        self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
        per_thread_pop_count = OBJECT_COUNT // NTHREADS

        def heappop_func(heap, pop_count):
            local_list = []
            for _ in range(pop_count):
                item = heapq.heappop(heap)
                local_list.append(item)

            # Each local list should be sorted
            self.assertTrue(self.is_sorted_ascending(local_list))

        self.run_concurrently(
            worker_func=heappop_func,
            args=(heap, per_thread_pop_count),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(heap), 0)

    def test_racing_heappushpop(self):
        heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
        pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

        def heappushpop_func(heap, pushpop_items):
            for item in pushpop_items:
                popped_item = heapq.heappushpop(heap, item)
                self.assertLessEqual(popped_item, item)

        self.run_concurrently(
            worker_func=heappushpop_func,
            args=(heap, pushpop_items),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(heap), OBJECT_COUNT)
        self.test_heapq.check_invariant(heap)

    def test_racing_heapreplace(self):
        heap = self.create_heap(OBJECT_COUNT, Heap.MIN)
        replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

        def heapreplace_func(heap, replace_items):
            for item in replace_items:
                heapq.heapreplace(heap, item)

        self.run_concurrently(
            worker_func=heapreplace_func,
            args=(heap, replace_items),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(heap), OBJECT_COUNT)
        self.test_heapq.check_invariant(heap)

    def test_racing_heapify_max(self):
        max_heap = list(range(OBJECT_COUNT))
        shuffle(max_heap)

        self.run_concurrently(
            worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
        )
        self.test_heapq.check_max_invariant(max_heap)

    def test_racing_heappush_max(self):
        max_heap = []

        def heappush_max_func(max_heap):
            for item in range(OBJECT_COUNT):
                heapq.heappush_max(max_heap, item)

        self.run_concurrently(
            worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
        )
        self.test_heapq.check_max_invariant(max_heap)

    def test_racing_heappop_max(self):
        max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)

        # Each thread pops (OBJECT_COUNT / NTHREADS) items
        self.assertEqual(OBJECT_COUNT % NTHREADS, 0)
        per_thread_pop_count = OBJECT_COUNT // NTHREADS

        def heappop_max_func(max_heap, pop_count):
            local_list = []
            for _ in range(pop_count):
                item = heapq.heappop_max(max_heap)
                local_list.append(item)

            # Each local list should be sorted
            self.assertTrue(self.is_sorted_descending(local_list))

        self.run_concurrently(
            worker_func=heappop_max_func,
            args=(max_heap, per_thread_pop_count),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(max_heap), 0)

    def test_racing_heappushpop_max(self):
        max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
        pushpop_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

        def heappushpop_max_func(max_heap, pushpop_items):
            for item in pushpop_items:
                popped_item = heapq.heappushpop_max(max_heap, item)
                self.assertGreaterEqual(popped_item, item)

        self.run_concurrently(
            worker_func=heappushpop_max_func,
            args=(max_heap, pushpop_items),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(max_heap), OBJECT_COUNT)
        self.test_heapq.check_max_invariant(max_heap)

    def test_racing_heapreplace_max(self):
        max_heap = self.create_heap(OBJECT_COUNT, Heap.MAX)
        replace_items = self.create_random_list(-5_000, 10_000, OBJECT_COUNT)

        def heapreplace_max_func(max_heap, replace_items):
            for item in replace_items:
                heapq.heapreplace_max(max_heap, item)

        self.run_concurrently(
            worker_func=heapreplace_max_func,
            args=(max_heap, replace_items),
            nthreads=NTHREADS,
        )
        self.assertEqual(len(max_heap), OBJECT_COUNT)
        self.test_heapq.check_max_invariant(max_heap)

    @staticmethod
    def is_sorted_ascending(lst):
        """
        Check if the list is sorted in ascending order (non-decreasing).
        """
        return all(lst[i - 1] <= lst[i] for i in range(1, len(lst)))

    @staticmethod
    def is_sorted_descending(lst):
        """
        Check if the list is sorted in descending order (non-increasing).
        """
        return all(lst[i - 1] >= lst[i] for i in range(1, len(lst)))

    @staticmethod
    def create_heap(size, heap_kind):
        """
        Create a min/max heap holding the integers 0 .. size - 1, shuffled
        before heapify.

        heap_kind selects the ordering: Heap.MIN builds a min-heap, any
        other value builds a max-heap.
        """
        # Honour the requested size rather than the module-level default.
        heap = list(range(size))
        shuffle(heap)
        if heap_kind == Heap.MIN:
            heapq.heapify(heap)
        else:
            heapq.heapify_max(heap)

        return heap

    @staticmethod
    def create_random_list(a, b, size):
        """
        Create a list of `size` random integers between a and b (inclusive).
        """
        # Pass the lower bound through unchanged: negating it would turn a
        # negative bound positive and silently narrow the sampled range.
        return [randint(a, b) for _ in range(size)]

    def run_concurrently(self, worker_func, args, nthreads):
        """
        Run the worker function concurrently in multiple threads.

        Starts `nthreads` threads that each call worker_func(*args), waits
        for all of them to finish, and fails the test if any of them raised.
        """
        barrier = Barrier(nthreads)

        def wrapper_func(*worker_args):
            # Wait for all threads to reach this point before proceeding.
            barrier.wait()
            worker_func(*worker_args)

        with threading_helper.catch_threading_exception() as cm:
            workers = (
                Thread(target=wrapper_func, args=args) for _ in range(nthreads)
            )
            with threading_helper.start_threads(workers):
                pass

            # Worker threads should not raise any exceptions.  Checked while
            # still inside the catch_threading_exception() block, where `cm`
            # is in use.
            self.assertIsNone(cm.exc_value)
237+
238+
239+
# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make methods in :mod:`heapq` thread-safe on the :term:`free threaded <free threading>` build.

0 commit comments

Comments
 (0)