diff --git a/sklearn/_config.py b/sklearn/_config.py
index 05549c88a9ddc..a8096c112a317 100644
--- a/sklearn/_config.py
+++ b/sklearn/_config.py
@@ -20,6 +20,7 @@
     "transform_output": "default",
     "enable_metadata_routing": False,
     "skip_parameter_validation": False,
+    "sparse_interface": "spmatrix",
 }
 _threadlocal = threading.local()
 
@@ -68,6 +69,7 @@ def set_config(
     transform_output=None,
     enable_metadata_routing=None,
     skip_parameter_validation=None,
+    sparse_interface=None,
 ):
     """Set global scikit-learn configuration.
 
@@ -174,6 +176,16 @@
 
         .. versionadded:: 1.3
 
+    sparse_interface : str, default="spmatrix"
+
+        The sparse interface used for every sparse object that scikit-learn produces,
+        e.g., function returns, estimator attributes, estimator properties, etc.
+
+        - `"sparray"`: Return sparse as SciPy sparse array
+        - `"spmatrix"`: Return sparse as SciPy sparse matrix
+
+        .. versionadded:: 1.7
+
     See Also
     --------
     config_context : Context manager for global scikit-learn configuration.
@@ -209,6 +221,8 @@
         local_config["enable_metadata_routing"] = enable_metadata_routing
     if skip_parameter_validation is not None:
         local_config["skip_parameter_validation"] = skip_parameter_validation
+    if sparse_interface is not None:
+        local_config["sparse_interface"] = sparse_interface
 
 
 @contextmanager
@@ -224,6 +238,7 @@
     transform_output=None,
     enable_metadata_routing=None,
     skip_parameter_validation=None,
+    sparse_interface=None,
 ):
     """Context manager for global scikit-learn configuration.
 
@@ -329,6 +344,16 @@
 
         .. versionadded:: 1.3
 
+    sparse_interface : str, default="spmatrix"
+
+        The sparse interface used for every sparse object that scikit-learn produces,
+        e.g., function returns, estimator attributes, estimator properties, etc.
+
+        - `"sparray"`: Return sparse as SciPy sparse array
+        - `"spmatrix"`: Return sparse as SciPy sparse matrix
+
+        .. versionadded:: 1.7
+
     Yields
     ------
     None.
@@ -368,6 +393,7 @@ def config_context(
         transform_output=transform_output,
         enable_metadata_routing=enable_metadata_routing,
         skip_parameter_validation=skip_parameter_validation,
+        sparse_interface=sparse_interface,
     )
 
     try:
diff --git a/sklearn/feature_extraction/tests/test_text.py b/sklearn/feature_extraction/tests/test_text.py
index ab3f84668fd2d..e38992752b6b8 100644
--- a/sklearn/feature_extraction/tests/test_text.py
+++ b/sklearn/feature_extraction/tests/test_text.py
@@ -27,6 +27,7 @@
 from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
 from sklearn.pipeline import Pipeline
 from sklearn.svm import LinearSVC
+from sklearn.utils import _as_sparse
 from sklearn.utils._testing import (
     assert_allclose_dense_sparse,
     assert_almost_equal,
@@ -1611,7 +1612,12 @@ def test_tfidf_transformer_copy(csr_container):
     assert X_transform is not X_csr
 
     X_transform = transformer.transform(X_csr, copy=False)
-    assert X_transform is X_csr
+    # allow for config["sparse_interface"] to change output type
+    # there should be no data copied, but the `id` will change.
+    if _as_sparse(X_csr) is X_csr:
+        assert X_transform is X_csr
+    else:
+        assert X_transform is not X_csr
 
     with pytest.raises(AssertionError):
         assert_allclose_dense_sparse(X_csr, X_csr_original)
diff --git a/sklearn/feature_extraction/text.py b/sklearn/feature_extraction/text.py
index eb3226b01c79e..67508291a29d4 100644
--- a/sklearn/feature_extraction/text.py
+++ b/sklearn/feature_extraction/text.py
@@ -16,7 +16,7 @@
 import numpy as np
 import scipy.sparse as sp
 
-from sklearn.utils import metadata_routing
+from sklearn.utils import _as_sparse, metadata_routing
 
 from ..base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin, _fit_context
 from ..exceptions import NotFittedError
@@ -880,7 +880,7 @@ def transform(self, X):
             X.data.fill(1)
         if self.norm is not None:
             X = normalize(X, norm=self.norm, copy=False)
-        return X
+        return _as_sparse(X)
 
     def fit_transform(self, X, y=None):
         """Transform a sequence of documents to a document-term matrix.
@@ -1393,7 +1393,7 @@ def fit_transform(self, raw_documents, y=None):
             X = self._sort_features(X, vocabulary)
             self.vocabulary_ = vocabulary
 
-        return X
+        return _as_sparse(X)
 
     def transform(self, raw_documents):
         """Transform documents to document-term matrix.
@@ -1421,7 +1421,7 @@ def transform(self, raw_documents):
         _, X = self._count_vocab(raw_documents, fixed_vocab=True)
         if self.binary:
             X.data.fill(1)
-        return X
+        return _as_sparse(X)
 
     def inverse_transform(self, X):
         """Return terms per document with nonzero entries in X.
@@ -1720,7 +1720,7 @@ def transform(self, X, copy=True):
         if self.norm is not None:
             X = normalize(X, norm=self.norm, copy=False)
 
-        return X
+        return _as_sparse(X)
 
     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py
index c0c14cbb12f32..d4366988ed379 100644
--- a/sklearn/linear_model/_coordinate_descent.py
+++ b/sklearn/linear_model/_coordinate_descent.py
@@ -12,7 +12,7 @@
 from joblib import effective_n_jobs
 from scipy import sparse
 
-from sklearn.utils import metadata_routing
+from sklearn.utils import _as_sparse, metadata_routing
 
 from ..base import MultiOutputMixin, RegressorMixin, _fit_context
 from ..model_selection import check_cv
@@ -1129,7 +1129,8 @@ def fit(self, X, y, sample_weight=None, check_input=True):
     @property
     def sparse_coef_(self):
         """Sparse representation of the fitted `coef_`."""
+        # np.atleast_2d keeps the returned container two-dimensional for 1d and 2d coef_
-        return sparse.csr_matrix(self.coef_)
+        return _as_sparse(sparse.csr_array(np.atleast_2d(self.coef_)))
 
     def _decision_function(self, X):
         """Decision function of the linear model.
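Usage note: the feature_extraction and linear_model hunks above route every sparse return value through `_as_sparse`, so the container class a caller receives now follows the `sparse_interface` option added in `sklearn/_config.py`. A minimal sketch of the intended behaviour once this diff is applied; the two-line corpus below is made up purely for illustration:

    import sklearn
    from scipy.sparse import csr_array, csr_matrix
    from sklearn.feature_extraction.text import CountVectorizer

    docs = ["first toy document", "second toy document"]

    with sklearn.config_context(sparse_interface="sparray"):
        X = CountVectorizer().fit_transform(docs)
    assert isinstance(X, csr_array)  # sparse array inside the context

    X = CountVectorizer().fit_transform(docs)
    assert isinstance(X, csr_matrix)  # default "spmatrix" outside the context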
diff --git a/sklearn/manifold/_locally_linear.py b/sklearn/manifold/_locally_linear.py
index e6967446274ad..3864f6db12a11 100644
--- a/sklearn/manifold/_locally_linear.py
+++ b/sklearn/manifold/_locally_linear.py
@@ -7,7 +7,7 @@
 
 import numpy as np
 from scipy.linalg import eigh, qr, solve, svd
-from scipy.sparse import csr_matrix, eye, lil_matrix
+from scipy.sparse import csr_array, eye, lil_array
 from scipy.sparse.linalg import eigsh
 
 from ..base import (
@@ -21,6 +21,7 @@
 from ..utils import check_array, check_random_state
 from ..utils._arpack import _init_arpack_v0
 from ..utils._param_validation import Interval, StrOptions, validate_params
+from ..utils._sparse import _as_sparse
 from ..utils.extmath import stable_cumsum
 from ..utils.validation import FLOAT_DTYPES, check_is_fitted, validate_data
 
@@ -118,7 +119,8 @@ def barycenter_kneighbors_graph(X, n_neighbors, reg=1e-3, n_jobs=None):
     ind = knn.kneighbors(X, return_distance=False)[:, 1:]
     data = barycenter_weights(X, X, ind, reg=reg)
     indptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors)
-    return csr_matrix((data.ravel(), ind.ravel(), indptr), shape=(n_samples, n_samples))
+    csr = csr_array((data.ravel(), ind.ravel(), indptr), shape=(n_samples, n_samples))
+    return _as_sparse(csr)
 
 
 def null_space(
@@ -229,7 +231,7 @@
     )
 
     M_sparse = eigen_solver != "dense"
-    M_container_constructor = lil_matrix if M_sparse else np.zeros
+    M_container_constructor = lil_array if M_sparse else np.zeros
 
     if method == "standard":
         W = barycenter_kneighbors_graph(
@@ -432,7 +434,7 @@
             M[neighbors[i], neighbors[i]] += np.ones(shape=n_neighbors)
 
     if M_sparse:
-        M = M.tocsr()
+        M = _as_sparse(M.tocsr())
 
     return null_space(
         M,
diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py
index bf35eee623c18..b7beda31c2fd5 100644
--- a/sklearn/tests/test_config.py
+++ b/sklearn/tests/test_config.py
@@ -21,6 +21,7 @@ def test_config_context():
         "transform_output": "default",
         "enable_metadata_routing": False,
         "skip_parameter_validation": False,
+        "sparse_interface": "spmatrix",
     }
 
     # Not using as a context manager affects nothing
@@ -39,6 +40,7 @@
         "transform_output": "default",
         "enable_metadata_routing": False,
         "skip_parameter_validation": False,
+        "sparse_interface": "spmatrix",
     }
 
     assert get_config()["assume_finite"] is False
@@ -74,6 +76,7 @@
         "transform_output": "default",
         "enable_metadata_routing": False,
         "skip_parameter_validation": False,
+        "sparse_interface": "spmatrix",
     }
 
     # No positional arguments
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index 941126c6b083f..99d6fad7d3535 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -20,6 +20,7 @@
     shuffle,
 )
 from ._mask import safe_mask
+from ._sparse import _as_sparse, _select_interface_if_sparse
 from ._tags import (
     ClassifierTags,
     InputTags,
@@ -56,6 +57,8 @@
     "Tags",
     "TargetTags",
     "TransformerTags",
+    "_as_sparse",
+    "_select_interface_if_sparse",
     "all_estimators",
     "as_float_array",
     "assert_all_finite",
diff --git a/sklearn/utils/_sparse.py b/sklearn/utils/_sparse.py
new file mode 100644
index 0000000000000..590581d81c436
--- /dev/null
+++ b/sklearn/utils/_sparse.py
@@ -0,0 +1,62 @@
+"""Control sparse interface based on config"""
+
+# Authors: The scikit-learn developers
+# SPDX-License-Identifier: BSD-3-Clause
+
+import scipy as sp
+
+from .._config import get_config
+
+
+def _as_sparse(X_sparse):
+    """
+    Convert to sparse interface as set in config. Input must be sparse.
+    If you know the input is sparse, use e.g. `return _as_sparse(X)`.
+    Otherwise you should check if sparse before calling:
+
+    if sp.sparse.issparse(X):
+        X = _as_sparse(X)
+    """
+    if not sp.sparse.issparse(X_sparse):
+        raise TypeError("Input should be a SciPy sparse container")
+
+    return _convert_sparse_to_config_chosen_interface(X_sparse)
+
+
+def _select_interface_if_sparse(X):
+    """
+    Convert to sparse interface as set in config. Input X can be dense or sparse.
+    If sparse, convert to the sparse_interface indicated by get_config.
+    Otherwise, return X unchanged.
+
+    X = _select_interface_if_sparse(X)
+    """
+    if not sp.sparse.issparse(X):
+        return X
+
+    return _convert_sparse_to_config_chosen_interface(X)
+
+
+def _convert_sparse_to_config_chosen_interface(X_sparse):
+    # assume there are only two sparse interfaces: sparray and spmatrix
+    X_is_sparray = not sp.sparse.isspmatrix(X_sparse)
+    config_sparse_interface = get_config()["sparse_interface"]
+
+    if config_sparse_interface == "sparray":
+        if X_is_sparray:
+            return X_sparse
+        return _convert_from_spmatrix_to_sparray(X_sparse)
+    else:  # global is spmatrix
+        if not X_is_sparray:
+            return X_sparse
+        return _convert_from_sparray_to_spmatrix(X_sparse)
+
+
+def _convert_from_spmatrix_to_sparray(X_sparse):
+    """Fundamental code to switch to sparray in any format"""
+    return getattr(sp.sparse, X_sparse.format + "_array")(X_sparse)
+
+
+def _convert_from_sparray_to_spmatrix(X_sparse):
+    """Fundamental code to switch to spmatrix in any format"""
+    return getattr(sp.sparse, X_sparse.format + "_matrix")(X_sparse)
diff --git a/sklearn/utils/tests/test_sparse.py b/sklearn/utils/tests/test_sparse.py
new file mode 100644
index 0000000000000..65c3c78d2277d
--- /dev/null
+++ b/sklearn/utils/tests/test_sparse.py
@@ -0,0 +1,90 @@
+import numpy as np
+import pytest
+from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix
+
+import sklearn
+import sklearn.datasets
+import sklearn.feature_extraction.text
+import sklearn.linear_model
+import sklearn.manifold
+
+
+@pytest.mark.parametrize(
+    ["sparse_interface", "x", "result_type"],
+    [
+        ("sparray", csr_array([[1, 2, 3]]), csr_array),
+        ("sparray", csr_matrix([[1, 2, 3]]), csr_array),
+        ("sparray", csc_array([[1, 2, 3]]), csc_array),
+        ("sparray", csc_matrix([[1, 2, 3]]), csc_array),
+        ("spmatrix", csr_array([[1, 2, 3]]), csr_matrix),
+        ("spmatrix", csr_matrix([[1, 2, 3]]), csr_matrix),
+        ("spmatrix", csc_array([[1, 2, 3]]), csc_matrix),
+        ("spmatrix", csc_matrix([[1, 2, 3]]), csc_matrix),
+    ],
+)
+def test_interface_if_sparse(sparse_interface, x, result_type):
+    with sklearn.config_context(sparse_interface=sparse_interface):
+        selected = sklearn.utils._select_interface_if_sparse(x)
+        converted = sklearn.utils._as_sparse(x)
+    assert isinstance(selected, result_type) and isinstance(converted, result_type)
+
+
+@pytest.mark.parametrize(
+    ["sparse_interface", "x", "result_type"],
+    [
+        ("sparray", np.array([[1, 2, 3]]), np.ndarray),
+        ("spmatrix", np.array([[1, 2, 3]]), np.ndarray),
+    ],
+)
+def test_ndarray_select_interface_if_sparse(sparse_interface, x, result_type):
+    with sklearn.config_context(sparse_interface=sparse_interface):
+        result = sklearn.utils._select_interface_if_sparse(x)
+    assert isinstance(result, result_type)
+
+
+@pytest.mark.parametrize("sparse_interface", ["sparray", "spmatrix"])
+def test_ndarray_as_sparse_raises(sparse_interface):
+    with sklearn.config_context(sparse_interface=sparse_interface):
+        with pytest.raises(TypeError, match="sparse container"):
+            sklearn.utils._as_sparse(np.array([[1, 2, 3]]))
+
+
+@pytest.mark.parametrize(
+    ["sparse_interface", "result_type"],
+    [("sparray", csr_array), ("spmatrix", csr_matrix)],
+)
+def test_transform_returns_sparse(sparse_interface, result_type):
+    corpus = [
+        "This is the first document.",
+        "This document is the second document.",
+        "And this is the third one.",
+        "Is this the first document?",
+    ]
+    with sklearn.config_context(sparse_interface=sparse_interface):
+        vectorizer = sklearn.feature_extraction.text.CountVectorizer()
+        X = vectorizer.fit_transform(corpus)
+    assert isinstance(X, result_type)
+
+
+@pytest.mark.parametrize(
+    ["sparse_interface", "result_type"],
+    [("sparray", csr_array), ("spmatrix", csr_matrix)],
+)
+def test_function_returns_sparse(sparse_interface, result_type):
+    with sklearn.config_context(sparse_interface=sparse_interface):
+        X, y = sklearn.datasets.make_regression(n_features=2, random_state=0)
+        X = sklearn.manifold._locally_linear.barycenter_kneighbors_graph(X, 1)
+    assert isinstance(X, result_type)
+
+
+@pytest.mark.parametrize(
+    ["sparse_interface", "result_type"],
+    [("sparray", csr_array), ("spmatrix", csr_matrix)],
+)
+def test_estimator_property_sparse(sparse_interface, result_type):
+    with sklearn.config_context(sparse_interface=sparse_interface):
+        X, y = sklearn.datasets.make_regression(n_features=2, random_state=0)
+        regr = sklearn.linear_model.ElasticNet(random_state=0)
+        regr.fit(X, y)
+        # check the sparse_coef_ property while the config is still active
+        assert isinstance(regr.sparse_coef_, result_type)
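Implementation note: the helpers in `sklearn/utils/_sparse.py` only swap the container class; the sparse format is preserved because the target class is looked up as `X_sparse.format + "_array"` or `X_sparse.format + "_matrix"`, and, as the comment added to `test_tfidf_transformer_copy` says, no data should be copied. A small sketch of that conversion, assuming the diff above is applied (`_as_sparse` is a private helper and is used here only to illustrate the behaviour):

    import scipy.sparse as sp
    import sklearn
    from sklearn.utils import _as_sparse

    M = sp.csc_matrix([[0, 1], [2, 0]])
    with sklearn.config_context(sparse_interface="sparray"):
        A = _as_sparse(M)  # csc_matrix in, csc_array out
    assert isinstance(A, sp.csc_array)
    assert A.format == M.format == "csc"  # format is preserved
    # only the wrapper class changes; the underlying data buffers are shared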