Skip to content

Commit e9e94e8

Browse files
committed
test our_rand_r
1 parent 5916742 commit e9e94e8

File tree

4 files changed

+19
-10
lines changed

4 files changed

+19
-10
lines changed

build_tools/azure/test_script.cmd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ if "%COVERAGE%" == "true" (
1313
set PYTEST_ARGS=%PYTEST_ARGS% --cov sklearn
1414
)
1515

16-
pytest -s -k "test_seq_dataset_shuffle or test_sgd_proba" --pyargs sklearn
16+
pytest -s -k test_our_rand_r --pyargs sklearn

build_tools/azure/test_script.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ except ImportError:
2121
python -c "import multiprocessing as mp; print('%d CPUs' % mp.cpu_count())"
2222
pip list
2323

24-
TEST_CMD="python -m pytest -s -k \"test_seq_dataset_shuffle or test_sgd_proba\" sklearn"
24+
TEST_CMD="python -m pytest -s -k test_our_rand_r sklearn"
2525

2626
if [[ "$COVERAGE" == "true" ]]; then
2727
TEST_CMD="$TEST_CMD"

sklearn/utils/seq_dataset.pyx.tp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,8 @@ cdef inline np.uint32_t our_rand_r(np.uint32_t* seed) nogil:
372372
seed[0] ^= <np.uint32_t>(seed[0] << 5)
373373

374374
return seed[0] % (<np.uint32_t>RAND_R_MAX + 1)
375+
376+
377+
def our_rand_r_py(seed):
378+
cdef np.uint32_t my_seed = <np.uint32_t>seed
379+
return our_rand_r(&my_seed)

sklearn/utils/tests/test_seq_dataset.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,15 @@
33
#
44
# License: BSD 3 clause
55

6-
import pytest
76
import numpy as np
8-
from numpy.testing import assert_array_equal
9-
from sklearn.utils.testing import assert_allclose
7+
import pytest
108
import scipy.sparse as sp
11-
12-
from sklearn.utils.seq_dataset import ArrayDataset64
13-
from sklearn.utils.seq_dataset import ArrayDataset32
14-
from sklearn.utils.seq_dataset import CSRDataset64
15-
from sklearn.utils.seq_dataset import CSRDataset32
9+
from numpy.testing import assert_array_equal
10+
from sklearn.utils.seq_dataset import ArrayDataset32, ArrayDataset64, \
11+
CSRDataset32, CSRDataset64, our_rand_r_py
1612

1713
from sklearn.datasets import load_iris
14+
from sklearn.utils.testing import assert_allclose
1815

1916
iris = load_iris()
2017
X64 = iris.data.astype(np.float64)
@@ -154,3 +151,10 @@ def test_buffer_dtype_mismatch_error():
154151
with pytest.raises(ValueError, match='Buffer dtype mismatch'):
155152
CSRDataset32(X_csr64.data, X_csr64.indptr, X_csr64.indices, y64,
156153
sample_weight64, seed=42),
154+
155+
156+
def test_our_rand_r():
157+
seed = 1273642419
158+
assert seed <= np.iinfo(np.int32).max
159+
my_random_int = our_rand_r_py(seed)
160+
assert my_random_int == 131541053

0 commit comments

Comments
 (0)