
Enable config setting sparse_interface to control sparray and spmatrix creation #31177

Open · wants to merge 1 commit into base: main

26 changes: 26 additions & 0 deletions sklearn/_config.py
@@ -20,6 +20,7 @@
"transform_output": "default",
"enable_metadata_routing": False,
"skip_parameter_validation": False,
"sparse_interface": "spmatrix",
}
_threadlocal = threading.local()

@@ -68,6 +69,7 @@ def set_config(
transform_output=None,
enable_metadata_routing=None,
skip_parameter_validation=None,
sparse_interface=None,
):
"""Set global scikit-learn configuration.

@@ -174,6 +176,16 @@ def set_config(

.. versionadded:: 1.3

sparse_interface : str, default="spmatrix"

The sparse interface used for every sparse object that scikit-learn produces:
function return values, estimator attributes, estimator properties, etc.

- `"sparray"`: Return sparse as SciPy sparse array
- `"spmatrix"`: Return sparse as SciPy sparse matrix

.. versionadded:: 1.7

See Also
--------
config_context : Context manager for global scikit-learn configuration.
@@ -209,6 +221,8 @@ def set_config(
local_config["enable_metadata_routing"] = enable_metadata_routing
if skip_parameter_validation is not None:
local_config["skip_parameter_validation"] = skip_parameter_validation
if sparse_interface is not None:
local_config["sparse_interface"] = sparse_interface


@contextmanager
@@ -224,6 +238,7 @@ def config_context(
transform_output=None,
enable_metadata_routing=None,
skip_parameter_validation=None,
sparse_interface=None,
):
"""Context manager for global scikit-learn configuration.

@@ -329,6 +344,16 @@ def config_context(

.. versionadded:: 1.3

sparse_interface : str, default="spmatrix"

The sparse interface used for every sparse object that scikit-learn produces:
function return values, estimator attributes, estimator properties, etc.

- `"sparray"`: Return sparse as SciPy sparse array
- `"spmatrix"`: Return sparse as SciPy sparse matrix

.. versionadded:: 1.7

Yields
------
None.
@@ -368,6 +393,7 @@ def config_context(
transform_output=transform_output,
enable_metadata_routing=enable_metadata_routing,
skip_parameter_validation=skip_parameter_validation,
sparse_interface=sparse_interface,
)

try:
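A quick sketch of how the new option is exercised from user code, assuming the diff above is applied (`get_config`, `set_config` and `config_context` are the existing public entry points; only the `sparse_interface` key is new):

import sklearn

# The default keeps the current behaviour: sparse outputs are spmatrix instances.
assert sklearn.get_config()["sparse_interface"] == "spmatrix"

# Opt in globally ...
sklearn.set_config(sparse_interface="sparray")

# ... or only within a block; the previous value is restored on exit.
with sklearn.config_context(sparse_interface="spmatrix"):
    assert sklearn.get_config()["sparse_interface"] == "spmatrix"
assert sklearn.get_config()["sparse_interface"] == "sparray"
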
8 changes: 7 additions & 1 deletion sklearn/feature_extraction/tests/test_text.py
@@ -27,6 +27,7 @@
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.utils import _as_sparse
from sklearn.utils._testing import (
assert_allclose_dense_sparse,
assert_almost_equal,
@@ -1611,7 +1612,12 @@ def test_tfidf_transformer_copy(csr_container):
assert X_transform is not X_csr

X_transform = transformer.transform(X_csr, copy=False)
assert X_transform is X_csr
# allow config["sparse_interface"] to change the output container type:
# no data is copied, but the object identity (and hence `id`) changes.
if _as_sparse(X_csr) is X_csr:
assert X_transform is X_csr
else:
assert X_transform is not X_csr
with pytest.raises(AssertionError):
assert_allclose_dense_sparse(X_csr, X_csr_original)

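The branch above exists because converting between spmatrix and sparray swaps the container object without copying the underlying buffers, so the `is` identity check can no longer be unconditional. A small illustration of that behaviour with plain SciPy (assuming the default `copy=False` constructor semantics for same-format sparse input):

import numpy as np
from scipy.sparse import csr_array, csr_matrix

X_mat = csr_matrix(np.eye(3))
X_arr = csr_array(X_mat)  # same format, copy defaults to False

assert X_arr is not X_mat                        # a new container object ...
assert np.shares_memory(X_arr.data, X_mat.data)  # ... backed by the same data buffer
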
10 changes: 5 additions & 5 deletions sklearn/feature_extraction/text.py
@@ -16,7 +16,7 @@
import numpy as np
import scipy.sparse as sp

from sklearn.utils import metadata_routing
from sklearn.utils import _as_sparse, metadata_routing

from ..base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin, _fit_context
from ..exceptions import NotFittedError
@@ -880,7 +880,7 @@ def transform(self, X):
X.data.fill(1)
if self.norm is not None:
X = normalize(X, norm=self.norm, copy=False)
return X
return _as_sparse(X)

def fit_transform(self, X, y=None):
"""Transform a sequence of documents to a document-term matrix.
@@ -1393,7 +1393,7 @@ def fit_transform(self, raw_documents, y=None):
X = self._sort_features(X, vocabulary)
self.vocabulary_ = vocabulary

return X
return _as_sparse(X)

def transform(self, raw_documents):
"""Transform documents to document-term matrix.
@@ -1421,7 +1421,7 @@ def transform(self, raw_documents):
_, X = self._count_vocab(raw_documents, fixed_vocab=True)
if self.binary:
X.data.fill(1)
return X
return _as_sparse(X)

def inverse_transform(self, X):
"""Return terms per document with nonzero entries in X.
@@ -1720,7 +1720,7 @@ def transform(self, X, copy=True):
if self.norm is not None:
X = normalize(X, norm=self.norm, copy=False)

return X
return _as_sparse(X)

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
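With the vectorizers' return values routed through `_as_sparse`, the output container now follows the configured interface. A minimal illustration, assuming this diff is applied:

import scipy.sparse as sp
import sklearn
from sklearn.feature_extraction.text import CountVectorizer

corpus = ["the first document", "the second document"]

with sklearn.config_context(sparse_interface="sparray"):
    assert isinstance(CountVectorizer().fit_transform(corpus), sp.csr_array)

with sklearn.config_context(sparse_interface="spmatrix"):
    assert isinstance(CountVectorizer().fit_transform(corpus), sp.csr_matrix)
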
4 changes: 2 additions & 2 deletions sklearn/linear_model/_coordinate_descent.py
@@ -12,7 +12,7 @@
from joblib import effective_n_jobs
from scipy import sparse

from sklearn.utils import metadata_routing
from sklearn.utils import _as_sparse, metadata_routing

from ..base import MultiOutputMixin, RegressorMixin, _fit_context
from ..model_selection import check_cv
@@ -1129,7 +1129,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
@property
def sparse_coef_(self):
"""Sparse representation of the fitted `coef_`."""
return sparse.csr_matrix(self.coef_)
return _as_sparse(sparse.csr_array([self.coef_]))

def _decision_function(self, X):
"""Decision function of the linear model.
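The list wrapping in `csr_array([self.coef_])` is what keeps `sparse_coef_` two-dimensional: for a single-target model `coef_` is a 1-D array, and `csr_matrix` promotes it to a `(1, n_features)` row implicitly, whereas `csr_array` does not (depending on the SciPy version, 1-D input is either rejected or kept 1-D). A small sketch of the difference:

import numpy as np
from scipy.sparse import csr_array, csr_matrix

coef = np.array([0.0, 1.5, 0.0, -2.0])    # 1-D, like ElasticNet.coef_ for one target

assert csr_matrix(coef).shape == (1, 4)   # spmatrix promotes 1-D input to a row
assert csr_array([coef]).shape == (1, 4)  # sparray needs the row axis added explicitly
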
10 changes: 6 additions & 4 deletions sklearn/manifold/_locally_linear.py
@@ -7,7 +7,7 @@

import numpy as np
from scipy.linalg import eigh, qr, solve, svd
from scipy.sparse import csr_matrix, eye, lil_matrix
from scipy.sparse import csr_array, eye, lil_array
from scipy.sparse.linalg import eigsh

from ..base import (
@@ -21,6 +21,7 @@
from ..utils import check_array, check_random_state
from ..utils._arpack import _init_arpack_v0
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils._sparse import _as_sparse
from ..utils.extmath import stable_cumsum
from ..utils.validation import FLOAT_DTYPES, check_is_fitted, validate_data

@@ -118,7 +119,8 @@ def barycenter_kneighbors_graph(X, n_neighbors, reg=1e-3, n_jobs=None):
ind = knn.kneighbors(X, return_distance=False)[:, 1:]
data = barycenter_weights(X, X, ind, reg=reg)
indptr = np.arange(0, n_samples * n_neighbors + 1, n_neighbors)
return csr_matrix((data.ravel(), ind.ravel(), indptr), shape=(n_samples, n_samples))
csr = csr_array((data.ravel(), ind.ravel(), indptr), shape=(n_samples, n_samples))
return _as_sparse(csr)


def null_space(
@@ -229,7 +231,7 @@ def _locally_linear_embedding(
)

M_sparse = eigen_solver != "dense"
M_container_constructor = lil_matrix if M_sparse else np.zeros
M_container_constructor = lil_array if M_sparse else np.zeros

if method == "standard":
W = barycenter_kneighbors_graph(
@@ -432,7 +434,7 @@ def _locally_linear_embedding(
M[neighbors[i], neighbors[i]] += np.ones(shape=n_neighbors)

if M_sparse:
M = M.tocsr()
M = _as_sparse(M.tocsr())

return null_space(
M,
3 changes: 3 additions & 0 deletions sklearn/tests/test_config.py
@@ -21,6 +21,7 @@ def test_config_context():
"transform_output": "default",
"enable_metadata_routing": False,
"skip_parameter_validation": False,
"sparse_interface": "spmatrix",
}

# Not using as a context manager affects nothing
@@ -39,6 +40,7 @@ def test_config_context():
"transform_output": "default",
"enable_metadata_routing": False,
"skip_parameter_validation": False,
"sparse_interface": "spmatrix",
}
assert get_config()["assume_finite"] is False

@@ -74,6 +76,7 @@ def test_config_context():
"transform_output": "default",
"enable_metadata_routing": False,
"skip_parameter_validation": False,
"sparse_interface": "spmatrix",
}

# No positional arguments
3 changes: 3 additions & 0 deletions sklearn/utils/__init__.py
@@ -20,6 +20,7 @@
shuffle,
)
from ._mask import safe_mask
from ._sparse import _as_sparse, _select_interface_if_sparse
from ._tags import (
ClassifierTags,
InputTags,
@@ -56,6 +57,8 @@
"Tags",
"TargetTags",
"TransformerTags",
"_as_sparse",
"_select_interface_if_sparse",
"all_estimators",
"as_float_array",
"assert_all_finite",
62 changes: 62 additions & 0 deletions sklearn/utils/_sparse.py
@@ -0,0 +1,62 @@
"""Control sparse interface based on config"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import scipy as sp

from .._config import get_config


def _as_sparse(X_sparse):
"""
Convert to the sparse interface set in the config. The input must be sparse.
If you know the input is sparse, use e.g. `return _as_sparse(X)`.
Otherwise, check that the input is sparse before calling:

if sp.sparse.issparse(X):
X = _as_sparse(X)
"""
if not sp.sparse.issparse(X_sparse):
raise TypeError("Input should be a SciPy sparse container")

return _convert_sparse_to_config_chosen_interface(X_sparse)


def _select_interface_if_sparse(X):
"""
Convert to the sparse interface set in the config. Input X can be dense or sparse.
If sparse, convert to the interface indicated by `get_config()["sparse_interface"]`;
otherwise, return X unchanged.

X = _select_interface_if_sparse(X)
"""
if not sp.sparse.issparse(X):
return X

return _convert_sparse_to_config_chosen_interface(X)


def _convert_sparse_to_config_chosen_interface(X_sparse):
# assume there are only two sparse interfaces: sparray and spmatrix
X_is_sparray = not sp.sparse.isspmatrix(X_sparse)
config_sparse_interface = get_config()["sparse_interface"]

if config_sparse_interface == "sparray":
if X_is_sparray:
return X_sparse
return _convert_from_spmatrix_to_sparray(X_sparse)
else: # global is spmatrix
if not X_is_sparray:
return X_sparse
return _convert_from_sparray_to_spmatrix(X_sparse)


def _convert_from_spmatrix_to_sparray(X_sparse):
"""Fundamental code to switch to sparray in any format"""
return getattr(sp.sparse, X_sparse.format + "_array")(X_sparse)


def _convert_from_sparray_to_spmatrix(X_sparse):
"""Fundamental code to switch to spmatrix in any format"""
return getattr(sp.sparse, X_sparse.format + "_matrix")(X_sparse)
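Taken together, the intended calling convention for these helpers looks roughly like this (a sketch against the functions defined above, using the re-exports from `sklearn.utils`):

import numpy as np
import scipy.sparse as sp
import sklearn
from sklearn.utils import _as_sparse, _select_interface_if_sparse

X = sp.csr_matrix(np.eye(3))

with sklearn.config_context(sparse_interface="sparray"):
    # known-sparse input: convert directly at the return boundary
    assert isinstance(_as_sparse(X), sp.csr_array)
    # possibly-dense input: sparse gets converted, dense passes through unchanged
    assert isinstance(_select_interface_if_sparse(X), sp.csr_array)
    dense = np.eye(3)
    assert _select_interface_if_sparse(dense) is dense
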
86 changes: 86 additions & 0 deletions sklearn/utils/tests/test_sparse.py
@@ -0,0 +1,86 @@
import numpy as np
import pytest
from scipy.sparse import csc_array, csc_matrix, csr_array, csr_matrix

import sklearn


@pytest.mark.parametrize(
["sparse_interface", "x", "result_type"],
[
("sparray", csr_array([[1, 2, 3]]), csr_array),
("sparray", csr_matrix([[1, 2, 3]]), csr_array),
("sparray", csc_array([[1, 2, 3]]), csc_array),
("sparray", csc_matrix([[1, 2, 3]]), csc_array),
("spmatrix", csr_array([[1, 2, 3]]), csr_matrix),
("spmatrix", csr_matrix([[1, 2, 3]]), csr_matrix),
("spmatrix", csc_array([[1, 2, 3]]), csc_matrix),
("spmatrix", csc_matrix([[1, 2, 3]]), csc_matrix),
],
)
def test_interface_if_sparse(sparse_interface, x, result_type):
with sklearn.config_context(sparse_interface=sparse_interface):
result_select = sklearn.utils._select_interface_if_sparse(x)
result_as_sparse = sklearn.utils._as_sparse(x)
assert isinstance(result_select, result_type)
assert isinstance(result_as_sparse, result_type)


@pytest.mark.parametrize(
["sparse_interface", "x", "result_type"],
[
("sparray", np.array([[1, 2, 3]]), np.ndarray),
("spmatrix", np.array([[1, 2, 3]]), np.ndarray),
],
)
def test_ndarray_select_interface_if_sparse(sparse_interface, x, result_type):
with sklearn.config_context(sparse_interface=sparse_interface):
result = sklearn.utils._select_interface_if_sparse(x)
assert isinstance(result, result_type)


@pytest.mark.parametrize("sparse_interface", ["sparray", "spmatrix"])
def test_ndarray_as_sparse_raises(sparse_interface):
with sklearn.config_context(sparse_interface=sparse_interface):
with pytest.raises(TypeError, match="sparse container"):
sklearn.utils._as_sparse(np.array([[1, 2, 3]]))


@pytest.mark.parametrize(
["sparse_interface", "result_type"],
[("sparray", csr_array), ("spmatrix", csr_matrix)],
)
def test_transform_returns_sparse(sparse_interface, result_type):
corpus = [
"This is the first document.",
"This document is the second document.",
"And this is the third one.",
"Is this the first document?",
]
with sklearn.config_context(sparse_interface=sparse_interface):
vectorizer = sklearn.feature_extraction.text.CountVectorizer()
X = vectorizer.fit_transform(corpus)
assert isinstance(X, result_type)


@pytest.mark.parametrize(
["sparse_interface", "result_type"],
[("sparray", csr_array), ("spmatrix", csr_matrix)],
)
def test_function_returns_sparse(sparse_interface, result_type):
with sklearn.config_context(sparse_interface=sparse_interface):
X, y = sklearn.datasets.make_regression(n_features=2, random_state=0)
X = sklearn.manifold._locally_linear.barycenter_kneighbors_graph(X, 1)
assert isinstance(X, result_type)


@pytest.mark.parametrize(
["sparse_interface", "result_type"],
[("sparray", csr_array), ("spmatrix", csr_matrix)],
)
def test_estimator_property_sparse(sparse_interface, result_type):
with sklearn.config_context(sparse_interface=sparse_interface):
X, y = sklearn.datasets.make_regression(n_features=2, random_state=0)
regr = sklearn.linear_model.ElasticNet(random_state=0)
regr.fit(X, y)
# check the sparse_coef_ property
assert isinstance(regr.sparse_coef_, result_type)