Skip to content

Commit a1f2d58

Browse files
authored
Merge pull request #28361 from eendebakpt/nonzero_unit_tests
BUG: Make np.nonzero threading safe
2 parents dc8f46d + d1c7b4a commit a1f2d58

File tree

4 files changed

+41
-6
lines changed

4 files changed

+41
-6
lines changed

.github/workflows/compiler_sanitizers.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
- name: Test
6969
run: |
7070
# pass -s to pytest to see ASAN errors and warnings, otherwise pytest captures them
71-
ASAN_OPTIONS=detect_leaks=0:symbolize=1:strict_init_order=true:allocator_may_return_null=1:halt_on_error=1 \
71+
ASAN_OPTIONS=detect_leaks=0:symbolize=1:strict_init_order=true:allocator_may_return_null=1 \
7272
python -m spin test -- -v -s --timeout=600 --durations=10
7373
7474
clang_TSAN:
@@ -121,7 +121,7 @@ jobs:
121121
- name: Test
122122
run: |
123123
# These tests are slow, so only run tests in files that do "import threading" to make them count
124-
TSAN_OPTIONS=allocator_may_return_null=1:halt_on_error=1 \
124+
TSAN_OPTIONS="allocator_may_return_null=1:suppressions=$GITHUB_WORKSPACE/tools/ci/tsan_suppressions.txt" \
125125
python -m spin test \
126126
`find numpy -name "test*.py" | xargs grep -l "import threading" | tr '\n' ' '` \
127127
-- -v -s --timeout=600 --durations=10

numpy/_core/src/multiarray/item_selection.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2893,10 +2893,11 @@ PyArray_Nonzero(PyArrayObject *self)
28932893
* the fast bool count is followed by this sparse path is faster
28942894
* than combining the two loops, even for larger arrays
28952895
*/
2896+
npy_intp * multi_index_end = multi_index + nonzero_count;
28962897
if (((double)nonzero_count / count) <= 0.1) {
28972898
npy_intp subsize;
28982899
npy_intp j = 0;
2899-
while (1) {
2900+
while (multi_index < multi_index_end) {
29002901
npy_memchr(data + j * stride, 0, stride, count - j,
29012902
&subsize, 1);
29022903
j += subsize;
@@ -2911,11 +2912,10 @@ PyArray_Nonzero(PyArrayObject *self)
29112912
* stalls that are very expensive on most modern processors.
29122913
*/
29132914
else {
2914-
npy_intp *multi_index_end = multi_index + nonzero_count;
29152915
npy_intp j = 0;
29162916

29172917
/* Manually unroll for GCC and maybe other compilers */
2918-
while (multi_index + 4 < multi_index_end) {
2918+
while (multi_index + 4 < multi_index_end && (j < count - 4) ) {
29192919
*multi_index = j;
29202920
multi_index += data[0] != 0;
29212921
*multi_index = j + 1;
@@ -2928,7 +2928,7 @@ PyArray_Nonzero(PyArrayObject *self)
29282928
j += 4;
29292929
}
29302930

2931-
while (multi_index < multi_index_end) {
2931+
while (multi_index < multi_index_end && (j < count) ) {
29322932
*multi_index = j;
29332933
multi_index += *data != 0;
29342934
data += stride;

numpy/_core/tests/test_multithreading.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,27 @@ def closure(b):
271271
# Reducing the number of threads means the test doesn't trigger the
272272
# bug. Better to skip on some platforms than add a useless test.
273273
pytest.skip("Couldn't spawn enough threads to run the test")
274+
275+
@pytest.mark.parametrize("dtype", [bool, int, float])
276+
def test_nonzero(dtype):
277+
# See: gh-28361
278+
#
279+
# np.nonzero uses np.count_nonzero to determine the size of the output array
280+
# In a second pass the indices of the non-zero elements are determined, but the array contents may have changed in between
281+
#
282+
# This test triggers a data race which is suppressed in the TSAN CI. The test is to ensure
283+
# np.nonzero does not generate a segmentation fault
284+
x = np.random.randint(4, size=10_000).astype(dtype)
285+
286+
def func(index):
287+
for _ in range(10):
288+
if index == 0:
289+
x[::2] = np.random.randint(2)
290+
else:
291+
try:
292+
_ = np.nonzero(x)
293+
except RuntimeError as ex:
294+
assert 'number of non-zero array elements changed during function execution' in str(ex)
295+
296+
run_threaded(func, max_workers=10, pass_count=True, outer_iterations=50)
297+

tools/ci/tsan_suppressions.txt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# This file contains suppressions for the TSAN tool
2+
#
3+
# Reference: https://github.com/google/sanitizers/wiki/ThreadSanitizerSuppressions
4+
5+
# For np.nonzero, see gh-28361
6+
race:PyArray_Nonzero
7+
race:count_nonzero_int
8+
race:count_nonzero_bool
9+
race:count_nonzero_float
10+
race:DOUBLE_nonzero
11+

0 commit comments

Comments
 (0)