Skip to content

Commit b5827cb

Browse files
authored
CI Fix scipy-dev build (scikit-learn#28047)
1 parent 4ce8e19 commit b5827cb

File tree

8 files changed

+27
-15
lines changed

8 files changed

+27
-15
lines changed

sklearn/datasets/tests/test_svmlight_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def test_invalid_filename():
261261
def test_dump(csr_container):
262262
X_sparse, y_dense = _load_svmlight_local_test_file(datafile)
263263
X_dense = X_sparse.toarray()
264-
y_sparse = csr_container(y_dense)
264+
y_sparse = csr_container(np.atleast_2d(y_dense))
265265

266266
# slicing a csr_matrix can unsort its .indices, so test that we sort
267267
# those correctly

sklearn/ensemble/tests/test_forest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1595,7 +1595,7 @@ def test_max_samples_boundary_classifiers(name):
15951595
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
15961596
def test_forest_y_sparse(csr_container):
15971597
X = [[1, 2, 3]]
1598-
y = csr_container([4, 5, 6])
1598+
y = csr_container([[4, 5, 6]])
15991599
est = RandomForestClassifier()
16001600
msg = "sparse multilabel-indicator for y is not supported."
16011601
with pytest.raises(ValueError, match=msg):

sklearn/metrics/tests/test_dist_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def test_readonly_kwargs():
366366
(np.array([1, 1.5, np.nan]), ValueError, "w contains NaN"),
367367
*[
368368
(
369-
csr_container([1, 1.5, 1]),
369+
csr_container([[1, 1.5, 1]]),
370370
TypeError,
371371
"Sparse data was passed for w, but dense data is required",
372372
)

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,8 +476,8 @@ def test_is_sorted_by_data(csr_container):
476476
# _is_sorted_by_data should return True when entries are sorted by data,
477477
# and False in all other cases.
478478

479-
# Test with sorted 1D array
480-
X = csr_container(np.arange(10))
479+
# Test with sorted single row sparse array
480+
X = csr_container(np.arange(10).reshape(1, 10))
481481
assert _is_sorted_by_data(X)
482482
# Test with unsorted 1D array
483483
X[0, 2] = 5

sklearn/utils/_testing.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ def _convert_container(
765765
elif constructor_name == "array":
766766
return np.asarray(container, dtype=dtype)
767767
elif constructor_name == "sparse":
768-
return sp.sparse.csr_matrix(container, dtype=dtype)
768+
return sp.sparse.csr_matrix(np.atleast_2d(container), dtype=dtype)
769769
elif constructor_name in ("pandas", "dataframe"):
770770
pd = pytest.importorskip("pandas", minversion=minversion)
771771
result = pd.DataFrame(container, columns=columns_name, dtype=dtype, copy=False)
@@ -803,18 +803,18 @@ def _convert_container(
803803
elif constructor_name == "slice":
804804
return slice(container[0], container[1])
805805
elif constructor_name == "sparse_csr":
806-
return sp.sparse.csr_matrix(container, dtype=dtype)
806+
return sp.sparse.csr_matrix(np.atleast_2d(container), dtype=dtype)
807807
elif constructor_name == "sparse_csr_array":
808808
if sp_version >= parse_version("1.8"):
809-
return sp.sparse.csr_array(container, dtype=dtype)
809+
return sp.sparse.csr_array(np.atleast_2d(container), dtype=dtype)
810810
raise ValueError(
811811
f"sparse_csr_array is only available with scipy>=1.8.0, got {sp_version}"
812812
)
813813
elif constructor_name == "sparse_csc":
814-
return sp.sparse.csc_matrix(container, dtype=dtype)
814+
return sp.sparse.csc_matrix(np.atleast_2d(container), dtype=dtype)
815815
elif constructor_name == "sparse_csc_array":
816816
if sp_version >= parse_version("1.8"):
817-
return sp.sparse.csc_array(container, dtype=dtype)
817+
return sp.sparse.csc_array(np.atleast_2d(container), dtype=dtype)
818818
raise ValueError(
819819
f"sparse_csc_array is only available with scipy>=1.8.0, got {sp_version}"
820820
)

sklearn/utils/tests/test_class_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,6 @@ def test_class_weight_does_not_contains_more_classes():
311311
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
312312
def test_compute_sample_weight_sparse(csc_container):
313313
"""Check that we can compute weight for sparse `y`."""
314-
y = csc_container(np.asarray([0, 1, 1])).T
314+
y = csc_container(np.asarray([[0], [1], [1]]))
315315
sample_weight = compute_sample_weight("balanced", y)
316316
assert_allclose(sample_weight, [1.5, 0.75, 0.75])

sklearn/utils/tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def test_resample_stratify_sparse_error(csr_container):
168168
n_samples = 100
169169
X = rng.normal(size=(n_samples, 2))
170170
y = rng.randint(0, 2, size=n_samples)
171-
stratify = csr_container(y)
171+
stratify = csr_container(y.reshape(-1, 1))
172172
with pytest.raises(TypeError, match="Sparse data was passed"):
173173
X, y = resample(X, y, n_samples=50, random_state=rng, stratify=stratify)
174174

sklearn/utils/tests/test_validation.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -639,9 +639,21 @@ def test_check_array_accept_sparse_no_exception():
639639
@pytest.fixture(params=["csr", "csc", "coo", "bsr"])
640640
def X_64bit(request):
641641
X = sp.rand(20, 10, format=request.param)
642-
for attr in ["indices", "indptr", "row", "col"]:
643-
if hasattr(X, attr):
644-
setattr(X, attr, getattr(X, attr).astype("int64"))
642+
643+
if request.param == "coo":
644+
if hasattr(X, "indices"):
645+
# for scipy >= 1.13 .indices is a new attribute and is a tuple. The
646+
# .col and .row attributes do not seem to be able to change the
647+
# dtype, for more details see https://github.com/scipy/scipy/pull/18530/
648+
X.indices = tuple(v.astype("int64") for v in X.indices)
649+
else:
650+
# scipy < 1.13
651+
X.row = X.row.astype("int64")
652+
X.col = X.col.astype("int64")
653+
else:
654+
X.indices = X.indices.astype("int64")
655+
X.indptr = X.indptr.astype("int64")
656+
645657
yield X
646658

647659

0 commit comments

Comments
 (0)