From a6c488d2c2596eacf0ba1943655e13396a83c634 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 12 Oct 2023 11:18:23 +0200 Subject: [PATCH 01/22] separate checks for sparse array and sparse matrix input --- sklearn/cluster/_bicluster.py | 3 +- sklearn/utils/_testing.py | 6 +- sklearn/utils/estimator_checks.py | 71 +++++++++++++++++--- sklearn/utils/tests/test_estimator_checks.py | 26 ++++++- 4 files changed, 91 insertions(+), 15 deletions(-) diff --git a/sklearn/cluster/_bicluster.py b/sklearn/cluster/_bicluster.py index 65280c06319d9..18c98ad5348b5 100644 --- a/sklearn/cluster/_bicluster.py +++ b/sklearn/cluster/_bicluster.py @@ -198,7 +198,8 @@ def _more_tags(self): "check_estimators_dtypes": "raises nan error", "check_fit2d_1sample": "_scale_normalize fails", "check_fit2d_1feature": "raises apply_along_axis error", - "check_estimator_sparse_data": "does not fail gracefully", + "check_estimator_sparse_matrix": "does not fail gracefully", + "check_estimator_sparse_array": "does not fail gracefully", "check_methods_subset_invariance": "empty array passed inside", "check_dont_overwrite_parameters": "empty array passed inside", "check_fit2d_predict1d": "empty array passed inside", diff --git a/sklearn/utils/_testing.py b/sklearn/utils/_testing.py index cb43a5f46a236..d5c0e9220b35f 100644 --- a/sklearn/utils/_testing.py +++ b/sklearn/utils/_testing.py @@ -938,7 +938,7 @@ def __exit__(self, exc_type, exc_value, _): class MinimalClassifier: - """Minimal classifier implementation with inheriting from BaseEstimator. + """Minimal classifier implementation without inheriting from BaseEstimator. This estimator should be tested with: @@ -987,7 +987,7 @@ def score(self, X, y): class MinimalRegressor: - """Minimal regressor implementation with inheriting from BaseEstimator. + """Minimal regressor implementation without inheriting from BaseEstimator. This estimator should be tested with: @@ -1027,7 +1027,7 @@ def score(self, X, y): class MinimalTransformer: - """Minimal transformer implementation with inheriting from + """Minimal transformer implementation without inheriting from BaseEstimator. This estimator should be tested with: diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 317a789771ccc..092c75ec0d1a7 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -133,7 +133,8 @@ def _yield_checks(estimator): if hasattr(estimator, "sparsify"): yield check_sparsify_coefficients - yield check_estimator_sparse_data + yield check_estimator_sparse_array + yield check_estimator_sparse_matrix # Test that estimators can be pickled, and once pickled # give the same answer as before. @@ -821,17 +822,17 @@ def _is_pairwise_metric(estimator): return bool(metric == "precomputed") -def _generate_sparse_matrix(X_csr): - """Generate sparse matrices with {32,64}bit indices of diverse format. +def _generate_sparse_data(X_csr): + """Generate sparse matrices or arrays with {32,64}bit indices of diverse format. Parameters ---------- - X_csr: CSR Matrix - Input matrix in CSR format. + X_csr: scipy.sparse.csr_matrix or scipy.sparse.csr_array + Input in CSR format. Returns ------- - out: iter(Matrices) + out: iter(Matrices) or iter(Arrays) In format['dok', 'lil', 'dia', 'bsr', 'csr', 'csc', 'coo', 'coo_64', 'csc_64', 'csr_64'] """ @@ -1014,19 +1015,71 @@ def check_array_api_input_and_values( ) -def check_estimator_sparse_data(name, estimator_orig): +def check_estimator_sparse_matrix(name, estimator_orig): rng = np.random.RandomState(0) X = rng.uniform(size=(40, 3)) X[X < 0.8] = 0 X = _enforce_estimator_tags_X(estimator_orig, X) - X_csr = sparse.csr_matrix(X) y = (4 * rng.uniform(size=40)).astype(int) # catch deprecation warnings with ignore_warnings(category=FutureWarning): estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) tags = _safe_tags(estimator_orig) - for matrix_format, X in _generate_sparse_matrix(X_csr): + for matrix_format, X in _generate_sparse_data(sparse.csr_matrix(X)): + # catch deprecation warnings + with ignore_warnings(category=FutureWarning): + estimator = clone(estimator_orig) + if name in ["Scaler", "StandardScaler"]: + estimator.set_params(with_mean=False) + # fit and predict + if "64" in matrix_format: + err_msg = ( + f"Estimator {name} doesn't seem to support {matrix_format} " + "matrix, and is not failing gracefully, e.g. by using " + "check_array(X, accept_large_sparse=False)" + ) + else: + err_msg = ( + f"Estimator {name} doesn't seem to fail gracefully on sparse " + "data: error message should state explicitly that sparse " + "input is not supported if this is not the case." + ) + with raises( + (TypeError, ValueError), + match=["sparse", "Sparse"], + may_pass=True, + err_msg=err_msg, + ): + with ignore_warnings(category=FutureWarning): + estimator.fit(X, y) + if hasattr(estimator, "predict"): + pred = estimator.predict(X) + if tags["multioutput_only"]: + assert pred.shape == (X.shape[0], 1) + else: + assert pred.shape == (X.shape[0],) + if hasattr(estimator, "predict_proba"): + probs = estimator.predict_proba(X) + if tags["binary_only"]: + expected_probs_shape = (X.shape[0], 2) + else: + expected_probs_shape = (X.shape[0], 4) + assert probs.shape == expected_probs_shape + + +def check_estimator_sparse_array(name, estimator_orig): + rng = np.random.RandomState(0) + X = rng.uniform(size=(40, 3)) + X[X < 0.8] = 0 + X = _enforce_estimator_tags_X(estimator_orig, X) + y = (4 * rng.uniform(size=40)).astype(int) + # catch deprecation warnings + with ignore_warnings(category=FutureWarning): + estimator = clone(estimator_orig) + y = _enforce_estimator_tags_y(estimator, y) + tags = _safe_tags(estimator_orig) + for matrix_format, X in _generate_sparse_data(sparse.csr_array(X)): # catch deprecation warnings with ignore_warnings(category=FutureWarning): estimator = clone(estimator_orig) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 120233293542e..3550a6c4e566f 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -356,6 +356,13 @@ def predict(self, X): class LargeSparseNotSupportedClassifier(BaseEstimator): + """Estimator that claims to support large sparse data + (accept_large_sparse=True), but doesn't""" + + def __init__(self, raise_for_type=None): + # raise_for_type : str, expects "sparse_array" or "sparse_matrix" + self.raise_for_type = raise_for_type + def fit(self, X, y): X, y = self._validate_data( X, @@ -365,7 +372,15 @@ def fit(self, X, y): multi_output=True, y_numeric=True, ) - if sp.issparse(X): + # the following is only here since sp.csr_array is an instance of + # sp.csr_matrix, but not the other way around + if self.raise_for_type == "sparse_array": + correct_type = isinstance(X, sp.csr_array) and isinstance(X, sp.csr_matrix) + elif self.raise_for_type == "sparse_matrix": + correct_type = not isinstance(X, sp.csr_array) and isinstance( + X, sp.csr_matrix + ) + if correct_type: if X.getformat() == "coo": if X.row.dtype == "int64" or X.col.dtype == "int64": raise ValueError("Estimator doesn't support 64-bit indices") @@ -652,7 +667,14 @@ def test_check_estimator(): r"support \S{3}_64 matrix, and is not failing gracefully.*" ) with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier()) + check_estimator(LargeSparseNotSupportedClassifier(sp.csr_matrix)) + + msg = ( + "Estimator LargeSparseNotSupportedClassifier doesn't seem to " + r"support \S{3}_64 matrix, and is not failing gracefully.*" + ) + with raises(AssertionError, match=msg): + check_estimator(LargeSparseNotSupportedClassifier(sp.csr_array)) # does error on binary_only untagged estimator msg = "Only 2 classes are supported" From c60e2323a7f08b311937aeff7d850a6e11e9c411 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 12 Oct 2023 13:52:24 +0200 Subject: [PATCH 02/22] fix input for raise_for_type --- sklearn/utils/tests/test_estimator_checks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 3550a6c4e566f..72fae3368c7df 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -380,6 +380,8 @@ def fit(self, X, y): correct_type = not isinstance(X, sp.csr_array) and isinstance( X, sp.csr_matrix ) + else: + raise ValueError("Invalid value for `raise_for_type`.") if correct_type: if X.getformat() == "coo": if X.row.dtype == "int64" or X.col.dtype == "int64": @@ -667,14 +669,14 @@ def test_check_estimator(): r"support \S{3}_64 matrix, and is not failing gracefully.*" ) with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier(sp.csr_matrix)) + check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) msg = ( "Estimator LargeSparseNotSupportedClassifier doesn't seem to " r"support \S{3}_64 matrix, and is not failing gracefully.*" ) with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier(sp.csr_array)) + check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) # does error on binary_only untagged estimator msg = "Only 2 classes are supported" From 0071c10baa420d3229812ce9260970abbfc166e5 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 12 Oct 2023 21:15:57 +0200 Subject: [PATCH 03/22] correcter error message --- sklearn/utils/tests/test_estimator_checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 72fae3368c7df..c8f68eb52e9d1 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -673,7 +673,7 @@ def test_check_estimator(): msg = ( "Estimator LargeSparseNotSupportedClassifier doesn't seem to " - r"support \S{3}_64 matrix, and is not failing gracefully.*" + r"support \S{3}_64 array, and is not failing gracefully.*" ) with raises(AssertionError, match=msg): check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) From 8b1eb77718d2154fd8465e754ea8bb798717815f Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Fri, 13 Oct 2023 00:53:35 +0200 Subject: [PATCH 04/22] repaired datatype selection and CI failures --- sklearn/utils/estimator_checks.py | 7 ++++--- sklearn/utils/tests/test_estimator_checks.py | 12 ++---------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 092c75ec0d1a7..5ff8a590996a4 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -839,7 +839,8 @@ def _generate_sparse_data(X_csr): assert X_csr.format == "csr" yield "csr", X_csr.copy() - for sparse_format in ["dok", "lil", "dia", "bsr", "csc", "coo"]: + # TODO re-insert "dia" when PR #27372 is merged + for sparse_format in ["dok", "lil", "bsr", "csc", "coo"]: yield sparse_format, X_csr.asformat(sparse_format) # Generate large indices matrix only if its supported by scipy @@ -1037,7 +1038,7 @@ def check_estimator_sparse_matrix(name, estimator_orig): err_msg = ( f"Estimator {name} doesn't seem to support {matrix_format} " "matrix, and is not failing gracefully, e.g. by using " - "check_array(X, accept_large_sparse=False)" + "check_array(X, accept_large_sparse=False)." ) else: err_msg = ( @@ -1090,7 +1091,7 @@ def check_estimator_sparse_array(name, estimator_orig): err_msg = ( f"Estimator {name} doesn't seem to support {matrix_format} " "matrix, and is not failing gracefully, e.g. by using " - "check_array(X, accept_large_sparse=False)" + "check_array(X, accept_large_sparse=False)." ) else: err_msg = ( diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index c8f68eb52e9d1..0db2c154a0996 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -372,14 +372,10 @@ def fit(self, X, y): multi_output=True, y_numeric=True, ) - # the following is only here since sp.csr_array is an instance of - # sp.csr_matrix, but not the other way around if self.raise_for_type == "sparse_array": - correct_type = isinstance(X, sp.csr_array) and isinstance(X, sp.csr_matrix) + correct_type = isinstance(X, sp.sparray) elif self.raise_for_type == "sparse_matrix": - correct_type = not isinstance(X, sp.csr_array) and isinstance( - X, sp.csr_matrix - ) + correct_type = isinstance(X, sp.spmatrix) else: raise ValueError("Invalid value for `raise_for_type`.") if correct_type: @@ -671,10 +667,6 @@ def test_check_estimator(): with raises(AssertionError, match=msg): check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) - msg = ( - "Estimator LargeSparseNotSupportedClassifier doesn't seem to " - r"support \S{3}_64 array, and is not failing gracefully.*" - ) with raises(AssertionError, match=msg): check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) From 028d580afc334ee6062747bb6b51c583150f927c Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 18 Oct 2023 13:08:48 +0200 Subject: [PATCH 05/22] Lasso and ElasticNet don't support large sparse data --- sklearn/linear_model/_coordinate_descent.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index a988438f95653..7a1d316632bfe 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -896,7 +896,8 @@ def fit(self, X, y, sample_weight=None, check_input=True): Parameters ---------- - X : {ndarray, sparse matrix} of (n_samples, n_features) + X : {ndarray, sparse matrix with a bit width of up to 32} of + (n_samples, n_features) Data. y : ndarray of shape (n_samples,) or (n_samples, n_targets) @@ -948,6 +949,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): accept_sparse="csc", order="F", dtype=[np.float64, np.float32], + accept_large_sparse=False, copy=X_copied, multi_output=True, y_numeric=True, @@ -1524,7 +1526,7 @@ def fit(self, X, y, sample_weight=None): X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. Pass directly as Fortran-contiguous data to avoid unnecessary memory duplication. If y is mono-output, - X can be sparse. + X can be sparse with a bit width of up to 32. y : array-like of shape (n_samples,) or (n_samples, n_targets) Target values. @@ -1563,7 +1565,10 @@ def fit(self, X, y, sample_weight=None): # csr. We also want to allow y to be 64 or 32 but check_X_y only # allows to convert for 64. check_X_params = dict( - accept_sparse="csc", dtype=[np.float64, np.float32], copy=False + accept_sparse="csc", + dtype=[np.float64, np.float32], + copy=False, + accept_large_sparse=False, ) X, y = self._validate_data( X, y, validate_separately=(check_X_params, check_y_params) From 8e9feaacc693b6dc6c3dc4746429cbecd24d839d Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 18 Oct 2023 19:48:52 +0200 Subject: [PATCH 06/22] refactored code of almost identical checks --- sklearn/utils/estimator_checks.py | 62 +++----------------- sklearn/utils/tests/test_estimator_checks.py | 18 ++++-- 2 files changed, 22 insertions(+), 58 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index a44d6b7f24271..e4d3c648897e5 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -839,7 +839,8 @@ def _generate_sparse_data(X_csr): assert X_csr.format == "csr" yield "csr", X_csr.copy() - # TODO re-insert "dia" when PR #27372 is merged + # re-insert "dia" when PR #27372 is merged + # only merge the present PR afterwards for sparse_format in ["dok", "lil", "bsr", "csc", "coo"]: yield sparse_format, X_csr.asformat(sparse_format) @@ -1016,7 +1017,7 @@ def check_array_api_input_and_values( ) -def check_estimator_sparse_matrix(name, estimator_orig): +def _check_estimator_sparse_container(name, estimator_orig, sparse_type): rng = np.random.RandomState(0) X = rng.uniform(size=(40, 3)) X[X < 0.8] = 0 @@ -1027,7 +1028,7 @@ def check_estimator_sparse_matrix(name, estimator_orig): estimator = clone(estimator_orig) y = _enforce_estimator_tags_y(estimator, y) tags = _safe_tags(estimator_orig) - for matrix_format, X in _generate_sparse_data(sparse.csr_matrix(X)): + for matrix_format, X in _generate_sparse_data(sparse_type(X)): # catch deprecation warnings with ignore_warnings(category=FutureWarning): estimator = clone(estimator_orig) @@ -1069,57 +1070,12 @@ def check_estimator_sparse_matrix(name, estimator_orig): assert probs.shape == expected_probs_shape +def check_estimator_sparse_matrix(name, estimator_orig): + _check_estimator_sparse_container(name, estimator_orig, sparse.csr_matrix) + + def check_estimator_sparse_array(name, estimator_orig): - rng = np.random.RandomState(0) - X = rng.uniform(size=(40, 3)) - X[X < 0.8] = 0 - X = _enforce_estimator_tags_X(estimator_orig, X) - y = (4 * rng.uniform(size=40)).astype(int) - # catch deprecation warnings - with ignore_warnings(category=FutureWarning): - estimator = clone(estimator_orig) - y = _enforce_estimator_tags_y(estimator, y) - tags = _safe_tags(estimator_orig) - for matrix_format, X in _generate_sparse_data(sparse.csr_array(X)): - # catch deprecation warnings - with ignore_warnings(category=FutureWarning): - estimator = clone(estimator_orig) - if name in ["Scaler", "StandardScaler"]: - estimator.set_params(with_mean=False) - # fit and predict - if "64" in matrix_format: - err_msg = ( - f"Estimator {name} doesn't seem to support {matrix_format} " - "matrix, and is not failing gracefully, e.g. by using " - "check_array(X, accept_large_sparse=False)." - ) - else: - err_msg = ( - f"Estimator {name} doesn't seem to fail gracefully on sparse " - "data: error message should state explicitly that sparse " - "input is not supported if this is not the case." - ) - with raises( - (TypeError, ValueError), - match=["sparse", "Sparse"], - may_pass=True, - err_msg=err_msg, - ): - with ignore_warnings(category=FutureWarning): - estimator.fit(X, y) - if hasattr(estimator, "predict"): - pred = estimator.predict(X) - if tags["multioutput_only"]: - assert pred.shape == (X.shape[0], 1) - else: - assert pred.shape == (X.shape[0],) - if hasattr(estimator, "predict_proba"): - probs = estimator.predict_proba(X) - if tags["binary_only"]: - expected_probs_shape = (X.shape[0], 2) - else: - expected_probs_shape = (X.shape[0], 4) - assert probs.shape == expected_probs_shape + _check_estimator_sparse_container(name, estimator_orig, sparse.csr_array) @ignore_warnings(category=FutureWarning) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 0db2c154a0996..3746dd0c1538a 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -206,10 +206,17 @@ def fit(self, X, y): class NoSparseClassifier(BaseBadClassifier): + def __init__(self, raise_for_type=None): + # raise_for_type : str, expects "sparse_array" or "sparse_matrix" + self.raise_for_type = raise_for_type + def fit(self, X, y): X, y = self._validate_data(X, y, accept_sparse=["csr", "csc"]) if sp.issparse(X): - raise ValueError("Nonsensical Error") + if self.raise_for_type == "sparse_array": + raise ValueError("Nonsensical Error") + elif self.raise_for_type == "sparse_matrix": + raise ValueError("Nonsensical Error") return self def predict(self, X): @@ -376,8 +383,6 @@ def fit(self, X, y): correct_type = isinstance(X, sp.sparray) elif self.raise_for_type == "sparse_matrix": correct_type = isinstance(X, sp.spmatrix) - else: - raise ValueError("Invalid value for `raise_for_type`.") if correct_type: if X.getformat() == "coo": if X.row.dtype == "int64" or X.col.dtype == "int64": @@ -643,11 +648,14 @@ def test_check_estimator(): ) with raises(AssertionError, match=msg): check_estimator(NotInvariantPredict()) - # check for sparse matrix input handling + # check for sparse data input handling name = NoSparseClassifier.__name__ msg = "Estimator %s doesn't seem to fail gracefully on sparse data" % name with raises(AssertionError, match=msg): - check_estimator(NoSparseClassifier()) + check_estimator(NoSparseClassifier("sparse_matrix")) + + with raises(AssertionError, match=msg): + check_estimator(NoSparseClassifier("sparse_array")) # check for classifiers reducing to less than two classes via sample weights name = OneClassSampleErrorClassifier.__name__ From 22b399857c2b1b47d9fa471a07e173379f417aab Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Wed, 18 Oct 2023 20:14:14 +0200 Subject: [PATCH 07/22] changelog --- doc/whats_new/v1.4.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 8cd4498f53cf0..dc995d12419a2 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -334,6 +334,11 @@ Changelog proportional to the number of coefficients (`n_features * n_classes`). :pr:`27417` by :user:`Christian Lorentzen `. +- |Fix| :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`, + :class:`linear_model.Lasso` and :class:`linear_model.LassoCV` now explicitly don't + accept large sparse data formats. :pr:`27576` by :user:`Stefanie Senger + `. + :mod:`sklearn.metrics` ...................... From b156ef659e2edb22d8157d227ddab415265bd624 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 19 Oct 2023 12:49:37 +0200 Subject: [PATCH 08/22] test should now run with lower scipy version --- sklearn/utils/estimator_checks.py | 8 +++++--- sklearn/utils/fixes.py | 19 +++++++++++++++++++ sklearn/utils/tests/test_estimator_checks.py | 10 ++++++++-- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index e4d3c648897e5..03c5d59c3ba7d 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -61,7 +61,7 @@ generate_invalid_param_val, make_constraint, ) -from ..utils.fixes import parse_version, sp_version +from ..utils.fixes import SPARSE_ARRAY_PRESENT, parse_version, sp_version from ..utils.validation import check_is_fitted from . import IS_PYPY, is_scalar_nan, shuffle from ._param_validation import Interval @@ -1045,7 +1045,8 @@ def _check_estimator_sparse_container(name, estimator_orig, sparse_type): err_msg = ( f"Estimator {name} doesn't seem to fail gracefully on sparse " "data: error message should state explicitly that sparse " - "input is not supported if this is not the case." + "input is not supported if this is not the case, e.g. by using " + "check_array(X, accept_sparse=False)." ) with raises( (TypeError, ValueError), @@ -1075,7 +1076,8 @@ def check_estimator_sparse_matrix(name, estimator_orig): def check_estimator_sparse_array(name, estimator_orig): - _check_estimator_sparse_container(name, estimator_orig, sparse.csr_array) + if SPARSE_ARRAY_PRESENT: + _check_estimator_sparse_container(name, estimator_orig, sparse.csr_array) @ignore_warnings(category=FutureWarning) diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index ab2dc207137fd..85c2002ce2938 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -52,6 +52,25 @@ BSR_CONTAINERS.append(scipy.sparse.bsr_array) DIA_CONTAINERS.append(scipy.sparse.dia_array) + +# Remove when minimum scipy version is 1.11.0 +try: + from scipy.sparse import sparray # noqa + + SPARRAY_PRESENT = True +except ImportError: + SPARRAY_PRESENT = False + + +# Remove when minimum scipy version is 1.8 +try: + from scipy.sparse import csr_array # noqa + + SPARSE_ARRAY_PRESENT = True +except ImportError: + SPARSE_ARRAY_PRESENT = False + + try: from scipy.optimize._linesearch import line_search_wolfe1, line_search_wolfe2 except ImportError: # SciPy < 1.8 diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 3746dd0c1538a..56cffe9290a03 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -64,6 +64,7 @@ check_requires_y_none, set_random_state, ) +from sklearn.utils.fixes import SPARRAY_PRESENT from sklearn.utils.metaestimators import available_if from sklearn.utils.validation import check_array, check_is_fitted, check_X_y @@ -224,6 +225,10 @@ def predict(self, X): return np.ones(X.shape[0]) +"""from sklearn.utils.estimator_checks import check_estimator +check_estimator(NoSparseClassifier("sparse_matrix"))""" + + class CorrectNotFittedErrorClassifier(BaseBadClassifier): def fit(self, X, y): X, y = self._validate_data(X, y) @@ -672,8 +677,9 @@ def test_check_estimator(): "Estimator LargeSparseNotSupportedClassifier doesn't seem to " r"support \S{3}_64 matrix, and is not failing gracefully.*" ) - with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) + if SPARRAY_PRESENT: + with raises(AssertionError, match=msg): + check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) with raises(AssertionError, match=msg): check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) From bba2616f9a3b42676bf828b6bc26ffb28d56feaa Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 19 Oct 2023 13:02:47 +0200 Subject: [PATCH 09/22] fix --- sklearn/utils/tests/test_estimator_checks.py | 28 +++++++++----------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 56cffe9290a03..10e2485af54c4 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -213,11 +213,12 @@ def __init__(self, raise_for_type=None): def fit(self, X, y): X, y = self._validate_data(X, y, accept_sparse=["csr", "csc"]) - if sp.issparse(X): - if self.raise_for_type == "sparse_array": - raise ValueError("Nonsensical Error") - elif self.raise_for_type == "sparse_matrix": - raise ValueError("Nonsensical Error") + if self.raise_for_type == "sparse_array": + correct_type = isinstance(X, sp.sparray) + elif self.raise_for_type == "sparse_matrix": + correct_type = isinstance(X, sp.spmatrix) + if correct_type: + raise ValueError("Nonsensical Error") return self def predict(self, X): @@ -225,10 +226,6 @@ def predict(self, X): return np.ones(X.shape[0]) -"""from sklearn.utils.estimator_checks import check_estimator -check_estimator(NoSparseClassifier("sparse_matrix"))""" - - class CorrectNotFittedErrorClassifier(BaseBadClassifier): def fit(self, X, y): X, y = self._validate_data(X, y) @@ -659,8 +656,9 @@ def test_check_estimator(): with raises(AssertionError, match=msg): check_estimator(NoSparseClassifier("sparse_matrix")) - with raises(AssertionError, match=msg): - check_estimator(NoSparseClassifier("sparse_array")) + if SPARRAY_PRESENT: + with raises(AssertionError, match=msg): + check_estimator(NoSparseClassifier("sparse_array")) # check for classifiers reducing to less than two classes via sample weights name = OneClassSampleErrorClassifier.__name__ @@ -677,12 +675,12 @@ def test_check_estimator(): "Estimator LargeSparseNotSupportedClassifier doesn't seem to " r"support \S{3}_64 matrix, and is not failing gracefully.*" ) + with raises(AssertionError, match=msg): + check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) + if SPARRAY_PRESENT: with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) - - with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) + check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) # does error on binary_only untagged estimator msg = "Only 2 classes are supported" From 2c7b59a1b6ffb25bc04a03251059d4a34dfc7715 Mon Sep 17 00:00:00 2001 From: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Date: Thu, 19 Oct 2023 13:07:54 +0200 Subject: [PATCH 10/22] Update sklearn/linear_model/_coordinate_descent.py Co-authored-by: Adrin Jalali --- sklearn/linear_model/_coordinate_descent.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 5f8a14bf31215..57a5c3988a90c 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -903,9 +903,11 @@ def fit(self, X, y, sample_weight=None, check_input=True): Parameters ---------- - X : {ndarray, sparse matrix with a bit width of up to 32} of - (n_samples, n_features) + X : {ndarray, sparse matrix, sparse array} of (n_samples, n_features) Data. + + Note that large sparse matrices and arrays requiring `int64` + indices are not accepted. y : ndarray of shape (n_samples,) or (n_samples, n_targets) Target. Will be cast to X's dtype if necessary. From 3314d79d7b63b0895c2fffc3dfb626f96946a1a6 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 19 Oct 2023 13:09:48 +0200 Subject: [PATCH 11/22] doc for LinearModelCV --- sklearn/linear_model/_coordinate_descent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 57a5c3988a90c..ff6553f7245c8 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -905,7 +905,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): ---------- X : {ndarray, sparse matrix, sparse array} of (n_samples, n_features) Data. - + Note that large sparse matrices and arrays requiring `int64` indices are not accepted. @@ -1535,7 +1535,8 @@ def fit(self, X, y, sample_weight=None, **params): X : {array-like, sparse matrix} of shape (n_samples, n_features) Training data. Pass directly as Fortran-contiguous data to avoid unnecessary memory duplication. If y is mono-output, - X can be sparse with a bit width of up to 32. + X can be sparse. Note that large sparse matrices and arrays + requiring `int64` indices are not accepted. y : array-like of shape (n_samples,) or (n_samples, n_targets) Target values. From 22a0f1dd70a0ce64aeada8a60b5d5fa511e21a32 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Sat, 11 Nov 2023 21:53:44 +0100 Subject: [PATCH 12/22] 'dia' part of data generation and handled DeprecationWarning --- sklearn/utils/estimator_checks.py | 4 +--- sklearn/utils/tests/test_estimator_checks.py | 12 +++++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 65f628131886b..f7cd002bbe517 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -839,9 +839,7 @@ def _generate_sparse_data(X_csr): assert X_csr.format == "csr" yield "csr", X_csr.copy() - # re-insert "dia" when PR #27372 is merged - # only merge the present PR afterwards - for sparse_format in ["dok", "lil", "bsr", "csc", "coo"]: + for sparse_format in ["dok", "lil", "dia", "bsr", "csc", "coo"]: yield sparse_format, X_csr.asformat(sparse_format) # Generate large indices matrix only if its supported by scipy diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index ff7a6af6f189c..914b221c861ef 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -678,12 +678,14 @@ def test_check_estimator(): "Estimator LargeSparseNotSupportedClassifier doesn't seem to " r"support \S{3}_64 matrix, and is not failing gracefully.*" ) - with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) - - if SPARRAY_PRESENT: + with warnings.catch_warnings(record=True) as records: + warnings.filterwarnings("ignore", category=DeprecationWarning) with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) + check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) + + if SPARRAY_PRESENT: + with raises(AssertionError, match=msg): + check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) # does error on binary_only untagged estimator msg = "Only 2 classes are supported" From 8792ce83536f80944fa781508666e4462544d0a6 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Sat, 11 Nov 2023 23:47:57 +0100 Subject: [PATCH 13/22] ignore_warnings instead of filterwarnings --- sklearn/utils/tests/test_estimator_checks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 914b221c861ef..43e398fea527c 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -678,8 +678,7 @@ def test_check_estimator(): "Estimator LargeSparseNotSupportedClassifier doesn't seem to " r"support \S{3}_64 matrix, and is not failing gracefully.*" ) - with warnings.catch_warnings(record=True) as records: - warnings.filterwarnings("ignore", category=DeprecationWarning) + with ignore_warnings(category=DeprecationWarning): with raises(AssertionError, match=msg): check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) From 28a141701ea390bbef890b25a39a35f14dd689b3 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 7 Dec 2023 01:53:46 +0100 Subject: [PATCH 14/22] use warnings.filterwarnings --- sklearn/utils/tests/test_estimator_checks.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 43e398fea527c..9432c29e30134 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -678,7 +678,14 @@ def test_check_estimator(): "Estimator LargeSparseNotSupportedClassifier doesn't seem to " r"support \S{3}_64 matrix, and is not failing gracefully.*" ) - with ignore_warnings(category=DeprecationWarning): + with warnings.catch_warnings(record=True) as records: + warnings.filterwarnings( + "error", + category=DeprecationWarning, + message=( + "is deprecated and will be removed in v1.13.0; use `X.format` instead." + ), + ) with raises(AssertionError, match=msg): check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) From 2c6e49a904aedfbc4066a1c929df6dd166ad79b8 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 7 Dec 2023 09:45:43 +0100 Subject: [PATCH 15/22] match warning message from start and collect it --- sklearn/utils/tests/test_estimator_checks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 9432c29e30134..5776c0b275ff8 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -680,10 +680,11 @@ def test_check_estimator(): ) with warnings.catch_warnings(record=True) as records: warnings.filterwarnings( - "error", + "always", category=DeprecationWarning, message=( - "is deprecated and will be removed in v1.13.0; use `X.format` instead." + "`getformat` is deprecated and will be removed in v1.13.0; use" + " `X.format` instead." ), ) with raises(AssertionError, match=msg): From be3ea35bf9a2ff1a16860448fa9bd60260e0294d Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Fri, 5 Jan 2024 11:09:41 +0100 Subject: [PATCH 16/22] moved change log entry --- doc/whats_new/v1.4.rst | 5 ----- doc/whats_new/v1.5.rst | 8 ++++++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index 56ca6c4783f80..d2de5ee433f94 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -541,11 +541,6 @@ Changelog proportional to the number of coefficients (`n_features * n_classes`). :pr:`27417` by :user:`Christian Lorentzen `. -- |Fix| :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`, - :class:`linear_model.Lasso` and :class:`linear_model.LassoCV` now explicitly don't - accept large sparse data formats. :pr:`27576` by :user:`Stefanie Senger - `. - - |Fix| Ensure that the `sigma_` attribute of :class:`linear_model.ARDRegression` and :class:`linear_model.BayesianRidge` always has a `float32` dtype when fitted on `float32` data, even with the diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index fbd8a3f83b1dd..0babc56f7bd2a 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -38,3 +38,11 @@ TODO: update at the time of the release. - |Feature| A fitted :class:`compose.ColumnTransformer` now implements `__getitem__` which returns the fitted transformers by name. :pr:`27990` by `Thomas Fan`_. + +:mod:`sklearn.linear_model` +........................... + +- |Fix| :class:`linear_model.ElasticNet`, :class:`linear_model.ElasticNetCV`, + :class:`linear_model.Lasso` and :class:`linear_model.LassoCV` now explicitly don't + accept large sparse data formats. :pr:`27576` by :user:`Stefanie Senger + `. From 4c008515431350cc7ad9ccf94d593a3c977790cb Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Tue, 6 Feb 2024 13:01:56 +0100 Subject: [PATCH 17/22] swapped deprecated X.getformat() for the newer X.format --- sklearn/utils/tests/test_estimator_checks.py | 23 ++++++-------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 5776c0b275ff8..1e0a083a9c989 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -386,10 +386,10 @@ def fit(self, X, y): elif self.raise_for_type == "sparse_matrix": correct_type = isinstance(X, sp.spmatrix) if correct_type: - if X.getformat() == "coo": + if X.format == "coo": if X.row.dtype == "int64" or X.col.dtype == "int64": raise ValueError("Estimator doesn't support 64-bit indices") - elif X.getformat() in ["csc", "csr"]: + elif X.format in ["csc", "csr"]: assert "int64" not in ( X.indices.dtype, X.indptr.dtype, @@ -678,21 +678,12 @@ def test_check_estimator(): "Estimator LargeSparseNotSupportedClassifier doesn't seem to " r"support \S{3}_64 matrix, and is not failing gracefully.*" ) - with warnings.catch_warnings(record=True) as records: - warnings.filterwarnings( - "always", - category=DeprecationWarning, - message=( - "`getformat` is deprecated and will be removed in v1.13.0; use" - " `X.format` instead." - ), - ) - with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) + with raises(AssertionError, match=msg): + check_estimator(LargeSparseNotSupportedClassifier("sparse_matrix")) - if SPARRAY_PRESENT: - with raises(AssertionError, match=msg): - check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) + if SPARRAY_PRESENT: + with raises(AssertionError, match=msg): + check_estimator(LargeSparseNotSupportedClassifier("sparse_array")) # does error on binary_only untagged estimator msg = "Only 2 classes are supported" From 7a2fc5023204af123ccbef3fbbd7140d559490cc Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Fri, 16 Feb 2024 13:16:51 +0100 Subject: [PATCH 18/22] convert sparse dok_array into sparse coo_array before hstack --- sklearn/multioutput.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index bfb83884399ef..ecfc4640c2dd3 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -698,6 +698,11 @@ def fit(self, X, Y, **fit_params): X_aug = np.hstack((X, Y_pred_chain)) elif sp.issparse(X): + # to prevent scipy.sparse.hstack from breaking, we convert the sparse + # dok_array to a coo_array format, it's also faster; see scipy issue + # https://github.com/scipy/scipy/issues/20060#issuecomment-1937007039 + if isinstance(X, sp._dok.dok_array): + X = sp.coo_array(X) Y_pred_chain = sp.lil_matrix((X.shape[0], Y.shape[1])) X_aug = sp.hstack((X, Y_pred_chain), format="lil") From a0bd979b0854f2d3321077e348fd14ea1b676fd4 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Fri, 16 Feb 2024 14:23:52 +0100 Subject: [PATCH 19/22] compatibility with scipy versions < 1.11 --- sklearn/multioutput.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index ecfc4640c2dd3..b8dbabe2d3b2e 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -701,8 +701,10 @@ def fit(self, X, Y, **fit_params): # to prevent scipy.sparse.hstack from breaking, we convert the sparse # dok_array to a coo_array format, it's also faster; see scipy issue # https://github.com/scipy/scipy/issues/20060#issuecomment-1937007039 - if isinstance(X, sp._dok.dok_array): - X = sp.coo_array(X) + # convert type of X only in case it a scipy.sparray: + if not sp.isspmatrix(X): + if isinstance(X, sp._dok.dok_array): + X = sp.coo_array(X) Y_pred_chain = sp.lil_matrix((X.shape[0], Y.shape[1])) X_aug = sp.hstack((X, Y_pred_chain), format="lil") From 83439c3999a1b9cadf860c94eb790d0ea12a71d1 Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Tue, 20 Feb 2024 09:33:23 +0100 Subject: [PATCH 20/22] public path for isinstance check --- sklearn/multioutput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index b8dbabe2d3b2e..d9be43b96c990 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -703,7 +703,7 @@ def fit(self, X, Y, **fit_params): # https://github.com/scipy/scipy/issues/20060#issuecomment-1937007039 # convert type of X only in case it a scipy.sparray: if not sp.isspmatrix(X): - if isinstance(X, sp._dok.dok_array): + if isinstance(X, sp.dok_array): X = sp.coo_array(X) Y_pred_chain = sp.lil_matrix((X.shape[0], Y.shape[1])) X_aug = sp.hstack((X, Y_pred_chain), format="lil") From 3cfc057a8b9e6d410ee0c8d7c613faf5cabb8ece Mon Sep 17 00:00:00 2001 From: Stefanie Senger <91849487+StefanieSenger@users.noreply.github.com> Date: Tue, 20 Feb 2024 16:53:52 +0100 Subject: [PATCH 21/22] Apply suggestions from code review Co-authored-by: Adrin Jalali --- sklearn/multioutput.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index d9be43b96c990..44babe462e7e0 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -702,10 +702,9 @@ def fit(self, X, Y, **fit_params): # dok_array to a coo_array format, it's also faster; see scipy issue # https://github.com/scipy/scipy/issues/20060#issuecomment-1937007039 # convert type of X only in case it a scipy.sparray: - if not sp.isspmatrix(X): - if isinstance(X, sp.dok_array): - X = sp.coo_array(X) - Y_pred_chain = sp.lil_matrix((X.shape[0], Y.shape[1])) + if not sp.isspmatrix(X) and X.format == "dok": + X = sp.coo_array(X) + Y_pred_chain = sp.coo_matrix((X.shape[0], Y.shape[1])) X_aug = sp.hstack((X, Y_pred_chain), format="lil") else: From 2bf114a1574040ef00c0fd468f07289f87c38d0d Mon Sep 17 00:00:00 2001 From: Stefanie Senger Date: Thu, 22 Feb 2024 11:24:15 +0100 Subject: [PATCH 22/22] added conversion into sparse array --- sklearn/multioutput.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index 44babe462e7e0..c4b4f2bd3dd27 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -698,13 +698,19 @@ def fit(self, X, Y, **fit_params): X_aug = np.hstack((X, Y_pred_chain)) elif sp.issparse(X): - # to prevent scipy.sparse.hstack from breaking, we convert the sparse - # dok_array to a coo_array format, it's also faster; see scipy issue - # https://github.com/scipy/scipy/issues/20060#issuecomment-1937007039 - # convert type of X only in case it a scipy.sparray: - if not sp.isspmatrix(X) and X.format == "dok": - X = sp.coo_array(X) - Y_pred_chain = sp.coo_matrix((X.shape[0], Y.shape[1])) + # TODO: remove this condition check when the minimum supported scipy version + # doesn't support sparse matrices anymore + if not sp.isspmatrix(X): + # if `X` is a scipy sparse dok_array, we convert it to a sparse + # coo_array format before hstacking, it's faster; see + # https://github.com/scipy/scipy/issues/20060#issuecomment-1937007039: + if X.format == "dok": + X = sp.coo_array(X) + # in case that `X` is a sparse array we create `Y_pred_chain` as a + # sparse array format: + Y_pred_chain = sp.coo_array((X.shape[0], Y.shape[1])) + else: + Y_pred_chain = sp.coo_matrix((X.shape[0], Y.shape[1])) X_aug = sp.hstack((X, Y_pred_chain), format="lil") else: