@@ -0,0 +1,3 @@
+- :class:`decomposition.SparseCoder` now follows the transformer API of scikit-learn.
+  In addition, the :meth:`fit` method now validates the input and parameters.
@jeremiedbb (Member) commented on Sep 4, 2025:
Please reference yourself as author of the PR: "By :user:`your name <your handle>`".

+  By :user:`François Paugam <FrancoisPgm>`.
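As a usage sketch of what this entry describes (illustration only, not part of the diff; shapes and data are invented):

```python
import numpy as np
from sklearn.decomposition import SparseCoder

rng = np.random.RandomState(0)
dictionary = rng.normal(size=(5, 8))  # 5 atoms, 8 features (invented shapes)
X = rng.normal(size=(10, 8))          # data with a matching number of features

coder = SparseCoder(dictionary=dictionary, transform_algorithm="lasso_lars")
coder.fit(X)                 # now validates X and the constructor parameters
print(coder.n_features_in_)  # 8, set during fit
print(coder.n_components_)   # 5, one per dictionary atom
codes = coder.transform(X)   # sparse codes of shape (10, 5)
```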
58 changes: 33 additions & 25 deletions sklearn/decomposition/_dict_learning.py
@@ -356,14 +356,11 @@ def sparse_encode(
            [ 0., 1., 1., 0., 0.]])
     """
     if check_input:
-        if algorithm == "lasso_cd":
-            dictionary = check_array(
-                dictionary, order="C", dtype=[np.float64, np.float32]
-            )
-            X = check_array(X, order="C", dtype=[np.float64, np.float32])
-        else:
-            dictionary = check_array(dictionary)
-            X = check_array(X)
+        order = "C" if algorithm == "lasso_cd" else None
+        dictionary = check_array(
+            dictionary, order=order, dtype=[np.float64, np.float32]
+        )
+        X = check_array(X, order=order, dtype=[np.float64, np.float32])

     if dictionary.shape[1] != X.shape[1]:
         raise ValueError(
@@ -421,7 +418,7 @@ def _sparse_encode(
            regularization = 1.0

    if gram is None and algorithm != "threshold":
-        gram = np.dot(dictionary, dictionary.T)
+        gram = np.dot(dictionary, dictionary.T).astype(X.dtype, copy=False)

    if cov is None and algorithm != "lasso_cd":
        copy_cov = False
@@ -1301,6 +1298,19 @@ class SparseCoder(_BaseSparseCoding, BaseEstimator):
            [ 0., 1., 1., 0., 0.]])
     """

+    _parameter_constraints: dict = {
+        "dictionary": ["array-like"],
+        "transform_algorithm": [
+            StrOptions({"lasso_lars", "lasso_cd", "lars", "omp", "threshold"})
+        ],
+        "transform_n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
+        "transform_alpha": [Interval(Real, 0, None, closed="left"), None],
+        "split_sign": ["boolean"],
+        "n_jobs": [Integral, None],
+        "positive_code": ["boolean"],
+        "transform_max_iter": [Interval(Integral, 0, None, closed="left")],
+    }
+
     def __init__(
         self,
         dictionary,
@@ -1324,16 +1334,17 @@ def __init__(
         )
         self.dictionary = dictionary

+    @_fit_context(prefer_skip_nested_validation=True)
     def fit(self, X, y=None):
-        """Do nothing and return the estimator unchanged.
+        """Validate the estimator's parameters and the input.

-        This method is just there to implement the usual API and hence
-        work in pipelines.
+        This method (i) validates the parameters of the estimator and
+        (ii) makes the estimator consistent with the scikit-learn transformer API.

         Parameters
         ----------
-        X : Ignored
-            Not used, present for API consistency by convention.
+        X : array-like of shape (n_samples, n_features)
+            Training data. Only used for input validation.

         y : Ignored
             Not used, present for API consistency by convention.
@@ -1343,6 +1354,13 @@ def fit(self, X, y=None):
         self : object
             Returns the instance itself.
         """
+        X = validate_data(self, X)
+        self.n_components_ = self.dictionary.shape[0]
+        if X.shape[1] != self.dictionary.shape[1]:
+            raise ValueError(
+                "Dictionary and X have different numbers of features: "
+                f"dictionary.shape: {self.dictionary.shape}, X.shape: {X.shape}"
+            )
         return self

     def transform(self, X, y=None):
@@ -1353,7 +1371,7 @@ def transform(self, X, y=None):

         Parameters
         ----------
-        X : ndarray of shape (n_samples, n_features)
+        X : array-like of shape (n_samples, n_features)
             Training vector, where `n_samples` is the number of samples
             and `n_features` is the number of features.

@@ -1389,16 +1407,6 @@ def __sklearn_tags__(self):
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
         return tags

-    @property
-    def n_components_(self):
-        """Number of atoms."""
-        return self.dictionary.shape[0]
-
-    @property
-    def n_features_in_(self):
-        """Number of features seen during `fit`."""
-        return self.dictionary.shape[1]
-
Comment on lines -1397 to -1401
Member:
Note for reviewers: with this modification, n_features_in_ is now only set if fit is called. This makes the estimator follow our API, in line with the discussion in #27724 (comment).
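A small sketch of that behavior change, assuming the code in this diff (illustration only, not part of the PR):

```python
import numpy as np
from sklearn.decomposition import SparseCoder

coder = SparseCoder(dictionary=np.eye(3))
print(hasattr(coder, "n_features_in_"))  # False: no longer a property derived from dictionary
coder.fit(np.zeros((2, 3)))
print(coder.n_features_in_)              # 3: set by fit through validate_data
```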

     @property
     def _n_features_out(self):
         """Number of transformed output features."""
20 changes: 16 additions & 4 deletions sklearn/decomposition/tests/test_dict_learning.py
@@ -622,7 +622,7 @@ def test_sparse_coder_estimator():
 def test_sparse_coder_estimator_clone():
     n_components = 12
     rng = np.random.RandomState(0)
-    V = rng.randn(n_components, n_features)  # random init
+    V = rng.normal(size=(n_components, n_features))  # random init
     V /= np.sum(V**2, axis=1)[:, np.newaxis]
     coder = SparseCoder(
         dictionary=V, transform_algorithm="lasso_lars", transform_alpha=0.001
@@ -631,8 +631,6 @@ def test_sparse_coder_estimator_clone():
     assert id(cloned) != id(coder)
     np.testing.assert_allclose(cloned.dictionary, coder.dictionary)
     assert id(cloned.dictionary) != id(coder.dictionary)
-    assert cloned.n_components_ == coder.n_components_
-    assert cloned.n_features_in_ == coder.n_features_in_
     data = np.random.rand(n_samples, n_features).astype(np.float32)
     np.testing.assert_allclose(cloned.transform(data), coder.transform(data))

@@ -677,10 +675,24 @@ def test_sparse_coder_common_transformer():

 def test_sparse_coder_n_features_in():
     d = np.array([[1, 2, 3], [1, 2, 3]])
+    X = np.array([[1, 2, 3]])
     sc = SparseCoder(d)
+    sc.fit(X)
     assert sc.n_features_in_ == d.shape[1]


+def test_sparse_encoder_feature_number_error():
+    n_components = 10
+    rng = np.random.RandomState(0)
+    D = rng.uniform(size=(n_components, n_features))
+    X = rng.uniform(size=(n_samples, n_features + 1))
+    coder = SparseCoder(D)
+    with pytest.raises(
+        ValueError, match="Dictionary and X have different numbers of features"
+    ):
+        coder.fit(X)
+
+
 def test_update_dict():
     # Check the dict update in batch mode vs online mode
     # Non-regression test for #4866
@@ -958,7 +970,7 @@ def test_dict_learning_online_numerical_consistency(method):
 @pytest.mark.parametrize(
     "estimator",
     [
-        SparseCoder(X.T),
+        SparseCoder(rng_global.uniform(size=(n_features, n_features))),
         DictionaryLearning(),
         MiniBatchDictionaryLearning(batch_size=4, max_iter=10),
     ],
54 changes: 31 additions & 23 deletions sklearn/tests/test_common.py
@@ -39,6 +39,7 @@
     _get_check_estimator_ids,
     _get_expected_failed_checks,
     _tested_estimators,
+    _yield_instances_for_check,
 )
 from sklearn.utils._testing import (
     SkipTest,
@@ -256,24 +257,27 @@ def _estimators_that_predict_in_fit():


 @pytest.mark.parametrize(
-    "estimator", column_name_estimators, ids=_get_check_estimator_ids
+    "estimator_orig", column_name_estimators, ids=_get_check_estimator_ids
 )
-def test_pandas_column_name_consistency(estimator):
-    if isinstance(estimator, ColumnTransformer):
+def test_pandas_column_name_consistency(estimator_orig):
+    if isinstance(estimator_orig, ColumnTransformer):
         pytest.skip("ColumnTransformer is not tested here")
     if "check_dataframe_column_names_consistency" in _get_expected_failed_checks(
-        estimator
+        estimator_orig
     ):
         pytest.skip(
             "Estimator does not support check_dataframe_column_names_consistency"
         )
-    with ignore_warnings(category=(FutureWarning)):
-        with warnings.catch_warnings(record=True) as record:
-            check_dataframe_column_names_consistency(
-                estimator.__class__.__name__, estimator
-            )
-        for warning in record:
-            assert "was fitted without feature names" not in str(warning.message)
+    for estimator in _yield_instances_for_check(
+        check_dataframe_column_names_consistency, estimator_orig
+    ):
+        with ignore_warnings(category=(FutureWarning)):
+            with warnings.catch_warnings(record=True) as record:
+                check_dataframe_column_names_consistency(
+                    estimator.__class__.__name__, estimator
+                )
+            for warning in record:
+                assert "was fitted without feature names" not in str(warning.message)


 # TODO: As more modules support get_feature_names_out they should be removed
@@ -347,21 +351,24 @@ def test_check_param_validation(estimator):


 @pytest.mark.parametrize(
-    "estimator", SET_OUTPUT_ESTIMATORS, ids=_get_check_estimator_ids
+    "estimator_orig", SET_OUTPUT_ESTIMATORS, ids=_get_check_estimator_ids
 )
-def test_set_output_transform(estimator):
-    name = estimator.__class__.__name__
-    if not hasattr(estimator, "set_output"):
+def test_set_output_transform(estimator_orig):
+    name = estimator_orig.__class__.__name__
+    if not hasattr(estimator_orig, "set_output"):
         pytest.skip(
             f"Skipping check_set_output_transform for {name}: Does not support"
             " set_output API"
         )
-    with ignore_warnings(category=(FutureWarning)):
-        check_set_output_transform(estimator.__class__.__name__, estimator)
+    for estimator in _yield_instances_for_check(
+        check_set_output_transform, estimator_orig
+    ):
+        with ignore_warnings(category=(FutureWarning)):
+            check_set_output_transform(estimator.__class__.__name__, estimator)


 @pytest.mark.parametrize(
-    "estimator", SET_OUTPUT_ESTIMATORS, ids=_get_check_estimator_ids
+    "estimator_orig", SET_OUTPUT_ESTIMATORS, ids=_get_check_estimator_ids
 )
 @pytest.mark.parametrize(
     "check_func",
@@ -372,15 +379,16 @@ def test_set_output_transform(estimator):
         check_global_set_output_transform_polars,
     ],
 )
-def test_set_output_transform_configured(estimator, check_func):
-    name = estimator.__class__.__name__
-    if not hasattr(estimator, "set_output"):
+def test_set_output_transform_configured(estimator_orig, check_func):
+    name = estimator_orig.__class__.__name__
+    if not hasattr(estimator_orig, "set_output"):
         pytest.skip(
             f"Skipping {check_func.__name__} for {name}: Does not support"
             " set_output API yet"
         )
-    with ignore_warnings(category=(FutureWarning)):
-        check_func(estimator.__class__.__name__, estimator)
+    for estimator in _yield_instances_for_check(check_func, estimator_orig):
+        with ignore_warnings(category=(FutureWarning)):
+            check_func(estimator.__class__.__name__, estimator)


 @pytest.mark.parametrize(
39 changes: 38 additions & 1 deletion sklearn/utils/_test_common/instance_generator.py
@@ -9,6 +9,8 @@
 from functools import partial
 from inspect import isfunction

+import numpy as np
+
 from sklearn import clone, config_context
 from sklearn.calibration import CalibratedClassifierCV
 from sklearn.cluster import (
@@ -177,6 +179,8 @@

 CROSS_DECOMPOSITION = ["PLSCanonical", "PLSRegression", "CCA", "PLSSVD"]

+rng = np.random.RandomState(0)
+
 # The following dictionary is to indicate constructor arguments suitable for the test
 # suite, which uses very small datasets, and is intended to run rather quickly.
 INIT_PARAMS = {
@@ -441,6 +445,7 @@
     SGDClassifier: dict(max_iter=5),
     SGDOneClassSVM: dict(max_iter=5),
     SGDRegressor: dict(max_iter=5),
+    SparseCoder: dict(dictionary=rng.normal(size=(5, 3))),
     SparsePCA: dict(max_iter=5),
     # Due to the jl lemma and often very few samples, the number
     # of components of the random matrix projection will be probably
@@ -711,6 +716,38 @@
         ],
     },
     SkewedChi2Sampler: {"check_dict_unchanged": dict(n_components=1)},
+    SparseCoder: {
+        "check_estimators_dtypes": dict(dictionary=rng.normal(size=(5, 5))),
Member:
is there a single value we could set to have it pass all the tests?

Member:
dictionary is not a very friendly parameter because it needs to have a shape compatible with X, but all the checks use different Xs
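A quick sketch of that constraint with invented shapes (illustration only): the dictionary fixes n_features at construction time, so any X a check feeds to fit must have that same number of columns.

```python
import numpy as np
from sklearn.decomposition import SparseCoder

rng = np.random.RandomState(0)
coder = SparseCoder(dictionary=rng.normal(size=(5, 8)))  # atoms with 8 features

coder.fit(rng.normal(size=(10, 8)))      # OK: X also has 8 features
try:
    coder.fit(rng.normal(size=(10, 2)))  # a check using 2-feature data would fail
except ValueError as e:
    print(e)  # Dictionary and X have different numbers of features: ...
```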

"check_dtype_object": dict(dictionary=rng.normal(size=(5, 10))),
"check_transformers_unfitted_stateless": dict(
dictionary=rng.normal(size=(5, 5))
),
"check_fit_idempotent": dict(dictionary=rng.normal(size=(5, 2))),
"check_transformer_preserve_dtypes": dict(
dictionary=rng.normal(size=(5, 3)).astype(np.float32)
),
"check_set_output_transform": dict(dictionary=rng.normal(size=(5, 5))),
"check_global_output_transform_pandas": dict(
dictionary=rng.normal(size=(5, 5))
),
"check_set_output_transform_pandas": dict(dictionary=rng.normal(size=(5, 5))),
"check_set_output_transform_polars": dict(dictionary=rng.normal(size=(5, 5))),
"check_global_set_output_transform_polars": dict(
dictionary=rng.normal(size=(5, 5))
),
"check_dataframe_column_names_consistency": dict(
dictionary=rng.normal(size=(5, 8))
),
"check_estimators_overwrite_params": dict(dictionary=rng.normal(size=(5, 2))),
"check_estimators_fit_returns_self": dict(dictionary=rng.normal(size=(5, 2))),
"check_readonly_memmap_input": dict(dictionary=rng.normal(size=(5, 2))),
"check_n_features_in_after_fitting": dict(dictionary=rng.normal(size=(5, 4))),
"check_fit_check_is_fitted": dict(dictionary=rng.normal(size=(5, 2))),
"check_n_features_in": dict(dictionary=rng.normal(size=(5, 2))),
"check_positive_only_tag_during_fit": dict(dictionary=rng.normal(size=(5, 4))),
"check_fit2d_1sample": dict(dictionary=rng.normal(size=(5, 10))),
"check_fit2d_1feature": dict(dictionary=rng.normal(size=(5, 1))),
},
SparsePCA: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
SparseRandomProjection: {"check_dict_unchanged": dict(n_components=1)},
SpectralBiclustering: {
@@ -748,7 +785,7 @@ def _tested_estimators(type_filter=None):
             yield estimator


-SKIPPED_ESTIMATORS = [SparseCoder, FrozenEstimator]
+SKIPPED_ESTIMATORS = [FrozenEstimator]


 def _construct_instances(Estimator):