Skip to content

Commit 05c0e08

Browse files
cozekKaushik Amar Dasglemaitre
authored
ENH BaseLabelPropagation to accept sparse matrices (#19664)
Co-authored-by: Kaushik Amar Das <kaushik.amar.das@accenture.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent baefe83 commit 05c0e08

File tree

3 files changed

+39
-6
lines changed

3 files changed

+39
-6
lines changed

doc/whats_new/v1.3.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,13 @@ Changelog
214214
during `transform` with no prior call to `fit` or `fit_transform`.
215215
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.
216216

217+
:mod:`sklearn.semi_supervised`
218+
..............................
219+
220+
- |Enhancement| :meth:`LabelSpreading.fit` and :meth:`LabelPropagation.fit` now
221+
accept sparse matrices.
222+
:pr:`19664` by :user:`Kaushik Amar Das <cozek>`.
223+
217224
Code and Documentation Contributors
218225
-----------------------------------
219226

sklearn/semi_supervised/_label_propagation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def fit(self, X, y):
241241
242242
Parameters
243243
----------
244-
X : array-like of shape (n_samples, n_features)
244+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
245245
Training data, where `n_samples` is the number of samples
246246
and `n_features` is the number of features.
247247
@@ -256,7 +256,12 @@ def fit(self, X, y):
256256
Returns the instance itself.
257257
"""
258258
self._validate_params()
259-
X, y = self._validate_data(X, y)
259+
X, y = self._validate_data(
260+
X,
261+
y,
262+
accept_sparse=["csr", "csc"],
263+
reset=True,
264+
)
260265
self.X_ = X
261266
check_classification_targets(y)
262267

@@ -365,7 +370,7 @@ class LabelPropagation(BaseLabelPropagation):
365370
366371
Attributes
367372
----------
368-
X_ : ndarray of shape (n_samples, n_features)
373+
X_ : {ndarray, sparse matrix} of shape (n_samples, n_features)
369374
Input array.
370375
371376
classes_ : ndarray of shape (n_classes,)
@@ -463,7 +468,7 @@ def fit(self, X, y):
463468
464469
Parameters
465470
----------
466-
X : array-like of shape (n_samples, n_features)
471+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
467472
Training data, where `n_samples` is the number of samples
468473
and `n_features` is the number of features.
469474

sklearn/semi_supervised/tests/test_label_propagation.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
assert_allclose,
1616
assert_array_equal,
1717
)
18+
from sklearn.utils._testing import _convert_container
19+
20+
CONSTRUCTOR_TYPES = ("array", "sparse_csr", "sparse_csc")
1821

1922
ESTIMATORS = [
2023
(label_propagation.LabelPropagation, {"kernel": "rbf"}),
@@ -122,9 +125,27 @@ def test_label_propagation_closed_form(global_dtype):
122125
assert_allclose(expected, clf.label_distributions_, atol=1e-4)
123126

124127

125-
def test_convergence_speed():
128+
@pytest.mark.parametrize("accepted_sparse_type", ["sparse_csr", "sparse_csc"])
129+
@pytest.mark.parametrize("index_dtype", [np.int32, np.int64])
130+
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
131+
@pytest.mark.parametrize("Estimator, parameters", ESTIMATORS)
132+
def test_sparse_input_types(
133+
accepted_sparse_type, index_dtype, dtype, Estimator, parameters
134+
):
135+
# This is a non-regression test for #17085
136+
X = _convert_container([[1.0, 0.0], [0.0, 2.0], [1.0, 3.0]], accepted_sparse_type)
137+
X.data = X.data.astype(dtype, copy=False)
138+
X.indices = X.indices.astype(index_dtype, copy=False)
139+
X.indptr = X.indptr.astype(index_dtype, copy=False)
140+
labels = [0, 1, -1]
141+
clf = Estimator(**parameters).fit(X, labels)
142+
assert_array_equal(clf.predict([[0.5, 2.5]]), np.array([1]))
143+
144+
145+
@pytest.mark.parametrize("constructor_type", CONSTRUCTOR_TYPES)
146+
def test_convergence_speed(constructor_type):
126147
# This is a non-regression test for #5774
127-
X = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 2.5]])
148+
X = _convert_container([[1.0, 0.0], [0.0, 1.0], [1.0, 2.5]], constructor_type)
128149
y = np.array([0, 1, -1])
129150
mdl = label_propagation.LabelSpreading(kernel="rbf", max_iter=5000)
130151
mdl.fit(X, y)

0 commit comments

Comments
 (0)